Source code for stonesoup.updater.categorical

# -*- coding: utf-8 -*-
import numpy as np

from ..base import Property
from ..models.measurement.categorical import CategoricalMeasurementModel
from ..types.prediction import MeasurementPrediction
from ..types.state import CategoricalState
from ..types.update import Update
from ..updater import Updater


[docs]class HMMUpdater(Updater): r"""Models the update step of a hidden Markov model""" measurement_model: CategoricalMeasurementModel = Property( default=None, doc="An observation-based measurement model. Measurements are assumed to be as defined " "above. This model need not be defined if a measurement model is provided in the " "measurement. If no model specified on construction, or in the measurement, then an " "error will be thrown.") def _check_measurement_model(self, measurement_model): """Check that the measurement model passed actually exists. If not attach the one in the updater. If that one's not specified, return an error. Parameters ---------- measurement_model : :class`~.MeasurementModel` A measurement model to be checked Returns ------- : :class`~.MeasurementModel` The measurement model to be used """ if measurement_model is None: if self.measurement_model is None: raise ValueError("No measurement model specified") else: measurement_model = self.measurement_model try: measurement_model.emission_matrix except AttributeError: raise ValueError("Measurement model must be categorical. I.E. it must have an " "Emission matrix property for the HMMUpdater") return measurement_model def _get_emission_matrix(self, hypothesis): measurement_model = self._check_measurement_model(hypothesis.measurement.measurement_model) return measurement_model.emission_matrix
[docs] def predict_measurement(self, predicted_state, measurement_model=None, category_names=None, **kwargs): r"""Predict the measurement implied by the predicted state mean Parameters ---------- predicted_state : :class:`~.State` The predicted state :math:`X_{k|k-1}`. measurement_model : :class:`~.MeasurementModel` The measurement model. If omitted, the model in the updater object is used. category_names : :class:`list` List of :class:`str` measurement category names. **kwargs : various These are passed to :meth:`~.MeasurementModel.function` and :meth:`~.MeasurementModel.matrix`. Returns ------- : :class:`~.MeasurementPrediction` The measurement prediction, :math:`Y_{k|k-1}` """ measurement_model = self._check_measurement_model(measurement_model) pred_meas = measurement_model.function(predicted_state, **kwargs) return MeasurementPrediction.from_state(predicted_state, pred_meas, category_names=category_names)
[docs] def update(self, hypothesis, **kwargs): r"""The update method. Given a hypothesised association between a predicted state or predicted measurement and an actual measurement, calculate the posterior state. Bayes' rule: :math:`p(x_k|z_{1:k}) \propto p(z_k|x_k) p(x_k|z_{1:k-1})`. The likelihood is calculated as an Nx1 vector of :math:`p(z_k|x_{k|k-1})` for the :math:`z_k` actually observed. This is element-wise multiplied by the prior and normalised. Parameters ---------- hypothesis : :class:`~.SingleHypothesis` the prediction-measurement association hypothesis. This hypothesis may carry a predicted measurement, or a predicted state. In the latter case a predicted measurement will be calculated. **kwargs : various These are passed to :meth:`predict_measurement` Returns ------- : :class:`~.CategoricalStateUpdate` The posterior categorical state """ prediction = hypothesis.prediction if not isinstance(prediction, CategoricalState): raise ValueError("Prediction must be a categorical state type") # category names to be passed to measurement prediction measurement_category_names = hypothesis.measurement.category_names if hypothesis.measurement_prediction is None: measurement_model = hypothesis.measurement.measurement_model measurement_model = self._check_measurement_model(measurement_model) # Attach the measurement prediction to the hypothesis hypothesis.measurement_prediction = self.predict_measurement( prediction, measurement_model=measurement_model, category_names=measurement_category_names, **kwargs) emission_matrix = self._get_emission_matrix(hypothesis) likelihood = emission_matrix @ hypothesis.measurement.state_vector # Bayes rule posterior = np.multiply(likelihood, hypothesis.prediction.state_vector) posterior = posterior / np.sum(posterior) return Update.from_state(hypothesis.prediction, posterior, timestamp=hypothesis.measurement.timestamp, hypothesis=hypothesis)