From b7c826b7863a4fb3c4ce0e10e5ed7400d32ed512 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Tue, 16 Aug 2022 15:19:39 -0700 Subject: [PATCH] Boolean indexing of cameras Summary: Reasonable to expect bool indexing. Reviewed By: bottler, kjchalup Differential Revision: D38741446 fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c --- pytorch3d/renderer/cameras.py | 32 +++++++++++++++++++++-------- pytorch3d/structures/meshes.py | 4 +++- pytorch3d/structures/pointclouds.py | 5 ++++- pytorch3d/structures/volumes.py | 5 ++++- pytorch3d/transforms/transform3d.py | 2 +- tests/test_cameras.py | 28 ++++++++++++++++++++----- 6 files changed, 58 insertions(+), 18 deletions(-) diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 74caf902c..7b96609f3 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -385,31 +385,45 @@ def get_image_size(self): return self.image_size if hasattr(self, "image_size") else None def __getitem__( - self, index: Union[int, List[int], torch.LongTensor] + self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor] ) -> "CamerasBase": """ Override for the __getitem__ method in TensorProperties which needs to be refactored. Args: - index: an int/list/long tensor used to index all the fields in the cameras given by - self._FIELDS. + index: an integer index, list/tensor of integer indices, or tensor of boolean + indicators used to filter all the fields in the cameras given by self._FIELDS. Returns: - if `index` is an index int/list/long tensor return an instance of the current - cameras class with only the values at the selected index. + an instance of the current cameras class with only the values at the selected index. """ kwargs = {} - # pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`. - if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)): - msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r" + tensor_types = { + "bool": (torch.BoolTensor, torch.cuda.BoolTensor), + "long": (torch.LongTensor, torch.cuda.LongTensor), + } + if not isinstance( + index, (int, list, *tensor_types["bool"], *tensor_types["long"]) + ) or ( + isinstance(index, list) + and not all(isinstance(i, int) and not isinstance(i, bool) for i in index) + ): + msg = ( + "Invalid index type, expected int, List[int] or Bool/LongTensor; got %r" + ) raise ValueError(msg % type(index)) if isinstance(index, int): index = [index] - if max(index) >= len(self): + if isinstance(index, tensor_types["bool"]): + if index.ndim != 1 or index.shape[0] != len(self): + raise ValueError( + f"Boolean index of shape {index.shape} does not match cameras" + ) + elif max(index) >= len(self): raise ValueError(f"Index {max(index)} is out of bounds for select cameras") for field in self._FIELDS: diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 40ff1c3cf..7f68052a7 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -472,7 +472,9 @@ def _set_verts_normals(self, verts_normals) -> None: def __len__(self) -> int: return self._N - def __getitem__(self, index) -> "Meshes": + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Meshes": """ Args: index: Specifying the index of the mesh to retrieve. diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 3dd2d1266..bf4439883 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -360,7 +360,10 @@ def _parse_auxiliary_input_list( def __len__(self) -> int: return self._N - def __getitem__(self, index) -> "Pointclouds": + def __getitem__( + self, + index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor], + ) -> "Pointclouds": """ Args: index: Specifying the index of the cloud to retrieve. diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index d9cd5ad45..7f3b66f02 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -501,7 +501,10 @@ def __len__(self) -> int: return self._densities.shape[0] def __getitem__( - self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor] + self, + index: Union[ + int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor + ], ) -> "Volumes": """ Args: diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index f330eb2c4..24a5663d7 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -181,7 +181,7 @@ def __len__(self) -> int: return self.get_matrix().shape[0] def __getitem__( - self, index: Union[int, List[int], slice, torch.Tensor] + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] ) -> "Transform3d": """ Args: diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 80a77df20..5feccca81 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -884,7 +884,8 @@ def test_camera_class_init(self): self.assertTrue(new_cam.device == device) def test_getitem(self): - R_matrix = torch.randn((6, 3, 3)) + N_CAMERAS = 6 + R_matrix = torch.randn((N_CAMERAS, 3, 3)) cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix) # Check get item returns an instance of the same class @@ -908,22 +909,39 @@ def test_getitem(self): self.assertClose(c012.R, R_matrix[0:3, ...]) # Check torch.LongTensor index - index = torch.tensor([1, 3, 5], dtype=torch.int64) + SLICE = [1, 3, 5] + index = torch.tensor(SLICE, dtype=torch.int64) c135 = cam[index] self.assertEqual(len(c135), 3) self.assertClose(c135.zfar, torch.tensor([100.0] * 3)) self.assertClose(c135.znear, torch.tensor([10.0] * 3)) - self.assertClose(c135.R, R_matrix[[1, 3, 5], ...]) + self.assertClose(c135.R, R_matrix[SLICE, ...]) + + # Check torch.BoolTensor index + bool_slice = [i in SLICE for i in range(N_CAMERAS)] + index = torch.tensor(bool_slice, dtype=torch.bool) + c135 = cam[index] + self.assertEqual(len(c135), 3) + self.assertClose(c135.zfar, torch.tensor([100.0] * 3)) + self.assertClose(c135.znear, torch.tensor([10.0] * 3)) + self.assertClose(c135.R, R_matrix[SLICE, ...]) # Check errors with get item with self.assertRaisesRegex(ValueError, "out of bounds"): - cam[6] + cam[N_CAMERAS] + + with self.assertRaisesRegex(ValueError, "does not match cameras"): + index = torch.tensor([1, 0, 1], dtype=torch.bool) + cam[index] with self.assertRaisesRegex(ValueError, "Invalid index type"): cam[slice(0, 1)] with self.assertRaisesRegex(ValueError, "Invalid index type"): - index = torch.tensor([1, 3, 5], dtype=torch.float32) + cam[[True, False]] + + with self.assertRaisesRegex(ValueError, "Invalid index type"): + index = torch.tensor(SLICE, dtype=torch.float32) cam[index] def test_get_full_transform(self):