diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d.cu b/pytorch3d/csrc/iou_box3d/iou_box3d.cu index 345b1914b..7b1679f25 100644 --- a/pytorch3d/csrc/iou_box3d/iou_box3d.cu +++ b/pytorch3d/csrc/iou_box3d/iou_box3d.cu @@ -90,7 +90,8 @@ __global__ void IoUBox3DKernel( for (int b2 = 0; b2 < box2_count; ++b2) { const bool is_coplanar = IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]); - if (is_coplanar) { + const float area = FaceArea(box1_intersect[b1]); + if ((is_coplanar) && (area > kEpsilon)) { tri2_keep[b2].keep = false; } } diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp b/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp index 754e77709..3370097f1 100644 --- a/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp +++ b/pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp @@ -81,7 +81,8 @@ std::tuple IoUBox3DCpu( for (int b2 = 0; b2 < box2_intersect.size(); ++b2) { const bool is_coplanar = IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]); - if (is_coplanar) { + const float area = FaceArea(box1_intersect[b1]); + if ((is_coplanar) && (area > kEpsilon)) { tri2_keep[b2] = 0; } } diff --git a/pytorch3d/csrc/iou_box3d/iou_utils.cuh b/pytorch3d/csrc/iou_box3d/iou_utils.cuh index e40c8ede2..960661e36 100644 --- a/pytorch3d/csrc/iou_box3d/iou_utils.cuh +++ b/pytorch3d/csrc/iou_box3d/iou_utils.cuh @@ -138,6 +138,26 @@ FaceNormal(const float3 v0, const float3 v1, const float3 v2) { return n; } +// The area of the face defined by vertices (v0, v1, v2) +// Define e0 to be the edge connecting (v1, v0) +// Define e1 to be the edge connecting (v2, v0) +// Area is the norm of the cross product of e0, e1 divided by 2.0 +// +// Args +// tri: FaceVerts of float3 coordinates of the vertices of the face +// +// Returns +// float: area for the face +// +__device__ inline float FaceArea(const FaceVerts& tri) { + // Get verts for face 1 + const float3 v0 = tri.v0; + const float3 v1 = tri.v1; + const float3 v2 = tri.v2; + const float3 n = cross(v1 - v0, v2 - v0); + return norm(n) / 2.0; +} + // The normal of a box plane defined by the verts in `plane` with // the centroid of the box given by `center`. // Args diff --git a/pytorch3d/csrc/iou_box3d/iou_utils.h b/pytorch3d/csrc/iou_box3d/iou_utils.h index a9e423951..9aea3fd44 100644 --- a/pytorch3d/csrc/iou_box3d/iou_utils.h +++ b/pytorch3d/csrc/iou_box3d/iou_utils.h @@ -145,6 +145,26 @@ inline vec3 FaceNormal(vec3 v0, vec3 v1, vec3 v2) { return n; } +// The area of the face defined by vertices (v0, v1, v2) +// Define e0 to be the edge connecting (v1, v0) +// Define e1 to be the edge connecting (v2, v0) +// Area is the norm of the cross product of e0, e1 divided by 2.0 +// +// Args +// tri: vec3 coordinates of the vertices of the face +// +// Returns +// float: area for the face +// +inline float FaceArea(const std::vector>& tri) { + // Get verts for face + const vec3 v0 = tri[0]; + const vec3 v1 = tri[1]; + const vec3 v2 = tri[2]; + const vec3 n = cross(v1 - v0, v2 - v0); + return norm(n) / 2.0; +} + // The normal of a box plane defined by the verts in `plane` with // the centroid of the box given by `center`. // Args diff --git a/pytorch3d/ops/iou_box3d.py b/pytorch3d/ops/iou_box3d.py index 22e317a1a..4401c3e2f 100644 --- a/pytorch3d/ops/iou_box3d.py +++ b/pytorch3d/ops/iou_box3d.py @@ -69,6 +69,28 @@ def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-4) -> None: return +def _check_nonzero(boxes: torch.Tensor, eps: float = 1e-4) -> None: + """ + Checks that the sides of the box have a non zero area + """ + faces = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device) + # pyre-fixme[16]: `boxes` has no attribute `index_select`. + verts = boxes.index_select(index=faces.view(-1), dim=1) + B = boxes.shape[0] + T, V = faces.shape + # (B, T, 3, 3) -> (B, T, 3) + v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2) + + normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # (B, T, 3) + face_areas = normals.norm(dim=-1) / 2 + + if (face_areas < eps).any().item(): + msg = "Planes have zero areas" + raise ValueError(msg) + + return + + class _box3d_overlap(Function): """ Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations. @@ -138,6 +160,8 @@ def box3d_overlap( _check_coplanar(boxes1, eps) _check_coplanar(boxes2, eps) + _check_nonzero(boxes1, eps) + _check_nonzero(boxes2, eps) # pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`. vol, iou = _box3d_overlap.apply(boxes1, boxes2) diff --git a/tests/test_iou_box3d.py b/tests/test_iou_box3d.py index 990b1f455..921144a53 100644 --- a/tests/test_iou_box3d.py +++ b/tests/test_iou_box3d.py @@ -111,6 +111,11 @@ def _test_iou(self, overlap_fn, device): self.assertClose( vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) ) + # symmetry + vol, iou = overlap_fn(box2[None], box1[None]) + self.assertClose( + vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) + ) # 3rd test dd = random.random() @@ -119,6 +124,11 @@ def _test_iou(self, overlap_fn, device): self.assertClose( vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) ) + # symmetry + vol, _ = overlap_fn(box2[None], box1[None]) + self.assertClose( + vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) + ) # 4th test ddx, ddy, ddz = random.random(), random.random(), random.random() @@ -132,6 +142,16 @@ def _test_iou(self, overlap_fn, device): dtype=vol.dtype, ), ) + # symmetry + vol, _ = overlap_fn(box2[None], box1[None]) + self.assertClose( + vol, + torch.tensor( + [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], + device=vol.device, + dtype=vol.dtype, + ), + ) # Also check IoU is 1 when computing overlap with the same shifted box vol, iou = overlap_fn(box2[None], box2[None]) @@ -152,6 +172,16 @@ def _test_iou(self, overlap_fn, device): dtype=vol.dtype, ), ) + # symmetry + vol, _ = overlap_fn(box2r[None], box1r[None]) + self.assertClose( + vol, + torch.tensor( + [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], + device=vol.device, + dtype=vol.dtype, + ), + ) # 6th test ddx, ddy, ddz = random.random(), random.random(), random.random() @@ -170,6 +200,17 @@ def _test_iou(self, overlap_fn, device): ), atol=1e-7, ) + # symmetry + vol, _ = overlap_fn(box2r[None], box1r[None]) + self.assertClose( + vol, + torch.tensor( + [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], + device=vol.device, + dtype=vol.dtype, + ), + atol=1e-7, + ) # 7th test: hand coded example and test with meshlab output @@ -214,6 +255,10 @@ def _test_iou(self, overlap_fn, device): vol, iou = overlap_fn(box1r[None], box2r[None]) self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1) + # symmetry + vol, iou = overlap_fn(box2r[None], box1r[None]) + self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1) + self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1) # 8th test: compare with sampling # create box1 @@ -232,7 +277,9 @@ def _test_iou(self, overlap_fn, device): iou_sampling = self._box3d_overlap_sampling_batched( box1r[None], box2r[None], num_samples=10000 ) - + self.assertClose(iou, iou_sampling, atol=1e-2) + # symmetry + vol, iou = overlap_fn(box2r[None], box1r[None]) self.assertClose(iou, iou_sampling, atol=1e-2) # 9th test: non overlapping boxes, iou = 0.0 @@ -240,6 +287,10 @@ def _test_iou(self, overlap_fn, device): vol, iou = overlap_fn(box1[None], box2[None]) self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) + # symmetry + vol, iou = overlap_fn(box2[None], box1[None]) + self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) + self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) # 10th test: Non coplanar verts in a plane box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device) @@ -284,6 +335,56 @@ def _test_iou(self, overlap_fn, device): vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None]) self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1) + # symmetry + vols, ious = overlap_fn(box_skew_2[None], box_skew_1[None]) + self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1) + self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1) + + # 12th test: Zero area bounding box (from GH issue #992) + box12a = torch.tensor( + [ + [-1.0000, -1.0000, -0.5000], + [1.0000, -1.0000, -0.5000], + [1.0000, 1.0000, -0.5000], + [-1.0000, 1.0000, -0.5000], + [-1.0000, -1.0000, 0.5000], + [1.0000, -1.0000, 0.5000], + [1.0000, 1.0000, 0.5000], + [-1.0000, 1.0000, 0.5000], + ], + device=device, + dtype=torch.float32, + ) + + box12b = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + device=device, + dtype=torch.float32, + ) + msg = "Planes have zero areas" + with self.assertRaisesRegex(ValueError, msg): + overlap_fn(box12a[None], box12b[None]) + # symmetry + with self.assertRaisesRegex(ValueError, msg): + overlap_fn(box12b[None], box12a[None]) + + # 13th test: From GH issue #992 + # Zero area coplanar face after intersection + ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]]) + whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]]) + box13a = TestIoU3D.create_box(ctrs[0], whl[0]) + box13b = TestIoU3D.create_box(ctrs[1], whl[1]) + vol, iou = overlap_fn(box13a[None], box13b[None]) + self.assertClose(vol, torch.tensor([[2.0]], device=vol.device, dtype=vol.dtype)) def _test_real_boxes(self, overlap_fn, device): data_filename = "./real_boxes.pkl" @@ -577,6 +678,13 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor: msg = "Plane vertices are not coplanar" raise ValueError(msg) + # Check all faces have non zero area + area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2 + area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2 + if (area1 < eps).any().item() or (area2 < eps).any().item(): + msg = "Planes have zero areas" + raise ValueError(msg) + # We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1). # With = 0 and = 0, where <.,.> refers to the dot product, # since that e0 is orthogonal to n. Same for e1. @@ -607,6 +715,27 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor: return n +def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor: + """ + Computes the area of the triangle faces in tri_verts + Args: + tri_verts: tensor of shape (T, 3, 3) + Returns: + areas: the area of the triangles (T, 1) + """ + add_dim = False + if tri_verts.ndim == 2: + tri_verts = tri_verts.unsqueeze(0) + add_dim = True + + v0, v1, v2 = tri_verts.unbind(1) + areas = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2.0 + + if add_dim: + areas = areas[0] + return areas + + def box_volume(box: torch.Tensor) -> torch.Tensor: """ Computes the volume of each box in boxes. @@ -988,7 +1117,10 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor): keep2 = torch.ones((tri_verts2.shape[0],), device=device, dtype=torch.bool) for i1 in range(tri_verts1.shape[0]): for i2 in range(tri_verts2.shape[0]): - if coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]): + if ( + coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]) + and tri_verts_area(tri_verts1[i1]) > 1e-4 + ): keep2[i2] = 0 keep2 = keep2.nonzero()[:, 0] tri_verts2 = tri_verts2[keep2]