# -*- coding: utf-8 -*-
from typing import Any, Union
from ..base import Property
from .array import CovarianceMatrix
from .base import Type
from .state import (State, GaussianState, ParticleState, SqrtGaussianState,
TaggedWeightedGaussianState, StateMutableSequence)
def _from_state(
        cls,
        state: State,
        *args: Any,
        prediction_type: Union['Prediction', 'MeasurementPrediction', None] = None,
        **kwargs: Any) -> Union['Prediction', 'MeasurementPrediction']:
    """Return new (Measurement)Prediction instance of suitable type using existing properties

    This function is written to be wrapped with :class:`classmethod`, so ``cls`` is the class
    whose ``class_mapping`` is consulted to find the prediction type matching ``state``.

    Parameters
    ----------
    state: State
        :class:`~.State` to use existing properties from, and identify prediction type from
    \\*args: Sequence
        Arguments to pass to newly created prediction, replacing those with same name on ``state``
        parameter.
    prediction_type: :class:`~.Prediction` or :class:`~.MeasurementPrediction`, optional
        Type to use for prediction, overriding one from :attr:`class_mapping`.
    \\*\\*kwargs: Mapping
        New property names and associate value for use in newly created prediction, replacing those
        on the ``state`` parameter.

    Returns
    -------
    :class:`~.Prediction` or :class:`~.MeasurementPrediction`
        New instance of ``prediction_type`` (or of the mapped type when it is not given).

    Raises
    ------
    TypeError
        If no class in the method resolution order of ``state`` is a key of
        ``cls.class_mapping``. This is checked even when ``prediction_type`` is supplied,
        as the matched state type determines which properties are copied.
    """
    # Handle being initialised with state sequence
    if isinstance(state, StateMutableSequence):
        state = state.state

    # Walk the MRO so a state whose own class is not registered still matches
    # the nearest registered ancestor.
    state_type = next(
        (type_ for type_ in type(state).mro() if type_ in cls.class_mapping), None)
    if state_type is None:
        raise TypeError(f'{cls.__name__} type not defined for {type(state).__name__}')

    if prediction_type is None:
        prediction_type = cls.class_mapping[state_type]

    # Names of the leading properties that the positional arguments will fill;
    # these must not be passed again as keywords.
    args_property_names = {name for name, _ in zip(prediction_type.properties, args)}

    # Use current state kwargs that also properties of prediction type
    new_kwargs = {
        name: getattr(state, name)
        for name in state_type.properties.keys() & prediction_type.properties.keys()
        if name not in args_property_names}
    # And replace them with any newly defined kwargs
    new_kwargs.update(kwargs)

    return prediction_type(*args, **new_kwargs)
class Prediction(Type):
    """ Prediction type

    This is the base prediction class. Subclasses are recorded in
    :attr:`class_mapping`, which :meth:`from_state` uses to pick the prediction
    type matching a given state.
    """

    # Maps a state class to the prediction class registered for it.
    class_mapping = {}
    from_state = classmethod(_from_state)

    def __init_subclass__(cls, **kwargs):
        # The last base class of the new subclass is used as the registry key
        # (subclasses in this module list their state type last). A later
        # subclass with the same last base replaces the earlier entry.
        state_type = cls.__bases__[-1]
        Prediction.class_mapping[state_type] = cls
        super().__init_subclass__(**kwargs)
class MeasurementPrediction(Type):
    """ MeasurementPrediction type

    This is the base measurement prediction class. Subclasses are recorded in
    :attr:`class_mapping`, which :meth:`from_state` uses to pick the
    measurement prediction type matching a given state.
    """

    # Maps a state class to the measurement prediction class registered for it.
    class_mapping = {}
    from_state = classmethod(_from_state)

    def __init_subclass__(cls, **kwargs):
        # The last base class of the new subclass is used as the registry key
        # (subclasses in this module list their state type last). A later
        # subclass with the same last base replaces the earlier entry.
        state_type = cls.__bases__[-1]
        MeasurementPrediction.class_mapping[state_type] = cls
        super().__init_subclass__(**kwargs)
class StatePrediction(Prediction, State):
    """ StatePrediction type

    Most simple state prediction type, which only has time and a state vector.
    """
class StateMeasurementPrediction(MeasurementPrediction, State):
    """ StateMeasurementPrediction type

    Most simple measurement prediction type, which only has time and a state
    vector.
    """
class GaussianStatePrediction(Prediction, GaussianState):
    """ GaussianStatePrediction type

    This is a simple Gaussian state prediction object, which, as the name
    suggests, is described by a Gaussian distribution.
    """
class SqrtGaussianStatePrediction(Prediction, SqrtGaussianState):
    """ SqrtGaussianStatePrediction type

    This is a Gaussian state prediction object, with the covariance held
    as the square root of the covariance matrix.
    """
class WeightedGaussianStatePrediction(Prediction, TaggedWeightedGaussianState):
    """ WeightedGaussianStatePrediction type

    This is a simple Gaussian state prediction object, which, as the name
    suggests, is described by a Gaussian distribution
    with an associated weight.
    """
    # NOTE(review): this class and TaggedWeightedGaussianStatePrediction both
    # list TaggedWeightedGaussianState as their final base, so
    # Prediction.__init_subclass__ registers both under the same class_mapping
    # key and the later definition wins. As a result from_state only returns
    # this class when it is passed explicitly as ``prediction_type``. The
    # docstring mentions a weight but no tag, so a weight-only state base may
    # have been intended -- confirm before changing, as that would remove the
    # tag-related properties from this type.
class TaggedWeightedGaussianStatePrediction(Prediction,
                                            TaggedWeightedGaussianState):
    """ TaggedWeightedGaussianStatePrediction type

    This is a simple Gaussian state prediction object, which, as the name
    suggests, is described by a Gaussian distribution, with an associated
    weight and unique tag.
    """
class GaussianMeasurementPrediction(MeasurementPrediction, GaussianState):
    """ GaussianMeasurementPrediction type

    This is a simple Gaussian measurement prediction object, which, as the name
    suggests, is described by a Gaussian distribution.

    Raises
    ------
    ValueError
        If ``cross_covar`` is given and its number of columns differs from the
        number of rows in ``state_vector``.
    """

    cross_covar: CovarianceMatrix = Property(
        default=None, doc="The state-measurement cross covariance matrix")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # cross_covar is optional (defaults to None); only check its shape
        # when one has been supplied.
        if (self.cross_covar is not None
                and self.cross_covar.shape[1] != self.state_vector.shape[0]):
            raise ValueError("cross_covar should have the same number of "
                             "columns as the number of rows in state_vector")
# Don't need to support Sqrt Covar for MeasurementPrediction: explicitly map
# SqrtGaussianState to the standard Gaussian measurement prediction so that
# MeasurementPrediction.from_state resolves square root states to it.
MeasurementPrediction.class_mapping[SqrtGaussianState] = GaussianMeasurementPrediction
class ParticleStatePrediction(Prediction, ParticleState):
    """ParticleStatePrediction type

    This is a simple Particle state prediction object.
    """
class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState):
    """ParticleMeasurementPrediction type

    This is a simple Particle measurement prediction object.
    """