Skip to content

Commit

Permalink
GOSPAMetric: Pre-group GTs and tracks
Browse files Browse the repository at this point in the history
  • Loading branch information
marvonlar committed Jul 15, 2024
1 parent 473322e commit c68017f
Showing 1 changed file with 39 additions and 36 deletions.
75 changes: 39 additions & 36 deletions stonesoup/metricgenerator/ospametric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from itertools import chain, zip_longest
from itertools import chain, groupby, zip_longest

import numpy as np
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -161,47 +161,51 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
exist for in the parameters. metric.value contains a list of metrics
for the GOSPA metric at each timestamp
"""

# Make a list of all the unique timestamps used
# Make a sorted list of all the unique timestamps used
timestamps = sorted({
state.timestamp
for state in chain(measured_states, truth_states)})

all_meas_points = np.array(measured_states)
all_meas_ids = np.array(measured_state_ids)
all_meas_timestamps = np.fromiter((state.timestamp for state in measured_states), dtype='datetime64[us]')
meas_order = np.argsort(all_meas_timestamps)
all_meas_points = all_meas_points[meas_order]
all_meas_ids = all_meas_ids[meas_order]
all_meas_timestamps = all_meas_timestamps[meas_order]

all_truth_points = np.array(truth_states)
all_truth_ids = np.array(truth_state_ids)
all_truth_timestamps = np.fromiter((state.timestamp for state in truth_states), dtype='datetime64[us]')
truth_order = np.argsort(all_truth_timestamps)
all_truth_points = all_truth_points[truth_order]
all_truth_ids = all_truth_ids[truth_order]
all_truth_timestamps = all_truth_timestamps[truth_order]

truth_iter = iter(groupby(range(len(all_truth_ids)), all_truth_timestamps.__getitem__))
meas_iter = iter(groupby(range(len(all_meas_ids)), all_meas_timestamps.__getitem__))

next_truth = next(truth_iter)
next_meas = next(meas_iter)

switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
gospa_metrics = []
import time
for timestamp in timestamps:
begin = time.time()
begin_ = time.time()
meas_mask = all_meas_timestamps == timestamp
end_ = time.time()
print(f'mask took {(end_ - begin_) * 1e3:.2f}ms')
# np.array doesn't work for ParticleState
begin_ = time.time()
meas_points = all_meas_points[meas_mask]
meas_ids = all_meas_ids[meas_mask]
end_ = time.time()
print(f'first took {(end_ - begin_) * 1e3:.2f}ms')

begin_ = time.time()
truth_mask = all_truth_timestamps == timestamp
end_ = time.time()
print(f'truth mask took {(end_ - begin_) * 1e3:.2f}ms')

begin_ = time.time()
truth_points = all_truth_points[truth_mask]
truth_ids = all_truth_ids[truth_mask]
end_ = time.time()
print(f'second took {(end_ - begin_) * 1e3:.2f}ms')

while next_truth is not None or next_meas is not None:
timestamp = min(next_truth[0], next_meas[0])

truth_idxs = []

if timestamp == next_truth[0]:
truth_idxs = np.fromiter(next_truth[1], dtype=int)
next_truth = next(truth_iter, None)

meas_idxs = []

if timestamp == next_meas[0]:
meas_idxs = np.fromiter(next_meas[1], dtype=int)
next_meas = next(meas_iter, None)

meas_points = all_meas_points[meas_idxs]
meas_ids = all_meas_ids[meas_idxs]

truth_points = all_truth_points[truth_idxs]
truth_ids = all_truth_ids[truth_idxs]

metric, truth_to_measured_assignment = self.compute_gospa_metric(meas_points, truth_points)
truth_mapping = {
Expand All @@ -214,17 +218,16 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
metric.value['switching']**self.alpha,
1.0/self.alpha)
gospa_metrics.append(metric)
end = time.time()
print(f'iter took {(end - begin) * 1e3:.2f}ms')
tmax = timestamp

# If only one timestamp is present then return a SingleTimeMetric
if len(timestamps) == 1:
if len(gospa_metrics) == 1:
return gospa_metrics[0]
else:
return TimeRangeMetric(
title='GOSPA Metrics',
value=gospa_metrics,
time_range=TimeRange(min(timestamps), max(timestamps)),
time_range=TimeRange(min(all_truth_timestamps[0], all_meas_timestamps[0]), max(all_truth_timestamps[-1], all_meas_timestamps[-1])),
generator=self)

def compute_assignments(self, cost_matrix):
Expand Down

0 comments on commit c68017f

Please sign in to comment.