Skip to content

Commit

Permalink
Merge pull request #1059 from marvonlar/gospa-speedup
Browse files Browse the repository at this point in the history
GOSPAMetric: Avoid unnecessary copies and inefficient masking
  • Loading branch information
sdhiscocks committed Jul 18, 2024
2 parents 92dae99 + 59bd95b commit cb9b3ca
Show file tree
Hide file tree
Showing 2 changed files with 406 additions and 19 deletions.
71 changes: 52 additions & 19 deletions stonesoup/metricgenerator/ospametric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from itertools import chain, zip_longest
from itertools import chain, groupby, zip_longest

import numpy as np
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -161,30 +161,58 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
exist for in the parameters. metric.value contains a list of metrics
for the GOSPA metric at each timestamp
"""
all_meas_timestamps = np.fromiter(
(state.timestamp for state in measured_states),
dtype='datetime64[us]'
)
meas_order = np.argsort(all_meas_timestamps)
all_meas_timestamps = all_meas_timestamps[meas_order]
all_meas_points = np.array(measured_states)[meas_order]
all_meas_ids = np.array(measured_state_ids)[meas_order]

all_truth_timestamps = np.fromiter(
(state.timestamp for state in truth_states),
dtype='datetime64[us]'
)
truth_order = np.argsort(all_truth_timestamps)
all_truth_timestamps = all_truth_timestamps[truth_order]
all_truth_points = np.array(truth_states)[truth_order]
all_truth_ids = np.array(truth_state_ids)[truth_order]

# Make a list of all the unique timestamps used
# Make a sorted list of all the unique timestamps used
timestamps = sorted({
state.timestamp
for state in chain(measured_states, truth_states)})
truth_iter = iter(groupby(range(len(all_truth_ids)), all_truth_timestamps.__getitem__))
meas_iter = iter(groupby(range(len(all_meas_ids)), all_meas_timestamps.__getitem__))

next_truth = next(truth_iter, None)
next_meas = next(meas_iter, None)

switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
gospa_metrics = []
for timestamp in timestamps:
meas_mask = [state.timestamp == timestamp for state in measured_states]
# np.array doesn't work for ParticleState
meas_points = np.empty(len(measured_states), dtype="O")
meas_points[:] = measured_states
meas_points = meas_points[meas_mask]

meas_ids = np.array(measured_state_ids)[meas_mask]
while next_truth is not None or next_meas is not None:
timestamp = min(group[0] for group in [next_truth, next_meas] if group is not None)

truth_idxs = []

truth_mask = [state.timestamp == timestamp for state in truth_states]
truth_points = np.array(truth_states)[truth_mask]
truth_ids = np.array(truth_state_ids)[truth_mask]
if next_truth is not None and timestamp == next_truth[0]:
truth_idxs = np.fromiter(next_truth[1], dtype=int)
next_truth = next(truth_iter, None)

meas_idxs = []

if next_meas is not None and timestamp == next_meas[0]:
meas_idxs = np.fromiter(next_meas[1], dtype=int)
next_meas = next(meas_iter, None)

meas_points = all_meas_points[meas_idxs]
meas_ids = all_meas_ids[meas_idxs]

truth_points = all_truth_points[truth_idxs]
truth_ids = all_truth_ids[truth_idxs]

metric, truth_to_measured_assignment = self.compute_gospa_metric(
meas_points, truth_points)
meas_points,
truth_points
)
truth_mapping = {
truth_id: meas_ids[meas_id] if meas_id != -1 else None
for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}
Expand All @@ -197,13 +225,16 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
gospa_metrics.append(metric)

# If only one timestamp is present then return a SingleTimeMetric
if len(timestamps) == 1:
if len(gospa_metrics) == 1:
return gospa_metrics[0]
else:
start_time = np.concatenate((all_truth_timestamps[:1], all_meas_timestamps[:1])).min()
end_time = np.concatenate((all_truth_timestamps[-1:], all_meas_timestamps[-1:])).max()

return TimeRangeMetric(
title='GOSPA Metrics',
value=gospa_metrics,
time_range=TimeRange(min(timestamps), max(timestamps)),
time_range=TimeRange(start=start_time, end=end_time),
generator=self)

def compute_assignments(self, cost_matrix):
Expand Down Expand Up @@ -319,6 +350,8 @@ def compute_gospa_metric(self, measured_states, truth_states):
if num_truth_states == 0:
# When truth states are empty all measured states are false
opt_cost = -1.0 * num_measured_states * dummy_cost
if self.alpha == 2:
gospa_metric['false'] = opt_cost
elif num_measured_states == 0:
# When measured states are empty all truth
# states are missed
Expand Down
Loading

0 comments on commit cb9b3ca

Please sign in to comment.