Source code for stonesoup.predictor.categorical

# -*- coding: utf-8 -*-
from ..base import Property
from ..models.transition.categorical import CategoricalTransitionModel
from ..predictor import Predictor
from ..predictor._utils import predict_lru_cache
from ..types.prediction import Prediction
from ..types.state import CategoricalState


class HMMPredictor(Predictor):
    r"""Models the prediction step of a Hidden Markov Model.

    :meth:`predict` passes the prior through the ``function`` of
    :attr:`transition_model` and wraps the result in a prediction created
    with :meth:`Prediction.from_state`.
    """

    transition_model: CategoricalTransitionModel = Property(
        doc="The transition model to be used.")

    def _transition_function(self, prior, **kwargs):
        """Evaluate the transition model on ``prior``.

        Parameters
        ----------
        prior : :class:`~.State`
            The prior state
        **kwargs :
            Forwarded unchanged to :meth:`transition_model.function`.

        Returns
        -------
        :
            Whatever :meth:`transition_model.function` returns.
        """
        model = self.transition_model
        return model.function(prior, **kwargs)

    def _predict_over_interval(self, prior, timestamp):
        """Private function to get the prediction interval (or None)

        Parameters
        ----------
        prior : :class:`~.State`
            The prior state
        timestamp : :class:`datetime.datetime`, optional
            The (current) timestamp

        Returns
        -------
        : :class:`datetime.timedelta`
            ``timestamp - prior.timestamp``, or ``None`` when either of the
            two timestamps is undefined.
        """
        if timestamp is not None and prior.timestamp is not None:
            return timestamp - prior.timestamp

        # At least one timestamp is undefined: no interval can be computed
        return None

    @predict_lru_cache()
    def predict(self, prior, timestamp=None, **kwargs):
        r"""A simple matrix multiplication.

        The Chapman-Kolmogorov equation is:

        .. math::

            p(x_k|z_{1:k-1}) &= \sum_{x_{k-1}} p(x_k|x_{k-1})
            p(x_{k-1}|z_{1:k-1})\\
            &= F_k p(x_{k-1}|z_{1:k-1})

        where :math:`F_k` is the category-transition matrix and :math:`p(x)`
        is encoded in the state vector.

        Parameters
        ----------
        prior : :class:`~.CategoricalState`
            :math:`\mathbf{x}_{k-1}`
        timestamp : :class:`datetime.datetime`, optional
            Timestamp given to the returned prediction. Together with
            ``prior.timestamp`` it determines the ``time_interval`` handed to
            the transition model.
        **kwargs :
            These are passed to the :meth:`transition_model.function` method.

        Returns
        -------
        : :class:`~.CategoricalStatePrediction`
            The predicted state, built with :meth:`Prediction.from_state`
            from ``prior`` and the transition model output.

        Raises
        ------
        ValueError
            If ``prior`` is not a :class:`~.CategoricalState`.

        Notes
        -----
        The categorical transition model is time-invariant and the evaluated
        `time_interval` can be `None`.
        """
        if not isinstance(prior, CategoricalState):
            raise ValueError("Prior must be a categorical state type")

        # May be None when either timestamp is undefined (see Notes)
        interval = self._predict_over_interval(prior, timestamp)

        predicted = self._transition_function(
            prior, time_interval=interval, **kwargs)

        return Prediction.from_state(
            prior,
            predicted,
            timestamp=timestamp,
            transition_model=self.transition_model,
        )