Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GOSPAMetric: Avoid unnecessary copies and inefficient masking #1059

Merged
merged 10 commits into from
Jul 18, 2024
71 changes: 52 additions & 19 deletions stonesoup/metricgenerator/ospametric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from itertools import chain, zip_longest
from itertools import chain, groupby, zip_longest

import numpy as np
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -161,30 +161,58 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
exist for in the parameters. metric.value contains a list of metrics
for the GOSPA metric at each timestamp
"""
all_meas_timestamps = np.fromiter(
(state.timestamp for state in measured_states),
dtype='datetime64[us]'
)
meas_order = np.argsort(all_meas_timestamps)
all_meas_timestamps = all_meas_timestamps[meas_order]
all_meas_points = np.array(measured_states)[meas_order]
all_meas_ids = np.array(measured_state_ids)[meas_order]

all_truth_timestamps = np.fromiter(
(state.timestamp for state in truth_states),
dtype='datetime64[us]'
)
truth_order = np.argsort(all_truth_timestamps)
all_truth_timestamps = all_truth_timestamps[truth_order]
all_truth_points = np.array(truth_states)[truth_order]
all_truth_ids = np.array(truth_state_ids)[truth_order]

# Make a list of all the unique timestamps used
# Make a sorted list of all the unique timestamps used
timestamps = sorted({
state.timestamp
for state in chain(measured_states, truth_states)})
truth_iter = iter(groupby(range(len(all_truth_ids)), all_truth_timestamps.__getitem__))
meas_iter = iter(groupby(range(len(all_meas_ids)), all_meas_timestamps.__getitem__))

next_truth = next(truth_iter, None)
next_meas = next(meas_iter, None)

switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
gospa_metrics = []
for timestamp in timestamps:
meas_mask = [state.timestamp == timestamp for state in measured_states]
# np.array doesn't work for ParticleState
meas_points = np.empty(len(measured_states), dtype="O")
meas_points[:] = measured_states
meas_points = meas_points[meas_mask]

meas_ids = np.array(measured_state_ids)[meas_mask]
while next_truth is not None or next_meas is not None:
timestamp = min(group[0] for group in [next_truth, next_meas] if group is not None)

truth_idxs = []

truth_mask = [state.timestamp == timestamp for state in truth_states]
truth_points = np.array(truth_states)[truth_mask]
truth_ids = np.array(truth_state_ids)[truth_mask]
if next_truth is not None and timestamp == next_truth[0]:
truth_idxs = np.fromiter(next_truth[1], dtype=int)
next_truth = next(truth_iter, None)

meas_idxs = []

if next_meas is not None and timestamp == next_meas[0]:
meas_idxs = np.fromiter(next_meas[1], dtype=int)
next_meas = next(meas_iter, None)

meas_points = all_meas_points[meas_idxs]
meas_ids = all_meas_ids[meas_idxs]

truth_points = all_truth_points[truth_idxs]
truth_ids = all_truth_ids[truth_idxs]

metric, truth_to_measured_assignment = self.compute_gospa_metric(
meas_points, truth_points)
meas_points,
truth_points
)
truth_mapping = {
truth_id: meas_ids[meas_id] if meas_id != -1 else None
for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}
Expand All @@ -197,13 +225,16 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
gospa_metrics.append(metric)

# If only one timestamp is present then return a SingleTimeMetric
if len(timestamps) == 1:
if len(gospa_metrics) == 1:
return gospa_metrics[0]
else:
start_time = np.concatenate((all_truth_timestamps[:1], all_meas_timestamps[:1])).min()
end_time = np.concatenate((all_truth_timestamps[-1:], all_meas_timestamps[-1:])).max()

return TimeRangeMetric(
title='GOSPA Metrics',
value=gospa_metrics,
time_range=TimeRange(min(timestamps), max(timestamps)),
time_range=TimeRange(start=start_time, end=end_time),
generator=self)

def compute_assignments(self, cost_matrix):
Expand Down Expand Up @@ -319,6 +350,8 @@ def compute_gospa_metric(self, measured_states, truth_states):
if num_truth_states == 0:
# When truth states are empty all measured states are false
opt_cost = -1.0 * num_measured_states * dummy_cost
if self.alpha == 2:
gospa_metric['false'] = opt_cost
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was not part of my first patch, but came up as I added some tests to provide coverage of my original changes.

It seems reasonable that GOSPA should report tracks as false also in the case where there are 0 ground truth states, right?

I don't fully understand why we test for self.alpha == 2, but this is what is done everywhere else (see for instance line 369).

elif num_measured_states == 0:
# When measured states are empty all truth
# states are missed
Expand Down
Loading