diff --git a/csrc/mask_api/src/mask.cpp b/csrc/mask_api/src/mask.cpp index 495da2f..bb20b15 100644 --- a/csrc/mask_api/src/mask.cpp +++ b/csrc/mask_api/src/mask.cpp @@ -461,7 +461,7 @@ namespace mask_api uint64_t m = R[0].m; uint64_t i, a, b; - std::vector cnts(h*w+1); + std::vector cnts(h * w + 1); for (a = 0; a < m; a++) { cnts[a] = R[0].cnts[a]; @@ -751,12 +751,14 @@ namespace mask_api } } - std::variant, std::vector> _preproc(const py::object &pyobj) + std::tuple, std::vector>, size_t> _preproc(const py::object &pyobj) { std::string type = py::str(py::type::of(pyobj)); if (type == "") { - return _preproc_bbox_array(pyobj); + std::vector result = _preproc_bbox_array(pyobj); + + return std::make_tuple(result, (size_t)(result.size() / 4)); } else if (type == "") { @@ -764,7 +766,7 @@ namespace mask_api if (pyobj_list.size() == 0) { - return std::vector(0); + return std::make_tuple(std::vector(0), 0); } bool isbox = true; @@ -799,11 +801,13 @@ namespace mask_api if (isbox) { - return _preproc_bbox_array(pyobj); + std::vector result = _preproc_bbox_array(pyobj); + return std::make_tuple(result, (size_t)(result.size() / 4)); } else if (isrle) { - return _frString(pyobj.cast>()); + std::vector result = _frString(pyobj.cast>()); + return std::make_tuple(result, (size_t)result.size()); } else { @@ -820,23 +824,25 @@ namespace mask_api std::variant, std::vector> iou(const py::object &dt, const py::object >, const std::vector &iscrowd) { - auto _dt = _preproc(dt); - auto _gt = _preproc(gt); + auto [_dt, m] = _preproc(dt); + auto [_gt, n] = _preproc(gt); - std::size_t crowd_length = iscrowd.size(); + if (m == 0 || n == 0) + { + return std::vector(0); + } if (_dt.index() != _gt.index()) { throw std::out_of_range("The dt and gt should have the same data type, either RLEs, list or np.ndarray"); } + std::size_t crowd_length = iscrowd.size(); std::vector iou; - std::size_t m, n; if (std::holds_alternative>(_dt)) { std::vector _gt_box = std::get>(_gt); - n = (std::size_t)(_gt_box.size() / 4); if (crowd_length > 0 && crowd_length == n) { @@ -848,7 +854,6 @@ namespace mask_api else { std::vector _gt_rle = std::get>(_gt); - n = _gt_rle.size(); if (crowd_length > 0 && crowd_length == n) { std::vector _dt_rle = std::get>(_dt); @@ -862,10 +867,6 @@ namespace mask_api printf("crowd_length=%zu, n=%zu\n", crowd_length, n); throw std::out_of_range("iscrowd must have the same length as gt"); } - if (m == 0 || n == 0) - { - return std::vector(0); - } return py::array(iou.size(), iou.data()).reshape({m, n}); }