Source code for stonesoup.updater.ensemble

from functools import lru_cache

import numpy as np
import scipy

from .kalman import KalmanUpdater
from ..base import Property
from ..types.state import State, EnsembleState
from ..types.array import StateVectors
from ..types.prediction import MeasurementPrediction
from ..types.update import Update
from ..models.measurement import MeasurementModel


class EnsembleUpdater(KalmanUpdater):
    r"""Ensemble Kalman Filter Updater class

    The EnKF is a hybrid of the Kalman updating scheme and the
    Monte Carlo approach of the particle filter.

    Deliberately structured to resemble the Vanilla Kalman Filter,
    :meth:`update` first calls :meth:`predict_measurement` function which
    proceeds by calculating the predicted measurement, innovation covariance
    and measurement cross-covariance. Note however, these are not propagated
    explicitly, they are derived from the sample covariance of the ensemble
    itself.

    Note that the EnKF equations are simpler when written in the following
    formalism. Note that :math:`h` is not necessarily a matrix, but could be
    a nonlinear measurement function.

    .. math::

        \mathbf{A}_k = \hat{X} - E(X)

        \mathbf{HA}_k = h(\hat{X}) - h(E(X))

    The cross covariance and measurement covariance are given by:

    .. math::

        P_{xz} = \frac{1}{M-1} \mathbf{A}_k \mathbf{HA}_k^T

        P_{zz} = \frac{1}{M-1} \mathbf{HA}_k \mathbf{HA}_k^T + R

    The Kalman gain is then calculated via:

    .. math::

        K_{k} = P_{xz} P_{zz}^{-1}

    and each member :math:`i` of the posterior ensemble is obtained as

    .. math::

        \mathbf{x}^{(i)}_{k|k} = \mathbf{x}^{(i)}_{k|k-1}
            + K_k (\mathbf{z}^{(i)}_k - \mathbf{z}^{(i)}_{k|k-1})

    where :math:`\mathbf{z}^{(i)}_k` is a member of an ensemble generated
    around the received measurement with covariance :math:`R`, and
    :math:`\mathbf{z}^{(i)}_{k|k-1}` is the corresponding member of the
    predicted measurement ensemble.

    This is returned as a :class:`~.EnsembleStateUpdate` object.

    References
    ----------
    1. J. Hiles, S. M. O’Rourke, R. Niu and E. P. Blasch,
    "Implementation of Ensemble Kalman Filters in Stone-Soup,"
    International Conference on Information Fusion, (2021)

    2. Mandel, Jan. "A brief tutorial on the ensemble Kalman filter."
    arXiv preprint arXiv:0901.3725 (2009).
    """
    measurement_model: MeasurementModel = Property(
        default=None,
        doc="A measurement model. This need not be defined if a measurement "
            "model is provided in the measurement. If no model specified on "
            "construction, or in the measurement, then error will be thrown. "
    )

    def _check_measurement_prediction(self, hypothesis, **kwargs):
        """Ensure the hypothesis carries a measurement prediction.

        If ``hypothesis.measurement_prediction`` is ``None``, one is computed
        with :meth:`predict_measurement` from the hypothesis' prediction and
        attached to the hypothesis (the hypothesis is modified in place).

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.
        **kwargs
            Forwarded to :meth:`predict_measurement`.

        Returns
        -------
        hypothesis : :class:`~.SingleHypothesis`
            The same hypothesis object, now guaranteed to have a
            ``measurement_prediction``.
        """
        # Get the predicted state out of the hypothesis
        predicted_state = hypothesis.prediction
        # If there is no measurement prediction in the hypothesis then do the
        # measurement prediction (and attach it back to the hypothesis).
        if hypothesis.measurement_prediction is None:
            # Get the measurement model out of the measurement if it's there.
            # If not, use the one native to the updater (which might still be
            # none)
            measurement_model = hypothesis.measurement.measurement_model
            measurement_model = self._check_measurement_model(
                measurement_model)

            # Attach the measurement prediction to the hypothesis
            hypothesis.measurement_prediction = self.predict_measurement(
                predicted_state, measurement_model=measurement_model, **kwargs)
        return hypothesis

    @lru_cache()
    def predict_measurement(self, predicted_state, measurement_model=None,
                            measurement_noise=True, **kwargs):
        r"""Predict the measurement implied by the predicted state ensemble

        Parameters
        ----------
        predicted_state : :class:`~.State`
            The predicted state :math:`\mathbf{x}_{k|k-1}`
        measurement_model : :class:`~.MeasurementModel`
            The measurement model. If omitted, the model in the updater object
            is used
        measurement_noise : bool
            Whether to include measurement noise :math:`R` when generating
            ensemble. Default `True`
        **kwargs
            Forwarded to the measurement model's ``function``.

        Returns
        -------
        : :class:`~.EnsembleMeasurementPrediction`
            The measurement prediction, :math:`\mathbf{z}_{k|k-1}`
        """
        # If a measurement model is not specified then use the one that's
        # native to the updater
        measurement_model = self._check_measurement_model(measurement_model)

        # Propagate each vector through the measurement model.
        pred_meas_ensemble = measurement_model.function(
            predicted_state, noise=measurement_noise, **kwargs)

        return MeasurementPrediction.from_state(
            predicted_state, state_vector=pred_meas_ensemble)

    def update(self, hypothesis, **kwargs):
        r"""The Ensemble Kalman update method.

        The Ensemble Kalman filter simply uses the Kalman Update scheme
        to evolve a set or Ensemble of state vectors as a group. This ensemble
        of vectors contains all the information on the system state.

        The measurement model used for :math:`R` and :math:`h` is the one
        attached to the measurement, falling back to the updater's own
        :attr:`measurement_model` (the same resolution used when predicting
        the measurement).

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.

        Returns
        -------
        : :class:`~.EnsembleStateUpdate`
            The posterior state which contains an ensemble of state vectors
            and a timestamp.

        Raises
        ------
        ValueError
            Raised by ``_check_measurement_model`` when neither the
            measurement nor the updater provides a measurement model.
        """
        hypothesis = self._check_measurement_prediction(hypothesis)
        pred_state = hypothesis.prediction
        num_vectors = pred_state.num_vectors

        # Resolve the model the same way the measurement prediction was
        # resolved; using self.measurement_model directly would fail (or mix
        # models) when the model is only attached to the measurement.
        measurement_model = self._check_measurement_model(
            hypothesis.measurement.measurement_model)
        meas_mean = hypothesis.measurement.state_vector
        meas_covar = measurement_model.covar()

        # Generate an ensemble of measurements around the received measurement
        measurement_ensemble = pred_state.generate_ensemble(
            mean=meas_mean, covar=meas_covar, num_vectors=num_vectors)

        # Deviations in state space (A) and measurement space (HA), following
        # Dr. Jan Mandel's EnKF formalism.
        innovation_ensemble = pred_state.state_vector - pred_state.mean
        meas_innovation = (
            measurement_model.function(pred_state, num_samples=num_vectors)
            - measurement_model.function(State(pred_state.mean)))

        # Sample cross covariance, innovation covariance and Kalman gain
        scale = 1 / (num_vectors - 1)
        cross_covar = scale * innovation_ensemble @ meas_innovation.T
        innovation_covar = (
            scale * meas_innovation @ meas_innovation.T + meas_covar)
        kalman_gain = cross_covar @ scipy.linalg.inv(innovation_covar)

        # Calculate Posterior Ensemble.
        # NOTE(review): when the measurement prediction is produced by
        # _check_measurement_prediction, predict_measurement defaults to
        # measurement_noise=True, and measurement_ensemble is also perturbed
        # with covariance R -- confirm noise is not intended to enter twice.
        posterior_ensemble = pred_state.state_vector + kalman_gain @ (
            measurement_ensemble
            - hypothesis.measurement_prediction.state_vector)

        return Update.from_state(
            pred_state,
            posterior_ensemble,
            timestamp=hypothesis.measurement.timestamp,
            hypothesis=hypothesis)
class EnsembleSqrtUpdater(EnsembleUpdater):
    r"""The Ensemble Square Root filter propagates the mean and square root
    covariance through time, and samples a new ensemble. This has the
    advantage of not requiring perturbation of the measurement, which reduces
    sampling error.

    The posterior mean is calculated via:

    .. math::

        \mathbf{x}_{k|k} = \mathbf{x}_{k|k-1}
            + K_k (\mathbf{z}_k - \bar{\mathbf{z}}_{k|k-1})

    where :math:`\bar{\mathbf{z}}_{k|k-1}` is the mean of the predicted
    measurement ensemble. The Kalman gain is calculated via:

    .. math::

        K_{k} = P_{xz} P_{zz}^{-1}

    The cross covariance and measurement covariance respectively are
    approximated via the sample square root covariances:

    .. math::

        P_{xz} \approx \tilde{P}_k (\tilde{Z}_k)^T

        P_{zz} \approx \tilde{Z}_k (\tilde{Z}_k)^T + R_k

    and the posterior covariance is propagated through time via:

    .. math::

        \mathbf{P}_{k|k} = \tilde{P}^- B (\tilde{P}^- B)^T

    where :math:`\tilde{P}^-` represents the prediction square root
    covariance and :math:`B` is the matrix square root

    .. math::

        B = \left[\mathbf{I} - (\tilde{Z}_k)^T [P_{zz}]^{-1} \tilde{Z}_k
            \right]^{1/2}

    The posterior mean and covariance are used to sample a new ensemble.
    The resulting state is returned via a :class:`~.EnsembleStateUpdate`
    object.

    References
    ----------
    1. J. Hiles, S. M. O’Rourke, R. Niu and E. P. Blasch,
    "Implementation of Ensemble Kalman Filters in Stone-Soup",
    International Conference on Information Fusion, (2021)

    2. Livings, Dance, S. L., & Nichols, N. K.
    "Unbiased ensemble square root filters."
    Physica. D, 237(8), 1021–1028. (2008)
    """

    def update(self, hypothesis, **kwargs):
        r"""The Ensemble Square Root Kalman update method.

        The Ensemble Square Root filter propagates the mean and square root
        covariance through time, and samples a new ensemble. This has the
        advantage of not perturbing the measurement with statistical noise,
        and thus is less prone to sampling error for small ensembles.

        The measurement model used for :math:`R` is the one attached to the
        measurement, falling back to the updater's own
        :attr:`measurement_model`.

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.

        Returns
        -------
        : :class:`~.EnsembleStateUpdate`
            The posterior state which contains an ensemble of state vectors
            and a timestamp.

        Raises
        ------
        ValueError
            Raised by ``_check_measurement_model`` when neither the
            measurement nor the updater provides a measurement model.
        """
        hypothesis = self._check_measurement_prediction(hypothesis)
        prediction = hypothesis.prediction
        meas_prediction = hypothesis.measurement_prediction
        num_vectors = prediction.num_vectors

        # Same model resolution as the measurement prediction (measurement's
        # model first, then the updater's).
        measurement_model = self._check_measurement_model(
            hypothesis.measurement.measurement_model)

        pred_mean = prediction.mean
        pred_sqrt_covar = prediction.sqrt_covar
        pred_meas_mean = meas_prediction.mean
        pred_meas_sqrt_covar = meas_prediction.sqrt_covar
        measurement = hypothesis.measurement.state_vector
        meas_covar = measurement_model.covar()

        # Calculate Posterior Mean
        cross_covar = pred_sqrt_covar @ pred_meas_sqrt_covar.T
        innovation_covar = (
            pred_meas_sqrt_covar @ pred_meas_sqrt_covar.T + meas_covar)
        # The inverse is needed both for the gain and for B; compute it once.
        inv_innovation_covar = scipy.linalg.inv(innovation_covar)
        kalman_gain = cross_covar @ inv_innovation_covar
        posterior_mean = pred_mean + kalman_gain @ (measurement - pred_meas_mean)

        # Calculate Posterior Covariance via the square-root transform B.
        transform = scipy.linalg.sqrtm(
            np.eye(num_vectors)
            - pred_meas_sqrt_covar.T @ inv_innovation_covar
            @ pred_meas_sqrt_covar)
        posterior_sqrt_covar = pred_sqrt_covar @ transform
        posterior_covar = posterior_sqrt_covar @ posterior_sqrt_covar.T

        # Sample a fresh ensemble from the posterior mean and covariance
        posterior_ensemble = EnsembleState.generate_ensemble(
            posterior_mean, posterior_covar, num_vectors)

        return Update.from_state(
            prediction,
            posterior_ensemble,
            timestamp=hypothesis.measurement.timestamp,
            hypothesis=hypothesis)
class LinearisedEnsembleUpdater(EnsembleUpdater):
    """
    Implementation of 'The Linearized EnKF Update' algorithm from "Ensemble
    Kalman Filter with Bayesian Recursive Update" by Kristen Michaelson,
    Andrey A. Popov and Renato Zanetti.

    Similar to the EnsembleUpdater, but uses a different form of Kalman gain.
    This alternative form of the EnKF calculates a separate Kalman gain for
    each ensemble member. This alternative Kalman gain calculation involves
    linearization of the measurement. An additional step is introduced to
    perform inflation.

    References
    ----------
    1. K. Michaelson, A. A. Popov and R. Zanetti,
    "Ensemble Kalman Filter with Bayesian Recursive Update"
    """
    inflation_factor: float = Property(
        default=1.,
        doc="Parameter to control inflation")

    def update(self, hypothesis, **kwargs):
        r"""The LinearisedEnsembleUpdater update method.

        This method includes an additional step over the EnsembleUpdater
        update step to perform inflation: the ensemble deviations from the
        mean are scaled by :attr:`inflation_factor`, and the prior covariance
        used for every per-member Kalman gain is the sample covariance of
        that inflated ensemble.

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.

        Returns
        -------
        : :class:`~.EnsembleStateUpdate`
            The posterior state which contains an ensemble of state vectors
            and a timestamp.

        Raises
        ------
        ValueError
            Raised by ``_check_measurement_model`` when neither the
            measurement nor the updater provides a measurement model.
        """
        # Extract the number of vectors from the prediction
        num_vectors = hypothesis.prediction.num_vectors

        # Assign measurement prediction if prediction is missing
        hypothesis = self._check_measurement_prediction(hypothesis)
        prediction = hypothesis.prediction

        # Same model resolution as the measurement prediction (measurement's
        # model first, then the updater's).
        measurement_model = self._check_measurement_model(
            hypothesis.measurement.measurement_model)

        # Prior state vectors, measurement covariance and the received
        # measurement (the latter two are loop-invariant).
        X0 = prediction.state_vector
        R = measurement_model.covar()
        measurement = hypothesis.measurement.state_vector
        timestamp = prediction.timestamp

        # Line 1: Compute mean
        m = prediction.mean

        # Line 2: Compute inflation
        inflated_deviations = self.inflation_factor * (X0 - m)
        X = StateVectors(m + inflated_deviations)

        # Line 3: Recompute prior covariance from the *inflated* ensemble, so
        # that inflation_factor actually widens the covariance used in the
        # gain (previously the un-inflated deviations were used here).
        P = 1/(num_vectors-1) * inflated_deviations @ inflated_deviations.T

        # Line 5: Y_hat
        # NOTE(review): Y_hat comes from the measurement prediction of the
        # un-inflated prediction, while x below is the inflated member; these
        # only coincide when inflation_factor == 1 -- confirm intended.
        Y_hat = hypothesis.measurement_prediction.state_vector

        # Line 4: per-member update. Assumes iterating StateVectors yields one
        # ensemble member (column) at a time -- TODO confirm.
        states = []
        for x, y_hat in zip(X, Y_hat):
            # Line 6: Compute Jacobian at this member
            H = measurement_model.jacobian(
                State(state_vector=x, timestamp=timestamp))

            # Line 7: Calculate innovation covariance
            S = H @ P @ H.T + R

            # Line 8: Calculate Kalman gain
            K = P @ H.T @ scipy.linalg.inv(S)

            # Line 9: Recalculate x and collect it
            states.append(x + K @ (measurement - y_hat))

        # Convert list of state vectors into a StateVectors class
        X = StateVectors(np.hstack(states))

        return Update.from_state(
            prediction, X,
            timestamp=hypothesis.measurement.timestamp,
            hypothesis=hypothesis)