Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize line zone #1228

Merged
merged 10 commits into from
Jun 5, 2024
52 changes: 35 additions & 17 deletions supervision/detection/line_zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,28 +158,29 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
]
)

cross_products_1 = self._cross_product(all_anchors, self.limits[0])
cross_products_2 = self._cross_product(all_anchors, self.limits[1])
# anchor is in limits if it's on the same side of both limit vectors
in_limits = ~np.logical_xor(cross_products_1 > 0, cross_products_2 > 0)
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
# Reduce array to find out if all anchors for a detection are within limits
in_limits = np.min(in_limits, axis=0)
LinasKo marked this conversation as resolved.
Show resolved Hide resolved

# Calculate which anchors lie to the left of the line
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
triggers = self._cross_product(all_anchors, self.vector) < 0
# Reduce to find out if all anchors for a
# detection lie to the left (or right) of the line
max_triggers = np.max(triggers, axis=0)
SkalskiP marked this conversation as resolved.
Show resolved Hide resolved
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
min_triggers = np.min(triggers, axis=0)
for i, tracker_id in enumerate(detections.tracker_id):
box_anchors = [Point(x=x, y=y) for x, y in all_anchors[:, i, :]]

in_limits = all(
[
self.is_point_in_limits(point=anchor, limits=self.limits)
for anchor in box_anchors
]
)

if not in_limits:
if not in_limits[i]:
continue

triggers = [
self.vector.cross_product(point=anchor) < 0 for anchor in box_anchors
]

if len(set(triggers)) == 2:
if min_triggers[i] != max_triggers[i]:
# One anchor lies to the left of the line
# whilst another lies to the right
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
continue

tracker_state = triggers[0]

tracker_state = max_triggers[i]
if tracker_id not in self.tracker_state:
self.tracker_state[tracker_id] = tracker_state
continue
Expand All @@ -197,6 +198,23 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:

return crossed_in, crossed_out

@staticmethod
def _cross_product(anchors: np.ndarray, vector: Vector) -> np.ndarray:
LinasKo marked this conversation as resolved.
Show resolved Hide resolved
"""
Get array of cross products of each anchor with a vector.
Args:
anchors: Array of anchors of shape (number of anchors, detections, 2)
vector: Vector to calculate cross product with

Returns:
Array of cross products of shape (number of anchors, detections)
"""
vector_at_zero = np.array(
[vector.end.x - vector.start.x, vector.end.y - vector.start.y]
)
vector_start = np.array([vector.start.x, vector.start.y])
return np.cross(vector_at_zero, anchors - vector_start)


class LineZoneAnnotator:
def __init__(
Expand Down