# Source code for stonesoup.types.prediction

# -*- coding: utf-8 -*-
from typing import Any, Union

from ..base import Property
from .array import CovarianceMatrix
from .base import Type
from .state import (State, GaussianState, ParticleState, SqrtGaussianState,
                    TaggedWeightedGaussianState, StateMutableSequence)


def _from_state(
        cls,
        state: State,
        *args: Any,
        prediction_type: Union['Prediction', 'MeasurementPrediction', None] = None,
        **kwargs: Any) -> Union['Prediction', 'MeasurementPrediction']:
    """Return new (Measurement)Prediction instance of suitable type using existing properties

    The prediction type is taken from ``cls.class_mapping``. The key used is the first class in
    the method resolution order of ``type(state)`` that has an entry. Properties shared by that
    state type and the prediction type are copied from ``state``, unless they are supplied
    positionally in ``args`` or by name in ``kwargs``.

    Parameters
    ----------
    state: State
        :class:`~.State` to use existing properties from, and identify prediction type from.
        If a :class:`~.StateMutableSequence` is given, its ``state`` attribute is used instead.
    \\*args: Sequence
        Arguments to pass to newly created prediction, replacing those with same name on
        ``state`` parameter. They fill the prediction type's properties in declaration order.
    prediction_type: :class:`~.Prediction` or :class:`~.MeasurementPrediction`, optional
        Type (class) to use for prediction, overriding one from :attr:`class_mapping`.
    \\*\\*kwargs: Mapping
        New property names and associated values for use in newly created prediction,
        replacing those on the ``state`` parameter.

    Returns
    -------
    : :class:`~.Prediction` or :class:`~.MeasurementPrediction`
        Newly created instance of the selected prediction type.

    Raises
    ------
    TypeError
        If no class in the MRO of ``state``'s type is a key of ``cls.class_mapping``. This
        check is made even when ``prediction_type`` is given.
    """
    # Handle being initialised with state sequence
    if isinstance(state, StateMutableSequence):
        state = state.state
    # Walk the MRO so that subclasses of a registered state type resolve to its prediction
    # type. This runs even with an explicit ``prediction_type``, because ``state_type``
    # determines which properties are copied across below.
    try:
        state_type = next(type_ for type_ in type(state).mro() if type_ in cls.class_mapping)
    except StopIteration:
        # The StopIteration is an implementation detail of ``next``; don't chain it.
        raise TypeError(
            f'{cls.__name__} type not defined for {type(state).__name__}') from None
    if prediction_type is None:
        prediction_type = cls.class_mapping[state_type]

    # Positional args bind to the first ``len(args)`` properties of the prediction type.
    # Those properties must therefore not also be passed by keyword.
    args_property_names = set(list(prediction_type.properties)[:len(args)])
    # Use current state values for properties common to the state and prediction types
    new_kwargs = {
        name: getattr(state, name)
        for name in state_type.properties.keys() & prediction_type.properties.keys()
        if name not in args_property_names}
    # And replace them with any newly defined kwargs
    new_kwargs.update(kwargs)

    return prediction_type(*args, **new_kwargs)


class Prediction(Type):
    """ Prediction type

    Base class for all state predictions.
    """

    #: Registry mapping a state type to the prediction subclass that wraps it. Subclasses
    #: add themselves via :meth:`__init_subclass__`, and :meth:`from_state` reads it.
    class_mapping = {}

    from_state = classmethod(_from_state)

    def __init_subclass__(cls, **kwargs):
        # The final entry in the base list is taken as the wrapped state type. The entry is
        # written to Prediction's own registry, so every subclass shares a single mapping.
        wrapped_state_type = cls.__bases__[-1]
        Prediction.class_mapping[wrapped_state_type] = cls
        super().__init_subclass__(**kwargs)
class MeasurementPrediction(Type):
    """ MeasurementPrediction type

    Base class for all measurement predictions.
    """

    #: Registry mapping a state type to the measurement prediction subclass that wraps it.
    #: Subclasses add themselves via :meth:`__init_subclass__`, and :meth:`from_state`
    #: reads it.
    class_mapping = {}

    from_state = classmethod(_from_state)

    def __init_subclass__(cls, **kwargs):
        # The final entry in the base list is taken as the wrapped state type. The entry is
        # written to MeasurementPrediction's own registry, so every subclass shares it.
        wrapped_state_type = cls.__bases__[-1]
        MeasurementPrediction.class_mapping[wrapped_state_type] = cls
        super().__init_subclass__(**kwargs)
class StatePrediction(Prediction, State):
    """ StatePrediction type

    The simplest state prediction type: it holds only a time and a state vector.
    """
class StateMeasurementPrediction(MeasurementPrediction, State):
    """ StateMeasurementPrediction type

    The simplest measurement prediction type: it holds only a time and a state vector.
    """
class GaussianStatePrediction(Prediction, GaussianState):
    """ GaussianStatePrediction type

    A simple state prediction described, as the name suggests, by a Gaussian distribution.
    """
class SqrtGaussianStatePrediction(Prediction, SqrtGaussianState):
    """ SqrtGaussianStatePrediction type

    A Gaussian state prediction whose covariance is stored as the square root of the
    covariance matrix.
    """
# NOTE(review): the docstring describes only a weight, yet this class derives from
# TaggedWeightedGaussianState, so it also carries a tag. Prediction.__init_subclass__ keys the
# registry on the last base. TaggedWeightedGaussianStatePrediction, defined next, uses the same
# last base and therefore replaces this class in Prediction.class_mapping. The intended base
# looks like WeightedGaussianState. Confirm before changing it, because that would drop `tag`.
class WeightedGaussianStatePrediction(Prediction, TaggedWeightedGaussianState):
    """ WeightedGaussianStatePrediction type

    A simple state prediction described, as the name suggests, by a Gaussian distribution
    together with an associated weight.
    """
class TaggedWeightedGaussianStatePrediction(Prediction, TaggedWeightedGaussianState):
    """ TaggedWeightedGaussianStatePrediction type

    A simple state prediction described, as the name suggests, by a Gaussian distribution
    together with an associated weight and a unique tag.
    """
class GaussianMeasurementPrediction(MeasurementPrediction, GaussianState):
    """ GaussianMeasurementPrediction type

    A simple measurement prediction described, as the name suggests, by a Gaussian
    distribution. It may optionally carry the state-measurement cross covariance.
    """
    cross_covar: CovarianceMatrix = Property(
        default=None, doc="The state-measurement cross covariance matrix")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Without a cross covariance there is nothing to validate.
        if self.cross_covar is None:
            return
        # The column count of cross_covar must match the length of state_vector.
        n_cross_columns = self.cross_covar.shape[1]
        n_vector_rows = self.state_vector.shape[0]
        if n_cross_columns != n_vector_rows:
            raise ValueError("cross_covar should have the same number of "
                             "columns as the number of rows in state_vector")
# Don't need to support Sqrt Covar for MeasurementPrediction. A SqrtGaussianState therefore
# maps to the full-covariance GaussianMeasurementPrediction. Without this entry, the MRO
# lookup in ``_from_state`` would skip past SqrtGaussianState to a more generic mapping.
MeasurementPrediction.class_mapping[SqrtGaussianState] = GaussianMeasurementPrediction
class ParticleStatePrediction(Prediction, ParticleState):
    """ ParticleStatePrediction type

    A simple state prediction represented by particles.
    """
class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState):
    """ ParticleMeasurementPrediction type

    A simple measurement prediction represented by particles.
    """