Source code for stonesoup.metricgenerator.manager

from collections.abc import Sequence, Iterable
from itertools import chain
from typing import Union

from ..base import Property
from ..dataassociator import Associator
from ..platform import Platform
from ..types.detection import Detection
from ..types.groundtruth import GroundTruthPath
from ..types.track import Track
from .base import MetricGenerator, MetricManager


class MultiManager(MetricManager):
    """MultiManager class for metric management

    :class:`~.MetricManager` for the generation of metrics on multiple sets of
    :class:`~.Track`, :class:`~.Detection` and :class:`~.GroundTruthPath` objects passed in as
    dictionaries.
    """
    generators: Sequence[MetricGenerator] = Property(doc='List of generators to use', default=None)
    associator: Associator = Property(doc="Associator to combine tracks and truth", default=None)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps a user-chosen key (e.g. 'tracks') to the set of objects stored under it.
        self.states_sets = dict()
        # Result of the most recent call to `associate_tracks`.
        self.association_set = None
        # Nested dict produced by the most recent call to `generate_metrics`.
        self.metrics = None

    def add_data(self, metric_data: dict = None, overwrite=True):
        """Adds data to the metric generator

        Parameters
        ----------
        metric_data : dict of lists or dict of sets of :class:`~.GroundTruthPath`, \
                :class:`~.Track`, and/or :class:`~.Detection`
            Ground truth paths, Tracks, and/or detections to be added to the manager.
            ``None`` (the default) adds nothing.
        overwrite: bool
            Declaring whether pre-existing data will be overwritten. Note that overwriting one
            key-value pair (e.g. 'tracks') does not affect the others.
        """
        self._add(overwrite, metric_data=metric_data)

    def _add(self, overwrite, metric_data):
        """Store each value of `metric_data` as a set under its key.

        With `overwrite` true, any existing set for a key is replaced; otherwise new items are
        merged into the existing set. Keys absent from `metric_data` are left untouched.
        """
        if metric_data is None:
            # `add_data` advertises None as its default, so treat it as "nothing to add"
            # rather than failing on `None.items()`.
            return
        for key, value in metric_data.items():
            if overwrite or key not in self.states_sets:
                self.states_sets[key] = set(value)
            else:
                self.states_sets[key].update(value)

    def _generator_list(self):
        """Return the configured generators as a list.

        Accepts ``None`` (no generators), a single generator, or any sequence of generators
        (list, tuple, ...), so that every method sees the same normalised view.
        """
        generators = self.generators
        if generators is None:
            return []
        if isinstance(generators, Sequence):
            return list(generators)
        return [generators]

    def associate_tracks(self, generator):
        """Associate tracks to truth using the associator to produce an
        :class:`~.AssociationSet`

        The result is stored in `self.association_set`.

        Parameters
        ----------
        generator : :class:`~.MetricGenerator`
            :class:`~.MetricGenerator` containing `tracks_key` and `truths_key` to extract tracks
            and truths from :class:`~.MetricManager` for association.
        """
        self.association_set = self.associator.associate_tracks(
            self.states_sets[generator.tracks_key],
            self.states_sets[generator.truths_key])

    def _get_metrics(self):
        """Return the stored metrics in the form handed back by `generate_metrics`."""
        return self.metrics

    def generate_metrics(self):
        """Generate metrics using the generators and data that has been added

        Returns
        -------
        : nested dict of :class:`~.Metric`
            Metrics generated, keyed by generator name and then by metric title. Empty if no
            generators are configured.
        """
        metrics: dict[str, dict] = {}

        for generator in self._generator_list():
            # Association is only meaningful for generators that name both a tracks set and a
            # truths set; it is refreshed per generator because the keys may differ.
            if self.associator is not None and \
                    hasattr(generator, 'tracks_key') and hasattr(generator, 'truths_key'):
                self.associate_tracks(generator)

            metric_list = generator.compute_metric(self)
            if not isinstance(metric_list, list):  # If not already a list, force it to be one
                metric_list = [metric_list]

            for metric in metric_list:
                # The per-generator dict is created lazily, so a generator that yields no
                # metrics gets no entry.
                metrics.setdefault(generator.generator_name, {})[metric.title] = metric

        self.metrics = metrics
        return self._get_metrics()

    def list_timestamps(self, generator=None):
        """List all the unique timestamps used in the tracks and truth associated with a given
        generator, in order

        Parameters
        ----------
        generator : :class:`~.MetricGenerator`
            :class:`~.MetricGenerator` containing `tracks_key` and `truths_key` to extract tracks
            and truths from :class:`~.MetricManager` to extract timestamps from.
            Default None to take tracks and truths values from first :class:`~.MetricGenerator`
            in `self.generators`.

        Returns
        -------
        : list of :class:`datetime.datetime`
            unique timestamps present in the internal tracks and truths.

        Raises
        ------
        ValueError
            If `generator` is None and no generators are configured.
        """
        if generator is None:
            generators = self._generator_list()
            if not generators:
                raise ValueError("No generator given and no generators configured")
            generator = generators[0]

        timestamps = {state.timestamp
                      for sequence in chain(self.states_sets[generator.tracks_key],
                                            self.states_sets[generator.truths_key])
                      for state in sequence}
        return sorted(timestamps)
class SimpleManager(MultiManager):
    """SimpleManager class for metric management

    Simple :class:`~.MetricManager` for the generation of metrics on multiple
    :class:`~.Track`, :class:`~.Detection` and :class:`~.GroundTruthPath` objects.
    """
    generators: Sequence[MetricGenerator] = Property(doc='List of generators to use', default=None)
    associator: Associator = Property(doc="Associator to combine tracks and truth", default=None)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.states_sets = dict()
        self.association_set = None
        self.metrics = None

    def add_data(self, groundtruth_paths: Iterable[Union[GroundTruthPath, Platform]] = None,
                 tracks: Iterable[Track] = None, detections: Iterable[Detection] = None,
                 overwrite=True):
        """Adds data to the metric generator

        Each argument is stored under a key equal to its parameter name; arguments left as
        None are skipped.

        Parameters
        ----------
        groundtruth_paths : list or set of :class:`~.GroundTruthPath`
            Ground truth paths to be added to the manager.
        tracks : list or set of :class:`~.Track`
            Tracks objects to be added to the manager.
        detections : list or set of :class:`~.Detection`
            Detections to be added to the manager.
        overwrite: bool
            declaring whether pre-existing data will be overwritten. Note that
            overwriting one field (e.g. tracks) does not affect the others
        """
        self._add(overwrite,
                  groundtruth_paths=groundtruth_paths,
                  tracks=tracks,
                  detections=detections)

    def _add(self, overwrite, **kwargs):
        """Store every non-None keyword argument as a set under its keyword name.

        With `overwrite` true, an existing set is replaced; otherwise the new items are merged
        into it.
        """
        for name, items in kwargs.items():
            if items is None:
                # Nothing supplied for this field: leave whatever is stored alone.
                continue
            if overwrite or name not in self.states_sets:
                self.states_sets[name] = set(items)
            else:
                self.states_sets[name].update(items)

    def _get_metrics(self):
        """Return the stored metrics flattened into a single dict keyed by metric title.

        Titles repeated across generators are resolved in favour of the later generator.
        """
        flattened = {}
        for per_generator in self.metrics.values():
            flattened.update(per_generator)
        return flattened