# Source code for stonesoup.dataassociator.tree

# -*- coding: utf-8 -*-
import datetime
from collections import defaultdict
from operator import attrgetter
from typing import Sequence

import numpy as np
import scipy as sp
from scipy.spatial import KDTree
try:
    import rtree
except (ImportError, AttributeError, OSError) as err:
    # AttributeError or OSError raised when libspatialindex missing or unable to load.
    # rtree is optional: only the TPR-tree based associators need it, so warn rather
    # than fail, and leave a None sentinel that those classes can check for.
    import warnings
    warnings.warn(f"Failed to import 'rtree': {err!r}")
    rtree = None

from .base import DataAssociator
from ..base import Property
from ..models.base import LinearModel
from ..models.measurement import MeasurementModel
from ..predictor import Predictor
from ..types.update import Update
from ..updater import Updater

class DetectionKDTreeMixIn(DataAssociator):
    """Detection kd-tree based mixin

    Construct a kd-tree from detections and then use a :class:`~.Predictor` and
    :class:`~.Updater` to get prediction of track in measurement space. This is
    then queried against the kd-tree, and only matching detections are passed to
    the :attr:`hypothesiser`.

    Notes
    -----
    This is only suitable where measurements are in same space as each other and
    at the same timestamp.
    """
    predictor: Predictor = Property(
        doc="Predict tracks to detection times")
    updater: Updater = Property(
        doc="Updater used to get measurement prediction")
    number_of_neighbours: int = Property(
        default=None,
        doc="Number of neighbours to find. Default `None`, which means all "
            "points within the :attr:`max_distance` are returned.")
    max_distance: float = Property(
        default=np.inf,
        doc="Max distance to return points. Default `inf`")

    def generate_hypotheses(self, tracks, detections, timestamp, **kwargs):
        """Generate hypotheses for each track using only nearby detections.

        Each track is predicted to `timestamp` and converted to a measurement
        prediction, which is used to query a kd-tree built over the detections'
        state vectors. Only the detections returned by that query are passed to
        the hypothesiser for the track.

        Parameters
        ----------
        tracks : collection of :class:`~.Track`
            Tracks to generate hypotheses for.
        detections : collection of :class:`~.Detection`
            Detections available at `timestamp`.
        timestamp : datetime.datetime
            Time to predict the tracks to.
        **kwargs
            Forwarded to :meth:`hypothesiser.hypothesise`.

        Returns
        -------
        dict
            Mapping of each track to the hypotheses returned by the hypothesiser.
        """
        # No need for tree here.
        if not tracks:
            return {}
        if not detections:
            return {track: self.hypothesiser.hypothesise(
                        track, detections, timestamp, **kwargs)
                    for track in tracks}

        # Fixed ordering so kd-tree indices map back to detections.
        detections_list = list(detections)
        n_detections = len(detections_list)
        tree = KDTree(np.vstack(
            [detection.state_vector[:, 0] for detection in detections_list]))

        track_detections = defaultdict(set)
        for track in tracks:
            prediction = self.predictor.predict(track.state, timestamp)
            meas_pred = self.updater.predict_measurement(prediction)
            query_point = meas_pred.state_vector.ravel()

            if self.number_of_neighbours is None:
                indexes = tree.query_ball_point(query_point, r=self.max_distance)
            else:
                _, indexes = tree.query(
                    query_point,
                    k=self.number_of_neighbours,
                    distance_upper_bound=self.max_distance)

            # query returns a scalar index when k == 1, so normalise to an array.
            for index in np.atleast_1d(indexes):
                # Index is equal to length of detections when no neighbours found
                if index != n_detections:
                    track_detections[track].add(detections_list[index])

        return {track: self.hypothesiser.hypothesise(
                    track, track_detections[track], timestamp, **kwargs)
                for track in tracks}
class TPRTreeMixIn(DataAssociator):
    """Detection TPR tree based mixin

    Construct a TPR-tree (time-parameterised R-tree) of tracks, then query it with
    each detection. Only tracks whose indexed region intersects a detection are
    passed that detection by the :attr:`hypothesiser`.

    Each track is indexed by a box spanning +/-3 standard deviations (taken from
    the diagonal of its covariance) around its position, together with a velocity
    box built the same way. Requires the optional :mod:`rtree` package.
    """
    measurement_model: MeasurementModel = Property(
        doc="Measurement model used within the TPR tree")
    horizon_time: datetime.timedelta = Property(
        doc="How far the TPR tree should look into the future")
    pos_mapping: Sequence[int] = Property(
        default=None,
        doc="Mapping for position coordinates. Default `None`, which uses the measurement model "
            "mapping")
    vel_mapping: Sequence[int] = Property(
        default=None,
        doc="Mapping for velocity coordinates. Default `None`, which uses the position mapping "
            "adding offset of 1 to each")

    def __init__(self, *args, **kwargs):
        """Resolve default mappings and create an empty TPR-tree.

        Raises
        ------
        ImportError
            If the optional :mod:`rtree` package could not be imported.
        """
        if rtree is None:
            # Fail with a clear message rather than an AttributeError on None.
            raise ImportError(
                f"{type(self).__name__} requires the 'rtree' package "
                "(and libspatialindex), which failed to import")
        super().__init__(*args, **kwargs)
        if self.pos_mapping is None:
            self.pos_mapping = self.measurement_model.mapping
        # if no vel_mapping take position mapping and plus 1 to each dimension
        # e.g. 0,2 would become 1,3
        if self.vel_mapping is None:
            self.vel_mapping = [i + 1 for i in self.pos_mapping]

        # Create tree
        tree_property = rtree.index.Property(
            type=rtree.index.RT_TPRTree,
            tpr_horizon=self.horizon_time.total_seconds(),
            dimension=len(self.pos_mapping))
        self._tree = rtree.index.RtreeContainer(properties=tree_property)
        # Coordinates each track was last inserted with; needed to delete it later.
        self._coords = {}

    @staticmethod
    def _tree_time(time):
        """Convert a datetime to the POSIX seconds used as tree time."""
        return time.astimezone(datetime.timezone.utc).timestamp()

    def _track_tree_coordinates(self, track):
        """Return the TPR-tree coordinates for `track`.

        Returns a tuple of ``((*min_pos, *max_pos), (*min_vel, *max_vel), time)``
        where the bounds are +/-3 standard deviations about the track's position
        and velocity components, and ``time`` is the track timestamp in seconds.
        """
        variances = np.diag(track.covar)
        state_vector = track.state_vector[self.pos_mapping, :]
        state_delta = 3 * np.sqrt(variances[self.pos_mapping].reshape(-1, 1))
        vel_vector = track.state_vector[self.vel_mapping, :]
        vel_delta = 3 * np.sqrt(variances[self.vel_mapping].reshape(-1, 1))

        min_pos = (state_vector - state_delta).ravel()
        max_pos = (state_vector + state_delta).ravel()
        min_vel = (vel_vector - vel_delta).ravel()
        max_vel = (vel_vector + vel_delta).ravel()

        return ((*min_pos, *max_pos), (*min_vel, *max_vel),
                self._tree_time(track.timestamp))

    def _delete_track(self, track, c_time):
        """Delete `track` from the tree and forget its stored coordinates.

        The deletion time interval runs from just before the time the track was
        inserted with up to the tree's current time `c_time`.
        """
        coords = self._coords.pop(track)
        self._tree.delete(
            track,
            coords[:-1] + ((coords[-1] - 1e-6, self._tree_time(c_time)),))

    def generate_hypotheses(self, tracks, detections, timestamp, **kwargs):
        """Generate hypotheses for each track using detections intersecting it.

        The tree is first brought up to date in time order. New tracks are
        inserted, and tracks no longer present are deleted. Tracks whose latest
        state is an :class:`~.Update` are re-inserted. Each detection is then
        converted to the track position space and used to query the tree at its
        timestamp.

        Parameters
        ----------
        tracks : collection of :class:`~.Track`
            Current tracks (a set, or any re-iterable collection).
        detections : collection of :class:`~.Detection`
            Detections to associate.
        timestamp : datetime.datetime
            Time passed on to the hypothesiser.
        **kwargs
            Forwarded to the measurement model and to
            :meth:`hypothesiser.hypothesise`.

        Returns
        -------
        dict
            Mapping of each track to the hypotheses returned by the hypothesiser.
        """
        # No need for tree here.
        if not tracks:
            return {}

        # Build the set once: allows non-set inputs and gives O(1) membership.
        current_tracks = set(tracks)

        # Update the tree in this first section, visiting both current tracks and
        # tracks still in the tree in time order.
        sorted_tracks = sorted(
            current_tracks.union(self._tree), key=attrgetter('timestamp'))
        # Get initial starting time from earliest track
        c_time = sorted_tracks[0].timestamp
        for track in sorted_tracks:
            if track not in self._tree:
                # track not in tree, so insert it
                self._coords[track] = self._track_tree_coordinates(track)
                self._tree.insert(track, self._coords[track])
            elif track not in current_tracks:
                # track in tree, but not in tracks now; remove it from tree
                self._delete_track(track, c_time)
            elif isinstance(track.state, Update):
                # Track in tree, and latest state is an Update; re-insert it.
                # NOTE(review): this also happens when the track was already
                # indexed with this same Update on a previous call.
                self._delete_track(track, c_time)
                self._coords[track] = self._track_tree_coordinates(track)
                self._tree.insert(track, self._coords[track])
            # Set current tree to tracks timestamp
            c_time = track.timestamp

        # With tree up to date, find tracks that intersect with detections
        track_detections = defaultdict(set)
        zero_velocity = (0, 0) * len(self.pos_mapping)
        for detection in sorted(detections, key=attrgetter('timestamp')):
            if detection.measurement_model is not None:
                model = detection.measurement_model
            else:
                model = self.measurement_model

            # Convert detection to track state space
            if isinstance(model, LinearModel):
                model_matrix = model.matrix(**kwargs)
                inv_model_matrix = sp.linalg.pinv(model_matrix)
                state_meas = (inv_model_matrix @ detection.state_vector)[self.pos_mapping, :]
            else:
                state_meas = model.inverse_function(
                    detection, **kwargs)[self.pos_mapping, :]

            # Find intersections: a degenerate (point) box with zero velocity,
            # over a short interval starting at the detection time.
            point = tuple(state_meas.ravel())
            det_time = self._tree_time(detection.timestamp)
            intersected_tracks = self._tree.intersection((
                point + point,
                zero_velocity,
                (det_time, det_time + 1e-3)))
            for track in intersected_tracks:
                track_detections[track].add(detection)

        return {track: self.hypothesiser.hypothesise(
                    track, track_detections[track], timestamp, **kwargs)
                for track in tracks}