Source code for stonesoup.proposal.simple

from typing import Union

import numpy as np
from scipy.stats import multivariate_normal as mvn

from stonesoup.base import Property
from stonesoup.models.transition import TransitionModel
from stonesoup.proposal.base import Proposal
from stonesoup.types.array import StateVector, StateVectors
from stonesoup.types.detection import Detection
from stonesoup.types.state import State, GaussianState, SqrtGaussianState
from stonesoup.types.prediction import Prediction
from stonesoup.updater.base import Updater
from stonesoup.predictor.base import Predictor
from stonesoup.predictor.kalman import SqrtKalmanPredictor
from stonesoup.types.hypothesis import SingleHypothesis


class PriorAsProposal(Proposal):
    """Proposal that uses the dynamics model as the importance density.

    This proposal uses the dynamics model to predict the next state, and then
    uses the predicted state as the prior for the measurement model.
    """

    transition_model: TransitionModel = Property(
        doc="The transition model used to make the prediction")

    def rvs(self, prior: State, measurement=None, time_interval=None,
            **kwargs) -> Prediction:
        """Generate samples from the proposal.

        Parameters
        ----------
        prior: :class:`~.State`
            The state to generate samples from.
        measurement: :class:`~.Detection`
            When provided, its timestamp sets the prediction time and the
            time interval is recomputed as ``measurement.timestamp -
            prior.timestamp``, overriding any ``time_interval`` passed in
            (the default is `None`).
        time_interval: :class:`datetime.timedelta`
            Interval over which the prior is propagated. Only used, and then
            required, when ``measurement`` is `None`.
        **kwargs
            Forwarded unchanged to ``transition_model.function``.

        Returns
        -------
        : :class:`~.Prediction`
            Prediction created from ``prior`` with
            :meth:`Prediction.from_state`, holding the propagated state
            vector and the prediction timestamp.

        Raises
        ------
        TypeError
            If neither ``measurement`` nor ``time_interval`` is provided.
        """
        if measurement is not None:
            # The measurement time takes precedence over the supplied interval
            timestamp = measurement.timestamp
            time_interval = measurement.timestamp - prior.timestamp
        elif time_interval is None:
            # Without this guard the addition below fails with an opaque
            # "unsupported operand" TypeError; keep the type, clarify the cause
            raise TypeError(
                "rvs() requires either 'measurement' or 'time_interval' to "
                "determine the prediction time")
        else:
            timestamp = prior.timestamp + time_interval

        new_state_vector = self.transition_model.function(
            prior, time_interval=time_interval, **kwargs)

        return Prediction.from_state(prior,
                                     parent=prior,
                                     state_vector=new_state_vector,
                                     timestamp=timestamp,
                                     transition_model=self.transition_model,
                                     prior=prior)
class KFasProposal(Proposal):
    """This proposal uses the Kalman filter prediction and update steps to
    generate a new set of particles and weights.
    """

    predictor: Predictor = Property(
        doc="predictor to use the various values")
    updater: Updater = Property(
        doc="Updater used for update the values")

    def rvs(self, prior: State, measurement: Detection = None, time_interval=None,
            **kwargs) -> Prediction:
        """Generate samples from the proposal.

        Use the Kalman filter predictor and updater to create a new
        distribution for every particle, draw one sample from each and
        correct the weights for the proposal that was used.

        Parameters
        ----------
        prior: :class:`~.State`
            The state to generate samples from.
        measurement: :class:`~.Detection`
            The measurement that is used in the update step of the Kalman
            filter. When provided, its timestamp sets the prediction time and
            overrides ``time_interval``; when `None` only the prediction
            step is applied (the default is `None`).
        time_interval: :class:`datetime.timedelta`
            Interval over which the states are propagated. Only used, and
            then required, when ``measurement`` is `None`.
        **kwargs
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        : :class:`~.Prediction`
            Prediction created from ``prior`` with
            :meth:`Prediction.from_state`, holding the drawn samples and the
            adjusted log weights. If the time interval is zero the prior
            state vector is returned unchanged as the prediction.

        Raises
        ------
        TypeError
            If neither ``measurement`` nor ``time_interval`` is provided.
        """
        if measurement is not None:
            # The measurement time takes precedence over the supplied interval
            timestamp = measurement.timestamp
            time_interval = measurement.timestamp - prior.timestamp
        elif time_interval is None:
            # Without this guard the addition below fails with an opaque
            # "unsupported operand" TypeError; keep the type, clarify the cause
            raise TypeError(
                "rvs() requires either 'measurement' or 'time_interval' to "
                "determine the prediction time")
        else:
            timestamp = prior.timestamp + time_interval

        # Nothing to propagate: hand the prior back as the prediction
        if time_interval.total_seconds() == 0:
            return Prediction.from_state(prior,
                                         parent=prior,
                                         state_vector=prior.state_vector,
                                         timestamp=prior.timestamp,
                                         transition_model=self.predictor.transition_model,
                                         prior=prior)

        prior_cls = GaussianState  # Default
        if isinstance(self.predictor, SqrtKalmanPredictor):
            prior_cls = SqrtGaussianState

        # Null covariance for the particles: each particle is wrapped in a
        # Gaussian-type state with an all-zero covariance before prediction
        null_covar = np.zeros_like(prior.covar)

        predictions = [
            self.predictor.predict(
                prior_cls(particle_sv, null_covar, prior.timestamp),
                timestamp=timestamp)
            for particle_sv in prior.state_vector]

        if measurement is not None:
            updates = [self.updater.update(SingleHypothesis(prediction, measurement))
                       for prediction in predictions]
        else:
            updates = predictions  # keep the prediction

        # Draw the samples: one zero-mean Gaussian perturbation per particle,
        # added to that particle's predicted/updated mean
        samples = np.array([state.state_vector.reshape(-1) + mvn.rvs(cov=state.covar).T
                            for state in updates])

        # Compute the log of q(x_k|x_{k-1}, y_k)
        post_log_weights = np.array(
            [mvn.logpdf(sample - update.state_vector.reshape(-1), cov=update.covar)
             for sample, update in zip(samples, updates)])

        pred_state = Prediction.from_state(prior,
                                           parent=prior,
                                           state_vector=StateVectors(samples.T),
                                           timestamp=timestamp,
                                           transition_model=self.predictor.transition_model,
                                           prior=prior)

        # Log density of the drawn samples under the transition model
        prior_log_weights = self.predictor.transition_model.logpdf(
            pred_state, prior, time_interval=time_interval)

        # Importance correction: add the transition log density and subtract
        # the log density of the proposal the samples were drawn from
        pred_state.log_weight = (pred_state.log_weight + prior_log_weights
                                 - post_log_weights)

        return pred_state