Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add track stitching class and example. #764

Merged
merged 20 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 29 additions & 234 deletions docs/examples/Track_Stitcher_Example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@
# Introduction
# ------------
# Track Stitching considers a set of broken fragments of track (which we call tracklets), and aims to identify which
# fragments should be stitched (joined) together to form a track. This is done by considering the state of a tracked
# fragments should be stitched (joined) together to form one track. This is done by considering the state of a tracked
# object and predicting its state at a future (or past) time. This example generates a set of `tracklets` , before
# applying track stitching. The figure below visualizes the aim of track stitching: taking a set of tracklets
# (left, black) and producing a set of tracks (right, blue/red).

# %%
# .. image:: ../_static/track_stitching_basic_example.png
# :width: 500
# :alt: Image showing NN association of two tracks
# :alt: Image showing basic example of track stitching

# %%
# Track Stitching Method
# ^^^^^^^^^^^^^^^^^^^^^^
# Consider the following scenario: We have a bunch of sections of track that are all disconnected from eachother. We
# Consider the following scenario: We have a bunch of sections of track that are all disconnected from each other. We
# aim to stitch the track sections together into full tracks. We can use the known states of tracklets at known times
# to predict where the tracked object would be at a different time. We can use this information to associate tracklets
# with eachother. Methods of doing this are explained below the following figure.
# with each other. Methods of associating tracklets are explained below.
#
# Predicting forward
# ^^^^^^^^^^^^^^^^^^
Expand All @@ -38,14 +38,14 @@
#
# Predicting backward
# ^^^^^^^^^^^^^^^^^^^
# Similarly to predicting forward, we can consider the state at the start point of a track section, call this at time
# Similarly to predicting forward, we can consider the state at the start point of a track section, call this time
# :math:`k`, and predict what the state would have been at time :math:`k - \delta k`. We can then associate and stitch tracks
# together as before. This method is used in the function `backward_predict`.
#
# Using both predictions
# ^^^^^^^^^^^^^^^^^^^^^^
# We can use both methods at the same time to calculate the probability that two track sections are part of the same
# track. The track stitcher in this example uses the `KalmanPredictor` to make predictions on which tracklets should
# track. The track stitcher in this example uses the `KalmanPredictor` to make predictions about which tracklets should
# be stitched into the same track.

# %%
Expand All @@ -57,7 +57,8 @@
from stonesoup.types.track import Track
from stonesoup.types.state import GaussianState
from stonesoup.plotter import Plotter, Dimension
from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity,OrnsteinUhlenbeck
from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity, \
OrnsteinUhlenbeck
from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater
from stonesoup.hypothesiser.distance import DistanceHypothesiser
Expand All @@ -79,7 +80,7 @@
# -------------------
# Set Variables for Scenario Generation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The following cell contains parameters used to generate input truths.
# The code below contains parameters used to generate input truth paths.
#
# The `number_of_targets` is the total number of truth paths generated in the initial simulation.
#
Expand All @@ -88,16 +89,16 @@
# Each truth object is split into a number of segments chosen randomly from the range (1, `max_segments`).
#
# You can define the minimum and maximum length that segments can be, by setting `min_segment_length` and
# `max_segment_length` respectively.
# `max_segment_length`, respectively.
#
# Similarly, the length of disjoint sections can be bounded by `min_disjoint_length` and `max_disjoint_length`.
#
# The start time of each truthpath is bounded between :math:`t` = 0 and :math:`t` = `max_track_start`.
# The start time of each truth path is bounded between :math:`t` = 0 and :math:`t` = `max_track_start`.
#
# The simulation will run for any number of spacial dimensions, given by `n_spacial_dimensions`.
#
# Finally, the transition model can be set by setting `TM` to either "CV" or "KTR" as indicated in the comments in the
# cell below.
# code below.
start_time = datetime.now().replace(second=0, microsecond=0)
np.random.seed(100)

Expand Down Expand Up @@ -153,28 +154,7 @@
# %%
# Define Tracker function
# ^^^^^^^^^^^^^^^^^^^^^^^


def tracker(all_measurements, initiator, deleter, data_associator, hypothesiser, predictor, updater):
tracks = set()
historic_tracks = set()
for n, measurements in enumerate(all_measurements):
hypotheses = data_associator.associate(tracks, measurements, start_time + timedelta(seconds=n))
associated_measurements = set()
for track in tracks:
hypothesis = hypotheses[track]
if hypothesis.measurement:
post = updater.update(hypothesis)
track.append(post)
associated_measurements.add(hypothesis.measurement)
else:
track.append(hypothesis.prediction)
del_tracks = deleter.delete_tracks(tracks)
tracks -= del_tracks
tracks |= initiator.initiate(measurements - associated_measurements, start_time + timedelta(seconds=n))
historic_tracks |= del_tracks
historic_tracks |= tracks
return historic_tracks
from stonesoup.stitcher import tracker
spike-dstl marked this conversation as resolved.
Show resolved Hide resolved

# %%
# Generate ground truths and truthlets
Expand All @@ -200,18 +180,25 @@ def tracker(all_measurements, initiator, deleter, data_associator, hypothesiser,
for i in range(number_of_targets):
# Sets number of segments from range of random numbers
number_of_segments = int(np.random.choice(range(1, max_segments), 1))

# Set length of first truthlet segment
truthlet0_length = np.random.choice(range(max_track_start), 1)

# Set lengths of each of the truthlet segments
truthlet_lengths = np.random.choice(range(min_segment_length, max_segment_length), number_of_segments)

# Set lengths of each disjoint section
disjoint_lengths = np.random.choice(range(min_disjoint_length, max_disjoint_length), number_of_segments)

# Sum pairs of truthlets and disjoints, and set the start-point of the truth path
segment_pair_lengths = np.insert(truthlet_lengths + disjoint_lengths, 0, truthlet0_length, axis=0)

# Cumulative sum of segments, giving the start point of each truth segment
truthlet_startpoints = np.cumsum(segment_pair_lengths)

# Sum truth segments length to start point, giving end point for each segment
truthlet_endpoints = truthlet_startpoints + np.append(truthlet_lengths, 0)

# Set start and end points for each segment
starts = truthlet_startpoints[:number_of_segments]
stops = truthlet_endpoints[:number_of_segments]
Expand Down Expand Up @@ -247,7 +234,7 @@ def tracker(all_measurements, initiator, deleter, data_associator, hypothesiser,
groundtruth_path=truthlet)
measurementlet.append({m0})
all_measurements.append({m0})
tracklet = tracker(measurementlet, initiator, deleter, data_associator, hypothesiser, predictor, updater)
tracklet = tracker(measurementlet, initiator, deleter, data_associator, hypothesiser, predictor, updater, start_time)
for t in tracklet:
all_tracks.add(t)

Expand Down Expand Up @@ -283,199 +270,7 @@ def tracker(all_measurements, initiator, deleter, data_associator, hypothesiser,
# be stitched together. The function `stitch` uses `forward_predict` and `backward_predict` to pair and 'stitch' track
# sections together.
spike-dstl marked this conversation as resolved.
Show resolved Hide resolved

class TrackStitcher():
def __init__(self, forward_hypothesiser=None, backward_hypothesiser=None):
self.forward_hypothesiser = forward_hypothesiser
self.backward_hypothesiser = backward_hypothesiser

@staticmethod
def __extract_detection(track, backward=False):
state = track[0]
if backward:
state = track[-1]
return Detection(state_vector=state.state_vector,
timestamp=state.timestamp,
measurement_model=LinearGaussian(
ndim_state=2 * n_spacial_dimensions,
mapping=list(range(2 * n_spacial_dimensions)),
noise_covar=state.covar),
metadata=track.id)

@staticmethod
def __get_track(track_id, tracks):
for track in tracks:
if track.id == track_id:
return track

@staticmethod
def __merge(a, b):
if a[-1] == b[0]:
a.pop(-1)
return a + b
else:
return a

def forward_predict(self, tracks, start_time):
x_forward = {track.id: [] for track in tracks}
poss_pairs = []
for n in range(int((min([track[0].timestamp for track in tracks]) - start_time).total_seconds()),
int((max([track[-1].timestamp for track in tracks]) - start_time).total_seconds())):
poss_tracks = []
poss_detections = set()
for track in tracks:
if track[-1].timestamp < start_time + timedelta(seconds=n):
poss_tracks.append(track)
if track[0].timestamp == start_time + timedelta(seconds=n):
poss_detections.add(self.__extract_detection(track))
if len(poss_tracks) > 0 and len(poss_detections) > 0:
for track in poss_tracks:
a = self.forward_hypothesiser.hypothesise(track, poss_detections, start_time + timedelta(seconds=n))
if a[0].measurement.metadata == {}:
continue
else:
x_forward[track.id].append((a[0].measurement.metadata, a[0].distance))
return x_forward

def backward_predict(self, tracks, start_time):
x_backward = {track.id: [] for track in tracks}
poss_pairs = []
for n in range(int((max([track[-1].timestamp for track in tracks]) - start_time).total_seconds()),
int((min([track[0].timestamp for track in tracks]) - start_time).total_seconds()),
-1):
poss_tracks = []
poss_detections = set()
for track in tracks:
if track[0].timestamp > start_time + timedelta(seconds=n):
poss_tracks.append(track)
if track[-1].timestamp == start_time + timedelta(seconds=n):
poss_detections.add(self.__extract_detection(track, backward=True))
if len(poss_tracks) > 0 and len(poss_detections) > 0:
for track in poss_tracks:
a = self.backward_hypothesiser.hypothesise(track, poss_detections,
start_time + timedelta(seconds=n))
if a[0].measurement.metadata == {}:
continue
else:
x_backward[a[0].measurement.metadata].append((track.id, a[0].distance))
return x_backward

def stitch(self, tracks, start_time):
forward, backward = False, False
if self.forward_hypothesiser != None:
forward = True
if self.backward_hypothesiser != None:
backward = True

tracks = list(tracks)
x = {track.id: [] for track in tracks}
if forward:
x_forward = self.forward_predict(tracks, start_time)
if backward:
x_backward = self.backward_predict(tracks, start_time)

if forward and not (backward):
x = x_forward
elif not (forward) and backward:
x = x_backward
else:
for key in x:
if x_forward[key] == [] and x_backward[key] == []:
x[key] = []
elif x_forward[key] == [] and x_backward[key] != []:
x[key] = x_backward[key]
elif x_forward[key] != [] and x_backward[key] == []:
x[key] = x_forward[key]
else:
arr = []
for f_val in x_forward[key]:
for b_val in x_backward[key]:
if f_val[0] == b_val[0]:
arr.append((f_val[0], f_val[1] + b_val[1]))
for f_val in x_forward[key]:
in_arr = False
for a_val in arr:
if f_val[0] == a_val[0]:
in_arr = True
if not (in_arr):
arr.append((f_val[0], f_val[1] + 300))
for b_val in x_backward[key]:
in_arr = False
for a_val in arr:
if b_val[0] == a_val[0]:
in_arr = True
if not (in_arr):
arr.append((b_val[0], b_val[1] + 300))
x[key] = arr

matrix_val = [[300] * len(tracks) for i in range(len(tracks))]
matrix_track = [[None] * len(tracks) for i in range(len(tracks))]
for i in range(len(tracks)):
for j in range(len(tracks)):
if tracks[i].id in x:
if tracks[j].id in [combo[0] for combo in x[tracks[i].id]]:
matrix_val[i][j] = [tup[1] for tup in x[tracks[i].id] if tup[0] == tracks[j].id][0]
matrix_track[i][j] = (tracks[i].id, tracks[j].id)
else:
matrix_track[i][j] = (tracks[i].id, None)

row_ind, col_ind = linear_sum_assignment(matrix_val)

for i in range(len(col_ind)):
start_track = matrix_track[row_ind[i]][col_ind[i]][0]
end_track = matrix_track[row_ind[i]][col_ind[i]][1]
if end_track == None:
x[start_track] = None
else:
x[start_track] = end_track

combo = []
for key in x:
if x[key] is None:
continue
elif len(combo) == 0 or not (
any(key in sublist for sublist in combo) or any(x[key] in sublist for sublist in combo)):
combo.append([key, x[key]])
elif any(x[key] in sublist for sublist in combo):
for track_list in combo:
if x[key] in track_list:
track_list.insert(track_list.index(x[key]), key)
else:
for track_list in combo:
if key in track_list:
track_list.insert(track_list.index(key) + 1, x[key])

i = 0
count = 0
while i != len(combo):
id1 = combo[i]
id2 = combo[count]
new_list1 = self.__merge(deepcopy(id1), deepcopy(id2))
new_list2 = self.__merge(deepcopy(id2), deepcopy(id1))
if len(new_list1) == len(id1) and len(new_list2) == len(id2):
count += 1
else:
combo.remove(id1)
combo.remove(id2)
count = 0
i = 0
if len(new_list1) > len(id1):
combo.append(new_list1)
else:
combo.append(new_list2)
if count == len(combo):
count = 0
i += 1
continue
tracks = set(tracks)
for ids in combo:
x = []
for a in ids:
track = self.__get_track(a, tracks)
x = x + track.states
tracks.remove(track)
tracks.add(Track(x))

return tracks
from stonesoup.stitcher import TrackStitcher

# %%
# Applying the Track Stitcher
Expand All @@ -491,7 +286,7 @@ def stitch(self, tracks, start_time):
hypothesiser = DistanceHypothesiser(predictor, updater, Mahalanobis(), missed_distance=300)
stitcher = TrackStitcher(forward_hypothesiser=hypothesiser)

stitched_tracks = stitcher.stitch(all_tracks, start_time)
stitched_tracks, _ = stitcher.stitch(all_tracks, start_time)

for pair in dim_pairs:
plotter = Plotter();
Expand Down Expand Up @@ -527,7 +322,7 @@ def StitcherCorrectness(StitchedTracks):
total += 1
if id1[0] == id2[0] and id1[1] == (id2[1] - 1):
count += 1
return (count / total * 100)
return count / total * 100


print("Tracklets stitched correctly: ", StitcherCorrectness(stitched_tracks), "%")
Expand All @@ -536,12 +331,12 @@ def StitcherCorrectness(StitchedTracks):
# SIAP Metrics
# ^^^^^^^^^^^^
# The following cell calculates and records a range of SIAP (Single Integrated Air Picture) metrics to assess the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# The following cell calculates and records a range of SIAP (Single Integrated Air Picture) metrics to assess the
# The code below calculates and records a range of SIAP (Single Integrated Air Picture) metrics to assess the

# accuracy of the stitcher. The value of 'association_threshold' should be adjusted to represent the acceptable
# distance for association for the scenaio that is being considered. For example, associating with a threshold of 50
# metres may be acceptable if tracking a large ship, but not so useful for tracking cell movement.
# accuracy of the stitcher. The value of math:`association_threshold` should be adjusted to represent the acceptable
# distance for association for the scenario that is being considered. For example, associating with a threshold of 50
# metres may be acceptable if tracking a large ship, but not so useful for tracking biological cell movement.
#
# SIAP Ambiguity: Number of tracks assigned to a true object. Important as a value not equal to 1 suggests that the
# stitcher is not stitching whole tracks together, or stitching multiple tracks into one.
# SIAP Ambiguity: Number of tracks assigned to a single true object. Important as a value not equal to 1 suggests that
# the stitcher is not stitching whole tracks together, or stitching multiple tracks into one.
#
# SIAP Completeness: Fraction of true objects being tracked. Not a valuable metric for track stitching evaluation as we
# are only tracking fractions of the true objects - metric value is scaled by the ratio of truthlets to disjoint
spike-dstl marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading