Skip to content

Commit

Permalink
simple warning for bin overflow
Browse files Browse the repository at this point in the history
Summary: Since coarse rasterization on cuda can overflow bins, we detect when this happens for memory safety. See #348 . Also try to print a warning.

Reviewed By: patricklabatut

Differential Revision: D33065604

fbshipit-source-id: 99b3c576d01b78e6d77776cf1a3e95984506c93a
  • Loading branch information
bottler authored and facebook-github-bot committed Jan 6, 2022
1 parent d6a12af commit 6726500
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,22 @@ __global__ void RasterizeCoarseCudaKernel(
// this effectively allocates space in the bin_faces array for the
// elems in the current chunk that fall into this bin.
const int start = atomicAdd(elems_per_bin + elems_per_bin_idx, count);
if (start + count > M) {
// The number of elems in this bin is so big that they won't fit.
// We print a warning using CUDA's printf. This may be invisible
// to notebook users, but apparent to others. It would be nice to
// also have a Python-friendly warning, but it is not obvious
// how to do this without slowing down the normal case.
const char* warning =
"Bin size was too small in the coarse rasterization phase. "
"This caused an overflow, meaning output may be incomplete. "
"To solve, "
"try increasing max_faces_per_bin / max_points_per_bin, "
"decreasing bin_size, "
"or setting bin_size to -1 to use the naive rasterization.";
printf(warning);
continue;
}

// Now loop over the binmask and write the active bits for this bin
// out to bin_faces.
Expand Down

0 comments on commit 6726500

Please sign in to comment.