diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py index 49690ec39..8ec828ec1 100644 --- a/pytorch3d/loss/chamfer.py +++ b/pytorch3d/loss/chamfer.py @@ -68,74 +68,28 @@ def _handle_pointcloud_input( return X, lengths, normals -def chamfer_distance( +def _chamfer_distance_single_direction( x, y, - x_lengths=None, - y_lengths=None, - x_normals=None, - y_normals=None, - weights=None, - batch_reduction: Union[str, None] = "mean", - point_reduction: str = "mean", - norm: int = 2, + x_lengths, + y_lengths, + x_normals, + y_normals, + weights, + batch_reduction: Union[str, None], + point_reduction: str, + norm: int, + abs_cosine: bool, ): - """ - Chamfer distance between two pointclouds x and y. - - Args: - x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing - a batch of point clouds with at most P1 points in each batch element, - batch size N and feature dimension D. - y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing - a batch of point clouds with at most P2 points in each batch element, - batch size N and feature dimension D. - x_lengths: Optional LongTensor of shape (N,) giving the number of points in each - cloud in x. - y_lengths: Optional LongTensor of shape (N,) giving the number of points in each - cloud in y. - x_normals: Optional FloatTensor of shape (N, P1, D). - y_normals: Optional FloatTensor of shape (N, P2, D). - weights: Optional FloatTensor of shape (N,) giving weights for - batch elements for reduction operation. - batch_reduction: Reduction operation to apply for the loss across the - batch, can be one of ["mean", "sum"] or None. - point_reduction: Reduction operation to apply for the loss across the - points, can be one of ["mean", "sum"]. - norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2. - - Returns: - 2-element tuple containing - - - **loss**: Tensor giving the reduced distance between the pointclouds - in x and the pointclouds in y. 
- - **loss_normals**: Tensor giving the reduced cosine distance of normals - between pointclouds in x and pointclouds in y. Returns None if - x_normals and y_normals are None. - """ - _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) - - if not ((norm == 1) or (norm == 2)): - raise ValueError("Support for 1 or 2 norm.") - - x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) - y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) - return_normals = x_normals is not None and y_normals is not None N, P1, D = x.shape - P2 = y.shape[1] # Check if inputs are heterogeneous and create a lengths mask. is_x_heterogeneous = (x_lengths != P1).any() - is_y_heterogeneous = (y_lengths != P2).any() x_mask = ( torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] ) # shape [N, P1] - y_mask = ( - torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] - ) # shape [N, P2] - if y.shape[0] != N or y.shape[2] != D: raise ValueError("y does not have the correct shape.") if weights is not None: @@ -153,75 +107,148 @@ def chamfer_distance( return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0) cham_norm_x = x.new_zeros(()) - cham_norm_y = x.new_zeros(()) x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1) - y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1) - cham_x = x_nn.dists[..., 0] # (N, P1) - cham_y = y_nn.dists[..., 0] # (N, P2) if is_x_heterogeneous: cham_x[x_mask] = 0.0 - if is_y_heterogeneous: - cham_y[y_mask] = 0.0 if weights is not None: cham_x *= weights.view(N, 1) - cham_y *= weights.view(N, 1) if return_normals: # Gather the normals using the indices and keep only value for k=0 x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :] - y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :] - cham_norm_x = 1 - torch.abs( - F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6) - ) - 
cham_norm_y = 1 - torch.abs( - F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6) - ) + cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6) + # If abs_cosine, ignore orientation and take the absolute value of the cosine sim. + cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim) if is_x_heterogeneous: cham_norm_x[x_mask] = 0.0 - if is_y_heterogeneous: - cham_norm_y[y_mask] = 0.0 if weights is not None: cham_norm_x *= weights.view(N, 1) - cham_norm_y *= weights.view(N, 1) + cham_norm_x = cham_norm_x.sum(1) # (N,) # Apply point reduction cham_x = cham_x.sum(1) # (N,) - cham_y = cham_y.sum(1) # (N,) - if return_normals: - cham_norm_x = cham_norm_x.sum(1) # (N,) - cham_norm_y = cham_norm_y.sum(1) # (N,) if point_reduction == "mean": x_lengths_clamped = x_lengths.clamp(min=1) - y_lengths_clamped = y_lengths.clamp(min=1) cham_x /= x_lengths_clamped - cham_y /= y_lengths_clamped if return_normals: cham_norm_x /= x_lengths_clamped - cham_norm_y /= y_lengths_clamped if batch_reduction is not None: # batch_reduction == "sum" cham_x = cham_x.sum() - cham_y = cham_y.sum() if return_normals: cham_norm_x = cham_norm_x.sum() - cham_norm_y = cham_norm_y.sum() if batch_reduction == "mean": div = weights.sum() if weights is not None else max(N, 1) cham_x /= div - cham_y /= div if return_normals: cham_norm_x /= div - cham_norm_y /= div - - cham_dist = cham_x + cham_y - cham_normals = cham_norm_x + cham_norm_y if return_normals else None + cham_dist = cham_x + cham_normals = cham_norm_x if return_normals else None return cham_dist, cham_normals + + +def chamfer_distance( + x, + y, + x_lengths=None, + y_lengths=None, + x_normals=None, + y_normals=None, + weights=None, + batch_reduction: Union[str, None] = "mean", + point_reduction: str = "mean", + norm: int = 2, + single_directional: bool = False, + abs_cosine: bool = True, +): + """ + Chamfer distance between two pointclouds x and y. 
+ + Args: + x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing + a batch of point clouds with at most P1 points in each batch element, + batch size N and feature dimension D. + y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing + a batch of point clouds with at most P2 points in each batch element, + batch size N and feature dimension D. + x_lengths: Optional LongTensor of shape (N,) giving the number of points in each + cloud in x. + y_lengths: Optional LongTensor of shape (N,) giving the number of points in each + cloud in y. + x_normals: Optional FloatTensor of shape (N, P1, D). + y_normals: Optional FloatTensor of shape (N, P2, D). + weights: Optional FloatTensor of shape (N,) giving weights for + batch elements for reduction operation. + batch_reduction: Reduction operation to apply for the loss across the + batch, can be one of ["mean", "sum"] or None. + point_reduction: Reduction operation to apply for the loss across the + points, can be one of ["mean", "sum"]. + norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2. + single_directional: If False (default), loss comes from both the distance between + each point in x and its nearest neighbor in y and each point in y and its nearest + neighbor in x. If True, loss is the distance between each point in x and its + nearest neighbor in y. + abs_cosine: If False, loss_normals is from one minus the cosine similarity. + If True (default), loss_normals is from one minus the absolute value of the + cosine similarity, which means that exactly opposite normals are considered + equivalent to exactly matching normals, i.e. sign does not matter. + + Returns: + 2-element tuple containing + + - **loss**: Tensor giving the reduced distance between the pointclouds + in x and the pointclouds in y. + - **loss_normals**: Tensor giving the reduced cosine distance of normals + between pointclouds in x and pointclouds in y. 
Returns None if + either x_normals or y_normals is None. + + """ + _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) + + if not ((norm == 1) or (norm == 2)): + raise ValueError("Support for 1 or 2 norm.") + x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) + y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) + + cham_x, cham_norm_x = _chamfer_distance_single_direction( + x, + y, + x_lengths, + y_lengths, + x_normals, + y_normals, + weights, + batch_reduction, + point_reduction, + norm, + abs_cosine, + ) + if single_directional: + return cham_x, cham_norm_x + else: + cham_y, cham_norm_y = _chamfer_distance_single_direction( + y, + x, + y_lengths, + x_lengths, + y_normals, + x_normals, + weights, + batch_reduction, + point_reduction, + norm, + abs_cosine, + ) + return ( + cham_x + cham_y, + (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None, + ) diff --git a/tests/test_chamfer.py b/tests/test_chamfer.py index 964a9fab9..e6a7897d9 100644 --- a/tests/test_chamfer.py +++ b/tests/test_chamfer.py @@ -88,7 +88,9 @@ def init_pointclouds( ) @staticmethod - def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"): + def chamfer_distance_naive_pointclouds( + p1, p2, norm: int = 2, device="cpu", abs_cosine=True + ): """ Naive iterative implementation of nearest neighbor and chamfer distance. x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -146,17 +148,20 @@ def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"): if return_normals: x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3) y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3) - lnorm1 = 1 - torch.abs( - F.cosine_similarity( - x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6 - ) + cosine_sim1 = F.cosine_similarity( + x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6 ) - lnorm2 = 1 - torch.abs( - F.cosine_similarity( - y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6 - ) + cosine_sim2 = F.cosine_similarity( + y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6 ) + if abs_cosine: + lnorm1 = 1 - torch.abs(cosine_sim1) + lnorm2 = 1 - torch.abs(cosine_sim2) + else: + lnorm1 = 1 - cosine_sim1 + lnorm2 = 1 - cosine_sim2 + if is_x_heterogeneous: lnorm1[x_mask] = 0.0 if is_y_heterogeneous: @@ -167,7 +172,9 @@ def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"): return loss, lnorm @staticmethod - def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2): + def chamfer_distance_naive( + x, y, x_normals=None, y_normals=None, norm: int = 2, abs_cosine=True + ): """ Naive iterative implementation of nearest neighbor and chamfer distance. Returns lists of the unreduced loss and loss_normals. 
This naive @@ -200,16 +207,21 @@ def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2): if return_normals: x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3) y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3) - lnorm1 = 1 - torch.abs( - F.cosine_similarity( - x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6 - ) + + cosine_sim1 = F.cosine_similarity( + x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6 ) - lnorm2 = 1 - torch.abs( - F.cosine_similarity( - y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6 - ) + cosine_sim2 = F.cosine_similarity( + y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6 ) + + if abs_cosine: + lnorm1 = 1 - torch.abs(cosine_sim1) + lnorm2 = 1 - torch.abs(cosine_sim2) + else: + lnorm1 = 1 - cosine_sim1 + lnorm2 = 1 - cosine_sim2 + lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)] return loss, lnorm @@ -323,6 +335,80 @@ def test_chamfer_vs_naive_pointcloud(self): y_lengths, ) + def test_single_directional_chamfer_vs_naive_pointcloud(self): + """ + Test the single directional settings for chamfer_distance + (point reduction = "mean" and batch_reduction="mean") but with heterogeneous + pointclouds as input. Compare with the naive implementation of chamfer + which supports heterogeneous pointcloud objects. + """ + N, max_P1, max_P2 = 3, 70, 70 + device = get_random_cuda_device() + + for norm in [1, 2]: + for abs_cosine in [True, False]: + points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) + weights = points_normals.weights + x_lengths = points_normals.p1_lengths + y_lengths = points_normals.p2_lengths + + # Chamfer with tensors as input for heterogeneous pointclouds. 
+ cham_tensor, norm_tensor = chamfer_distance( + points_normals.p1, + points_normals.p2, + x_normals=points_normals.n1, + y_normals=points_normals.n2, + x_lengths=points_normals.p1_lengths, + y_lengths=points_normals.p2_lengths, + weights=weights, + norm=norm, + single_directional=True, + abs_cosine=abs_cosine, + ) + + # Chamfer with pointclouds as input. + ( + pred_loss, + pred_norm_loss, + ) = TestChamfer.chamfer_distance_naive_pointclouds( + points_normals.cloud1, + points_normals.cloud2, + norm=norm, + device=device, + abs_cosine=abs_cosine, + ) + + # Mean reduction point loss. + pred_loss[0] *= weights.view(N, 1) + pred_loss_mean = pred_loss[0].sum(1) / x_lengths + pred_loss_mean = pred_loss_mean.sum() + pred_loss_mean /= weights.sum() + + # Mean reduction norm loss. + pred_norm_loss[0] *= weights.view(N, 1) + pred_norm_loss_mean = pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum() + + self.assertClose(pred_loss_mean, cham_tensor) + self.assertClose(pred_norm_loss_mean, norm_tensor) + + self._check_gradients( + cham_tensor, + norm_tensor, + pred_loss_mean, + pred_norm_loss_mean, + points_normals.cloud1.points_list(), + points_normals.p1, + points_normals.cloud2.points_list(), + points_normals.p2, + points_normals.cloud1.normals_list(), + points_normals.n1, + points_normals.cloud2.normals_list(), + points_normals.n2, + x_lengths, + y_lengths, + ) + def test_chamfer_pointcloud_object_withnormals(self): N = 5 P1, P2 = 100, 100 @@ -485,6 +571,53 @@ def test_chamfer_point_reduction_mean(self): loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22 ) + def test_single_direction_chamfer_point_reduction_mean(self): + """ + Compare output of vectorized single directional chamfer loss with naive implementation + for point_reduction = "mean" and batch_reduction = None.
+ """ + N, max_P1, max_P2 = 7, 10, 18 + device = get_random_cuda_device() + points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + p1_normals = points_normals.n1 + p2_normals = points_normals.n2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + P1 = p1.shape[1] + + pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive( + p1, p2, x_normals=p1_normals, y_normals=p2_normals + ) + + # point_reduction = "mean". + loss, loss_norm = chamfer_distance( + p11, + p22, + x_normals=p1_normals, + y_normals=p2_normals, + weights=weights, + batch_reduction=None, + point_reduction="mean", + single_directional=True, + ) + pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss_mean *= weights + self.assertClose(loss, pred_loss_mean) + + pred_loss_norm_mean = pred_loss_norm[0].sum(1) / P1 + pred_loss_norm_mean *= weights + self.assertClose(loss_norm, pred_loss_norm_mean) + + # Check gradients + self._check_gradients( + loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22 + ) + def test_chamfer_point_reduction_sum(self): """ Compare output of vectorized chamfer loss with naive implementation @@ -529,6 +662,51 @@ def test_chamfer_point_reduction_sum(self): loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22 ) + def test_single_directional_chamfer_point_reduction_sum(self): + """ + Compare output of vectorized single directional chamfer loss with naive implementation + for point_reduction = "sum" and batch_reduction = None. 
+ """ + N, P1, P2 = 7, 10, 18 + device = get_random_cuda_device() + points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + p1_normals = points_normals.n1 + p2_normals = points_normals.n2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + + pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive( + p1, p2, x_normals=p1_normals, y_normals=p2_normals + ) + + loss, loss_norm = chamfer_distance( + p11, + p22, + x_normals=p1_normals, + y_normals=p2_normals, + weights=weights, + batch_reduction=None, + point_reduction="sum", + single_directional=True, + ) + pred_loss_sum = pred_loss[0].sum(1) + pred_loss_sum *= weights + self.assertClose(loss, pred_loss_sum) + + pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm_sum *= weights + self.assertClose(loss_norm, pred_loss_norm_sum) + + # Check gradients + self._check_gradients( + loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22 + ) + def _check_gradients( self, loss,