Source code for stonesoup.sensor.categorical

# -*- coding: utf-8 -*-
from typing import Set, Sequence

from ..base import Property
from ..models.measurement.categorical import CategoricalMeasurementModel
from ..sensor.sensor import Sensor
from ..types.detection import TrueDetection, TrueCategoricalDetection
from ..types.groundtruth import GroundTruthState, GroundTruthPath
from ..types.state import CategoricalState


class CategoricalSensor(Sensor):
    """Sensor that observes categorical ground truth and emits categorical detections.

    Measurement vectors are produced by :attr:`measurement_model`. When
    :attr:`category_names` is given, its length must equal the measurement
    dimension (:attr:`ndim_meas`), otherwise construction fails.
    """

    measurement_model: CategoricalMeasurementModel = Property(
        doc="Categorical measurement model used in generating measurements.")
    category_names: Sequence[str] = Property(
        default=None,
        doc="Measurement category names.")

    @property
    def ndim_state(self):
        """State dimension, taken from the measurement model."""
        return self.measurement_model.ndim_state

    @property
    def ndim_meas(self):
        """Measurement vector length, taken from the measurement model."""
        return self.measurement_model.ndim_meas

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        names = self.category_names
        # Nothing to validate when no (or an empty collection of) names was supplied.
        if not names:
            return
        if len(names) != self.ndim_meas:
            raise ValueError(f"{len(names)} category names were given for a sensor "
                             f"which returns vectors of length {self.ndim_meas}")

    def measure(self, ground_truths: Set[GroundTruthState], **kwargs) -> Set[TrueDetection]:
        """Generate a categorical measurement for each given categorical truth.

        Parameters
        ----------
        ground_truths : Set[:class:`~.CategoricalGroundTruthState`]
            Truths to observe. Each entry is either a categorical state or a
            :class:`~.GroundTruthPath` whose most recent state is categorical.
        **kwargs
            Forwarded unchanged to the measurement model's ``function``.

        Returns
        -------
        Set[:class:`~.TrueCategoricalDetection`]
            One detection per truth. Each detection carries the truth's
            timestamp, this sensor's measurement model and category names, and
            references the truth it was generated from.

        Raises
        ------
        ValueError
            If a truth (or, for a path, its last state) is not a
            :class:`~.CategoricalState`.
        """
        detections = set()
        for truth in ground_truths:
            # A path is judged by its latest state; anything else is judged directly.
            latest = truth[-1] if isinstance(truth, GroundTruthPath) else truth
            if not isinstance(latest, CategoricalState):
                raise ValueError("Categorical sensor can only observe categorical states")

            vector = self.measurement_model.function(truth, **kwargs)
            detections.add(TrueCategoricalDetection(
                state_vector=vector,
                timestamp=truth.timestamp,
                measurement_model=self.measurement_model,
                groundtruth_path=truth,
                category_names=self.category_names))
        return detections