Skip to content

Commit

Permalink
fix gt or dt is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
MiXaiLL76 committed Jun 26, 2024
1 parent 1739926 commit 7e58139
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions csrc/mask_api/src/mask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ namespace mask_api
uint64_t m = R[0].m;
uint64_t i, a, b;

std::vector<uint> cnts(h*w+1);
std::vector<uint> cnts(h * w + 1);
for (a = 0; a < m; a++)
{
cnts[a] = R[0].cnts[a];
Expand Down Expand Up @@ -751,20 +751,22 @@ namespace mask_api
}
}

std::variant<std::vector<RLE>, std::vector<double>> _preproc(const py::object &pyobj)
std::tuple<std::variant<std::vector<RLE>, std::vector<double>>, size_t> _preproc(const py::object &pyobj)
{
std::string type = py::str(py::type::of(pyobj));
if (type == "<class 'numpy.ndarray'>")
{
return _preproc_bbox_array(pyobj);
std::vector<double> result = _preproc_bbox_array(pyobj);

return std::make_tuple(result, (size_t)(result.size() / 4));
}
else if (type == "<class 'list'>")
{
std::vector<py::object> pyobj_list = pyobj.cast<std::vector<py::object>>();

if (pyobj_list.size() == 0)
{
return std::vector<double>(0);
return std::make_tuple(std::vector<double>(0), 0);
}

bool isbox = true;
Expand Down Expand Up @@ -799,11 +801,13 @@ namespace mask_api

if (isbox)
{
return _preproc_bbox_array(pyobj);
std::vector<double> result = _preproc_bbox_array(pyobj);
return std::make_tuple(result, (size_t)(result.size() / 4));
}
else if (isrle)
{
return _frString(pyobj.cast<std::vector<py::dict>>());
std::vector<RLE> result = _frString(pyobj.cast<std::vector<py::dict>>());
return std::make_tuple(result, (size_t)result.size());
}
else
{
Expand All @@ -820,23 +824,25 @@ namespace mask_api
std::variant<py::array_t<double, py::array::f_style>, std::vector<double>> iou(const py::object &dt, const py::object &gt, const std::vector<int> &iscrowd)
{

auto _dt = _preproc(dt);
auto _gt = _preproc(gt);
auto [_dt, m] = _preproc(dt);
auto [_gt, n] = _preproc(gt);

std::size_t crowd_length = iscrowd.size();
if (m == 0 || n == 0)
{
return std::vector<double>(0);
}

if (_dt.index() != _gt.index())
{
throw std::out_of_range("The dt and gt should have the same data type, either RLEs, list or np.ndarray");
}

std::size_t crowd_length = iscrowd.size();
std::vector<double> iou;
std::size_t m, n;

if (std::holds_alternative<std::vector<double>>(_dt))
{
std::vector<double> _gt_box = std::get<std::vector<double>>(_gt);
n = (std::size_t)(_gt_box.size() / 4);
if (crowd_length > 0 && crowd_length == n)
{

Expand All @@ -848,7 +854,6 @@ namespace mask_api
else
{
std::vector<RLE> _gt_rle = std::get<std::vector<RLE>>(_gt);
n = _gt_rle.size();
if (crowd_length > 0 && crowd_length == n)
{
std::vector<RLE> _dt_rle = std::get<std::vector<RLE>>(_dt);
Expand All @@ -862,10 +867,6 @@ namespace mask_api
printf("crowd_length=%zu, n=%zu\n", crowd_length, n);
throw std::out_of_range("iscrowd must have the same length as gt");
}
if (m == 0 || n == 0)
{
return std::vector<double>(0);
}
return py::array(iou.size(), iou.data()).reshape({m, n});
}

Expand Down

0 comments on commit 7e58139

Please sign in to comment.