diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index db5252b5c..74f401c88 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -266,30 +266,8 @@ def _parse_auxiliary_input( aux_input_C = None if isinstance(aux_input, list): - if len(aux_input) != self._N: - raise ValueError("Points and auxiliary input must be the same length.") - for p, d in zip(self._num_points_per_cloud, aux_input): - if p != d.shape[0]: - raise ValueError( - "A cloud has mismatched numbers of points and inputs" - ) - if d.device != self.device: - raise ValueError( - "All auxiliary inputs must be on the same device as the points." - ) - if p > 0: - if d.dim() != 2: - raise ValueError( - "A cloud auxiliary input must be of shape PxC or empty" - ) - if aux_input_C is None: - aux_input_C = d.shape[1] - if aux_input_C != d.shape[1]: - raise ValueError( - "The clouds must have the same number of channels" - ) - return aux_input, None, aux_input_C - elif torch.is_tensor(aux_input): + return self._parse_auxiliary_input_list(aux_input) + if torch.is_tensor(aux_input): if aux_input.dim() != 3: raise ValueError("Auxiliary input tensor has incorrect dimensions.") if self._N != aux_input.shape[0]: @@ -312,6 +290,72 @@ def _parse_auxiliary_input( points in a cloud." ) + def _parse_auxiliary_input_list( + self, aux_input: list + ) -> Tuple[Optional[List[torch.Tensor]], None, Optional[int]]: + """ + Interpret the auxiliary inputs (normals, features) given to __init__, + if a list. + + Args: + aux_input: + - List where each element is a tensor of shape (num_points, C) + containing the features for the points in the cloud. + For normals, C = 3 + + Returns: + 3-element tuple of list, padded=None, num_channels. + If aux_input is list, then padded is None. If aux_input is a tensor, + then list is None. + """ + aux_input_C = None + good_empty = None + needs_fixing = False + + if len(aux_input) != self._N: + raise ValueError("Points and auxiliary input must be the same length.") + for p, d in zip(self._num_points_per_cloud, aux_input): + valid_but_empty = p == 0 and d is not None and d.ndim == 2 + if p > 0 or valid_but_empty: + if p != d.shape[0]: + raise ValueError( + "A cloud has mismatched numbers of points and inputs" + ) + if d.dim() != 2: + raise ValueError( + "A cloud auxiliary input must be of shape PxC or empty" + ) + if aux_input_C is None: + aux_input_C = d.shape[1] + elif aux_input_C != d.shape[1]: + raise ValueError("The clouds must have the same number of channels") + if d.device != self.device: + raise ValueError( + "All auxiliary inputs must be on the same device as the points." + ) + else: + needs_fixing = True + + if aux_input_C is None: + # We found nothing useful + return None, None, None + + # If we have empty but "wrong" inputs we want to store "fixed" versions. + if needs_fixing: + if good_empty is None: + good_empty = torch.zeros((0, aux_input_C), device=self.device) + aux_input_out = [] + for p, d in zip(self._num_points_per_cloud, aux_input): + valid_but_empty = p == 0 and d is not None and d.ndim == 2 + if p > 0 or valid_but_empty: + aux_input_out.append(d) + else: + aux_input_out.append(good_empty) + else: + aux_input_out = aux_input + + return aux_input_out, None, aux_input_C + def __len__(self) -> int: return self._N diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index f1705e154..fa37368f4 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -383,6 +383,43 @@ def test_empty(self): self.assertTrue(features_padded[n, p:, :].eq(0).all()) self.assertTrue(points_per_cloud[n] == p) + def test_list_someempty(self): + # We want + # point_cloud = Pointclouds( + # [pcl.points_packed() for pcl in point_clouds], + # features=[pcl.features_packed() for pcl in point_clouds], + # ) + # to work if point_clouds is a list of pointclouds with some empty and some not. + points_list = [torch.rand(30, 3), torch.zeros(0, 3)] + features_list = [torch.rand(30, 3), None] + pcls = Pointclouds(points=points_list, features=features_list) + self.assertEqual(len(pcls), 2) + self.assertClose( + pcls.points_padded(), + torch.stack([points_list[0], torch.zeros_like(points_list[0])]), + ) + self.assertClose(pcls.points_packed(), points_list[0]) + self.assertClose( + pcls.features_padded(), + torch.stack([features_list[0], torch.zeros_like(points_list[0])]), + ) + self.assertClose(pcls.features_packed(), features_list[0]) + + points_list = [torch.zeros(0, 3), torch.rand(30, 3)] + features_list = [None, torch.rand(30, 3)] + pcls = Pointclouds(points=points_list, features=features_list) + self.assertEqual(len(pcls), 2) + self.assertClose( + pcls.points_padded(), + torch.stack([torch.zeros_like(points_list[1]), points_list[1]]), + ) + self.assertClose(pcls.points_packed(), points_list[1]) + self.assertClose( + pcls.features_padded(), + torch.stack([torch.zeros_like(points_list[1]), features_list[1]]), + ) + self.assertClose(pcls.features_packed(), features_list[1]) + def test_clone_list(self): N = 5 clouds = self.init_cloud(N, 100, 5)