From 7c04c24482976b2c187151710cead473c5e4c4ad Mon Sep 17 00:00:00 2001 From: Nate Haddad Date: Sun, 25 Feb 2024 12:00:54 -0500 Subject: [PATCH] Fix classes bug (#22) * Fix classes bug * Fix type --- CHANGELOG.md | 3 +++ app.py | 1 + .../detectors/third_party/yolov8/yolov8_onnx.py | 4 +++- .../detectors/third_party/yolov9/yolov9_onnx.py | 4 +++- fast_track/detectors/util.py | 12 ++++++------ pyproject.toml | 2 +- 6 files changed, 17 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19e398f..973f9fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ ## Known Issues: - None +## v1.1.1 - Nate Haddad, 2/25/2024 +- Add support for class selection (fixes bug) + ## v1.1.0 - Nate Haddad, 2/25/2024 - Add `YOLOv9ONNX` to `detectors` diff --git a/app.py b/app.py index bf69d6f..5925fcc 100644 --- a/app.py +++ b/app.py @@ -89,6 +89,7 @@ def run_fast_track(input_video: str, info="Select a supported object detector type. ONNX models are accepted for YOLOv8 and YOLOv7", choices=[ "YOLO-NAS", + "YOLOv9", "YOLOv8", "YOLOv7", ], diff --git a/fast_track/detectors/third_party/yolov8/yolov8_onnx.py b/fast_track/detectors/third_party/yolov8/yolov8_onnx.py index d6f83ac..3c515e1 100644 --- a/fast_track/detectors/third_party/yolov8/yolov8_onnx.py +++ b/fast_track/detectors/third_party/yolov8/yolov8_onnx.py @@ -25,6 +25,7 @@ class YOLOv8ONNX(ObjectDetectorONNX): between 0.0 and 1.0. iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are between 0.0 and 1.0. + classes: A list on integers corresponding to class indices to be used for detection. agnostic:If True, the model is agnostic to the number of classes, and all classes will be considered as one. multi_label:If True, each box may have multiple labels. labels: A list of lists, where each inner list contains the apriori labels for a given image. The list should @@ -57,6 +58,7 @@ def __init__(self, between 0.0 and 1.0. iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are between 0.0 and 1.0. + classes: A list on integers corresponding to class indices to be used for detection. agnostic: If True, the model is agnostic to the number of classes, and all classes will be considered as one. multi_label:If True, each box may have multiple labels. @@ -86,7 +88,7 @@ def postprocess(self, tensor: np.ndarray) -> Tuple[list, list, list]: predictions = ops.non_max_suppression(torch.tensor(tensor[0]), conf_thres=self.conf_thresh, iou_thres=self.iou_thresh, - classes=len(self.classes), + classes=self.classes, agnostic=self.agnostic, multi_label=self.multi_label, labels=self.labels, diff --git a/fast_track/detectors/third_party/yolov9/yolov9_onnx.py b/fast_track/detectors/third_party/yolov9/yolov9_onnx.py index b1648cd..c9f3d16 100644 --- a/fast_track/detectors/third_party/yolov9/yolov9_onnx.py +++ b/fast_track/detectors/third_party/yolov9/yolov9_onnx.py @@ -25,6 +25,7 @@ class YOLOv9ONNX(ObjectDetectorONNX): between 0.0 and 1.0. iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are between 0.0 and 1.0. + classes: A list on integers corresponding to class indices to be used for detection. agnostic:If True, the model is agnostic to the number of classes, and all classes will be considered as one. multi_label: If True, each box may have multiple labels. labels: A list of lists, where each inner list contains the apriori labels for a given image. The list should @@ -56,6 +57,7 @@ def __init__(self, between 0.0 and 1.0. iou_thresh: The IoU threshold below which boxes will be filtered out during NMS. Valid values are between 0.0 and 1.0. + classes: A list on integers corresponding to class indices to be used for detection. agnostic: If True, the model is agnostic to the number of classes, and all classes will be considered as one. multi_label:If True, each box may have multiple labels. @@ -87,7 +89,7 @@ def postprocess(self, predictions = ops.non_max_suppression(torch.tensor(tensor[0]), conf_thres=self.conf_thresh, iou_thres=self.iou_thresh, - classes=len(self.classes), + classes=self.classes, agnostic=self.agnostic, multi_label=self.multi_label, labels=self.labels, diff --git a/fast_track/detectors/util.py b/fast_track/detectors/util.py index 597fe40..7dc253c 100644 --- a/fast_track/detectors/util.py +++ b/fast_track/detectors/util.py @@ -64,13 +64,13 @@ def get_detector(weights_path: str, **detector_params) elif detector_type.startswith("yolov7"): return YOLOv7ONNX(weights_path=weights_path, - names=names, - image_shape=image_shape, - **detector_params) + names=names, + image_shape=image_shape, + **detector_params) elif detector_type.startswith("yolov9"): return YOLOv9ONNX(weights_path=weights_path, - names=names, - image_shape=image_shape, - **detector_params) + names=names, + image_shape=image_shape, + **detector_params) else: raise ValueError("Detector name not found.") diff --git a/pyproject.toml b/pyproject.toml index 037d50b..fc3c11f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "fast_track" -version = "1.1.0" +version = "1.1.1" description = "Object detection and tracking pipeline" readme = "README.md" keywords = [