diff --git a/test/test_models.py b/test/test_models.py index faa14f8250e..7e1f0fb3b12 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -71,6 +71,9 @@ def get_available_video_models(): "keypointrcnn_resnet50_fpn": { 'unwrapper': lambda x: x[1] }, + "retinanet_resnet50_fpn": { + 'unwrapper': lambda x: x[1] + } } diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index c124bed79c8..4095dc7f7c7 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -565,7 +565,7 @@ def forward(self, images, targets=None): if not self._has_warned: warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") self._has_warned = True - return (losses, detections) + return losses, detections return self.eager_outputs(losses, detections)