# Source code for stonesoup.predictor.composite

from collections.abc import Sequence

from ..base import Property
from ..predictor import Predictor
from ..predictor._utils import predict_lru_cache
from ..types.prediction import CompositePrediction
from ..types.state import CompositeState


class CompositePredictor(Predictor):
    """Composite predictor type

    A composition of ordered sub-predictors (:class:`~.Predictor`). Independently predicts each
    sub-state of a :class:`~.CompositeState` forward using the corresponding sub-predictor: the
    sub-state at position ``i`` is predicted by the sub-predictor at position ``i``.

    Instances behave like a read-only sequence of their sub-predictors (``in``, indexing,
    iteration and ``len``); slicing returns a new composite predictor.
    """
    sub_predictors: Sequence[Predictor] = Property(
        doc="Sequence of sub-predictors comprising the composite predictor. Must not be empty.")

    def __init__(self, *args, **kwargs):
        """Construct the composite predictor and validate its sub-predictors.

        Raises
        ------
        ValueError
            If `sub_predictors` is not a :class:`~collections.abc.Sequence`, is empty, or
            contains an element that is not a :class:`~.Predictor`.
        """
        super().__init__(*args, **kwargs)
        # An ordered container is required because sub-predictors are paired positionally
        # with the sub-states of the prior in :meth:`predict`.
        if not isinstance(self.sub_predictors, Sequence):
            raise ValueError(f"Sub-predictors must be defined as an ordered list, not "
                             f"{type(self.sub_predictors)}")
        if len(self.sub_predictors) == 0:
            raise ValueError("Cannot create an empty composite predictor")
        if not all(isinstance(sub_predictor, Predictor)
                   for sub_predictor in self.sub_predictors):
            raise ValueError("All sub-predictors must be a Predictor type")

    @property
    def transition_model(self):
        """Not available for a composite predictor.

        Raises
        ------
        NotImplementedError
            Always; a composition has no single transition model (consult the individual
            sub-predictors instead).
        """
        raise NotImplementedError("A composition of predictors has no defined transition model")

    @predict_lru_cache()
    def predict(self, prior, timestamp=None, **kwargs):
        r"""The predict function

        Parameters
        ----------
        prior : :class:`~.CompositeState`
            The composite state of an object to be predicted forwards
        timestamp : :class:`datetime.datetime`, optional
            :math:`k`, the time to predict to. Passed unchanged to every sub-predictor.
        **kwargs :
            These are passed to each sub-predictor's prediction method

        Returns
        -------
        : :class:`~.CompositePrediction`
            The predicted composite state. Its sub-states are the predictions returned by the
            sub-predictors, in the same order as the prior's sub-states.

        Raises
        ------
        ValueError
            If `prior` is not a :class:`~.CompositeState`, or if its number of sub-states
            differs from the number of sub-predictors.
        """
        if not isinstance(prior, CompositeState):
            raise ValueError("CompositePredictor can only predict forward CompositeState types")
        if len(prior) != len(self):
            raise ValueError(f"Mismatch in number of prior sub-states {len(prior)} and number "
                             f"of sub-predictors {len(self)}")

        # Each sub-state is predicted independently by the sub-predictor at the same position.
        prediction_sub_states = [
            sub_predictor.predict(prior=sub_state, timestamp=timestamp, **kwargs)
            for sub_predictor, sub_state in zip(self.sub_predictors, prior.sub_states)]

        return CompositePrediction(sub_states=prediction_sub_states)

    def __contains__(self, item):
        """Return whether `item` is one of the sub-predictors."""
        return item in self.sub_predictors

    def __getitem__(self, index):
        """Can be indexed as a list, or sliced, in which case a new composite predictor will be
        created from the sub-list of sub-predictors.

        An empty slice raises :class:`ValueError`, as an empty composite predictor cannot be
        constructed.
        """
        if isinstance(index, slice):
            # Pass by keyword so construction does not depend on the positional order of the
            # declared properties.
            return self.__class__(sub_predictors=self.sub_predictors[index])
        return self.sub_predictors[index]

    def __iter__(self):
        """Iterate over the sub-predictors in order."""
        return iter(self.sub_predictors)

    def __len__(self):
        """Return the number of sub-predictors."""
        return len(self.sub_predictors)