# Source code for stonesoup.mixturereducer.gaussianmixture

# -*- coding: utf-8 -*-
import uuid

import numpy as np
from ordered_set import OrderedSet
from scipy.spatial import KDTree

from ..base import Property
from .base import MixtureReducer
from ..types.state import TaggedWeightedGaussianState, WeightedGaussianState
from ..measures import Mahalanobis
from operator import attrgetter


class GaussianMixtureReducer(MixtureReducer):
    """Gaussian Mixture Reducer class.

    Reduces the number of components in a Gaussian mixture to increase
    computational efficiency. See [1] for details.

    Achieved in three ways: pruning, merging, and truncating.

    Pruning is the act of removing low weight components from the mixture
    that fall below a pruning threshold.

    Merging is the act of combining similar components in the mixture that
    fall within a distance threshold into a single component.

    Truncating is the act of removing low weight components from the mixture
    so that the number of components in the mixture stays below a given
    threshold. Truncating is performed after the pruning and merging.

    References
    ----------
    [1] B.-N. Vo and W.-K. Ma, "The Gaussian Mixture Probability Hypothesis
    Density Filter," Signal Processing, IEEE Transactions on, vol. 54,
    no. 11, pp. 4091-4104, 2006.
    """

    prune_threshold: float = Property(default=1e-9, doc='Threshold for pruning')
    merge_threshold: float = Property(default=16, doc='Threshold for merging')
    max_number_components: int = Property(
        default=np.iinfo(np.int64).max,
        doc='Maximum number of components to keep '
            'in the Gaussian mixture')
    merging: bool = Property(default=True, doc='Flag for merging')
    pruning: bool = Property(
        default=True,
        doc='Flag for pruning components whose weight is below '
            ':attr:`prune_threshold`')
    truncating: bool = Property(
        default=True,
        doc='Flag for truncating components, keeping a maximum '
            'of :attr:`max_number_components` components')
    kdtree_max_distance: float = Property(
        default=None,
        doc="This defines the max Euclidean search distance for a kd-tree, "
            "used as part of the merge process as a coarse gate. Default "
            "`None` where tree isn't used and all components are checked "
            "against the merge threshold.")

    def reduce(self, components_list):
        """Reduce the components of a Gaussian mixture.

        Applies, in this order and each only if its flag is set: pruning,
        merging (only when more than one component is left) and truncating
        (only when more than :attr:`max_number_components` are left).

        Parameters
        ----------
        components_list : :class:`list`
            The components of the Gaussian mixture

        Returns
        -------
        :class:`list`
            Reduced components. An empty input is returned unchanged.
        """
        if len(components_list) > 0:
            if self.pruning:
                components_list = self.prune(components_list)
            if len(components_list) > 1 and self.merging:
                components_list = self.merge(components_list)
            if len(components_list) > self.max_number_components and self.truncating:
                components_list = self.truncate(components_list)
        return components_list

    def prune(self, components_list):
        """Remove components whose weight is below :attr:`prune_threshold`.

        The summed weight of the removed components is shared equally
        between the components that are kept. The kept components are
        modified in place (their ``weight`` attribute is increased).

        Parameters
        ----------
        components_list : :class:`list`
            The components of the Gaussian mixture to be pruned

        Returns
        -------
        remaining_components : :class:`list`
            Components that remain after pruning. Empty if every component
            was below the threshold; in that case the pruned weight is
            discarded.
        """
        threshold = self.prune_threshold
        pruned_weight_sum = sum(
            component.weight for component in components_list
            if component.weight < threshold)
        remaining_components = [
            component for component in components_list
            if component.weight >= threshold]

        # Distribute pruned weight across remaining components
        if remaining_components:
            weight_share = pruned_weight_sum / len(remaining_components)
            for component in remaining_components:
                component.weight += weight_share
        return remaining_components

    def merge_components(self, component_1, component_2):
        """Merge two similar components into one.

        With normalised weights ``w1`` and ``w2`` the merged mean is
        ``w1*m1 + w2*m2`` and the merged covariance is
        ``w1*P1 + w2*P2 + w1*w2*(m1 - m2)(m1 - m2)^T``. The merged weight is
        the sum of both weights, capped at 1. The timestamp (and the tag,
        for tagged components) is taken from ``component_1``.

        Parameters
        ----------
        component_1 : :class:`~.WeightedGaussianState`
            First component to be merged
        component_2 : :class:`~.WeightedGaussianState`
            Second component to be merged

        Returns
        -------
        merged_component : :class:`~.WeightedGaussianState`
            Merged Gaussian component; a
            :class:`~.TaggedWeightedGaussianState` if ``component_1`` is one.

        Raises
        ------
        TypeError
            If ``component_1`` is neither a
            :class:`~.TaggedWeightedGaussianState` nor a
            :class:`~.WeightedGaussianState`.
        """
        weight_sum = component_1.weight + component_2.weight
        w1 = component_1.weight / weight_sum
        w2 = component_2.weight / weight_sum

        merged_mean = component_1.mean*w1 + component_2.mean*w2
        merged_covar = component_1.covar*w1 + component_2.covar*w2
        # Spread-of-means term: column times row broadcasts to an outer product
        mu1_minus_m2 = component_1.mean - component_2.mean
        merged_covar = merged_covar + \
            mu1_minus_m2*mu1_minus_m2.T*w1*w2

        merged_weight = weight_sum
        if merged_weight > 1:
            merged_weight = 1

        # Tagged check comes first so that a tagged input keeps its tag
        if isinstance(component_1, TaggedWeightedGaussianState):
            return TaggedWeightedGaussianState(
                state_vector=merged_mean,
                covar=merged_covar,
                weight=merged_weight,
                tag=component_1.tag,
                timestamp=component_1.timestamp
            )
        if isinstance(component_1, WeightedGaussianState):
            return WeightedGaussianState(
                state_vector=merged_mean,
                covar=merged_covar,
                weight=merged_weight,
                timestamp=component_1.timestamp
            )
        raise TypeError(
            "component_1 must be a WeightedGaussianState or "
            "TaggedWeightedGaussianState, not "
            f"{type(component_1).__name__}")

    def merge(self, components_list):
        """Combine components that are close to each other.

        Components are processed from highest to lowest weight. Each
        remaining candidate whose squared Mahalanobis distance to the
        current reference component is below :attr:`merge_threshold` is
        merged into it with :meth:`merge_components`. The reference is
        replaced by the merged component after every merge, so later
        candidates are compared against the running merged component.

        If :attr:`kdtree_max_distance` is set, only candidates within that
        Euclidean distance of the reference (found with a kd-tree built over
        the input state vectors) are considered.

        If every resulting component is tagged and some tags are shared,
        only the highest weighted component keeps a shared tag; the others
        are given a new UUID tag.

        Parameters
        ----------
        components_list : :class:`list`
            Components of the Gaussian mixture to be merged

        Returns
        -------
        :class:`list`
            Merged components
        """
        if self.kdtree_max_distance is not None:
            tree = KDTree(
                np.vstack([component.state_vector[:, 0]
                           for component in components_list]))
        else:
            tree = None

        # Sorted ascending by weight, so pop() yields the heaviest component
        remaining_components = OrderedSet(sorted(
            components_list, key=attrgetter('weight')))
        merged_components = []
        measure = Mahalanobis()

        while remaining_components:
            # Get highest weighted component
            best_component = remaining_components.pop()

            # If kdtree_max_distance set, use this as gate
            if tree is not None:
                indexes = tree.query_ball_point(
                    best_component.state_vector.ravel(),
                    r=self.kdtree_max_distance)
                # De-duplicate while keeping the query order, so that the
                # order-dependent merging below is reproducible
                matched_components = list(dict.fromkeys(
                    components_list[i] for i in indexes
                    if components_list[i] in remaining_components))
            else:
                # remaining_components is modified in the loop, so copy it
                matched_components = remaining_components.copy()

            # Check for similar components against threshold
            for component in matched_components:
                # Squared distance between component and best component
                distance = (measure(state1=component, state2=best_component))**2
                # Merge if similar
                if distance < self.merge_threshold:
                    remaining_components.remove(component)
                    best_component = self.merge_components(
                        best_component, component
                    )

            # Add potentially merged component to new mixture
            merged_components.append(best_component)

        return self._resolve_duplicate_tags(merged_components)

    @staticmethod
    def _resolve_duplicate_tags(merged_components):
        """Return the components as a new list with every tag made unique.

        Untagged or mixed input, and input without shared tags, is returned
        in its original order. Otherwise components are grouped by tag (in
        first-seen order); within a group the highest weighted component
        keeps the tag and the rest get a new UUID tag assigned in place.
        """
        if not all(isinstance(component, TaggedWeightedGaussianState)
                   for component in merged_components):
            # Just weighted components (no tags)
            return list(merged_components)

        components_by_tag = {}
        for component in merged_components:
            components_by_tag.setdefault(component.tag, []).append(component)

        if len(components_by_tag) == len(merged_components):
            # No duplicates
            return list(merged_components)

        final_merged_components = []
        for shared_components in components_by_tag.values():
            shared_components.sort(key=attrgetter('weight'), reverse=True)
            final_merged_components.append(shared_components[0])
            for component in shared_components[1:]:
                # Assign a new uuid to the lower weighted components
                component.tag = str(uuid.uuid4())
                final_merged_components.append(component)
        return final_merged_components

    def truncate(self, components_list):
        """Keep only the highest weighted components.

        At most :attr:`max_number_components` components are kept. The
        summed weight of the discarded components is shared equally between
        the kept ones, which are modified in place.

        Parameters
        ----------
        components_list : :class:`list`
            Components of the Gaussian mixture to be truncated

        Returns
        -------
        :class:`list`
            The :attr:`max_number_components` components with the highest
            weights, sorted from highest to lowest weight
        """
        # Sort components by weight from highest to lowest
        all_components = sorted(
            components_list, key=attrgetter('weight'), reverse=True)

        # Slicing never raises, even if there are fewer components than
        # max_number_components; the truncated part is then simply empty
        remaining_components = all_components[:self.max_number_components]
        truncated_components = all_components[self.max_number_components:]
        truncated_weight_sum = sum(
            component.weight for component in truncated_components)

        # Distribute truncated weight across the components actually kept
        if remaining_components:
            weight_share = truncated_weight_sum / len(remaining_components)
            for component in remaining_components:
                component.weight += weight_share
        return remaining_components