Skip to content

Commit

Permalink
Fix classes bug (#22)
Browse files Browse the repository at this point in the history
* Fix classes bug

* Fix type
  • Loading branch information
nmhaddad committed Feb 25, 2024
1 parent 42b10d3 commit 7c04c24
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
## Known Issues:
- None

## v1.1.1 - Nate Haddad, 2/25/2024
- Add support for class selection (fixes bug)

## v1.1.0 - Nate Haddad, 2/25/2024
- Add `YOLOv9ONNX` to `detectors`

Expand Down
1 change: 1 addition & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def run_fast_track(input_video: str,
info="Select a supported object detector type. ONNX models are accepted for YOLOv8 and YOLOv7",
choices=[
"YOLO-NAS",
"YOLOv9",
"YOLOv8",
"YOLOv7",
],
Expand Down
4 changes: 3 additions & 1 deletion fast_track/detectors/third_party/yolov8/yolov8_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class YOLOv8ONNX(ObjectDetectorONNX):
between 0.0 and 1.0.
iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are
between 0.0 and 1.0.
classes: A list on integers corresponding to class indices to be used for detection.
agnostic:If True, the model is agnostic to the number of classes, and all classes will be considered as one.
multi_label:If True, each box may have multiple labels.
labels: A list of lists, where each inner list contains the apriori labels for a given image. The list should
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(self,
between 0.0 and 1.0.
iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are
between 0.0 and 1.0.
classes: A list on integers corresponding to class indices to be used for detection.
agnostic: If True, the model is agnostic to the number of classes, and all classes will be considered
as one.
multi_label:If True, each box may have multiple labels.
Expand Down Expand Up @@ -86,7 +88,7 @@ def postprocess(self, tensor: np.ndarray) -> Tuple[list, list, list]:
predictions = ops.non_max_suppression(torch.tensor(tensor[0]),
conf_thres=self.conf_thresh,
iou_thres=self.iou_thresh,
classes=len(self.classes),
classes=self.classes,
agnostic=self.agnostic,
multi_label=self.multi_label,
labels=self.labels,
Expand Down
4 changes: 3 additions & 1 deletion fast_track/detectors/third_party/yolov9/yolov9_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class YOLOv9ONNX(ObjectDetectorONNX):
between 0.0 and 1.0.
iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are
between 0.0 and 1.0.
classes: A list on integers corresponding to class indices to be used for detection.
agnostic:If True, the model is agnostic to the number of classes, and all classes will be considered as one.
multi_label: If True, each box may have multiple labels.
labels: A list of lists, where each inner list contains the apriori labels for a given image. The list should
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self,
between 0.0 and 1.0.
iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are
between 0.0 and 1.0.
classes: A list on integers corresponding to class indices to be used for detection.
agnostic: If True, the model is agnostic to the number of classes, and all classes will be considered
as one.
multi_label:If True, each box may have multiple labels.
Expand Down Expand Up @@ -87,7 +89,7 @@ def postprocess(self,
predictions = ops.non_max_suppression(torch.tensor(tensor[0]),
conf_thres=self.conf_thresh,
iou_thres=self.iou_thresh,
classes=len(self.classes),
classes=self.classes,
agnostic=self.agnostic,
multi_label=self.multi_label,
labels=self.labels,
Expand Down
12 changes: 6 additions & 6 deletions fast_track/detectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def get_detector(weights_path: str,
**detector_params)
elif detector_type.startswith("yolov7"):
return YOLOv7ONNX(weights_path=weights_path,
names=names,
image_shape=image_shape,
**detector_params)
names=names,
image_shape=image_shape,
**detector_params)
elif detector_type.startswith("yolov9"):
return YOLOv9ONNX(weights_path=weights_path,
names=names,
image_shape=image_shape,
**detector_params)
names=names,
image_shape=image_shape,
**detector_params)
else:
raise ValueError("Detector name not found.")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fast_track"
version = "1.1.0"
version = "1.1.1"
description = "Object detection and tracking pipeline"
readme = "README.md"
keywords = [
Expand Down

0 comments on commit 7c04c24

Please sign in to comment.