Skip to content

Commit

Permalink
GOSPAMetric: Fixed formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
marvonlar committed Jul 15, 2024
1 parent c68017f commit 270b87a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
21 changes: 16 additions & 5 deletions stonesoup/metricgenerator/ospametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,21 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
"""
all_meas_points = np.array(measured_states)
all_meas_ids = np.array(measured_state_ids)
all_meas_timestamps = np.fromiter((state.timestamp for state in measured_states), dtype='datetime64[us]')
all_meas_timestamps = np.fromiter(
(state.timestamp for state in measured_states),
dtype='datetime64[us]'
)
meas_order = np.argsort(all_meas_timestamps)
all_meas_points = all_meas_points[meas_order]
all_meas_ids = all_meas_ids[meas_order]
all_meas_timestamps = all_meas_timestamps[meas_order]

all_truth_points = np.array(truth_states)
all_truth_ids = np.array(truth_state_ids)
all_truth_timestamps = np.fromiter((state.timestamp for state in truth_states), dtype='datetime64[us]')
all_truth_timestamps = np.fromiter(
(state.timestamp for state in truth_states),
dtype='datetime64[us]'
)
truth_order = np.argsort(all_truth_timestamps)
all_truth_points = all_truth_points[truth_order]
all_truth_ids = all_truth_ids[truth_order]
Expand Down Expand Up @@ -207,7 +213,10 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
truth_points = all_truth_points[truth_idxs]
truth_ids = all_truth_ids[truth_idxs]

metric, truth_to_measured_assignment = self.compute_gospa_metric(meas_points, truth_points)
metric, truth_to_measured_assignment = self.compute_gospa_metric(
meas_points,
truth_points
)
truth_mapping = {
truth_id: meas_ids[meas_id] if meas_id != -1 else None
for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}
Expand All @@ -218,7 +227,6 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
metric.value['switching']**self.alpha,
1.0/self.alpha)
gospa_metrics.append(metric)
tmax = timestamp

# If only one timestamp is present then return a SingleTimeMetric
if len(gospa_metrics) == 1:
Expand All @@ -227,7 +235,10 @@ def compute_over_time(self, measured_states, measured_state_ids, truth_states,
return TimeRangeMetric(
title='GOSPA Metrics',
value=gospa_metrics,
time_range=TimeRange(min(all_truth_timestamps[0], all_meas_timestamps[0]), max(all_truth_timestamps[-1], all_meas_timestamps[-1])),
time_range=TimeRange(
start=min(all_truth_timestamps[0], all_meas_timestamps[0]),
end=max(all_truth_timestamps[-1], all_meas_timestamps[-1])
),
generator=self)

def compute_assignments(self, cost_matrix):
Expand Down
19 changes: 10 additions & 9 deletions stonesoup/metricgenerator/tests/test_ospametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,11 @@ def test_gospametric_speed():

track_states = np.vstack(
(
gt_states[:num_true_positives] + np.random.normal(0, track_position_sigma, (num_true_positives, num_timesteps, 2)),
5*gt_states[:num_false_positives] + np.random.normal(0, track_position_sigma, (num_false_positives, num_timesteps, 2))
gt_states[:num_true_positives],
5*gt_states[:num_false_positives]
)
)
track_states += np.random.normal(0, track_position_sigma, track_states.shape)

time = datetime.datetime.now()
# Multiple tracks and truths present at two timesteps
Expand All @@ -420,13 +421,13 @@ def test_gospametric_speed():
}
truths = {
GroundTruthPath(
states=[
State(
state_vector=gt_states[gt_idx, state_idx],
timestamp=time + datetime.timedelta(ts[state_idx])
)
for state_idx in range(gt_states.shape[1])
])
states=[
State(
state_vector=gt_states[gt_idx, state_idx],
timestamp=time + datetime.timedelta(ts[state_idx])
)
for state_idx in range(gt_states.shape[1])
])
for gt_idx in range(gt_states.shape[0])
}

Expand Down

0 comments on commit 270b87a

Please sign in to comment.