import warnings
from abc import ABC, abstractmethod
from collections.abc import Collection, Iterable
from datetime import datetime, timedelta
from enum import IntEnum
from itertools import chain
from typing import Optional, Union
import numpy as np
from matplotlib import animation as animation
from matplotlib import pyplot as plt
from matplotlib.legend_handler import HandlerPatch
from matplotlib.lines import Line2D
from matplotlib.patches import Ellipse
from mergedeep import merge
from scipy.integrate import quad
from scipy.optimize import brentq
from scipy.stats import kde
try:
from plotly import colors
except ImportError:
colors = None
try:
import plotly.graph_objects as go
except ImportError:
go = None
from .base import Base, Property
from .models.base import LinearModel, Model
from .types import detection
from .types.array import StateVector
from .types.groundtruth import GroundTruthPath
from .types.metric import SingleTimeMetric
from .types.state import State, StateMutableSequence
from .types.update import Update
class Dimension(IntEnum):
    """Dimension Enum class for specifying plotting parameters in the Plotter class.

    Used to sanitize inputs for the dimension attribute of Plotter().

    Attributes
    ----------
    ONE: int
        Specifies 1D plotting, used by Plotterly to plot a state component over time
    TWO: int
        Specifies 2D plotting for Plotter object
    THREE: int
        Specifies 3D plotting for Plotter object
    """
    ONE = 1  # 1D plotting mode (plot state over time in Plotterly)
    TWO = 2  # 2D plotting mode (original plotter.py functionality)
    THREE = 3  # 3D plotting mode
class _Plotter(ABC):
    """Abstract interface shared by the matplotlib (:class:`Plotter`) and plotly
    (:class:`Plotterly`) plotters.

    Subclasses implement the four ``plot_*`` methods. :meth:`_conv_measurements` is a shared
    helper that maps detections into the coordinates being plotted.
    """

    @abstractmethod
    def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs):
        raise NotImplementedError

    @abstractmethod
    def plot_measurements(self, measurements, mapping, measurement_model=None,
                          label="Measurements", **kwargs):
        raise NotImplementedError

    @abstractmethod
    def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks",
                    **kwargs):
        raise NotImplementedError

    @abstractmethod
    def plot_sensors(self, sensors, mapping, label="Sensors", **kwargs):
        raise NotImplementedError

    def _conv_measurements(self, measurements, mapping, measurement_model=None,
                           convert_measurements=True) -> \
            tuple[dict[detection.Detection, tuple], dict[detection.Clutter, tuple]]:
        """Convert measurements into plottable coordinates, split into detections and clutter.

        Parameters
        ----------
        measurements : iterable of :class:`~.Detection`
            Measurements to convert.
        mapping : list
            Indices of the components to keep after conversion.
        measurement_model : :class:`~.Model`, optional
            Fallback model, used when a detection has no ``measurement_model`` of its own.
        convert_measurements : bool
            If False, ``mapping`` is applied directly to each detection's state vector.
            If True, the detection is first taken back to state space: via the
            pseudo-inverse of the model matrix for a :class:`~.LinearModel`, or via
            ``inverse_function`` for any other :class:`~.Model`.

        Returns
        -------
        : tuple of (dict, dict)
            ``(detections, clutter)``. Each dict maps the measurement to a tuple of its
            mapped coordinates. A measurement is skipped with a warning when it cannot be
            converted, or when it is neither a Detection nor Clutter.
        """
        conv_detections = {}
        conv_clutter = {}
        for state in measurements:
            # Prefer the detection's own model; fall back to the one supplied by the caller
            meas_model = state.measurement_model
            if meas_model is None:
                meas_model = measurement_model

            if not convert_measurements:
                state_vec = state.state_vector[mapping, :]
            elif isinstance(meas_model, LinearModel):
                # Linear models: map back to state space with the Moore-Penrose pseudo-inverse
                inv_model_matrix = np.linalg.pinv(meas_model.matrix())
                state_vec = (inv_model_matrix @ state.state_vector)[mapping, :]
            elif isinstance(meas_model, Model):
                try:
                    state_vec = meas_model.inverse_function(state)[mapping, :]
                except (NotImplementedError, AttributeError):
                    warnings.warn('Nonlinear measurement model used with no inverse '
                                  'function available')
                    continue
            else:
                warnings.warn('Measurement model type not specified for all detections')
                continue

            # Clutter is checked first; presumably it subclasses Detection - TODO confirm
            if isinstance(state, detection.Clutter):
                conv_clutter[state] = (*state_vec, )
            elif isinstance(state, detection.Detection):
                conv_detections[state] = (*state_vec, )
            else:
                warnings.warn(f'Unknown type {type(state)}')
                continue

        return conv_detections, conv_clutter
class Plotter(_Plotter):
    """Plotting class for building graphs of Stone Soup simulations using matplotlib

    A plotting class which is used to simplify the process of plotting ground truths,
    measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
    particles if required. Legends are automatically generated with each plot.
    Three dimensional plots can be created using the optional dimension parameter.

    Parameters
    ----------
    dimension: enum \'Dimension\'
        Optional parameter to specify 2D or 3D plotting. Default is 2D plotting.
        ``Dimension.ONE`` is not supported by this plotter and raises :class:`ValueError`.
    \\*\\*kwargs: dict
        Additional arguments to be passed to :func:`matplotlib.pyplot.figure`. For example,
        figsize (Default is (10, 6)).

    Attributes
    ----------
    fig: matplotlib.figure.Figure
        Generated figure for graphs to be plotted on
    ax: matplotlib.axes.Axes
        Generated axes for graphs to be plotted on
    legend_dict: dict
        Dictionary mapping legend labels (str) to the artists used as legend handles
        (e.g. :class:`matplotlib.lines.Line2D`, :class:`matplotlib.patches.Ellipse`)
    """

    def __init__(self, dimension=Dimension.TWO, **kwargs):
        figure_kwargs = {"figsize": (10, 6)}
        figure_kwargs.update(kwargs)
        if isinstance(dimension, Dimension):
            self.dimension = dimension
        elif isinstance(dimension, int):
            self.dimension = Dimension(dimension)
        else:
            raise TypeError("%s is an unsupported type for \'dimension\'; "
                            "expected type %s" % (type(dimension), type(Dimension.TWO)))
        # Only 2D and 3D axes can be built here. Without this check Dimension.ONE
        # silently fell through to the 3D branch below.
        if self.dimension not in (Dimension.TWO, Dimension.THREE):
            raise ValueError(f"Plotter does not support {self.dimension!r}; "
                             "use Dimension.TWO or Dimension.THREE")

        # Generate plot axes
        self.fig = plt.figure(**figure_kwargs)
        if self.dimension is Dimension.TWO:  # 2D axes
            self.ax = self.fig.add_subplot(1, 1, 1)
            self.ax.axis('equal')
        else:  # 3D axes
            self.ax = self.fig.add_subplot(111, projection='3d')
            self.ax.axis('auto')
            self.ax.set_zlabel("$z$")
        self.ax.set_xlabel("$x$")
        self.ax.set_ylabel("$y$")

        # Legend entries keyed by label, so repeated labels collapse into one legend entry
        self.legend_dict = {}

    def _update_legend(self):
        """(Re)draw the axes legend from :attr:`legend_dict` and return the legend artist.

        The ellipse handler is always supplied so uncertainty-ellipse entries keep their
        ellipse shape, whichever plot method redraws the legend.
        """
        return self.ax.legend(handles=list(self.legend_dict.values()),
                              labels=list(self.legend_dict.keys()),
                              handler_map={Ellipse: _HandlerEllipse()})

    def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs):
        """Plots ground truth(s)

        Plots each ground truth path passed in to :attr:`truths` and generates a legend
        automatically. Ground truths are plotted as dashed lines with default colors.

        Users can change linestyle, color and marker using keyword arguments. Any changes
        will apply to all ground truths.

        Parameters
        ----------
        truths : Collection of :class:`~.GroundTruthPath`
            Collection of ground truths which will be plotted. If not a collection and instead a
            single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow
            for iteration.
        mapping: list
            List of items specifying the mapping of the position components of the state space.
        label: str
            Label for truth data. Default is "Ground Truth"
        \\*\\*kwargs: dict
            Additional arguments to be passed to plot function. Default is ``linestyle="--"``.

        Returns
        -------
        : list of :class:`matplotlib.artist.Artist`
            List of artists that have been added to the axis.

        .. deprecated:: 1.5
            ``label`` has replaced ``truths_label``. In the current implementation ``truths_label``
            overrides ``label``. However, use of ``truths_label`` may be removed in the future.
        """
        label = kwargs.pop('truths_label', None) or label
        truths_kwargs = dict(linestyle="--")
        truths_kwargs.update(kwargs)
        if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence):
            truths = {truths}  # Make a set of length 1

        artists = []
        for truth in truths:
            if self.dimension is Dimension.TWO:  # plots the ground truths in xy
                artists.extend(
                    self.ax.plot([state.state_vector[mapping[0]] for state in truth],
                                 [state.state_vector[mapping[1]] for state in truth],
                                 **truths_kwargs))
            elif self.dimension is Dimension.THREE:  # plots the ground truths in xyz
                artists.extend(
                    self.ax.plot3D([state.state_vector[mapping[0]] for state in truth],
                                   [state.state_vector[mapping[1]] for state in truth],
                                   [state.state_vector[mapping[2]] for state in truth],
                                   **truths_kwargs))
            else:
                raise NotImplementedError('Unsupported dimension type for truth plotting')

        # Legend swatch is black unless the user chose a colour
        colour = kwargs.get("color", "black")
        self.legend_dict[label] = Line2D([], [], linestyle=truths_kwargs['linestyle'],
                                         color=colour)
        artists.append(self._update_legend())
        return artists

    def plot_measurements(self, measurements, mapping, measurement_model=None,
                          label="Measurements", convert_measurements=True, **kwargs):
        """Plots measurements

        Plots detections and clutter, generating a legend automatically. Detections are plotted
        as blue circles by default. Detections of type :class:`~.Clutter` are plotted with a
        'tri-up' marker (``'2'``) and the default colour cycle. Keyword arguments other than
        ``marker`` (e.g. ``color``) are applied to clutter as well as detections.

        Parameters
        ----------
        measurements : Collection of :class:`~.Detection`
            Detections which will be plotted. If measurements is a set of lists it is flattened.
        mapping: list
            List of items specifying the mapping of the position components of the state space.
        measurement_model : :class:`~.Model`, optional
            User-defined measurement model to be used in finding measurement state inverses if
            they cannot be found from the measurements themselves.
        label : str
            Label for the measurements. Default is "Measurements".
        convert_measurements : bool
            Should the measurements be converted from measurement space to state space before
            being plotted. Default is True
        \\*\\*kwargs: dict
            Additional arguments to be passed to plot function for detections. Defaults are
            ``marker='o'`` and ``color='b'``.

        Returns
        -------
        : list of :class:`matplotlib.artist.Artist`
            List of artists that have been added to the axis.

        .. deprecated:: 1.5
            ``label`` has replaced ``measurements_label``. In the current implementation
            ``measurements_label`` overrides ``label``. However, use of ``measurements_label``
            may be removed in the future.
        """
        label = kwargs.pop('measurements_label', None) or label
        measurement_kwargs = dict(marker='o', color='b')
        measurement_kwargs.update(kwargs)

        if not isinstance(measurements, Collection):
            measurements = {measurements}  # Make a set of length 1

        if any(isinstance(item, set) for item in measurements):
            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
        else:
            measurements_set = measurements

        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
                                                                mapping,
                                                                measurement_model,
                                                                convert_measurements)

        artists = []
        if plot_detections:
            detection_array = np.array(list(plot_detections.values()))
            # *detection_array.T unpacks detection_array by columns
            # (same as passing in detection_array[:,0], detection_array[:,1], etc...)
            artists.append(self.ax.scatter(*detection_array.T, **measurement_kwargs))
            # Only qualify the label when clutter will also get its own legend entry
            name = label + "\n(Detections)" if plot_clutter else label
            self.legend_dict[name] = Line2D([], [], linestyle='', **measurement_kwargs)

        if plot_clutter:
            clutter_kwargs = kwargs.copy()
            clutter_kwargs.update(dict(marker='2'))
            clutter_array = np.array(list(plot_clutter.values()))
            artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs))
            self.legend_dict[label + "\n(Clutter)"] = Line2D([], [], linestyle='',
                                                             **clutter_kwargs)

        artists.append(self._update_legend())
        return artists

    def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks",
                    err_freq=1, same_color=False, **kwargs):
        """Plots track(s)

        Plots each track generated, generating a legend automatically. If ``uncertainty=True``
        and is being plotted in 2D, one-standard-deviation error ellipses are plotted. If being
        plotted in 3D, axis-aligned one-standard-deviation uncertainty bars are plotted every
        :attr:`err_freq` track step; the default plots uncertainty bars at every track step.
        Tracks are plotted as solid lines with point markers and default colors. Uncertainty
        bars are plotted with a default color which is the same for all tracks.

        Users can change linestyle, color and marker using keyword arguments. Uncertainty metrics
        will also be plotted with the user defined colour and any changes will apply to all tracks.

        Parameters
        ----------
        tracks : Collection of :class:`~.Track`
            Collection of tracks which will be plotted. If not a collection, and instead a single
            :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
        mapping: list
            List of items specifying the mapping of the position
            components of the state space.
        uncertainty : bool
            If True, function plots uncertainty ellipses or bars.
        particle : bool
            If True, function plots particles. Only supported in 2D.
        label: str
            Label to apply to all tracks for legend.
        err_freq: int
            Frequency of error bar plotting on tracks. Default value is 1, meaning
            error bars are plotted at every track step.
        same_color: bool
            Should all the tracks have the same color. Default False
        \\*\\*kwargs: dict
            Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``,
            ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states.

        Returns
        -------
        : list of :class:`matplotlib.artist.Artist`
            List of artists that have been added to the axis.

        Raises
        ------
        NotImplementedError
            If ``particle=True`` on a 3D plotter. This is checked before anything is drawn.

        .. deprecated:: 1.5
            ``label`` has replaced ``track_label``. In the current implementation
            ``track_label`` overrides ``label``. However, use of ``track_label``
            may be removed in the future.
        """
        label = kwargs.pop('track_label', None) or label
        if particle and self.dimension is not Dimension.TWO:
            raise NotImplementedError("Particle plotting is not currently supported for\n"
                                      "3D visualization")

        tracks_kwargs = dict(linestyle='-', marker="s", color=None)
        tracks_kwargs.update(kwargs)
        if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence):
            tracks = {tracks}  # Make a set of length 1

        artists = []
        track_colors = {}
        line = None
        for track in tracks:
            # Update states get the main marker; other states are over-plotted with 'o'
            update_indexes = []
            not_update_indexes = []
            for n, state in enumerate(track):
                if isinstance(state, Update):
                    update_indexes.append(n)
                else:
                    not_update_indexes.append(n)

            data = np.concatenate(
                [getattr(state, 'mean', state.state_vector)[mapping, :] for state in track],
                axis=1)

            line = self.ax.plot(*data, markevery=update_indexes, **tracks_kwargs)
            artists.extend(line)
            track_color = plt.getp(line[0], 'color')
            if not_update_indexes:
                artists.extend(self.ax.plot(
                    *data[:, not_update_indexes],
                    marker=kwargs.get('marker', "o"),
                    linestyle='',
                    color=track_color))
            track_colors[track] = track_color

            if same_color:
                tracks_kwargs['color'] = track_color

        if line is not None:
            # Exact for a single track or same_color=True; otherwise the legend shows the
            # final track's colour.
            tracks_kwargs['color'] = plt.getp(line[0], 'color')

        # Generate legend items for track
        self.legend_dict[label] = Line2D([], [], linestyle=tracks_kwargs['linestyle'],
                                         marker=tracks_kwargs['marker'],
                                         color=tracks_kwargs['color'])

        if uncertainty:
            if self.dimension is Dimension.TWO:
                artists.extend(self._plot_uncertainty_ellipses(tracks, mapping, err_freq,
                                                               track_colors))
                self.legend_dict["Uncertainty"] = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2,
                                                          color=tracks_kwargs['color'])
            else:
                artists.extend(self._plot_uncertainty_bars(tracks, mapping, err_freq,
                                                           tracks_kwargs['color']))

        if particle:  # only reachable in 2D, see the check at the top
            for track in tracks:
                for state in track:
                    data = state.state_vector[mapping[:2], :]
                    artists.extend(self.ax.plot(data[0], data[1], linestyle='', marker=".",
                                                markersize=1, alpha=0.5))
            self.legend_dict["Particles"] = Line2D([], [], linestyle='', color="black",
                                                   marker='.', markersize=1)

        # Draw the legend once, after every entry has been added
        artists.append(self._update_legend())
        return artists

    def _plot_uncertainty_ellipses(self, tracks, mapping, err_freq, track_colors):
        """Add a 2D one-standard-deviation ellipse at every ``err_freq``-th state of each track.

        Returns the list of ellipse artists added to the axes.
        """
        artists = []
        for track in tracks:
            HH = np.eye(track.ndim)[mapping, :]  # Get position mapping matrix
            for n, state in enumerate(track):
                if n % err_freq:  # plot states 0, err_freq, 2*err_freq, ...
                    continue
                w, v = np.linalg.eig(HH @ state.covar @ HH.T)
                if np.iscomplexobj(w) or np.iscomplexobj(v):
                    warnings.warn("Can not plot uncertainty for all states due to complex "
                                  "eigenvalues or eigenvectors", UserWarning)
                    continue
                max_ind = np.argmax(w)
                min_ind = np.argmin(w)
                # Orient the ellipse along the eigenvector of the largest variance
                orient = np.arctan2(v[1, max_ind], v[0, max_ind])
                ellipse = Ellipse(xy=state.mean[mapping[:2], 0],
                                  width=2 * np.sqrt(w[max_ind]),
                                  height=2 * np.sqrt(w[min_ind]),
                                  angle=np.rad2deg(orient), alpha=0.2,
                                  color=track_colors[track])
                self.ax.add_artist(ellipse)
                artists.append(ellipse)
        return artists

    def _plot_uncertainty_bars(self, tracks, mapping, err_freq, color):
        """Add 3D axis-aligned error bars at every ``err_freq``-th state of each track.

        Each bar is centred on the state's mean. Its half-length is the standard deviation of
        the position covariance along that axis, i.e. ``sqrt(diag(HH @ P @ HH.T))``.
        Returns the list of line artists added.
        """
        artists = []
        for track in tracks:
            HH = np.eye(track.ndim)[mapping, :]  # Get position mapping matrix
            for n, state in enumerate(track):
                if n % err_freq:  # plot states 0, err_freq, 2*err_freq, ...
                    continue
                position_covar = np.asarray(HH @ state.covar @ HH.T)
                # Per-axis standard deviations. The previous code used raw eigenvalues, which
                # are variances and are not aligned with the x/y/z axes.
                x_err, y_err, z_err = np.sqrt(np.diag(position_covar))[:3]
                position = np.asarray(getattr(state, 'mean', state.state_vector))[mapping, 0]
                xl, yl, zl = position[:3]
                artists.extend(
                    self.ax.plot3D([xl + x_err, xl - x_err], [yl, yl], [zl, zl],
                                   marker="_", color=color))
                artists.extend(
                    self.ax.plot3D([xl, xl], [yl + y_err, yl - y_err], [zl, zl],
                                   marker="_", color=color))
                artists.extend(
                    self.ax.plot3D([xl, xl], [yl, yl], [zl + z_err, zl - z_err],
                                   marker="_", color=color))
        return artists

    def plot_sensors(self, sensors, mapping=None, label="Sensors", **kwargs):
        """Plots sensor(s)

        Plots sensors. Users can change the color and marker of sensors using keyword
        arguments. Default is a black 'x' marker.

        Parameters
        ----------
        sensors : Collection of :class:`~.Sensor`
            Sensors to plot
        mapping: list
            List of items specifying the mapping of the position components of the
            sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension`
        label: str
            Label to apply to all sensors for legend.
        \\*\\*kwargs: dict
            Additional arguments to be passed to plot function for sensors. Defaults are
            ``marker='x'`` and ``color='black'``.

        Returns
        -------
        : list of :class:`matplotlib.artist.Artist`
            List of artists that have been added to the axis.

        .. deprecated:: 1.5
            ``label`` has replaced ``sensor_label``. In the current implementation
            ``sensor_label`` overrides ``label``. However, use of ``sensor_label``
            may be removed in the future.
        """
        label = kwargs.pop('sensor_label', None) or label
        sensor_kwargs = dict(marker='x', color='black')
        sensor_kwargs.update(kwargs)

        if not isinstance(sensors, Collection):
            sensors = {sensors}  # Make a set of length 1

        if mapping is None:
            mapping = list(range(self.dimension))

        artists = []
        for sensor in sensors:
            if self.dimension is Dimension.TWO:  # plots the sensors in xy
                artists.append(self.ax.scatter(sensor.position[mapping[0]],
                                               sensor.position[mapping[1]],
                                               **sensor_kwargs))
            elif self.dimension is Dimension.THREE:  # plots the sensors in xyz
                artists.extend(self.ax.plot3D(sensor.position[mapping[0]],
                                              sensor.position[mapping[1]],
                                              sensor.position[mapping[2]],
                                              **sensor_kwargs))
            else:
                raise NotImplementedError('Unsupported dimension type for sensor plotting')
        self.legend_dict[label] = Line2D([], [], linestyle='', **sensor_kwargs)
        artists.append(self._update_legend())
        return artists

    def set_equal_3daxis(self, axes=None):
        """Plots minimum/maximum points with no linestyle to increase the plotting region to
        simulate `.ax.axis('equal')` from matplotlib 2d plots which is not possible using 3d
        projection.

        Has no effect on a 2D plotter.

        Parameters
        ----------
        axes: list
            List of dimension index specifying the equal axes, equal x and y = [0,1].
            Default is x,y [0,1].
        """
        if not axes:
            axes = [0, 1]
        if self.dimension is not Dimension.THREE:
            return

        # Extents start at 0, so the origin is always inside the computed region
        min_xyz = [0, 0, 0]
        max_xyz = [0, 0, 0]
        for n in range(3):
            for line in self.ax.lines:
                min_xyz[n] = np.min([min_xyz[n], *line.get_data_3d()[n]])
                max_xyz[n] = np.max([max_xyz[n], *line.get_data_3d()[n]])

        extremes = np.max([x - y for x, y in zip(max_xyz, min_xyz)])
        equal_axes = [0, 0, 0]
        for i in axes:
            equal_axes[i] = 1
        centres = [np.mean([x, y]) for x, y in zip(max_xyz, min_xyz)]
        # Axes not selected are zeroed, so ghost points sit at 0 on those axes
        lower = (centres - extremes/2) * equal_axes
        upper = (centres + extremes/2) * equal_axes
        ghosts = GroundTruthPath(states=[State(state_vector=lower),
                                         State(state_vector=upper)])

        # Invisible line through the ghost points stretches the autoscaled limits
        self.ax.plot3D([state.state_vector[0] for state in ghosts],
                       [state.state_vector[1] for state in ghosts],
                       [state.state_vector[2] for state in ghosts],
                       linestyle="")

    def plot_density(self, state_sequences: Collection[StateMutableSequence],
                     index: Union[int, None] = -1,
                     mapping=(0, 2), n_bins=300, **kwargs):
        """Plot a Gaussian kernel density estimate of state positions as a colour mesh.

        Parameters
        ----------
        state_sequences : a collection of :class:`~.StateMutableSequence`
            Set of tracks which will be plotted. If not a collection, and instead a single
            :class:`~.StateMutableSequence` (e.g. :class:`~.Track`), it is wrapped so that it
            can be iterated.
        index: int
            Which index of the StateMutableSequences should be plotted.
            Default value is '-1' which is the last state in the sequences.
            index can be set to None if all indices of the sequence should be included in the plot
        mapping: list
            List of 2 items specifying the mapping of the x and y components of the state space.
        n_bins : int
            Size of the bins used to group the data
        \\*\\*kwargs: dict
            Additional arguments to be passed to pcolormesh function.

        Raises
        ------
        ValueError
            If there are no sequences, or if the x and y values coincide. Coinciding values
            give the kde a singular covariance.
        """
        # Public location of the KDE; the ``scipy.stats.kde`` namespace is deprecated
        from scipy.stats import gaussian_kde

        if isinstance(state_sequences, StateMutableSequence):
            state_sequences = [state_sequences]  # honour the documented single-sequence input

        if len(state_sequences) == 0:
            raise ValueError("Skipping plotting density due to state_sequences being empty.")
        if index is None:  # Plot all states in the sequence
            x = np.array([a_state.state_vector[mapping[0]]
                          for a_state_sequence in state_sequences
                          for a_state in a_state_sequence])
            y = np.array([a_state.state_vector[mapping[1]]
                          for a_state_sequence in state_sequences
                          for a_state in a_state_sequence])
        else:  # Only plot one state out of the sequences
            x = np.array([a_state_sequence.states[index].state_vector[mapping[0]]
                          for a_state_sequence in state_sequences])
            y = np.array([a_state_sequence.states[index].state_vector[mapping[1]]
                          for a_state_sequence in state_sequences])

        if np.allclose(x, y, atol=1e-10):
            raise ValueError("Skipping plotting density due to x and y values are the same. "
                             "This leads to a singular matrix in the kde function.")

        # Evaluate a gaussian kde on a regular grid of n_bins x n_bins over data extents
        k = gaussian_kde([x, y])
        xi, yi = np.mgrid[x.min():x.max():n_bins * 1j, y.min():y.max():n_bins * 1j]
        zi = k(np.vstack([xi.flatten(), yi.flatten()]))

        # Make the plot
        self.ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='auto', **kwargs)

    # Ellipse legend patch (used in Tutorial 3)
    @staticmethod
    def ellipse_legend(ax, label_list, color_list, **kwargs):
        """Adds an ellipse patch to the legend on the axes. One patch added for each item in
        `label_list` with the corresponding color from `color_list`.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Looks at the plot axes defined
        label_list : list of str
            Takes in list of strings intended to label ellipses in legend
        color_list : list of str
            Takes in list of colors corresponding to string/label
            Must be the same length as label_list
        \\*\\*kwargs: dict
                Additional arguments to be passed to plot function. Default is ``alpha=0.2``.
        """
        ellipse_kwargs = dict(alpha=0.2)
        ellipse_kwargs.update(kwargs)

        legend = ax.legend(handler_map={Ellipse: _HandlerEllipse()})
        handles, labels = ax.get_legend_handles_labels()
        for color in color_list:
            handle = Ellipse((0.5, 0.5), 0.5, 0.5, color=color, **ellipse_kwargs)
            handles.append(handle)
        for label in label_list:
            labels.append(label)
        # Rebuild the existing legend box in place with the extra handles (private API)
        legend._legend_box = None
        legend._init_legend_box(handles, labels)
        legend._set_loc(legend._loc)
        legend.set_title(legend.get_title().get_text())
class _HandlerEllipse(HandlerPatch):
    """Legend handler that renders an :class:`~matplotlib.patches.Ellipse` handle as an
    ellipse, rather than the rectangle a plain patch handler would draw."""

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        # Position the ellipse inside the legend handle box, sized from the box dimensions
        centre_x = 0.5 * width - 0.5 * xdescent
        centre_y = 0.5 * height - 0.5 * ydescent
        patch = Ellipse(xy=(centre_x, centre_y),
                        width=width + xdescent,
                        height=height + ydescent)
        # Copy visual properties (colour, alpha, ...) from the original handle
        self.update_prop(patch, orig_handle, legend)
        patch.set_transform(trans)
        return [patch]
class MetricPlotter(ABC):
    """Class for plotting Stone Soup metrics using matplotlib

    A plotting class which is used to simplify the process of plotting metrics.
    Legends are automatically generated with each plot.
    """

    def __init__(self):
        self.fig = None
        self.axes = None
        self.plottable_metrics = list()

    def plot_metrics(self, metrics, generator_names=None, metric_names=None,
                     combine_plots=True, **kwargs):
        """Plots metrics

        Plots each plottable metric passed in to :attr:`metrics` across a series of subplots
        and generates legend(s) automatically. Metrics are plotted as lines with default colors.

        Users can change linestyle, color and marker or other features using keyword arguments.
        Any changes will apply to all metrics.

        Parameters
        ----------
        metrics : dict of :class:`~.Metric`
            Dictionary of generated metrics to be plotted.
        generator_names: list of str
            Generator(s) to extract specific metrics from :attr:`metrics` for plotting.
            Default None to take all metrics.
        metric_names: list of str
            Specific metric(s) to extract from :class:`~.MetricGenerator` for plotting.
            Default None to take all metrics in generators.
        combine_plots: bool
            Plot metrics of same type on the same subplot. Default True.
        \\*\\*kwargs: dict
            Additional arguments to be passed to plot function. Default is ``linestyle="-"``.

        Returns
        -------
        : :class:`matplotlib.pyplot.figure`
            Figure containing subplots displaying all plottable metrics (``None`` if there was
            nothing to plot and no previous figure).
        """
        # A metric is plottable when its value is a list of SingleTimeMetric
        for metric_dict in metrics.values():
            for metric_name, metric in metric_dict.items():
                if metric_name in self.plottable_metrics:
                    continue  # avoid growing the list with duplicates on repeated calls
                if isinstance(metric.value, list) \
                        and all(isinstance(x, SingleTimeMetric) for x in metric.value):
                    self.plottable_metrics.append(metric_name)

        metrics_kwargs = dict(linestyle="-")
        metrics_kwargs.update(kwargs)

        if generator_names is None:
            generator_names = list(metrics.keys())

        if metric_names is not None:
            # warning for user input metrics that will not be plotted
            for metric_name in metric_names:
                if metric_name not in self.plottable_metrics:
                    warnings.warn(f"{metric_name} "
                                  f"is not a plottable metric and will not be plotted")
        else:
            metric_names = self.extract_metric_types(metrics)

        metrics_to_plot = self._extract_plottable_metrics(metrics, generator_names, metric_names)

        if combine_plots:
            return self.combine_plots(metrics_to_plot, metrics_kwargs)
        return self.plot_separately(metrics_to_plot, metrics_kwargs)

    @staticmethod
    def extract_metric_types(metrics):
        """Identify the different metric names held in a dict of generated metrics.

        Parameters
        ----------
        metrics: dict
            Mapping of generator name to a dict of ``{metric name: metric}``.

        Returns
        -------
        : list of str
            Unique metric names, in first-seen order, so subplot order is deterministic.
        """
        return list(dict.fromkeys(
            metric_name
            for generator_metrics in metrics.values()
            for metric_name in generator_metrics))

    def _extract_plottable_metrics(self, metrics, generator_names, metric_names):
        """
        Extract all plottable metrics from dict of generated metrics.

        Parameters
        ----------
        metrics: dict of :class:`~.Metric`
            Dictionary of generated metrics.
        generator_names: list of str
            Generator(s) to extract specific metrics from :attr:`metrics` for plotting.
            Names not present in :attr:`metrics` are skipped with a warning.
        metric_names: list of str
            Specific metric(s) to extract from :class:`~.MetricGenerator` for plotting.

        Returns
        -------
        : dict
            Dict of all plottable metrics, ``{generator name: {metric name: metric}}``.
            Generators with no selected plottable metric are omitted.
        """
        plottable = set(self.plottable_metrics)
        metrics_dict = {}
        for generator_name in generator_names:
            generator_metrics = metrics.get(generator_name)
            if generator_metrics is None:
                warnings.warn(f"{generator_name} is not a generator in the metrics "
                              f"and will not be plotted")
                continue
            for metric_name in metric_names:
                if metric_name in generator_metrics and metric_name in plottable:
                    metrics_dict.setdefault(generator_name, {})[metric_name] = \
                        generator_metrics[metric_name]
        return metrics_dict

    def _count_subplots(self, metrics_to_plot, combine_plots):
        """
        Calculate number of subplots needed to plot all metrics.

        Parameters
        ----------
        metrics_to_plot: dict of :class:`~.Metric`
            Dictionary of metrics to be plotted.
        combine_plots: bool
            Specifies whether same metric types should be plotted on same subplot.

        Returns
        -------
        : int
            Number of subplots to generate.
        """
        if combine_plots:
            # one subplot per distinct metric name
            return len(self.extract_metric_types(metrics_to_plot))
        # one subplot per (generator, metric) pair
        return sum(len(generator_metrics) for generator_metrics in metrics_to_plot.values())

    def combine_plots(self, metrics_to_plot, metrics_kwargs):
        """
        Generates one subplot for each different metric type and plots metrics of the same
        type on same subplot. Metrics are plotted over time.

        Each generator keeps the same colour on every subplot. If ``color`` is not given, a
        rainbow colour map spread over the generators is used. A single colour applies to
        every generator; a list of colours is assigned to generators in order, cycling if
        it is short.

        Parameters
        ----------
        metrics_to_plot: dict of :class:`~.Metric`
            Dictionary of metrics to plot.
        metrics_kwargs: dict
            Keyword arguments to be passed to plot function. ``color`` is removed from it.

        Returns
        -------
        : :class:`matplotlib.pyplot.figure`
            Figure containing subplots displaying metrics.
        """
        from itertools import cycle
        from matplotlib.colors import is_color_like

        # determine how many plots required - equal to number of metric types
        number_of_subplots = self._count_subplots(metrics_to_plot, True)
        if not number_of_subplots:
            warnings.warn("No plottable metrics to plot")
            return self.fig

        # initialise each subplot
        self.fig, axes = plt.subplots(number_of_subplots, figsize=(10, 6*number_of_subplots))
        self.fig.subplots_adjust(hspace=0.3)
        # plt.subplots returns a bare Axes (not iterable) when there is only one subplot
        self.axes = axes if isinstance(axes, Iterable) else [axes]

        generators = list(metrics_to_plot.keys())
        colours = metrics_kwargs.pop('color', None)
        if colours is None:
            colours = plt.cm.rainbow(np.linspace(0, 1, len(generators)))
        elif is_color_like(colours):
            colours = [colours]  # a single colour (e.g. 'red') applies to every generator
        generator_colours = dict(zip(generators, cycle(colours)))

        # extract data for each subplot and plot it
        metric_types = self.extract_metric_types(metrics_to_plot)
        for metric_type, axis in zip(metric_types, self.axes):
            legend_dict = {}
            for generator, generator_metrics in metrics_to_plot.items():
                if metric_type not in generator_metrics:
                    continue
                colour = generator_colours[generator]
                metric_values = generator_metrics[metric_type].value
                axis.plot([_.timestamp for _ in metric_values],
                          [_.value for _ in metric_values],
                          color=colour,
                          **metrics_kwargs)
                legend_dict[generator] = Line2D([], [], linestyle=metrics_kwargs['linestyle'],
                                                color=colour)

            # Generate legend
            axis.legend(handles=list(legend_dict.values()), labels=list(legend_dict.keys()))
            title = metric_type.split(' at times')[0]
            axis.set(title=title, xlabel="Time", ylabel=title)
        return self.fig

    def plot_separately(self, metrics_to_plot, metrics_kwargs):
        """
        Generates one subplot for each different individual metric and plots metric
        values over time.

        Parameters
        ----------
        metrics_to_plot: dict of :class:`~.Metric`
            Dictionary of metrics to plot.
        metrics_kwargs: dict
            Keyword arguments to be passed to plot function. ``color`` defaults to 'blue'.

        Returns
        -------
        : :class:`matplotlib.pyplot.figure`
            Figure containing subplots displaying metrics.
        """
        metrics_kwargs.setdefault('color', 'blue')

        # determine how many plots required - equal to number of metrics within the generators
        number_of_subplots = self._count_subplots(metrics_to_plot, False)
        if not number_of_subplots:
            warnings.warn("No plottable metrics to plot")
            return self.fig

        # initialise each plot
        self.fig, axes = plt.subplots(number_of_subplots, figsize=(10, 6*number_of_subplots))
        self.fig.subplots_adjust(hspace=0.3)

        # flatten to {"generator: metric": metric} so each gets its own subplot
        all_metrics = {f'{generator}: {metric_name}': metric
                       for generator, generator_metrics in metrics_to_plot.items()
                       for metric_name, metric in generator_metrics.items()}

        self.axes = axes if isinstance(axes, Iterable) else [axes]

        for metric_key, axis in zip(all_metrics, self.axes):
            metric = all_metrics[metric_key]
            y_label = str(metric.title).split(' at times')[0]
            axis.set(title=str(metric.title), xlabel='Time', ylabel=y_label)
            metric_values = metric.value
            axis.plot([_.timestamp for _ in metric_values],
                      [_.value for _ in metric_values],
                      **metrics_kwargs)

            # Generate legend
            metric_handle = Line2D([], [], linestyle=metrics_kwargs['linestyle'],
                                   color=metrics_kwargs['color'])
            axis.legend(handles=[metric_handle],
                        labels=[metric_key.split(' at times')[0]])
        return self.fig

    def set_fig_title(self, title):
        """
        Set title for the figure.

        Parameters
        ----------
        title: str
            Figure title text.

        Returns
        -------
        Text instance of figure title.
        """
        return self.fig.suptitle(t=title)

    def set_ax_title(self, titles):
        """
        Set axis titles for each axis in figure.

        Parameters
        ----------
        titles: list of str
            List of strings for title text for each axis.

        Returns
        -------
        List of Text instances of the axis titles.
        """
        return [axis.set_title(title) for axis, title in zip(self.axes, titles)]
[docs]
class Plotterly(_Plotter):
    """Plotting class for building graphs of Stone Soup simulations using plotly

    A plotting class which is used to simplify the process of plotting ground truths,
    measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
    particles if required. Legends are automatically generated with each plot.
    Three-dimensional plots can be created using the optional dimension parameter.

    Parameters
    ----------
    dimension: enum \'Dimension\'
        Optional parameter to specify 1D, 2D, or 3D plotting.
    axis_labels: list
        Optional parameter to specify the axis labels for non-xy dimensions. Default None, i.e.,
        "x" and "y". For 1D plotting the x-axis shows time: a single label is used for the
        y-axis (giving ``["Time", label]``) and the default is ``["Time", "x"]``.
    \\*\\*kwargs: dict
        Additional arguments to be passed to the Plotly.graph_objects Figure.

    Attributes
    ----------
    fig: plotly.graph_objects.Figure
        Generated figure to display graphs.
    """
    def __init__(self, dimension=Dimension.TWO, axis_labels=None, **kwargs):
        if dimension != Dimension.ONE:
            if not axis_labels:
                axis_labels = ["x", "y"]
        elif not axis_labels:
            axis_labels = ["Time", "x"]
        elif len(axis_labels) == 1:
            # 1D plots always show time along the x-axis
            axis_labels = ["Time", axis_labels[0]]

        if go is None:
            raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`")

        # Accepts 1, 2, 3, Dimension(1), Dimension(2) or Dimension(3)
        self.dimension = Dimension(dimension)

        from plotly import colors
        layout_kwargs = dict(
            xaxis_title=axis_labels[0],
            yaxis_title=axis_labels[1],
            colorway=colors.qualitative.Plotly,  # Needed to match colours later.
        )
        if self.dimension == 3:
            # auto shapes fig to fit data well
            layout_kwargs.update(dict(scene_aspectmode='data'))

        merge(layout_kwargs, kwargs)

        # Generate plot axes
        self.fig = go.Figure(layout=layout_kwargs)

    @staticmethod
    def _format_state_text(state):
        """Build the hover text for a state: type name, mean/state vector, timestamp, metadata."""
        text = [type(state).__name__,
                getattr(state, 'mean', state.state_vector),
                state.timestamp]
        text.extend([f"{key}: {value}" for key, value in getattr(state, 'metadata', {}).items()])
        return "<br>".join(str(t) for t in text)

    def _check_mapping(self, mapping):
        """Raise if ``mapping`` is empty or its length differs from the plotter dimension."""
        if len(mapping) == 0:
            raise ValueError("No indices provided in mapping.")
        elif len(mapping) != self.dimension:
            raise TypeError("Plotter dimension is not same as the mapping dimension.")

    def _legend_needed(self, legendgroup):
        """Return True if no trace in the figure belongs to ``legendgroup`` yet."""
        return legendgroup not in {trace.legendgroup for trace in self.fig.data}

    @staticmethod
    def _track_coords(track, index):
        """Component ``index`` of each state's mean (or state vector), as floats."""
        return [float(getattr(state, 'mean', state.state_vector)[index]) for state in track]

    def _colorway_color(self, index):
        """Return the layout colorway colour for trace number ``index``, wrapping around."""
        colorway = self.fig.layout.colorway
        # Modulo keeps the index within the colorway; colours are reused past its end.
        return colorway[index % len(colorway)]

    def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs):
        """Plots ground truth(s)

        Plots each ground truth path passed in to :attr:`truths` and generates a legend
        automatically. Ground truths are plotted as dashed lines with default colors.

        Users can change line style, color and marker using keyword arguments. Any changes
        will apply to all ground truths.

        Parameters
        ----------
        truths : Collection of :class:`~.GroundTruthPath`
            Collection of ground truths which will be plotted. If not a collection,
            and instead a single :class:`~.GroundTruthPath` type, the argument is modified to be a
            set to allow for iteration.
        mapping: list
            List of items specifying the mapping of the position components of the state space.
        label: str
            Label for truth data. Default is "Ground Truth"
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function. Default is
            ``line=dict(dash="dash")``.

        .. deprecated:: 1.5
            ``label`` has replaced ``truths_label``. In the current implementation
            ``truths_label`` overrides ``label``. However, use of ``truths_label``
            may be removed in the future.
        """
        label = kwargs.pop('truths_label', None) or label
        if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence):
            truths = {truths}

        self._check_mapping(mapping)  # ensure mapping is compatible with plotter dimension

        truths_kwargs = dict(
            mode="lines", line=dict(dash="dash"), legendgroup=label, legendrank=100,
            name=label)
        if self.dimension == 3:  # make ground truth line thicker so easier to see in 3d plot
            truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot")))
        merge(truths_kwargs, kwargs)
        # Only the first trace of a legend group not already in the figure gets a legend entry
        add_legend = self._legend_needed(truths_kwargs['legendgroup'])

        for truth in truths:
            scatter_kwargs = truths_kwargs.copy()
            scatter_kwargs['showlegend'] = add_legend
            add_legend = False
            if self.dimension == 1:
                self.fig.add_scatter(
                    x=[state.timestamp for state in truth],
                    y=[state.state_vector[mapping[0]] for state in truth],
                    text=[self._format_state_text(state) for state in truth],
                    **scatter_kwargs)
            elif self.dimension == 2:
                self.fig.add_scatter(
                    x=[state.state_vector[mapping[0]] for state in truth],
                    y=[state.state_vector[mapping[1]] for state in truth],
                    text=[self._format_state_text(state) for state in truth],
                    **scatter_kwargs)
            elif self.dimension == 3:
                self.fig.add_scatter3d(
                    x=[state.state_vector[mapping[0]] for state in truth],
                    y=[state.state_vector[mapping[1]] for state in truth],
                    z=[state.state_vector[mapping[2]] for state in truth],
                    text=[self._format_state_text(state) for state in truth],
                    **scatter_kwargs)

    def _add_measurement_trace(self, plot_data, trace_kwargs):
        """Add one scatter trace for a ``{detection: plotting vector}`` mapping.

        The detections (keys) provide timestamps (1D x-axis) and hover text; the plotting
        vectors (values) provide the plotted coordinates.
        """
        detections = list(plot_data.keys())
        coords = np.asarray(list(plot_data.values()), dtype=np.float64)
        text = [self._format_state_text(state) for state in detections]
        if self.dimension == 1:
            self.fig.add_scatter(
                x=[state.timestamp for state in detections],
                y=coords[:, 0],
                text=text,
                **trace_kwargs,
            )
        elif self.dimension == 2:
            self.fig.add_scatter(
                x=coords[:, 0],
                y=coords[:, 1],
                text=text,
                **trace_kwargs,
            )
        elif self.dimension == 3:
            self.fig.add_scatter3d(
                x=coords[:, 0],
                y=coords[:, 1],
                z=coords[:, 2],
                text=text,
                **trace_kwargs,
            )

    def plot_measurements(self, measurements, mapping, measurement_model=None,
                          label="Measurements", convert_measurements=True, **kwargs):
        """Plots measurements

        Plots detections and clutter, generating a legend automatically. Detections are plotted as
        blue circles by default unless the detection type is clutter.
        If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker.

        Users can change the color and marker of detections using keyword arguments. Note that
        the same keyword arguments are also merged into the clutter trace settings.

        Parameters
        ----------
        measurements : Collection of :class:`~.Detection`
            Detections which will be plotted. If measurements is a set of lists it is flattened.
        mapping: list
            List of items specifying the mapping of the position components of the state space.
        measurement_model : :class:`~.Model`, optional
            User-defined measurement model to be used in finding measurement state inverses if
            they cannot be found from the measurements themselves.
        label : str
            Label for the measurements. Default is "Measurements".
        convert_measurements: bool
            Should the measurements be converted from measurement space to state space before
            being plotted. Default is True
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function for detections. Defaults are
            ``marker=dict(color="#636EFA")``.

        .. deprecated:: 1.5
            ``label`` has replaced ``measurements_label``. In the current implementation
            ``measurements_label`` overrides ``label``. However, use of ``measurements_label``
            may be removed in the future.
        """
        label = kwargs.pop('measurements_label', None) or label
        if not isinstance(measurements, Collection):
            measurements = {measurements}

        if any(isinstance(item, set) for item in measurements):
            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
        else:
            measurements_set = set(measurements)

        self._check_mapping(mapping)

        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
                                                                mapping,
                                                                measurement_model,
                                                                convert_measurements)

        if plot_detections:
            # Distinguish the two legend entries only when clutter is also present
            name = label + "<br>(Detections)" if plot_clutter else label
            measurement_kwargs = dict(
                mode='markers', marker=dict(color='#636EFA'),
                name=name, legendgroup=name, legendrank=200)
            if self.dimension == 3:  # make markers smaller in 3d plot
                measurement_kwargs.update(dict(marker=dict(size=4, color='#636EFA')))
            merge(measurement_kwargs, kwargs)
            measurement_kwargs['showlegend'] = self._legend_needed(
                measurement_kwargs['legendgroup'])
            self._add_measurement_trace(plot_detections, measurement_kwargs)

        if plot_clutter:
            name = label + "<br>(Clutter)"
            clutter_kwargs = dict(
                mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'),
                name=name, legendgroup=name, legendrank=210)
            if self.dimension == 3:  # update - star-triangle-up not in 3d plotly
                clutter_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
                                                       color='#FECB52')))
            merge(clutter_kwargs, kwargs)
            clutter_kwargs['showlegend'] = self._legend_needed(clutter_kwargs['legendgroup'])
            self._add_measurement_trace(plot_clutter, clutter_kwargs)

    def get_next_color(self):
        """
        Find the colour of the most recently added plot trace.

        Returns the layout colorway colour for the last trace in ``fig.data``, which is the
        colour plotly assigns to that trace when none was set explicitly (assuming colours are
        taken from the colorway in trace order). Call this *after* the trace has been added.
        This approach to getting colour isn't ideal, but should work in most cases...

        Returns
        -------
        str
            Hex string for a colour
        """
        # The current plot has already been added to fig.data, so -1 is needed
        return self._colorway_color(len(self.fig.data) - 1)

    def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks",
                    ellipse_points=30, err_freq=1, same_color=False, **kwargs):
        """Plots track(s)

        Plots each track generated, generating a legend automatically. If ``uncertainty=True``
        error ellipses (2D) or error bars (3D) are plotted.
        Tracks are plotted as solid lines with point markers and default colors.
        Users can change line style, color and marker using keyword arguments.

        Parameters
        ----------
        tracks : Collection of :class:`~.Track`
            Collection of tracks which will be plotted. If not a collection, and instead a single
            :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
        mapping: list
            List of items specifying the mapping of the position
            components of the state space.
        uncertainty : bool
            If True, function plots uncertainty ellipses (2D) or error bars (3D).
            Not supported for 1D plots.
        particle : bool
            If True, function plots particles. Only supported for 2D plots.
        label: str
            Label to apply to all tracks for legend.
        ellipse_points: int
            Number of points for polygon approximating ellipse shape
        err_freq: int
            Frequency of error bar plotting on tracks (3D only). Default value is 1, meaning
            error bars are plotted at every track step.
        same_color: bool
            Should all the tracks have the same colour. Default False

        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function. Defaults are
            ``marker=dict(symbol='square')`` for :class:`~.Update` and
            ``marker=dict(symbol='circle')`` for other states.

        .. deprecated:: 1.5
            ``label`` has replaced ``track_label``. In the current implementation
            ``track_label`` overrides ``label``. However, use of ``track_label``
            may be removed in the future.
        """
        label = kwargs.pop('track_label', None) or label
        if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence):
            tracks = {tracks}  # Make a set of length 1

        self._check_mapping(mapping)  # check size of mapping against dimension of plotter

        # Plot tracks
        track_colors = {}  # colour used by each track, so its ellipses can match
        track_kwargs = dict(mode='markers+lines', legendgroup=label, legendrank=300)
        if self.dimension == 3:  # change visuals to work well in 3d
            track_kwargs.update(line=dict(width=7), marker=dict(size=4))
        merge(track_kwargs, kwargs)
        add_legend = self._legend_needed(track_kwargs['legendgroup'])

        if same_color:
            color = track_kwargs.get('marker', {}).get('color') or \
                track_kwargs.get('line', {}).get('color')

            # Set the colour if it hasn't already been set. No trace for these tracks exists
            # yet, so take the colour for the *next* trace index; get_next_color() assumes the
            # trace has already been added and would return the previous trace's colour.
            if color is None:
                track_kwargs['marker'] = track_kwargs.get('marker', {})
                track_kwargs['marker']['color'] = self._colorway_color(len(self.fig.data))

        for track in tracks:
            scatter_kwargs = track_kwargs.copy()
            scatter_kwargs['name'] = track.id
            if add_legend:
                # The legend entry represents the whole group, so use the group label
                scatter_kwargs['name'] = label
                scatter_kwargs['showlegend'] = True
                add_legend = False
            else:
                scatter_kwargs['showlegend'] = False
            # Copy so per-track symbol lists don't leak into the shared kwargs
            scatter_kwargs['marker'] = scatter_kwargs.get('marker', {}).copy()
            if 'symbol' not in scatter_kwargs['marker']:
                scatter_kwargs['marker']['symbol'] = [
                    'square' if isinstance(state, Update) else 'circle' for state in track]

            if self.dimension == 1:  # plot 1D tracks
                if uncertainty or particle:
                    raise NotImplementedError

                self.fig.add_scatter(
                    x=[state.timestamp for state in track],
                    y=self._track_coords(track, mapping[0]),
                    text=[self._format_state_text(state) for state in track],
                    **scatter_kwargs)

            elif self.dimension == 2:  # plot 2D tracks
                self.fig.add_scatter(
                    x=self._track_coords(track, mapping[0]),
                    y=self._track_coords(track, mapping[1]),
                    text=[self._format_state_text(state) for state in track],
                    **scatter_kwargs)

            elif self.dimension == 3:  # plot 3D tracks
                if particle:
                    raise NotImplementedError

                # Error-bar sizes per state; left as NaN for states without an error bar
                err_x = np.full(len(track), np.nan)
                err_y = np.full(len(track), np.nan)
                err_z = np.full(len(track), np.nan)
                if uncertainty:  # find x,y,z error bars for relevant states
                    for count, state in enumerate(track):
                        if not count % err_freq:  # ie count % err_freq = 0
                            HH = np.eye(track.ndim)[mapping, :]  # Get position mapping matrix
                            cov = HH @ state.covar @ HH.T
                            err_x[count] = np.sqrt(cov[0, 0])
                            err_y[count] = np.sqrt(cov[1, 1])
                            err_z[count] = np.sqrt(cov[2, 2])

                self.fig.add_scatter3d(
                    x=self._track_coords(track, mapping[0]),
                    error_x=dict(type='data', thickness=10, width=3, array=err_x),
                    y=self._track_coords(track, mapping[1]),
                    error_y=dict(type='data', thickness=10, width=3, array=err_y),
                    z=self._track_coords(track, mapping[2]),
                    error_z=dict(type='data', thickness=10, width=3, array=err_z),
                    # note that 3D error thickness seems to be broken in Plotly
                    text=[self._format_state_text(state) for state in track],
                    **scatter_kwargs)

            # Record the colour actually used by the trace just added
            track_colors[track] = (self.fig.data[-1].line.color
                                   or self.fig.data[-1].marker.color
                                   or self.get_next_color())

        # 1D raises above and 3D uses error bars, so ellipses only apply to 2D.
        if uncertainty and self.dimension == 2:
            name = track_kwargs['legendgroup'] + "<br>(Ellipses)"
            add_legend = self._legend_needed(name)
            for track in tracks:
                ellipse_kwargs = dict(
                    mode='none', fill='toself', fillcolor=track_colors[track],
                    opacity=0.2, hoverinfo='skip',
                    legendgroup=name, name=name,
                    legendrank=track_kwargs['legendrank'] + 10)
                for state in track:
                    points = self._generate_ellipse_points(state, mapping, ellipse_points)
                    ellipse_kwargs['showlegend'] = add_legend
                    add_legend = False
                    self.fig.add_scatter(x=points[0, :], y=points[1, :], **ellipse_kwargs)

        if particle and self.dimension == 2:
            name = track_kwargs['legendgroup'] + "<br>(Particles)"
            add_legend = self._legend_needed(name)
            for track in tracks:
                for state in track:
                    particle_kwargs = dict(
                        mode='markers', marker=dict(size=2),
                        opacity=0.4, hoverinfo='skip',
                        legendgroup=name, name=name,
                        legendrank=track_kwargs['legendrank'] + 20)
                    particle_kwargs['showlegend'] = add_legend
                    add_legend = False
                    data = state.state_vector[mapping[:2], :]
                    self.fig.add_scattergl(x=data[0], y=data[1], **particle_kwargs)

    @staticmethod
    def _generate_ellipse_points(state, mapping, n_points=30):
        """Generate error ellipse points for given state and mapping

        The ellipse is the one-standard-deviation contour of the covariance of the mapped
        components. Points are spaced (approximately) evenly by arc length and the first point
        is repeated at the end to close the polygon.

        Returns
        -------
        numpy.ndarray
            Array of shape (2, n_points + 1): x coordinates in row 0, y coordinates in row 1.
        """
        HH = np.eye(state.ndim)[mapping, :]  # Get position mapping matrix
        w, v = np.linalg.eig(HH @ state.covar @ HH.T)
        max_ind = np.argmax(w)
        min_ind = np.argmin(w)
        orient = np.arctan2(v[1, max_ind], v[0, max_ind])
        a = np.sqrt(w[max_ind])  # semi-major axis
        b = np.sqrt(w[min_ind])  # semi-minor axis

        if a == 0:
            # Zero covariance: the ellipse collapses to the mean (avoids 0/0 below).
            return np.zeros((2, n_points + 1)) + state.mean[mapping[:2], :]

        # Squared eccentricity, i.e. the parameter m of the elliptic integral of the second kind.
        m = 1 - (b**2 / a**2)

        def arc_integrand(x):
            # For the parametrisation (a sin t, b cos t): ds/dt = a * sqrt(1 - m sin^2 t).
            return np.sqrt(1 - m * np.sin(x)**2)

        def arc_length(z):
            # Arc length from t=0 to t=z, divided by a
            return quad(arc_integrand, 0, z)[0]

        perimeter = 4 * a * arc_length(np.pi / 2)

        angles = []
        for n in range(n_points):
            def residual(x, n=n):
                return n/n_points*perimeter - a*arc_length(x)

            angles.append(brentq(residual, 0, 2 * np.pi, xtol=1e-4))
        angles.append(angles[0])  # close the polygon

        c, s = np.cos(orient), np.sin(orient)
        rotational_matrix = np.array(((c, -s), (s, c)))
        points = np.array([[a * np.sin(t), b * np.cos(t)] for t in angles])
        points = rotational_matrix @ points.T
        return points + state.mean[mapping[:2], :]

    def plot_sensors(self, sensors, mapping=None, label="Sensors", **kwargs):
        """Plots sensor(s)

        Plots sensors. Users can change the color and marker of sensors using keyword
        arguments. Default is a black 'x' marker. Only supported for 2D plots.

        Parameters
        ----------
        sensors : Collection of :class:`~.Sensor`
            Sensors to plot
        mapping: list
            List of items specifying the mapping of the position
            components of the sensor's position. Default ``[0, 1]``.
        label: str
            Label to apply to all sensors for legend.
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function for sensors. Defaults are
            ``marker=dict(symbol='x', color='black')``.

        .. deprecated:: 1.5
            ``label`` has replaced ``sensor_label``. In the current implementation
            ``sensor_label`` overrides ``label``. However, use of ``sensor_label``
            may be removed in the future.
        """
        label = kwargs.pop('sensor_label', None) or label
        if mapping is None:  # avoid a shared mutable default list
            mapping = [0, 1]

        if not isinstance(sensors, Collection):
            sensors = {sensors}

        self._check_mapping(mapping)  # ensure mapping is compatible with plotter dimension

        if self.dimension == 1 or self.dimension == 3:
            raise NotImplementedError

        sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
                             legendgroup=label, legendrank=50)
        merge(sensor_kwargs, kwargs)

        sensor_kwargs['name'] = label
        # Previously both branches set showlegend=True, duplicating legend entries on
        # repeated calls; only a new legend group gets an entry.
        sensor_kwargs['showlegend'] = self._legend_needed(sensor_kwargs['legendgroup'])

        sensor_xy = np.array([sensor.position[mapping, 0] for sensor in sensors])
        self.fig.add_scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1], **sensor_kwargs)

    def hide_plot_traces(self, items_to_hide=None):
        """Hide Plot Traces

        This function allows plotting items to be invisible as default. Users can toggle the plot
        trace to visible.

        Parameters
        ----------
        items_to_hide : Iterable[str]
            The legend label (`legendgroups`) for the plot traces that should be invisible as
            default. If left as ``None`` no traces will be shown.
        """
        for fig_data in self.fig.data:
            if items_to_hide is None or fig_data.legendgroup in items_to_hide:
                fig_data.visible = "legendonly"
            else:
                fig_data.visible = None

    def show_plot_traces(self, items_to_show=None):
        """Show Plot Traces

        This function allows specific plotting items to be shown as default. All labels not
        mentioned in `items_to_show` will be invisible and can be manually toggled on.

        Parameters
        ----------
        items_to_show : Iterable[str]
            The legend label (`legendgroups`) for the plot traces that should be shown as
            default. If left as ``None`` all traces will be shown.
        """
        for fig_data in self.fig.data:
            if items_to_show is None or fig_data.legendgroup in items_to_show:
                fig_data.visible = None
            else:
                fig_data.visible = "legendonly"
[docs]
class PolarPlotterly(_Plotter):
    """Plotting class for polar plots of Stone Soup data using plotly.

    States are drawn with their angular component as ``theta`` and either a range component or
    the timestamp as ``r``. Only two-dimensional plotting is currently supported.

    Parameters
    ----------
    dimension: enum \'Dimension\'
        Plotting dimension; only ``Dimension.TWO`` (or ``2``) is supported.
    \\*\\*kwargs: dict
        Additional arguments to be passed to the plotly.graph_objects Figure layout.

    Attributes
    ----------
    fig: plotly.graph_objects.Figure
        Generated figure to display graphs.
    """
    def __init__(self, dimension=Dimension.TWO, **kwargs):
        if go is None:
            raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`")
        if isinstance(dimension, Dimension):
            self.dimension = dimension
        elif isinstance(dimension, int):
            self.dimension = Dimension(dimension)
        else:
            raise TypeError("%s is an unsupported type for \'dimension\'; "
                            "expected type %s" % (type(dimension), type(Dimension.TWO)))
        # Compare against the enum class: ``dimension`` may be a plain int, which has no
        # ``TWO`` attribute (previously ``dimension.TWO`` raised AttributeError for ints).
        if self.dimension != Dimension.TWO:
            raise TypeError("Only 2D plotting currently supported")

        # Generate plot axes
        self.fig = go.Figure(layout=dict(kwargs))

    def plot_state_sequence(self, state_sequences, angle_mapping: int,
                            range_mapping: Optional[int] = None, label="", **kwargs):
        """Plots state sequence(s)

        Plots each state sequence passed in to :attr:`state_sequences` and generates a legend
        automatically.

        Users can change line style, color and marker using keyword arguments. Any changes
        will apply to all state sequences.

        Parameters
        ----------
        state_sequences : Collection of :class:`~.StateMutableSequence`
            Collection of state sequences which will be plotted. If not a collection,
            and instead a single :class:`~.StateMutableSequence` type, the argument is modified
            to be a set to allow for iteration.
        angle_mapping: int
            Specifying the mapping of the angular component of the state space to be plotted.
        range_mapping: int, optional
            Specifying the mapping of the range component of the state space to be plotted. If
            `None`, the angular component will be plotted against time.
        label: str
            Label for the plotted data (used as legend group and trace name).
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function. Default is
            ``mode="markers"``.
            The default unit for the angular component is radians. This can be changed to degrees
            with the keyword argument ``thetaunit='degrees'``.
        """
        if not isinstance(state_sequences, Collection) \
                or isinstance(state_sequences, StateMutableSequence):
            state_sequences = {state_sequences}

        plotting_kwargs = dict(
            mode="markers", legendgroup=label, legendrank=200,
            name=label, thetaunit="radians")
        merge(plotting_kwargs, kwargs)
        # Only the first trace of a legend group not already in the figure gets a legend entry
        add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup
                                                            for trace in self.fig.data}

        for state_sequence in state_sequences:
            if range_mapping is None:
                r = [state.timestamp for state in state_sequence]
            else:
                r = [float(state.state_vector[range_mapping]) for state in state_sequence]
            bearings = [float(state.state_vector[angle_mapping]) for state in state_sequence]

            scatter_kwargs = plotting_kwargs.copy()
            scatter_kwargs['showlegend'] = add_legend
            add_legend = False
            self.fig.add_trace(go.Scatterpolar(r=r, theta=bearings, **scatter_kwargs))

    def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs):
        """Plots ground truth(s)

        Plots each ground truth path passed in to :attr:`truths` and generates a legend
        automatically. Ground truths are plotted as dashed lines with default colors.

        Users can change line style, color and marker using keyword arguments. Any changes
        will apply to all ground truths.

        Parameters
        ----------
        truths : Collection of :class:`~.GroundTruthPath`
            Collection of ground truths which will be plotted. If not a collection,
            and instead a single :class:`~.GroundTruthPath` type, the argument is modified to be a
            set to allow for iteration.
        mapping: list
            Mapping of the state space: ``mapping[0]`` is the angular component and, if present,
            ``mapping[1]`` is the range component (otherwise angle is plotted against time).
        label: str
            Label for truth data. Default is "Ground Truth".
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function. Default is
            ``line=dict(dash="dash")``.

        .. deprecated:: 1.5
            ``label`` has replaced ``truths_label``. In the current implementation
            ``truths_label`` overrides ``label``. However, use of ``truths_label``
            may be removed in the future.
        """
        label = kwargs.pop('truths_label', None) or label
        truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100)
        merge(truths_kwargs, kwargs)
        range_mapping = mapping[1] if len(mapping) > 1 else None
        self.plot_state_sequence(state_sequences=truths, angle_mapping=mapping[0],
                                 range_mapping=range_mapping, label=label, **truths_kwargs)

    def plot_measurements(self, measurements, mapping, measurement_model=None,
                          label="Measurements", convert_measurements=True, **kwargs):
        """Plots measurements

        Plots detections and clutter, generating a legend automatically. Detections are plotted as
        blue circles by default unless the detection type is clutter.
        If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker.

        Users can change the color and marker of detections using keyword arguments. Note that
        the same keyword arguments are also merged into the clutter trace settings.

        Parameters
        ----------
        measurements : Collection of :class:`~.Detection`
            Detections which will be plotted. If measurements is a set of lists it is flattened.
        mapping: list
            List of items specifying the mapping of the position components of the state space.
        measurement_model : :class:`~.Model`, optional
            User-defined measurement model to be used in finding measurement state inverses if
            they cannot be found from the measurements themselves.
        label : str
            Label for the measurements. Default is "Measurements".
        convert_measurements: bool
            Should the measurements be converted before being plotted. Default is True.
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function for detections. Defaults are
            ``marker=dict(color="#636EFA")``.

        .. deprecated:: 1.5
            ``label`` has replaced ``measurements_label``. In the current implementation
            ``measurements_label`` overrides ``label``. However, use of ``measurements_label``
            may be removed in the future.
        """
        label = kwargs.pop('measurements_label', None) or label
        if not isinstance(measurements, Collection):
            measurements = {measurements}

        if any(isinstance(item, set) for item in measurements):
            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
        else:
            measurements_set = set(measurements)

        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
                                                                mapping,
                                                                measurement_model,
                                                                convert_measurements)

        # The converted plotting vectors hold only the mapped components, in mapping order,
        # so the angle is index 0 and the range (if any) index 1.
        angle_mapping = 0
        range_mapping = 1 if len(mapping) > 1 else None

        if plot_detections:
            name = label + "<br>(Detections)" if plot_clutter else label
            measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200)
            merge(measurement_kwargs, kwargs)
            plotting_data = [State(state_vector=plotting_state_vector,
                                   timestamp=det.timestamp)
                             for det, plotting_state_vector in plot_detections.items()]

            self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping,
                                     range_mapping=range_mapping, label=name,
                                     **measurement_kwargs)

        if plot_clutter:
            name = label + "<br>(Clutter)"
            clutter_kwargs = dict(mode='markers', legendrank=210,
                                  marker=dict(symbol="star-triangle-up", color='#FECB52'))
            merge(clutter_kwargs, kwargs)
            plotting_data = [State(state_vector=plotting_state_vector,
                                   timestamp=det.timestamp)
                             for det, plotting_state_vector in plot_clutter.items()]

            self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping,
                                     range_mapping=range_mapping, label=name,
                                     **clutter_kwargs)

    def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks",
                    **kwargs):
        """Plots track(s)

        Plots each track generated, generating a legend automatically. Tracks are plotted as
        solid lines with point markers and default colors.
        Users can change line style, color and marker using keyword arguments.

        Parameters
        ----------
        tracks : Collection of :class:`~.Track`
            Collection of tracks which will be plotted. If not a collection, and instead a single
            :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
        mapping: list
            Mapping of the state space: ``mapping[0]`` is the angular component and, if present,
            ``mapping[1]`` is the range component (otherwise angle is plotted against time).
        uncertainty : bool
            Not currently implemented; if True, :class:`NotImplementedError` is raised.
        particle : bool
            Not currently implemented; if True, :class:`NotImplementedError` is raised.
        label: str
            Label to apply to all tracks for legend.
        \\*\\*kwargs: dict
            Additional arguments to be passed to scatter function. Defaults are
            ``mode='markers+lines'``.

        .. deprecated:: 1.5
            ``label`` has replaced ``track_label``. In the current implementation
            ``track_label`` overrides ``label``. However, use of ``track_label``
            may be removed in the future.
        """
        label = kwargs.pop('track_label', None) or label
        if uncertainty or particle:
            raise NotImplementedError

        track_kwargs = dict(mode='markers+lines', legendrank=300)
        merge(track_kwargs, kwargs)
        range_mapping = mapping[1] if len(mapping) > 1 else None
        self.plot_state_sequence(state_sequences=tracks, angle_mapping=mapping[0],
                                 range_mapping=range_mapping, label=label, **track_kwargs)

    def plot_sensors(self, sensors, mapping=None, label="Sensors", **kwargs):
        """Plotting sensors is not supported by this plotter.

        The signature mirrors the abstract ``_Plotter.plot_sensors``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError
class _AnimationPlotterDataClass(Base):
    """Container for one set of data drawn by :class:`AnimationPlotter`.

    Instances are stored in ``AnimationPlotter.plotting_data`` and handed to
    ``AnimationPlotter.run_animation`` as ``data``.
    """
    # States to draw. Each state's ``timestamp`` is read by ``AnimationPlotter.run`` when it
    # builds the default list of animation frame times.
    plotting_data: Iterable[State] = Property()
    # Label for this data; presumably shown in the legend — TODO confirm in run_animation.
    plotting_label: str = Property()
    # Keyword arguments for drawing this data; presumably forwarded to the matplotlib plot
    # call — TODO confirm in run_animation.
    plotting_keyword_arguments: dict = Property()
[docs]
class AnimationPlotter(_Plotter):
def __init__(self, dimension=Dimension.TWO, x_label: str = "$x$", y_label: str = "$y$",
title: str = None, legend_kwargs: dict = None, **kwargs):
self.figure_kwargs = {"figsize": (10, 6)}
self.figure_kwargs.update(kwargs)
if dimension != Dimension.TWO:
raise NotImplementedError
self.legend_kwargs = dict()
if legend_kwargs is not None:
self.legend_kwargs.update(legend_kwargs)
self.x_label: str = x_label
self.y_label: str = y_label
if title:
title += "\n"
self.title: str = title
self.plotting_data: list[_AnimationPlotterDataClass] = []
self.animation_output: animation.FuncAnimation = None
[docs]
def run(self,
times_to_plot: list[datetime] = None,
plot_item_expiry: Optional[timedelta] = None,
**kwargs):
"""Run the animation
Parameters
----------
times_to_plot : List of :class:`~.datetime`
List of datetime objects of when to refresh and draw the animation. Default `None`,
where unique timestamps of data will be used.
plot_item_expiry: :class:`~.timedelta`, Optional
Describes how long states will remain present in the figure. Default value of None
means data is shown indefinitely
\\*\\*kwargs: dict
Additional arguments to be passed to the animation.FuncAnimation function
"""
if times_to_plot is None:
times_to_plot = sorted({
state.timestamp
for plotting_data in self.plotting_data
for state in plotting_data.plotting_data})
self.animation_output = self.run_animation(
times_to_plot=times_to_plot,
data=self.plotting_data,
plot_item_expiry=plot_item_expiry,
x_label=self.x_label,
y_label=self.y_label,
figure_kwargs=self.figure_kwargs,
legend_kwargs=self.legend_kwargs,
animation_input_kwargs=kwargs,
plot_title=self.title
)
return self.animation_output
[docs]
def save(self, filename='example.mp4', **kwargs):
"""Save the animation
Parameters
----------
filename : str
filename of animation file
\\*\\*kwargs: dict
Additional arguments to be passed to the animation.save function
"""
if self.animation_output is None:
raise ValueError("Animation hasn't been run yet. Therefore there is no animation to "
"save")
self.animation_output.save(filename, **kwargs)
[docs]
def plot_ground_truths(self, truths, mapping: list[int], label: str = "Ground Truth",
**kwargs):
"""Plots ground truth(s)
Plots each ground truth path passed in to :attr:`truths` and generates a legend
automatically. Ground truths are plotted as dashed lines with default colors.
Users can change linestyle, color and marker using keyword arguments. Any changes
will apply to all ground truths.
Parameters
----------
truths : Collection of :class:`~.GroundTruthPath`
Collection of ground truths which will be plotted. If not a collection and instead a
single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow
for iteration.
mapping: list
List of items specifying the mapping of the position components of the state space.
label: str
Label for truth data. Default is "Ground Truth"
\\*\\*kwargs: dict
Additional arguments to be passed to plot function. Default is ``linestyle="--"``.
.. deprecated:: 1.5
``label`` has replaced ``truths_label``. In the current implementation
``truths_label`` overrides ``label``. However, use of ``truths_label``
may be removed in the future.
"""
label = kwargs.pop('truths_label', None) or label
truths_kwargs = dict(linestyle="--")
truths_kwargs.update(kwargs)
self.plot_state_mutable_sequence(truths, mapping, label, **truths_kwargs)
[docs]
def plot_tracks(self, tracks, mapping: list[int], uncertainty=False, particle=False,
                label="Tracks", **kwargs):
    """Plots track(s)

    Plots each track generated, generating a legend automatically. Tracks are plotted as solid
    lines with square markers and default colors. Users can change linestyle, color and marker
    using keyword arguments.

    Parameters
    ----------
    tracks : Collection of :class:`~.Track`
        Collection of tracks which will be plotted. If not a collection, and instead a single
        :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
    mapping: list
        List of items specifying the mapping of the position
        components of the state space.
    uncertainty : bool
        Not supported by this plotter. If True, :class:`NotImplementedError` is raised.
    particle : bool
        Not supported by this plotter. If True, :class:`NotImplementedError` is raised.
    label: str
        Label to apply to all tracks for legend.
    \\*\\*kwargs: dict
        Additional arguments to be passed to plot function. Defaults are ``linestyle='-'``,
        ``marker='s'`` and ``color=None``. The same style is applied to every state,
        regardless of its type.

    Raises
    ------
    NotImplementedError
        If ``uncertainty`` or ``particle`` is True.

    .. deprecated:: 1.5
        ``label`` has replaced ``track_label``. In the current implementation
        ``track_label`` overrides ``label``. However, use of ``track_label``
        may be removed in the future.
    """
    label = kwargs.pop('track_label', None) or label
    if uncertainty or particle:
        raise NotImplementedError(
            "Plotting track uncertainty or particles is not supported by this animated "
            "plotter")
    tracks_kwargs = dict(linestyle='-', marker="s", color=None)
    tracks_kwargs.update(kwargs)
    self.plot_state_mutable_sequence(tracks, mapping, label, **tracks_kwargs)
[docs]
def plot_state_mutable_sequence(self, state_mutable_sequences, mapping: list[int], label: str,
                                **plotting_kwargs):
    """Plots State Mutable Sequence(s)

    Each sequence is projected onto the state-vector components ``mapping[0]`` and
    ``mapping[1]`` and appended to :attr:`plotting_data` for later animation.

    Parameters
    ----------
    state_mutable_sequences : Collection of :class:`~.StateMutableSequence`
        Collection of state sequences to be plotted. A single
        :class:`~.StateMutableSequence` (or any non-collection) is wrapped in a set so it can
        be iterated.
    mapping: list
        List of items specifying the mapping of the position components of the state space.
    label : str
        Legend label. It is attached only to the first sequence so that the legend shows a
        single entry for the whole collection.
    \\*\\*plotting_kwargs: dict
        Additional arguments to be passed to plot function for states. The same dictionary is
        shared by every sequence.
    """
    if not isinstance(state_mutable_sequences, Collection) or \
            isinstance(state_mutable_sequences, StateMutableSequence):
        state_mutable_sequences = {state_mutable_sequences}  # Make a set of length 1
    for idx, state_mutable_sequence in enumerate(state_mutable_sequences):
        self.plotting_data.append(_AnimationPlotterDataClass(
            # keep only the two mapped components of every state, with its timestamp
            plotting_data=[State(state_vector=[state.state_vector[mapping[0]],
                                               state.state_vector[mapping[1]]],
                                 timestamp=state.timestamp)
                           for state in state_mutable_sequence],
            plotting_label=label if idx == 0 else None,
            plotting_keyword_arguments=plotting_kwargs
        ))
[docs]
def plot_measurements(self, measurements, mapping, measurement_model=None,
                      label="Measurements", convert_measurements=True, **kwargs):
    """Plots measurements

    Plots detections and clutter, generating a legend automatically. By default detections are
    plotted as blue circles and :class:`~.Clutter` as yellow 'tri-up' markers. Any keyword
    arguments are applied to both the detections and the clutter.

    Parameters
    ----------
    measurements : Collection of :class:`~.Detection`
        Detections which will be plotted. If measurements contains sets they are flattened.
    mapping: list
        List of items specifying the mapping of the position components of the state space.
    measurement_model : :class:`~.Model`, optional
        User-defined measurement model to be used in finding measurement state inverses if
        they cannot be found from the measurements themselves.
    label: str
        Label for measurements. Default is "Measurements".
    convert_measurements: bool
        Should the measurements be converted from measurement space to state space before
        being plotted. Default is True
    \\*\\*kwargs: dict
        Additional arguments to be passed to plot function. Defaults are ``linestyle=''``,
        ``marker='o'`` and ``color='b'`` for detections, and ``linestyle=''``,
        ``marker='2'`` and ``color='y'`` for clutter.

    .. deprecated:: 1.5
        ``label`` has replaced ``measurements_label``. In the current implementation
        ``measurements_label`` overrides ``label``. However, use of ``measurements_label``
        may be removed in the future.
    """
    label = kwargs.pop('measurements_label', None) or label
    if not isinstance(measurements, Collection):
        measurements = {measurements}  # Make a set of length 1
    if any(isinstance(item, set) for item in measurements):
        measurements_set = chain.from_iterable(measurements)  # Flatten into one iterable
    else:
        measurements_set = measurements
    plot_detections, plot_clutter = self._conv_measurements(measurements_set,
                                                            mapping,
                                                            measurement_model,
                                                            convert_measurements)

    # (measurement -> plotting vector mapping, legend label, default style) per category
    categories = []
    if plot_detections:
        # only disambiguate the label when both categories end up in the legend
        detections_label = label + "\n(Detections)" if plot_clutter else label
        categories.append(
            (plot_detections, detections_label, dict(linestyle='', marker='o', color='b')))
    if plot_clutter:
        categories.append(
            (plot_clutter, label + "\n(Clutter)", dict(linestyle='', marker='2', color='y')))

    for plot_dict, name, style_kwargs in categories:
        style_kwargs.update(kwargs)
        self.plotting_data.append(_AnimationPlotterDataClass(
            plotting_data=[State(state_vector=plotting_state_vector,
                                 timestamp=meas.timestamp)
                           for meas, plotting_state_vector in plot_dict.items()],
            plotting_label=name,
            plotting_keyword_arguments=style_kwargs
        ))
def plot_sensors(self, sensors, mapping=None, label="Sensors", **kwargs):
    """Sensor plotting is not supported by this animated plotter.

    The signature mirrors the abstract ``plot_sensors(sensors, mapping, label, **kwargs)`` so
    that calls written against the base interface reach the intended error instead of
    failing with a :class:`TypeError` on argument binding.

    Parameters
    ----------
    sensors
        Ignored.
    mapping: list, optional
        Ignored.
    label: str
        Ignored.
    \\*\\*kwargs: dict
        Ignored.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("Plotting sensors is not supported by this animated plotter")
[docs]
@classmethod
def run_animation(cls,
                  times_to_plot: list[datetime],
                  data: Iterable[_AnimationPlotterDataClass],
                  plot_item_expiry: Optional[timedelta] = None,
                  axis_padding: float = 0.1,
                  figure_kwargs: dict = None,
                  animation_input_kwargs: dict = None,
                  legend_kwargs: dict = None,
                  x_label: str = "$x$",
                  y_label: str = "$y$",
                  plot_title: str = None
                  ) -> animation.FuncAnimation:
    """
    Create a matplotlib animation of the given plotting data.

    Parameters
    ----------
    times_to_plot : Iterable[datetime]
        All the times that the plotter should plot; one animation frame is made per time
    data : Iterable[_AnimationPlotterDataClass]
        All the data that should be plotted. It is iterated only once, so a generator is
        accepted. Items whose ``plotting_data`` is ``None`` or empty are skipped.
    plot_item_expiry: timedelta
        How long a state should be displayed for. Default value of None
        means data is shown indefinitely
    axis_padding: float
        How much extra space should be given around the edge of the plot, as a fraction of
        the data range. If falsy, equal axis scaling is used instead of explicit limits.
    figure_kwargs: dict
        Keyword arguments for the pyplot figure function. See matplotlib.pyplot.figure for more
        details
    animation_input_kwargs: dict
        Keyword arguments for FuncAnimation class. See matplotlib.animation.FuncAnimation for
        more details. Default values are: blit=False, repeat=False, interval=50
    legend_kwargs: dict
        Keyword arguments for the pyplot legend function. See matplotlib.pyplot.legend for more
        details
    x_label: str
        Label for the x axis
    y_label: str
        Label for the y axis
    plot_title: str
        Title for the plot

    Returns
    -------
    : animation.FuncAnimation
        Animation object
    """
    animation_kwargs = dict(blit=False, repeat=False, interval=50)  # milliseconds
    if animation_input_kwargs is None:
        animation_input_kwargs = dict()
    animation_kwargs.update(animation_input_kwargs)

    if figure_kwargs is None:
        figure_kwargs = dict()
    fig1 = plt.figure(**figure_kwargs)

    the_lines = []
    plotting_data = []
    legends_key = []
    # ``data`` may be a one-shot iterable, so everything needed later is collected here
    for a_plot_object in data:
        if a_plot_object.plotting_data is None:
            continue
        states = list(a_plot_object.plotting_data)
        if not states:
            continue
        # Lines start empty; update_animation fills them frame by frame
        the_lines.append(plt.plot([], [], **a_plot_object.plotting_keyword_arguments)[0])
        legends_key.append(a_plot_object.plotting_label)
        plotting_data.append(states)

    # Every state that will appear in the animation
    all_states = [state for states in plotting_data for state in states]

    if axis_padding:
        if all_states:
            x_limits, y_limits = (
                [min(state.state_vector[idx] for state in all_states),
                 max(state.state_vector[idx] for state in all_states)]
                for idx in (0, 1))
            for axis_limits in (x_limits, y_limits):
                limit_padding = axis_padding * (axis_limits[1] - axis_limits[0])
                # Cast to float so that the limits do not contain angle classes
                axis_limits[0] = float(axis_limits[0] - limit_padding)
                axis_limits[1] = float(axis_limits[1] + limit_padding)
            plt.xlim(x_limits)
            plt.ylim(y_limits)
    else:
        plt.axis('equal')
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Only lines with a label get a legend entry
    labelled = [(line, label) for line, label in zip(the_lines, legends_key)
                if label is not None]
    if legend_kwargs is None:
        legend_kwargs = dict()
    plt.legend([line for line, _ in labelled], [label for _, label in labelled],
               **legend_kwargs)

    if plot_item_expiry is None:
        # Without expiry every frame shows everything from the earliest state onwards
        min_plot_time = min((state.timestamp for state in all_states), default=None)
        min_plot_times = [min_plot_time] * len(times_to_plot)
    else:
        min_plot_times = [time - plot_item_expiry for time in times_to_plot]

    line_ani = animation.FuncAnimation(fig1, cls.update_animation,
                                       frames=len(times_to_plot),
                                       fargs=(the_lines, plotting_data, min_plot_times,
                                              times_to_plot, plot_title),
                                       **animation_kwargs)
    plt.draw()
    return line_ani
[docs]
@staticmethod
def update_animation(index: int, lines: list[Line2D], data_list: list[list[State]],
                     start_times: list[datetime], end_times: list[datetime], title: str):
    """
    Redraw every line for a single animation frame.

    Parameters
    ----------
    index : int
        Frame number; selects which entry of ``start_times`` and ``end_times`` is used
    lines : List[Line2D]
        Line artists to update, one per entry of ``data_list``
    data_list : List[List[State]]
        States for each line; ``None`` entries are skipped
    start_times : List[datetime]
        Earliest timestamp (inclusive) shown in each frame
    end_times : List[datetime]
        Latest timestamp (inclusive) shown in each frame; also appended to the title
    title: str
        Title prefix for the plot; ``None`` is treated as an empty string

    Returns
    -------
    : List[Line2D]
        The updated line artists
    """
    window_start = start_times[index]
    window_end = end_times[index]
    prefix = "" if title is None else title
    plt.title(prefix + str(window_end))
    for position, states in enumerate(data_list):
        if states is None:
            continue
        in_window = np.array([state.state_vector for state in states
                              if window_start <= state.timestamp <= window_end])
        line = lines[position]
        if in_window.size > 0:
            line.set_data(in_window[:, 0], in_window[:, 1])
        else:
            # nothing visible in this frame: clear the line
            line.set_data([], [])
    return lines
[docs]
class AnimatedPlotterly(_Plotter):
"""
Class for a 2D animated plotter that uses Plotly graph objects rather than matplotlib.
This gives the user the ability to see how tracking works through time, while being
able to interact with tracks, truths, etc, in the same way that is enabled by
Plotly static plots.
Simplifies the process of plotting ground truths, measurements, clutter, and tracks.
Tracks can be plotted with uncertainty ellipses or particles if required. Legends
are automatically generated with each plot.
Parameters
----------
timesteps: Collection
Collection of equally-spaced timesteps. Each animation frame is a timestep.
tail_length: float
Percentage of sim time for which previous values will still be displayed for.
Value can be between 0 and 1. Default is 0.3.
equal_size: bool
Makes x and y axes equal when figure is resized. Default is False.
sim_duration: int
Time taken to run animation (s). Default is 6
\\*\\*kwargs
Additional arguments to be passed in the initialisation.
Attributes
----------
"""
def __init__(self, timesteps, tail_length=0.3, equal_size=False,
             sim_duration=6, **kwargs):
    """
    Initialise the figure and check that the inputs are correctly formatted.

    One empty frame is created per timestep, and the Play/Stop buttons and the time slider
    are configured. Any extra keyword arguments are applied to the figure layout last, so
    they override the defaults set here.
    """
    if go is None or colors is None:
        raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`")
    self.equal_size = equal_size

    # An animation needs at least two frames
    if len(timesteps) < 2:
        raise ValueError("Must be at least 2 timesteps for animation.")
    # Distinct gaps between consecutive timesteps; more than one means uneven spacing
    time_spaces = np.unique(np.diff(timesteps))
    if len(time_spaces) != 1:
        warnings.warn("Timesteps are not equally spaced, so the passage of time is not linear")
    self.timesteps = timesteps

    if tail_length < 0 or tail_length > 1:
        raise ValueError("Tail length should be between 0 and 1")
    self.tail_length = tail_length

    if sim_duration <= 0:
        raise ValueError("Simulation duration must be positive")

    # Span of time (a fraction of the whole simulation) for which past plots stay visible
    self.time_window = (timesteps[-1] - timesteps[0]) * tail_length
    self.colorway = colors.qualitative.Plotly[1:]  # plotting colours
    self.all_masks = dict()  # filled in later by the plotting methods
    # Lets only the first plotted data override the default axis limits (see _resize)
    self.plotting_function_called = False

    self.fig = go.Figure()
    self.fig.update_layout(dict(
        xaxis=dict(title=dict(text="<i>x</i>")),
        yaxis=dict(title=dict(text="<i>y</i>")),
        colorway=self.colorway,  # Needed to match colours later.
        height=550,
        autosize=True
    ))

    # One (initially empty) frame per timestep, named after its timestamp
    self.fig.frames = [dict(name=str(time), data=[], traces=[]) for time in timesteps]
    self.fig.update_xaxes(range=[0, 10])
    self.fig.update_yaxes(range=[0, 10])
    frame_duration = sim_duration * 1000 / len(self.fig.frames)

    # Which part of each frame name (str of a datetime) the slider shows
    if time_spaces[0] >= timedelta(days=1):
        label_slice = slice(None, 10)  # gaps of a day or more: date only
    elif timesteps[-1] - timesteps[0] > timedelta(days=1):
        label_slice = slice(None, None)  # runs over several days: full timestamp
    else:
        label_slice = slice(11, None)  # within a day: time of day only

    play_button = {
        "args": [None,
                 {"frame": {"duration": frame_duration, "redraw": True},
                  "fromcurrent": True, "transition": {"duration": 0}}],
        "label": "Play",
        "method": "animate"
    }
    stop_button = {
        "args": [[None], {"frame": {"duration": 0, "redraw": True},
                          "mode": "immediate",
                          "transition": {"duration": 0}}],
        "label": "Stop",
        "method": "animate"
    }
    updatemenus = [dict(type='buttons',
                        buttons=[play_button, stop_button],
                        direction='left',
                        pad=dict(r=10, t=75),
                        showactive=True, x=0.1, y=0, xanchor='right', yanchor='top')]

    slider_steps = [{'args': [[frame.name], {
                        'frame': {'duration': 1.0, 'easing': 'linear', 'redraw': True},
                        'transition': {'duration': 0, 'easing': 'linear'}}],
                     'label': frame.name[label_slice],
                     'method': 'animate'}
                    for frame in self.fig.frames]
    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Time: ', 'visible': True,
                                 'xanchor': 'right'},
                'transition': {'duration': frame_duration, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                'steps': slider_steps}]

    self.fig.update_layout(updatemenus=updatemenus, sliders=sliders)
    self.fig.update_layout(kwargs)
[docs]
def show(self):
    """
    Return the Plotly figure that holds the animation.

    Displaying it is left to the caller or the surrounding environment.
    """
    figure = self.fig
    return figure
def _resize(self, data, type="track"):
"""
Reshape figure so that everything is in view.
Parameters
----------
data:
Collection of values that are being added to the figure.
Will be a list if coming from plot_ground_Truths or
plot_tracks, but will be a dictionary if coming from plot_measurements.
"""
# fill in all data. If there is no data, fill all_x, all_y with current axis limits
if not data:
all_x = list(self.fig.layout.xaxis.range)
all_y = list(self.fig.layout.xaxis.range)
else:
all_x = list()
all_y = list()
# fill in data
if type == "measurements":
for key, item in data.items():
all_x.extend(data[key]["x"])
all_y.extend(data[key]["y"])
elif type in ("ground_truth", "tracks"):
for n, _ in enumerate(data):
all_x.extend(data[n]["x"])
all_y.extend(data[n]["y"])
elif type == "sensor":
sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in data])
all_x.extend(sensor_xy[:, 0])
all_y.extend(sensor_xy[:, 1])
elif type == "particle_or_uncertainty":
# data comes in format of list of dictionaries. Each dictionary contains 'x' and 'y',
# which are a list of lists.
for dictionary in data:
for x_values in dictionary["x"]:
all_x.extend([np.nanmax(x_values), np.nanmin(x_values)])
for y_values in dictionary["y"]:
all_y.extend([np.nanmax(y_values), np.nanmin(y_values)])
xmax = max(all_x)
ymax = max(all_y)
xmin = min(all_x)
ymin = min(all_y)
if self.equal_size:
xmax = ymax = max(xmax, ymax)
xmin = ymin = min(xmin, ymin)
# if it's first time plotting data, want to ensure plotter is bound to that data
# and not the default values. Issues arise if the initial plotted data is much
# smaller than the default 0 to 10 values.
if not self.plotting_function_called:
self.fig.update_xaxes(range=[xmin, xmax])
self.fig.update_yaxes(range=[ymin, ymax])
# need to check if it's actually necessary to resize or not
if xmax >= self.fig.layout.xaxis.range[1] or xmin <= self.fig.layout.xaxis.range[0]:
xmax = max(xmax, self.fig.layout.xaxis.range[1])
xmin = min(xmin, self.fig.layout.xaxis.range[0])
xrange = xmax - xmin
# update figure while adding a small buffer to the mins and maxes
self.fig.update_xaxes(range=[xmin - xrange / 20, xmax + xrange / 20])
if ymax >= self.fig.layout.yaxis.range[1] or ymin <= self.fig.layout.yaxis.range[0]:
ymax = max(ymax, self.fig.layout.yaxis.range[1])
ymin = min(ymin, self.fig.layout.yaxis.range[0])
yrange = ymax - ymin
self.fig.update_yaxes(range=[ymin - yrange / 20, ymax + yrange / 20])
[docs]
def plot_ground_truths(self, truths, mapping, label="Ground Truth",
                       resize=True, **kwargs):
    """Plots ground truth(s)

    Plots each ground truth path passed in to :attr:`truths` and generates a legend
    automatically. Ground truths are plotted as dashed lines with default colors.

    Users can change linestyle, color and marker using keyword arguments. Any changes
    will apply to all ground truths.

    Parameters
    ----------
    truths : Collection of :class:`~.GroundTruthPath`
        Collection of ground truths which will be plotted. If not a collection and instead a
        single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow
        for iteration.
    mapping: list
        List of items specifying the mapping of the position components of the state space.
    label: str
        Name of ground truths in legend/plot
    resize: bool
        if True, will resize figure to ensure that ground truths are in view
    \\*\\*kwargs: dict
        Additional arguments to be passed to the Plotly ``Scatter`` traces. Defaults are
        ``mode="lines"`` and ``line=dict(dash="dash")``.

    .. deprecated:: 1.5
        ``label`` has replaced ``truths_label``. In the current implementation
        ``truths_label`` overrides ``label``. However, use of ``truths_label``
        may be removed in the future.
    """
    label = kwargs.pop('truths_label', None) or label
    if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence):
        truths = {truths}  # Make a set of length 1

    # Per-truth plotting data, so that points can be selected for each frame later on
    data = [dict() for _ in truths]
    for n, truth in enumerate(truths):
        length = len(truth)
        data[n].update(x=np.zeros(length),
                       y=np.zeros(length),
                       time=np.empty(length, dtype=object),
                       time_str=np.empty(length, dtype=object),
                       type=np.empty(length, dtype=object))
        for k, state in enumerate(truth):
            data[n]["x"][k] = state.state_vector[mapping[0]]
            data[n]["y"][k] = state.state_vector[mapping[1]]
            data[n]["time"][k] = state.timestamp
            data[n]["time_str"][k] = str(state.timestamp)
            data[n]["type"][k] = type(state).__name__

    trace_base = len(self.fig.data)  # number of traces currently in the animation

    # A trace that keeps the legend entry for the entire simulation (even when no truths are
    # present), then one trace per truth.
    truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=label,
                        line=dict(dash="dash", color=self.colorway[0]), legendrank=100,
                        name=label, showlegend=True)
    merge(truth_kwargs, kwargs)
    self.fig.add_trace(go.Scatter(truth_kwargs))  # legend dummy trace

    # we don't want the legend for any of the actual traces
    truth_kwargs.update({"showlegend": False})
    for n, _ in enumerate(truths):
        # change the colour of each truth; user-supplied kwargs still take precedence
        merge(truth_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)])))
        merge(truth_kwargs, kwargs)
        self.fig.add_trace(go.Scatter(truth_kwargs))

    for frame in self.fig.frames:
        data_ = list(frame.data)
        traces_ = list(frame.traces)

        frame_time = datetime.fromisoformat(frame.name)
        cutoff_time = frame_time - self.time_window

        # Placeholder data for the legend trace, the same as plot_measurements/plot_tracks use.
        # The previous x=[0, 0], y=[0, 0] could draw a point at the origin if markers were
        # enabled through kwargs.
        data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf]))
        traces_.append(trace_base)

        for n, truth_data in enumerate(data):
            times = truth_data["time"]
            # truth points at or before the frame time, but not older than the tail window
            mask = np.logical_and(times <= frame_time, times >= cutoff_time)
            # append np.inf so that every trace has data at every timestep
            truth_x = np.append(truth_data["x"][mask], [np.inf])
            truth_y = np.append(truth_data["y"][mask], [np.inf])
            times_str = truth_data["time_str"][mask]
            data_.append(go.Scatter(x=truth_x,
                                    y=truth_y,
                                    meta=times_str,
                                    hovertemplate='GroundTruthState' +
                                                  '<br>(%{x}, %{y})' +
                                                  '<br>Time: %{meta}'))
            traces_.append(trace_base + n + 1)  # append data to correct trace

        frame.data = data_
        frame.traces = traces_

    if resize:
        self._resize(data, type="ground_truth")

    # we have called a plotting function so update flag (gets used in _resize)
    self.plotting_function_called = True
[docs]
def plot_measurements(self, measurements, mapping, measurement_model=None,
                      resize=True, label="Measurements",
                      convert_measurements=True, **kwargs):
    """Plots measurements

    Plots detections and clutter, generating a legend automatically. By default detections
    are plotted as blue (``"#636EFA"``) markers and :class:`~.Clutter` as yellow
    (``"#FECB52"``) ``"star-triangle-up"`` markers. Keyword arguments are merged into the
    traces for both detections and clutter.

    Parameters
    ----------
    measurements : Collection of :class:`~.Detection`
        Detections which will be plotted. If measurements contains sets they are flattened.
    mapping: list
        List of items specifying the mapping of the position components of the state space.
    measurement_model : :class:`~.Model`, optional
        User-defined measurement model to be used in finding measurement state inverses if
        they cannot be found from the measurements themselves.
    resize: bool
        If True, will resize figure to ensure measurements are in view
    label : str
        Label for the measurements. Default is "Measurements".
    convert_measurements : bool
        Should the measurements be converted from measurement space to state space before
        being plotted. Default is True
    \\*\\*kwargs: dict
        Additional arguments to be passed to scatter function for detections and clutter.
        Default for detections is ``marker=dict(color="#636EFA")``.

    .. deprecated:: 1.5
        ``label`` has replaced ``measurements_label``. In the current implementation
        ``measurements_label`` overrides ``label``. However, use of ``measurements_label``
        may be removed in the future.
    """
    label = kwargs.pop('measurements_label', None) or label
    if not isinstance(measurements, Collection):
        measurements = {measurements}  # Make a set of length 1
    if any(isinstance(item, set) for item in measurements):
        measurements_set = chain.from_iterable(measurements)  # Flatten into one iterable
    else:
        measurements_set = measurements

    plot_detections, plot_clutter = self._conv_measurements(measurements_set,
                                                            mapping,
                                                            measurement_model,
                                                            convert_measurements)
    plot_combined = {'Detection': plot_detections,
                     'Clutter': plot_clutter}

    # Plotting data per category, as numpy arrays that are easy to mask per frame.
    # Only non-empty categories are included.
    combined_data = dict()
    for key, plot_dict in plot_combined.items():
        if not plot_dict:
            continue
        length = len(plot_dict)
        entry = {
            "x": np.zeros(length),
            "y": np.zeros(length),
            "time": np.empty(length, dtype=object),
            "time_str": np.empty(length, dtype=object),
            "type": np.empty(length, dtype=object)}
        # iterate items() once instead of rebuilding the list of values for every element
        for n, (det, (x, y)) in enumerate(plot_dict.items()):
            entry["x"][n] = x
            entry["y"][n] = y
            entry["time"][n] = det.timestamp
            entry["time_str"][n] = str(det.timestamp)
            entry["type"][n] = type(det).__name__
        combined_data[key] = entry

    # Record the index of every trace as it is added, so that frame data always targets
    # the right trace whichever categories are present.
    next_trace = len(self.fig.data)
    legend_trace = None  # index of the detections legend trace, if one is added
    data_traces = dict()  # category key -> index of the trace holding its frame data

    if plot_detections:
        if plot_clutter:
            name = label + "<br>(Detections)"
        else:
            name = label
        measurement_kwargs = dict(x=[], y=[], mode='markers',
                                  name=name,
                                  legendgroup=name,
                                  legendrank=200, showlegend=True,
                                  marker=dict(color="#636EFA"), hoverinfo='none')
        merge(measurement_kwargs, kwargs)
        self.fig.add_trace(go.Scatter(measurement_kwargs))  # trace for legend
        legend_trace = next_trace
        next_trace += 1

        measurement_kwargs.update({"showlegend": False})
        self.fig.add_trace(go.Scatter(measurement_kwargs))  # trace for plotting
        data_traces['Detection'] = next_trace
        next_trace += 1

    if plot_clutter:
        name = label + "<br>(Clutter)"
        clutter_kwargs = dict(x=[], y=[], mode='markers',
                              name=name,
                              legendgroup=name,
                              legendrank=300, showlegend=True,
                              marker=dict(symbol="star-triangle-up", color='#FECB52'),
                              hoverinfo='none')
        merge(clutter_kwargs, kwargs)
        # a single trace provides both the legend entry and the clutter data
        self.fig.add_trace(go.Scatter(clutter_kwargs))
        data_traces['Clutter'] = next_trace
        next_trace += 1

    for frame in self.fig.frames:
        data_ = list(frame.data)
        traces_ = list(frame.traces)

        if legend_trace is not None:
            # placeholder data so that the detection legend stays in place
            data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf]))
            traces_.append(legend_trace)

        frame_time = datetime.fromisoformat(frame.name)
        # time at which measurements disappear from the figure
        cutoff_time = frame_time - self.time_window

        for key, values in combined_data.items():
            times = values["time"]
            # measurements that have arrived by this frame, but are not older than the tail
            mask = np.logical_and(times <= frame_time, times >= cutoff_time)
            det_x = np.append(values["x"][mask], [np.inf])
            det_y = np.append(values["y"][mask], [np.inf])
            det_times = values["time_str"][mask]
            data_.append(go.Scatter(x=det_x,
                                    y=det_y,
                                    meta=det_times,
                                    hovertemplate=f'{key}' +
                                                  '<br>(%{x}, %{y})' +
                                                  '<br>Time: %{meta}'))
            traces_.append(data_traces[key])

        frame.data = data_
        frame.traces = traces_

    if resize:
        self._resize(combined_data, "measurements")

    # we have called a plotting function so update flag (gets used in resize)
    self.plotting_function_called = True
[docs]
def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True,
                particle=False, plot_history=False, ellipse_points=30,
                label="Tracks", **kwargs):
    """
    Plots each track generated, generating a legend automatically. If 'uncertainty=True',
    error ellipses are plotted. Tracks are plotted as solid lines with point markers
    and default colours.

    Users can change linestyle, color, and marker using keyword arguments. Uncertainty metrics
    will also be plotted with the user defined colour and any changes will apply to all tracks.

    Parameters
    ----------
    tracks: Collection of :class:`~.Track`
        Collection of tracks which will be plotted. If not a collection, and instead a single
        :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
        Tracks without any states are accepted and contribute no points.
    mapping: list
        List of items specifying the mapping of the position
        components of the state space
    uncertainty: bool
        If True, function plots uncertainty ellipses
    resize: bool
        If True, plotter will change bounds so that tracks are in view
    particle: bool
        If True, function plots particles
    plot_history: bool
        If True, plots the particles and uncertainty ellipses of every state within the tail
        window up to the current time step. Otherwise only states at the current time step
        are shown.
    ellipse_points: int
        Number of points for polygon approximating ellipse shape
    label: str
        Label to apply to all tracks for legend
    \\*\\*kwargs: dict
        Additional arguments to be passed to the Plotly ``Scatter`` traces. Defaults are
        ``mode="markers+lines"`` and a per-track ``line`` colour. The same arguments are also
        applied to the uncertainty and particle traces.

    .. deprecated:: 1.5
        ``label`` has replaced ``track_label``. In the current implementation
        ``track_label`` overrides ``label``. However, use of ``track_label``
        may be removed in the future.
    """
    # NOTE(review): ``ellipse_points`` is not forwarded to _plot_particles_and_ellipses in
    # this method; confirm whether it is meant to control the ellipse resolution.
    label = kwargs.pop('track_label', None) or label
    if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence):
        tracks = {tracks}  # Make a set of length 1

    # Per-track plotting data, so that points can be selected for each frame later on
    data = [dict() for _ in tracks]
    for n, track in enumerate(tracks):
        states = list(track)
        if states:
            # use the mean where available (accounts for particle filters)
            xydata = np.concatenate(
                [getattr(state, 'mean', state.state_vector)[mapping, :] for state in states],
                axis=1)
        else:
            # np.concatenate rejects an empty list, so build empty rows directly
            xydata = np.empty((2, 0))
        data[n].update(x=xydata[0],
                       y=xydata[1],
                       time=np.empty(len(states), dtype=object),
                       time_str=np.empty(len(states), dtype=object),
                       type=np.empty(len(states), dtype=object))
        for k, state in enumerate(states):
            data[n]["time"][k] = state.timestamp
            data[n]["time_str"][k] = str(state.timestamp)
            data[n]["type"][k] = type(state).__name__

    trace_base = len(self.fig.data)  # number of traces

    # dummy trace that keeps the legend entry for the tracks
    track_kwargs = dict(x=[], y=[], mode="markers+lines", line=dict(color=self.colorway[2]),
                        legendgroup=label, legendrank=400, name=label,
                        showlegend=True)
    track_kwargs.update(kwargs)
    self.fig.add_trace(go.Scatter(track_kwargs))

    # one trace per track, with its own colour and no legend entry
    track_kwargs.update({'showlegend': False})
    for k, _ in enumerate(tracks):
        track_kwargs.update({'line': dict(color=self.colorway[(k + 2) % len(self.colorway)])})
        track_kwargs.update(kwargs)
        self.fig.add_trace(go.Scatter(track_kwargs))

    for frame in self.fig.frames:
        data_ = list(frame.data)
        traces_ = list(frame.traces)

        frame_time = datetime.fromisoformat(frame.name)
        self.all_masks[frame_time] = dict()  # masks reused for particles/ellipses
        cutoff_time = frame_time - self.time_window

        # add blank data to ensure legend stays in place
        data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf]))
        traces_.append(trace_base)

        for n, track_data in enumerate(data):
            times = track_data["time"]
            # track points at or before the frame time, but not older than the tail window
            mask = np.logical_and(times <= frame_time, times >= cutoff_time)

            # Masks are stored wrapped in a one-element list because
            # _plot_particles_and_ellipses indexes with ``tuple(mask)``
            if plot_history:
                self.all_masks[frame_time][n] = [mask]
            else:
                self.all_masks[frame_time][n] = [times == frame_time]

            # append np.inf so that the traces are present for the entire simulation
            track_x = np.append(track_data["x"][mask], [np.inf])
            track_y = np.append(track_data["y"][mask], [np.inf])
            track_type = track_data["type"][mask]
            times_str = track_data["time_str"][mask]
            data_.append(go.Scatter(x=track_x,  # plot track
                                    y=track_y,
                                    meta=track_type,
                                    customdata=times_str,
                                    hovertemplate='%{meta}' +
                                                  '<br>(%{x}, %{y})' +
                                                  '<br>Time: %{customdata}'))
            traces_.append(trace_base + n + 1)  # add to correct trace

        frame.data = data_
        frame.traces = traces_

    # skip resizing when there are no track points at all (nothing to bring into view)
    if resize and any(len(track_data["x"]) for track_data in data):
        self._resize(data, "tracks")

    if uncertainty:  # plot ellipses
        name = f'{label}<br>Uncertainty'
        uncertainty_kwargs = dict(x=[], y=[], legendgroup=name, fill='toself',
                                  fillcolor=self.colorway[2],
                                  opacity=0.2, legendrank=500, name=name,
                                  hoverinfo='skip',
                                  mode='none', showlegend=True)
        uncertainty_kwargs.update(kwargs)

        # dummy trace for legend for uncertainty
        self.fig.add_trace(go.Scatter(uncertainty_kwargs))

        # and an uncertainty ellipse trace for each track
        uncertainty_kwargs.update({'showlegend': False})
        for k, _ in enumerate(tracks):
            uncertainty_kwargs.update(
                {'fillcolor': self.colorway[(k + 2) % len(self.colorway)]})
            uncertainty_kwargs.update(kwargs)
            self.fig.add_trace(go.Scatter(uncertainty_kwargs))

        # find the uncertainty data points and add them to the frames
        self._plot_particles_and_ellipses(tracks, mapping, resize, method="uncertainty")

    if particle:  # plot particles
        # initialise traces. One for legend and one per track
        name = f'{label}<br>Particles'
        particle_kwargs = dict(mode='markers', marker=dict(size=2, color=self.colorway[2]),
                               opacity=0.4,
                               hoverinfo='skip', legendgroup=name, name=name,
                               legendrank=520, showlegend=True)
        particle_kwargs.update(kwargs)
        self.fig.add_trace(go.Scatter(particle_kwargs))  # legend trace

        particle_kwargs.update({"showlegend": False})
        for k, _ in enumerate(tracks):  # trace for each track
            particle_kwargs.update(
                {'marker': dict(size=2, color=self.colorway[(k + 2) % len(self.colorway)])})
            particle_kwargs.update(kwargs)
            self.fig.add_trace(go.Scatter(particle_kwargs))

        self._plot_particles_and_ellipses(tracks, mapping, resize, method="particles")

    # we have called a plotting function so update flag
    self.plotting_function_called = True
def _plot_particles_and_ellipses(self, tracks, mapping, resize, method="uncertainty"):
"""
The logic for plotting uncertainty ellipses and particles is nearly identical,
so it is put into one function.
Parameters
----------
tracks: Collection of :class '~Track'
Collection of tracks which will be plotted. If not a collection, and instead a single
:class:'~Track' type, the argument is modified to be a set to allow for iteration
mapping: list
List of items specifying the mapping of the position components of the state space.
method: str
Can either be "uncertainty" or "particles". Depends on what the function is plotting.
"""
data = [dict() for _ in tracks]
trace_base = len(self.fig.data)
for n, track in enumerate(tracks):
# initialise arrays that store particle/ellipse for later plotting
data[n].update(x=np.array([0 for _ in range(len(track))], dtype=object),
y=np.array([0 for _ in range(len(track))], dtype=object))
for k, state in enumerate(track):
# find data points
if method == "uncertainty":
data_x, data_y = Plotterly._generate_ellipse_points(state, mapping)
data_x = list(data_x)
data_y = list(data_y)
data_x.append(np.nan) # necessary to draw multiple ellipses at once
data_y.append(np.nan)
data[n]["x"][k] = data_x
data[n]["y"][k] = data_y
elif method == "particles":
data_xy = state.state_vector[mapping[:2], :]
data[n]["x"][k] = data_xy[0]
data[n]["y"][k] = data_xy[1]
else:
raise ValueError("Should be 'uncertainty' or 'particles'")
for frame in self.fig.frames:
frame_time = datetime.fromisoformat(frame.name)
data_ = list(frame.data) # current data in frame
traces_ = list(frame.traces) # current traces in frame
data_.append(go.Scatter(x=[-np.inf], y=[np.inf])) # add empty data for legend trace
traces_.append(trace_base - len(tracks) - 1) # ensure correct trace
for n, track in enumerate(tracks):
# now plot the data
_x = list(chain(*data[n]["x"][tuple(self.all_masks[frame_time][n])]))
_y = list(chain(*data[n]["y"][tuple(self.all_masks[frame_time][n])]))
_x.append(np.inf)
_y.append(np.inf)
data_.append(go.Scatter(x=_x, y=_y))
traces_.append(trace_base - len(tracks) + n)
frame.data = data_
frame.traces = traces_
if resize:
self._resize(data, type="particle_or_uncertainty")
[docs]
def plot_sensors(self, sensors, label="Sensors", resize=True, **kwargs):
    """Plots sensor(s)

    Plots sensors. Users can change the color and marker of the sensors using keyword
    arguments. Default is a black 'x' marker. Currently only works for stationary
    sensors: each sensor is drawn at the same (x, y) position in every animation frame.

    Parameters
    ----------
    sensors : Collection of :class:`~.Sensor`
        Sensors to plot. A single object that is not a ``Collection`` is wrapped in a
        set. The x and y coordinates are read from the first two elements of each
        sensor's ``position``.
    label: str
        Label to apply to all sensors for legend.
    resize: bool
        If True (default), ``self._resize`` is called with the sensors so that the
        figure extent can take them into account.
    \\*\\*kwargs: dict
        Additional arguments to be passed to scatter function for the sensors. Defaults
        are ``marker=dict(symbol='x', color='black')``. They are deep-merged into the
        defaults (``mergedeep.merge``), so nested ``marker`` entries can be overridden
        individually.

    .. deprecated:: 1.5
        ``label`` has replaced ``sensor_label``. In the current implementation
        ``sensor_label`` overrides ``label``. However, use of ``sensor_label``
        may be removed in the future.
    """
    label = kwargs.pop('sensor_label', None) or label

    if not isinstance(sensors, Collection):
        sensors = {sensors}

    # don't run any of this if there is no data input
    if not sensors:
        return

    trace_base = len(self.fig.data)  # index the new sensor trace will occupy
    sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
                         legendgroup=label, legendrank=50,
                         name=label, showlegend=True)
    merge(sensor_kwargs, kwargs)  # user kwargs override defaults in place
    self.fig.add_trace(go.Scatter(sensor_kwargs))  # initialises trace

    # sensor position: one (x, y) row per sensor
    sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in sensors])
    sensor_x, sensor_y = sensor_xy[:, 0], sensor_xy[:, 1]

    if resize:
        self._resize(sensors, "sensor")

    # stationary sensors: the same points are added to every frame
    for frame in self.fig.frames:
        frame.traces = [*frame.traces, trace_base]
        frame.data = [*frame.data, go.Scatter(x=sensor_x, y=sensor_y)]

    # we have called a plotting function so update flag (used in _resize)
    self.plotting_function_called = True