Skip to content

Commit

Permalink
Merge pull request #1228 from tc360950/vectorize-line-zone
Browse files Browse the repository at this point in the history
Vectorize line zone
  • Loading branch information
LinasKo committed Jun 5, 2024
2 parents cfbb668 + c27a282 commit 59c5ab5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
30 changes: 13 additions & 17 deletions supervision/detection/line_zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from supervision.detection.core import Detections
from supervision.detection.utils import cross_product
from supervision.draw.color import Color
from supervision.draw.utils import draw_text
from supervision.geometry.core import Point, Position, Vector
Expand Down Expand Up @@ -158,28 +159,23 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
]
)

for i, tracker_id in enumerate(detections.tracker_id):
box_anchors = [Point(x=x, y=y) for x, y in all_anchors[:, i, :]]

in_limits = all(
[
self.is_point_in_limits(point=anchor, limits=self.limits)
for anchor in box_anchors
]
)
cross_products_1 = cross_product(all_anchors, self.limits[0])
cross_products_2 = cross_product(all_anchors, self.limits[1])
in_limits = (cross_products_1 > 0) == (cross_products_2 > 0)
in_limits = np.all(in_limits, axis=0)

if not in_limits:
triggers = cross_product(all_anchors, self.vector) < 0
has_any_left_trigger = np.any(triggers, axis=0)
has_any_right_trigger = np.any(~triggers, axis=0)
is_uniformly_triggered = ~(has_any_left_trigger & has_any_right_trigger)
for i, tracker_id in enumerate(detections.tracker_id):
if not in_limits[i]:
continue

triggers = [
self.vector.cross_product(point=anchor) < 0 for anchor in box_anchors
]

if len(set(triggers)) == 2:
if not is_uniformly_triggered[i]:
continue

tracker_state = triggers[0]

tracker_state = has_any_left_trigger[i]
if tracker_id not in self.tracker_state:
self.tracker_state[tracker_id] = tracker_state
continue
Expand Down
18 changes: 18 additions & 0 deletions supervision/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy.typing as npt

from supervision.config import CLASS_NAME_DATA_FIELD
from supervision.geometry.core import Vector

MIN_POLYGON_POINT_COUNT = 3

Expand Down Expand Up @@ -833,3 +834,20 @@ def contains_multiple_segments(
mask_uint8, labels, connectivity=connectivity
)
return number_of_labels > 2


def cross_product(anchors: np.ndarray, vector: Vector) -> np.ndarray:
"""
Get array of cross products of each anchor with a vector.
Args:
anchors: Array of anchors of shape (number of anchors, detections, 2)
vector: Vector to calculate cross product with
Returns:
Array of cross products of shape (number of anchors, detections)
"""
vector_at_zero = np.array(
[vector.end.x - vector.start.x, vector.end.y - vector.start.y]
)
vector_start = np.array([vector.start.x, vector.start.y])
return np.cross(vector_at_zero, anchors - vector_start)

0 comments on commit 59c5ab5

Please sign in to comment.