from collections.abc import Collection
import numpy as np
from scipy.optimize import linear_sum_assignment
from ..base import Property
from ..dataassociator.base import Associator
from ..measures.base import BaseMeasure
from ..types.association import Association, AssociationSet
class OneToOneAssociator(Associator):
    """
    This is a general one to one associator. It can be used to associate objects/values that
    have a :class:`~.BaseMeasure` to compare them.

    Uses :func:`~scipy.optimize.linear_sum_assignment` to find the combination of pairs of
    objects from the two sources with the minimum (or maximum) total measure.

    Notes
    -----
    By default the association threshold is set to +/- a large number (1e10 was chosen
    arbitrarily). Infinity can't be used, as it breaks the association algorithm.
    """
    measure: BaseMeasure = Property(
        doc="This will compare two objects that could be associated together and will provide an "
            "indication of the separation between the objects.")
    association_threshold: float = Property(
        default=None,
        doc="The maximum (minimum if :attr:`~.maximise_measure` is true) value from the "
            ":attr:`~.measure` needed to associate two objects. If the default value of `None` is "
            "used then the association threshold is set to plus/minus an arbitrarily large number "
            "that shouldn't limit associations.")
    maximise_measure: bool = Property(
        default=False, doc="Should the association algorithm attempt to maximise or minimise the "
                           "output of the measure.")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.association_threshold is None:
            # Finite stand-in for -/+ infinity (infinity breaks the assignment algorithm).
            self.association_threshold = -1e10 if self.maximise_measure else 1e10

    def associate(self, objects_a: Collection, objects_b: Collection) \
            -> tuple[AssociationSet, Collection, Collection]:
        """Associate two collections of objects together.

        The measure is calculated between every object in `objects_a` and every object in
        `objects_b`. :func:`~scipy.optimize.linear_sum_assignment` is then used to find the
        pairing with the minimum (or maximum, if :attr:`maximise_measure` is true) total
        measure. Paired objects are only associated if their measure is strictly below (or
        strictly above, when maximising) :attr:`association_threshold`.

        Parameters
        ----------
        objects_a : collection of objects to associate to the objects in `objects_b`
        objects_b : collection of objects to associate to the objects in `objects_a`

        Returns
        -------
        AssociationSet
            Contains a set of :class:`~.Association` objects
        Collection
            Objects from `objects_a` that were not associated. This is a :class:`set`, except
            when either input collection is empty, in which case `objects_a` itself is returned.
        Collection
            Objects from `objects_b` that were not associated. This is a :class:`set`, except
            when either input collection is empty, in which case `objects_b` itself is returned.
        """
        if len(objects_a) == 0 or len(objects_b) == 0:
            # Nothing can be paired: hand the inputs back unchanged as "unassociated".
            return AssociationSet(), objects_a, objects_b

        # Fix an ordering so matrix indices can be mapped back to objects.
        list_of_as = list(objects_a)
        list_of_bs = list(objects_b)

        # Rows correspond to objects_a, columns to objects_b.
        distance_matrix = np.empty((len(objects_a), len(objects_b)))
        for i, a in enumerate(list_of_as):
            for j, b in enumerate(list_of_bs):
                distance_matrix[i, j] = self.individual_weighting(a, b)

        # Optimal ("shortest path") assignment on the measure matrix. `maximize` converts the
        # minimisation problem into a maximisation one (e.g. for probability-like measures).
        row_ind, col_ind = linear_sum_assignment(
            distance_matrix, maximize=self.maximise_measure)

        associations = AssociationSet()
        for i, j in zip(row_ind, col_ind):
            value = distance_matrix[i, j]
            # The assignment may pair objects whose measure is no good, so only keep pairs that
            # strictly beat the threshold. Values clipped to `fail_value` by
            # `individual_weighting` equal the threshold and are rejected here.
            if self.maximise_measure:
                meets_threshold = value > self.association_threshold
            else:
                meets_threshold = value < self.association_threshold
            if meets_threshold:
                associations.associations.add(
                    Association({list_of_as[i], list_of_bs[j]}))

        associated_objects = {obj
                              for assoc in associations.associations
                              for obj in assoc.objects}
        unassociated_a = set(objects_a) - associated_objects
        unassociated_b = set(objects_b) - associated_objects
        return associations, unassociated_a, unassociated_b

    @property
    def fail_value(self):
        """Measure value that guarantees two objects will not be associated.

        For an association to be valid, the measure must be strictly over the association
        threshold when :attr:`maximise_measure` is True (strictly under it otherwise).
        Therefore, setting the value to the association threshold will result in the
        association not taking place.
        """
        return self.association_threshold

    def individual_weighting(self, a, b):
        """Return the measure between `a` and `b`, bounded by :attr:`fail_value`.

        This wrapper around the measure function allows for filtering/error checking of the
        measure function. It can give an easy access point for subclasses that want to apply
        additional filtering or gating.

        If the measure returns `None`, :attr:`fail_value` is returned. Otherwise the measure is
        clipped so that it is never worse than :attr:`fail_value`. That means never below it
        when maximising, and never above it when minimising.
        """
        measure_output = self.measure(a, b)
        if measure_output is None:
            return self.fail_value
        if self.maximise_measure:
            return max(measure_output, self.fail_value)
        return min(measure_output, self.fail_value)

    def association_dict(self, objects_a: Collection, objects_b: Collection) -> dict:
        """
        This is a wrapper function around the :meth:`~.associate` function. The two collections
        of objects are associated to each other. The objects are entered into a dictionary:

        * The dictionary key is an object from either collection.
        * The value is the object it is associated to. If the key object isn't associated to an
          object then the value is `None`.

        As the objects are used as dictionary keys, they must be hashable or a
        :class:`TypeError` will be raised.

        Parameters
        ----------
        objects_a : collection of hashable objects to associate to the objects in ``objects_b``
        objects_b : collection of hashable objects to associate to the objects in ``objects_a``

        Returns
        -------
        dict
            Maps every object from both collections to its associated object, or to `None` if
            it was not associated.
        """
        output_dict = {}
        associations, unassociated_a, unassociated_b = self.associate(objects_a, objects_b)
        for assoc in associations.associations:
            # Each association is built from a two-object set in `associate`.
            # NOTE(review): if the same object appears in both collections and is paired with
            # itself, that set has a single element and this unpacking would fail — confirm.
            object_1, object_2 = assoc.objects
            output_dict[object_1] = object_2
            output_dict[object_2] = object_1
        for obj in [*unassociated_a, *unassociated_b]:
            output_dict[obj] = None
        return output_dict