# Source code for stonesoup.updater.ensemble

from functools import lru_cache

import numpy as np
import scipy

from .kalman import KalmanUpdater
from ..base import Property
from ..types.state import State, EnsembleState
from ..types.array import StateVectors
from ..types.prediction import MeasurementPrediction
from ..types.update import Update
from ..models.measurement import MeasurementModel

class EnsembleUpdater(KalmanUpdater):
    r"""Ensemble Kalman Filter Updater class

    The EnKF is a hybrid of the Kalman updating scheme and the
    Monte Carlo approach of the particle filter.

    Deliberately structured to resemble the Vanilla Kalman Filter,
    :meth:`update` first calls :meth:`predict_measurement` function which
    proceeds by calculating the predicted measurement, innovation covariance
    and measurement cross-covariance. Note however, these are not propagated
    explicitly, they are derived from the sample covariance of the ensemble
    itself.

    Note that the EnKF equations are simpler when written in the following
    formalism. Note that h is not necessarily a matrix, but could be a
    nonlinear measurement function.

    .. math::

        \mathbf{A}_k = \hat{X} - E(X)

        \mathbf{HA}_k = h(\hat{X}) - h(E(X))

    The cross covariance and measurement covariance are given by:

    .. math::

        P_{xz} = \frac{1}{M-1} \mathbf{A}_k \mathbf{HA}_k^T

        P_{zz} = \frac{1}{M-1} \mathbf{HA}_k \mathbf{HA}_k^T + R

    The Kalman gain is then calculated via:

    .. math::

        K_{k} = P_{xz} P_{zz}^{-1}

    and the posterior state mean and covariance are,

    .. math::

        \mathbf{x}_{k|k} = \mathbf{x}_{k|k-1} + K_k (\mathbf{z}_k - H_k
        \mathbf{x}_{k|k-1})

    This is returned as a :class:`~.EnsembleStateUpdate` object.

    References
    ----------

    1. J. Hiles, S. M. O’Rourke, R. Niu and E. P. Blasch,
       "Implementation of Ensemble Kalman Filters in Stone-Soup,"
       International Conference on Information Fusion, (2021)

    2. Mandel, Jan. "A brief tutorial on the ensemble Kalman filter."
       arXiv preprint arXiv:0901.3725 (2009).
    """

    measurement_model: MeasurementModel = Property(
        default=None,
        doc="A measurement model. This need not be defined if a measurement "
            "model is provided in the measurement. If no model specified on "
            "construction, or in the measurement, then error will be thrown. "
    )

    def _check_measurement_prediction(self, hypothesis, **kwargs):
        """Check to see if a measurement prediction exists in the hypothesis.

        If the hypothesis carries no measurement prediction, one is computed
        from the predicted state and attached to the hypothesis in place.

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.
        **kwargs
            Forwarded to :meth:`predict_measurement`.

        Returns
        -------
        hypothesis : :class:`~.SingleHypothesis`
            the same hypothesis object, now guaranteed to carry a measurement
            prediction.
        """
        # Get the predicted state out of the hypothesis
        predicted_state = hypothesis.prediction

        # If there is no measurement prediction in the hypothesis then do the
        # measurement prediction (and attach it back to the hypothesis).
        if hypothesis.measurement_prediction is None:
            # Get the measurement model out of the measurement if it's there.
            # If not, use the one native to the updater (which might still be
            # none)
            measurement_model = hypothesis.measurement.measurement_model
            measurement_model = self._check_measurement_model(
                measurement_model)

            # Attach the measurement prediction to the hypothesis
            hypothesis.measurement_prediction = self.predict_measurement(
                predicted_state, measurement_model=measurement_model, **kwargs)
        return hypothesis

    # Cached: a repeated call with the same arguments returns the previously
    # computed prediction, so the noisy measurement ensemble is not re-drawn.
    @lru_cache()
    def predict_measurement(self, predicted_state, measurement_model=None,
                            **kwargs):
        r"""Predict the measurement implied by the predicted state mean

        Parameters
        ----------
        predicted_state : :class:`~.State`
            The predicted state :math:`\mathbf{x}_{k|k-1}`
        measurement_model : :class:`~.MeasurementModel`
            The measurement model. If omitted, the model in the updater object
            is used

        Returns
        -------
        : :class:`~.EnsembleMeasurementPrediction`
            The measurement prediction, :math:`\mathbf{z}_{k|k-1}`
        """
        # If a measurement model is not specified then use the one that's
        # native to the updater
        measurement_model = self._check_measurement_model(measurement_model)

        # Propagate each vector through the measurement model.
        pred_meas_ensemble = measurement_model.function(predicted_state, noise=True)

        return MeasurementPrediction.from_state(
            predicted_state, pred_meas_ensemble)

    def update(self, hypothesis, **kwargs):
        r"""The Ensemble Kalman update method. The Ensemble Kalman filter
        simply uses the Kalman Update scheme to evolve a set or Ensemble
        of state vectors as a group. This ensemble of vectors contains all the
        information on the system state.

        The measurement model is taken from the measurement when it carries
        one, otherwise from the updater, matching the model used to build the
        measurement prediction.

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.

        Returns
        -------
        : :class:`~.EnsembleStateUpdate`
            The posterior state which contains an ensemble of state vectors
            and a timestamp.
        """
        # Assigning more readable variable names
        hypothesis = self._check_measurement_prediction(hypothesis)
        pred_state = hypothesis.prediction
        measurement = hypothesis.measurement
        num_vectors = pred_state.num_vectors

        # Resolve the model the same way the prediction step does. getattr
        # keeps the updater-model fallback working for measurements that do
        # not expose a ``measurement_model`` attribute.
        measurement_model = self._check_measurement_model(
            getattr(measurement, 'measurement_model', None))
        meas_covar = measurement_model.covar()

        # Generate an ensemble of measurements based on measurement
        measurement_ensemble = pred_state.generate_ensemble(
            mean=measurement.state_vector,
            covar=meas_covar,
            num_vectors=num_vectors)

        # Calculate Kalman Gain according to Dr. Jan Mandel's EnKF formalism.
        # Deviation of each ensemble member from the ensemble mean (A_k).
        innovation_ensemble = pred_state.state_vector - pred_state.mean

        # Same deviation expressed in measurement space (HA_k).
        meas_innovation = (
            measurement_model.function(pred_state, num_samples=num_vectors)
            - measurement_model.function(State(pred_state.mean)))

        # Sample cross covariance P_xz and innovation covariance P_zz
        scale = 1/(num_vectors-1)
        cross_covar = scale * innovation_ensemble @ meas_innovation.T
        innovation_covar = scale * meas_innovation @ meas_innovation.T + meas_covar

        # Calculate Kalman Gain
        kalman_gain = cross_covar @ scipy.linalg.inv(innovation_covar)

        # Calculate Posterior Ensemble
        posterior_ensemble = (
            pred_state.state_vector
            + kalman_gain@(measurement_ensemble - hypothesis.measurement_prediction.state_vector))

        return Update.from_state(pred_state,
                                 posterior_ensemble,
                                 timestamp=measurement.timestamp,
                                 hypothesis=hypothesis)

class EnsembleSqrtUpdater(EnsembleUpdater):
    r"""The Ensemble Square Root filter propagates the mean and square root
    covariance through time, and samples a new ensemble.
    This has the advantage of not requiring perturbation of the measurement
    which reduces sampling error.
    The posterior mean is calculated via:

    .. math::

        \mathbf{x}_{k|k} = \mathbf{x}_{k|k-1} + K_k (\mathbf{z}_k - H_k
        \mathbf{x}_{k|k-1})

    The Kalman gain is calculated via:

    .. math::

        K_{k} = P_{xz} P_{zz}^{-1}

    The cross covariance and measurement covariance respectively are approximated
    via the sample square root covariances:

    .. math::

        P_{xz} \approx \tilde{P}_k (\tilde{Z}_k)^T

        P_{zz} \approx \tilde{Z}_k (\tilde{Z}_k)^T + R_k

    and the posterior covariance is propagated through time via:

    .. math::

        \mathbf{P}_{k|k} = \tilde{P}^- B (\tilde{P}^- B)^T

    Where :math:`\tilde{P}^-` represents the prediction square root covariance
    and B is the matrix square root of:

    .. math::

        B = \mathbf{I} - (\tilde{Z}_k)^T [P_{zz}]^{-1} \tilde{Z}_k

    The posterior mean and covariance are used to sample a new ensemble.
    The resulting state is returned via a :class:`~.EnsembleStateUpdate` object.

    References
    ----------
    1. J. Hiles, S. M. O’Rourke, R. Niu and E. P. Blasch,
       "Implementation of Ensemble Kalman Filters in Stone-Soup",
       International Conference on Information Fusion, (2021)

    2. Livings, Dance, S. L., & Nichols, N. K.
       "Unbiased ensemble square root filters."
       Physica. D, 237(8), 1021–1028.  (2008)
    """

    def update(self, hypothesis, **kwargs):
        r"""The Ensemble Square Root Kalman update method. The Ensemble Square
        Root filter propagates the mean and square root covariance through time,
        and samples a new ensemble. This has the advantage of not perturbing the
        measurement with statistical noise, and thus is less prone to sampling
        error for small ensembles.

        The measurement model is taken from the measurement when it carries
        one, otherwise from the updater.

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.

        Returns
        -------
        : :class:`~.EnsembleStateUpdate`
            The posterior state which contains an ensemble of state vectors
            and a timestamp.
        """
        hypothesis = self._check_measurement_prediction(hypothesis)
        prediction = hypothesis.prediction
        measurement_prediction = hypothesis.measurement_prediction
        num_vectors = prediction.num_vectors

        pred_state = prediction.mean
        pred_state_sqrt_covar = prediction.sqrt_covar
        pred_measurement = measurement_prediction.mean
        pred_meas_sqrt_covar = measurement_prediction.sqrt_covar
        measurement = hypothesis.measurement.state_vector

        # Resolve the model the same way the prediction step does. getattr
        # keeps the updater-model fallback working for measurements that do
        # not expose a ``measurement_model`` attribute.
        measurement_model = self._check_measurement_model(
            getattr(hypothesis.measurement, 'measurement_model', None))
        meas_covar = measurement_model.covar()

        # Calculate Posterior Mean
        cross_covar = pred_state_sqrt_covar @ pred_meas_sqrt_covar.T
        innovation_covar = pred_meas_sqrt_covar @ pred_meas_sqrt_covar.T + meas_covar
        # The inverse is needed for both the gain and B, so compute it once.
        inv_innovation_covar = scipy.linalg.inv(innovation_covar)
        kalman_gain = cross_covar @ inv_innovation_covar
        posterior_mean = pred_state + kalman_gain @ (measurement - pred_measurement)

        # Calculate Posterior Covariance. Note that B has no obvious name.
        B = scipy.linalg.sqrtm(
            np.eye(num_vectors)
            - pred_meas_sqrt_covar.T @ inv_innovation_covar @ pred_meas_sqrt_covar)
        posterior_sqrt_covar = pred_state_sqrt_covar @ B
        posterior_covar = posterior_sqrt_covar @ posterior_sqrt_covar.T

        # Sample a fresh ensemble from the posterior mean and covariance
        posterior_ensemble = EnsembleState.generate_ensemble(posterior_mean,
                                                             posterior_covar,
                                                             num_vectors)

        return Update.from_state(prediction,
                                 posterior_ensemble,
                                 timestamp=hypothesis.measurement.timestamp,
                                 hypothesis=hypothesis)

class LinearisedEnsembleUpdater(EnsembleUpdater):
    """
    Implementation of 'The Linearized EnKF Update' algorithm from "Ensemble Kalman Filter with
    Bayesian Recursive Update" by Kristen Michaelson, Andrey A. Popov and Renato Zanetti.
    Similar to the EnsembleUpdater, but uses a different form of Kalman gain. This alternative form
    of the EnKF calculates a separate kalman gain for each ensemble member. This alternative
    Kalman gain calculation involves linearization of the measurement. An additional step is
    introduced to perform inflation.

    References
    ----------

    1. K. Michaelson, A. A. Popov and R. Zanetti,
       "Ensemble Kalman Filter with Bayesian Recursive Update"
    """
    inflation_factor: float = Property(
        default=1.,
        doc="Parameter to control inflation")

    def update(self, hypothesis, **kwargs):
        r"""The LinearisedEnsembleUpdater update method. This method includes an additional step
        over the EnsembleUpdater update step to perform inflation.

        The measurement model is taken from the measurement when it carries
        one, otherwise from the updater.

        Parameters
        ----------
        hypothesis : :class:`~.SingleHypothesis`
            the prediction-measurement association hypothesis. This hypothesis
            may carry a predicted measurement, or a predicted state. In the
            latter case a predicted measurement will be calculated.

        Returns
        -------
        : :class:`~.EnsembleStateUpdate`
            The posterior state which contains an ensemble of state vectors
            and a timestamp.
        """
        prediction = hypothesis.prediction

        # Extract the number of vectors from the prediction
        num_vectors = prediction.num_vectors

        # Assign measurement prediction if prediction is missing
        hypothesis = self._check_measurement_prediction(hypothesis)

        # Resolve the model the same way the prediction step does. getattr
        # keeps the updater-model fallback working for measurements that do
        # not expose a ``measurement_model`` attribute.
        measurement_model = self._check_measurement_model(
            getattr(hypothesis.measurement, 'measurement_model', None))

        # Loop-invariant quantities
        z = hypothesis.measurement.state_vector
        prediction_timestamp = prediction.timestamp

        # Prior state vector
        X0 = prediction.state_vector

        # Measurement covariance
        R = measurement_model.covar()

        # Line 1: Compute mean
        m = prediction.mean

        # Line 2: Compute inflation
        X = StateVectors(m + self.inflation_factor * (X0 - m))

        # Line 3: Recompute prior covariance
        # NOTE(review): this uses the un-inflated ensemble X0, not the
        # inflated X from Line 2 - confirm against the reference algorithm.
        P = 1/(num_vectors-1) * (X0 - m) @ (X0 - m).T

        states = list()

        # Line 5: Y_hat
        Y_hat = hypothesis.measurement_prediction.state_vector

        # Line 4: one gain per ensemble member, paired with its own
        # predicted measurement
        for x, y_hat in zip(X, Y_hat):

            # Line 6: Compute Jacobian (linearise about this ensemble member)
            H = measurement_model.jacobian(State(state_vector=x,
                                                 timestamp=prediction_timestamp))

            # Line 7: Calculate Innovation
            S = H @ P @ H.T + R

            # Line 8: Calculate Kalman gain
            K = P @ H.T @ scipy.linalg.inv(S)

            # Line 9: Recalculate X and collect the updated state vector
            states.append(x + K @ (z - y_hat))

        # Convert list of state vectors into a StateVectors class
        X = StateVectors(np.hstack(states))

        return Update.from_state(prediction,
                                 X,
                                 timestamp=hypothesis.measurement.timestamp,
                                 hypothesis=hypothesis)