Skip to content

Commit

Permalink
Changed api for devices (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Nov 25, 2023
1 parent a846b58 commit d162caa
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 28 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ Use the wrapper to quickly deploy face detection in your projects:
from yolo5face.get_model import get_model
import cv2

model = get_model("yolov5n", gpu=-1, target_size=512, min_face=24)
model = get_model("yolov5n", device=-1, target_size=512, min_face=24)

image = cv2.imread(<IMAGE_PATH>)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

boxes, key_points, scores = model(image)
```

* **gpu**: Specify the GPU number, `-1` or `cpu` for CPU usage.
* **device**: Specify device `cpu`, `cuda`, `mps` or integer for the number of cuda device.
* **target_size**: The minimum size of the target image for detection.
* **min_face**: The minimum face size in pixels. Faces smaller than this value will be ignored.

Expand All @@ -41,7 +41,7 @@ To use this feature:
```bash
from yolo5face.get_model import get_model

model = get_model("yolov5n", gpu=-1, target_size=[320, 640, 1280], min_face=24)
model = get_model("yolov5n", device=-1, target_size=[320, 640, 1280], min_face=24)

# Aggregate detections over the specified target sizes
boxes, key_points, scores = aggregator(image)
Expand Down
2 changes: 1 addition & 1 deletion tests/aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tests.conftest import test_images as images
from yolo5face.get_model import get_model

model = get_model("yolov5n", gpu=-1, target_size=[512, 1024])
model = get_model("yolov5n", device="cpu", target_size=[512, 1024])


@mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tests.conftest import test_images as images
from yolo5face.get_model import get_model

model = get_model("yolov5n", gpu=-1, target_size=512)
model = get_model("yolov5n", device="cpu", target_size=512)


@mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion yolo5face/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.7"
__version__ = "0.0.8"
14 changes: 12 additions & 2 deletions yolo5face/get_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import NamedTuple

import torch
from torch.hub import download_url_to_file

from yolo5face.yoloface.YoloDetectorAggregator import YoloDetectorAggregator
Expand All @@ -26,7 +27,7 @@ def get_file_name(url: str) -> str:

def get_model(
model_name: str,
gpu: int,
device: str,
target_size: int,
min_face: int = 24,
weights_path: str = "~/.torch/models",
Expand All @@ -46,10 +47,19 @@ def get_model(
if not config_file_path.exists():
download_url_to_file(config_name, config_file_path.as_posix(), progress=True)

if (
(torch.backends.mps.is_available() and device == "mps")
or (device == "cuda" or isinstance(device, int))
and torch.cuda.is_available()
):
device = torch.device(device)
else:
device = torch.device("cpu")

return YoloDetectorAggregator(
target_sizes=target_size,
min_face=min_face,
gpu=gpu,
device=device,
weights_name=weight_file_path,
config_name=config_file_path,
)
25 changes: 5 additions & 20 deletions yolo5face/yoloface/face_detector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import os
from pathlib import Path

import cv2
Expand All @@ -21,10 +20,10 @@
class YoloDetector:
def __init__(
self,
weights_name: str = "yolov5n_state_dict.pt",
config_name: str = "yolov5n.yaml",
gpu: int | str = 0,
min_face: int = 100,
weights_name: str,
config_name: str,
device: torch.device,
min_face: int,
target_size: int | None = None,
):
"""
Expand All @@ -37,27 +36,13 @@ def __init__(
"""
self._class_path = Path(__file__).parent.absolute()
self.gpu = gpu
self.device = device
self.target_size = target_size
self.min_face = min_face

self.detector = self.init_detector(weights_name, config_name)

def init_detector(self, weights_name: str, config_name: str) -> nn.Module:
# Check for MPS availability (specific to macOS with Apple Silicon)
if torch.backends.mps.is_available():
print("Using MPS (Apple Metal Performance Shaders)")
self.device = torch.device("mps")
# Check for CUDA availability
elif isinstance(self.gpu, int) and self.gpu >= 0 and torch.cuda.is_available():
print("Using CUDA")
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu)
self.device = torch.device("cuda:0")
else:
print("Using CPU")
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
self.device = torch.device("cpu")

state_dict = torch.load(weights_name)
detector = Model(cfg=config_name)
detector.load_state_dict(state_dict)
Expand Down

0 comments on commit d162caa

Please sign in to comment.