diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py
index 60d0afd0d..15b39d783 100644
--- a/pytorch3d/ops/points_alignment.py
+++ b/pytorch3d/ops/points_alignment.py
@@ -1,16 +1,18 @@
-#!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 import warnings
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
 from pytorch3d.structures.pointclouds import Pointclouds
+from pytorch3d.structures import utils as strutil
+from pytorch3d.ops import utils as oputil
 
 def corresponding_points_alignment(
     X: Union[torch.Tensor, Pointclouds],
     Y: Union[torch.Tensor, Pointclouds],
+    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
     estimate_scale: bool = False,
     allow_reflection: bool = False,
     eps: float = 1e-8,
@@ -28,9 +30,14 @@ def corresponding_points_alignment(
 
     Args:
         X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
-            or a `Pointclouds` object.
+            or a `Pointclouds` object.
         Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
-            or a `Pointclouds` object.
+            or a `Pointclouds` object.
+        weights: Batch of non-negative weights of
+            shape `(minibatch, num_point)` or a list of `minibatch` 1-dimensional
+            tensors that may have different shapes; in that case, the length of
+            the i-th tensor should be equal to the number of points in X_i and Y_i.
+            Passing `None` means uniform weights.
         estimate_scale: If `True`, also estimates a scaling component `s`
             of the transformation. Otherwise assumes an identity
             scale and returns a tensor of ones.
@@ -59,25 +66,45 @@ def corresponding_points_alignment(
             "Point sets X and Y have to have the same \
             number of batches, points and dimensions."
         )
+    if weights is not None:
+        if isinstance(weights, list):
+            if any(np != w.shape[0] for np, w in zip(num_points, weights)):
+                raise ValueError(
+                    "number of weights should equal the "
+                    + "number of points in the point cloud."
+                )
+            weights = [w[..., None] for w in weights]
+            weights = strutil.list_to_padded(weights)[..., 0]
+
+        if Xt.shape[:2] != weights.shape:
+            raise ValueError(
+                "weights should have the same first two dimensions as X."
+            )
 
     b, n, dim = Xt.shape
 
-    # compute the centroids of the point sets
-    Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
-    Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
-
-    # mean-center the point sets
-    Xc = Xt - Xmu[:, None]
-    Yc = Yt - Ymu[:, None]
-
     if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
         # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
         mask = (
-            torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
+            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
             < num_points[:, None]
-        ).type_as(Xc)
-        Xc *= mask[:, :, None]
-        Yc *= mask[:, :, None]
+        ).type_as(Xt)
+        weights = mask if weights is None else mask * weights.type_as(Xt)
+
+    # compute the centroids of the point sets
+    Xmu = oputil.wmean(Xt, weights, eps=eps)
+    Ymu = oputil.wmean(Yt, weights, eps=eps)
+
+    # mean-center the point sets
+    Xc = Xt - Xmu
+    Yc = Yt - Ymu
+
+    total_weight = torch.clamp(num_points, 1)
+    # special handling for heterogeneous point clouds and/or input weights
+    if weights is not None:
+        Xc *= weights[:, :, None]
+        Yc *= weights[:, :, None]
+        total_weight = torch.clamp(weights.sum(1), eps)
 
     if (num_points < (dim + 1)).any():
         warnings.warn(
@@ -87,7 +114,7 @@ def corresponding_points_alignment(
 
     # compute the covariance XYcov between the point sets Xc, Yc
     XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
-    XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
+    XYcov = XYcov / total_weight[:, None, None]
 
     # decompose the covariance matrix XYcov
     U, S, V = torch.svd(XYcov)
@@ -111,17 +138,16 @@ def corresponding_points_alignment(
     if estimate_scale:
         # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
-        Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
+        Xcov = (Xc * Xc).sum((1, 2)) / total_weight
 
         # the scaling component
         s = trace_ES / torch.clamp(Xcov, eps)
 
         # translation component
-        T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
-
+        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
     else:
         # translation component
-        T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
+        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]
 
         # unit scaling since we do not estimate scale
         s = T.new_ones(b)
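For reviewers, here is the estimator in outline. With non-negative per-point weights $w_i$ (uniform when `weights=None`), the code above implements a weighted variant of the Umeyama alignment, minimizing $\sum_i w_i \lVert s\,X_i R + T - Y_i\rVert^2$ over $R$, $T$ (and optionally $s$). The sketch below shows the standard single-$w_i$ form; note that the implementation multiplies both centered point sets by the weights before forming products, so its covariance and variance terms effectively carry $w_i^2$ rather than $w_i$:

```latex
\mu_X = \frac{\sum_i w_i X_i}{\max\big(\sum_i w_i,\ \varepsilon\big)}, \qquad
\mu_Y = \frac{\sum_i w_i Y_i}{\max\big(\sum_i w_i,\ \varepsilon\big)},
\\[4pt]
\Sigma_{XY}
  = \frac{\sum_i w_i (X_i - \mu_X)^\top (Y_i - \mu_Y)}{\max\big(\sum_i w_i,\ \varepsilon\big)}
  \;\overset{\text{SVD}}{=}\; U S V^\top,
\\[4pt]
R = U E V^\top, \qquad
s = \frac{\operatorname{tr}(E S)}{\sigma_X^2}, \qquad
T = \mu_Y - s\, \mu_X R,
```

where $E$ is the identity, except that its last diagonal entry is set to $\operatorname{sign}(\det(UV^\top))$ when `allow_reflection=False`, and $\sigma_X^2$ is the weighted variance of the centered $X$ (the `Xcov` term above).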
diff --git a/pytorch3d/ops/utils.py b/pytorch3d/ops/utils.py
new file mode 100644
index 000000000..6813288c2
--- /dev/null
+++ b/pytorch3d/ops/utils.py
@@ -0,0 +1,44 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import torch
+
+
+def wmean(
+    x: torch.Tensor,
+    weight: Optional[torch.Tensor] = None,
+    dim: Union[int, Tuple[int]] = -2,
+    keepdim: bool = True,
+    eps: float = 1e-9,
+) -> torch.Tensor:
+    """
+    Finds the mean of the input tensor across the specified dimension.
+    If the `weight` argument is provided, computes a weighted mean.
+    Args:
+        x: tensor of shape `(*, D)`, where D is assumed to be spatial;
+        weight: if given, a non-negative tensor of shape `(*,)`. It must be
+            broadcastable to `x.shape[:-1]`. Note that the weights for
+            the last (spatial) dimension are assumed to be identical;
+        dim: dimension(s) in `x` to average over;
+        keepdim: whether to keep the resulting singleton dimension;
+        eps: minimum clamping value in the denominator.
+    Returns:
+        the mean tensor:
+        * if `weight` is None => `mean(x, dim)`,
+        * otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
+    """
+    args = dict(dim=dim, keepdim=keepdim)
+
+    if weight is None:
+        return x.mean(**args)
+
+    if any(
+        xd != wd and xd != 1 and wd != 1
+        for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
+    ):
+        raise ValueError("wmean: weights are not compatible with the tensor")
+
+    return (
+        (x * weight[..., None]).sum(**args)
+        / weight[..., None].sum(**args).clamp(eps)
+    )
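To make `wmean`'s shape semantics concrete, here is a small usage sketch (the tensors below are illustrative, not part of the diff):

```python
import torch

from pytorch3d.ops import utils as oputil

x = torch.rand(4, 100, 3)  # a batch of 4 point clouds, 100 points each, 3D
w = torch.rand(4, 100)     # one non-negative weight per point

# the default dim=-2 averages over the points dimension;
# keepdim=True retains the reduced dimension as a singleton
mu = oputil.wmean(x, weight=w)                      # shape (4, 1, 3)
mu_flat = oputil.wmean(x, weight=w, keepdim=False)  # shape (4, 3)

# weights broadcast against x.shape[:-1], so a single (100,) vector
# is shared across the whole batch
w_shared = torch.rand(100)
mu_shared = oputil.wmean(x, weight=w_shared)        # shape (4, 1, 3)
```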
+ """ + args = dict(dim=dim, keepdim=keepdim) + + if weight is None: + return x.mean(**args) + + if any( + xd != wd and xd != 1 and wd != 1 + for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1]) + ): + raise ValueError("wmean: weights are not compatible with the tensor") + + return ( + (x * weight[..., None]).sum(**args) + / weight[..., None].sum(**args).clamp(eps) + ) diff --git a/tests/bm_points_alignment.py b/tests/bm_points_alignment.py index f823b0cca..75464602f 100644 --- a/tests/bm_points_alignment.py +++ b/tests/bm_points_alignment.py @@ -16,6 +16,7 @@ def bm_corresponding_points_alignment() -> None: "dim": [3, 20], "estimate_scale": [True, False], "n_points": [100, 10000], + "random_weights": [False, True], "use_pointclouds": [False], } diff --git a/tests/common_testing.py b/tests/common_testing.py index 4549aae19..9bbcb1874 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Optional import unittest @@ -35,13 +36,15 @@ def assertClose( *, rtol: float = 1e-05, atol: float = 1e-08, - equal_nan: bool = False + equal_nan: bool = False, + msg: Optional[str] = None, ) -> None: """ Verify that two tensors or arrays are the same shape and close. Args: input, other: two tensors or two arrays. rtol, atol, equal_nan: as for torch.allclose. + msg: message in case the assertion is violated. Note: Optional arguments here are all keyword-only, to avoid confusion with msg arguments on other assert functions. @@ -54,5 +57,7 @@ def assertClose( input, other, rtol=rtol, atol=atol, equal_nan=equal_nan ) else: - close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) - self.assertTrue(close) + close = np.allclose( + input, other, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + self.assertTrue(close, msg) diff --git a/tests/test_ops_utils.py b/tests/test_ops_utils.py new file mode 100644 index 000000000..81099bdc2 --- /dev/null +++ b/tests/test_ops_utils.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
diff --git a/tests/test_ops_utils.py b/tests/test_ops_utils.py
new file mode 100644
index 000000000..81099bdc2
--- /dev/null
+++ b/tests/test_ops_utils.py
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import unittest
+
+import numpy as np
+import torch
+
+from common_testing import TestCaseMixin
+
+from pytorch3d.ops import utils as oputil
+
+
+class TestOpsUtils(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+    def test_wmean(self):
+        device = torch.device("cuda:0")
+        n_points = 20
+
+        x = torch.rand(n_points, 3, device=device)
+        weight = torch.rand(n_points, device=device)
+        x_np = x.cpu().data.numpy()
+        weight_np = weight.cpu().data.numpy()
+
+        # test unweighted
+        mean = oputil.wmean(x, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test weighted
+        mean = oputil.wmean(x, weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test keepdim
+        mean = oputil.wmean(x, weight=weight, keepdim=True)
+        self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
+
+        # test binary weights
+        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test broadcasting
+        x = torch.rand(10, n_points, 3, device=device)
+        x_np = x.cpu().data.numpy()
+        mean = oputil.wmean(x, weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        weight = weight[None, None, :].repeat(3, 1, 1)
+        mean = oputil.wmean(x, weight=weight, keepdim=False)
+        self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
+
+        # test failing broadcasting
+        weight = torch.rand(x.shape[0], device=device)
+        with self.assertRaises(ValueError) as context:
+            oputil.wmean(x, weight=weight, keepdim=False)
+        self.assertTrue("weights are not compatible" in str(context.exception))
+
+        # test dim
+        weight = torch.rand(x.shape[0], n_points, device=device)
+        weight_np = np.tile(
+            weight[:, :, None].cpu().data.numpy(), (1, 1, x_np.shape[-1])
+        )
+        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=0, weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test dim tuple
+        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
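One edge case the tests above do not spell out: if every weight in a cloud is zero, the `eps` clamp in `wmean`'s denominator makes the result zeros rather than NaNs. A quick sketch:

```python
import torch

from pytorch3d.ops import utils as oputil

x = torch.rand(2, 5, 3)
w = torch.zeros(2, 5)  # fully masked-out clouds

# the numerator sums to zero; the denominator is clamped to eps,
# so the division stays finite
print(oputil.wmean(x, weight=w))  # zeros of shape (2, 1, 3), no NaNs
```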
diff --git a/tests/test_points_alignment.py b/tests/test_points_alignment.py
index fc2a6d9f9..823f78ccc 100644
--- a/tests/test_points_alignment.py
+++ b/tests/test_points_alignment.py
@@ -6,6 +6,8 @@ import unittest
 
 import torch
 
+from common_testing import TestCaseMixin
+
 from pytorch3d.ops import points_alignment
 from pytorch3d.structures.pointclouds import Pointclouds
 from pytorch3d.transforms import rotation_conversions
@@ -35,7 +37,7 @@ def _apply_pcl_transformation(X, R, T, s=None):
     return X_t
 
 
-class TestCorrespondingPointsAlignment(unittest.TestCase):
+class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
         torch.manual_seed(42)
@@ -171,6 +173,7 @@ def corresponding_points_alignment(
         estimate_scale=False,
         allow_reflection=False,
         reflect=False,
+        random_weights=False,
     ):
 
         device = torch.device("cuda:0")
@@ -198,12 +201,27 @@ def corresponding_points_alignment(
         # point cloud X
         X_t = _apply_pcl_transformation(X, R, T, s=s)
 
+        weights = None
+        if random_weights:
+            template = X.points_padded() if use_pointclouds else X
+            weights = torch.rand_like(template[:, :, 0])
+            weights = weights / weights.sum(dim=1, keepdim=True)
+            # zero out some weights, as zero weights are a common use case;
+            # the 0.3-of-mean threshold keeps at least one weight positive
+            weights *= (weights * template.size()[1] > 0.3).to(weights)
+            if use_pointclouds:  # convert to List[Tensor]
+                weights = [
+                    w[:npts]
+                    for w, npts in zip(weights, X.num_points_per_cloud())
+                ]
+
         torch.cuda.synchronize()
 
         def run_corresponding_points_alignment():
             points_alignment.corresponding_points_alignment(
                 X,
                 X_t,
+                weights,
                 allow_reflection=allow_reflection,
                 estimate_scale=estimate_scale,
             )
@@ -230,26 +248,28 @@ def test_corresponding_points_alignment(self, batch_size=10):
         """
 
         # run this for several different point cloud sizes
-        for n_points in (100, 3, 2, 1, 0):
+        for n_points in (100, 3, 2, 1):
             # run this for several different dimensionalities
-            for dim in torch.arange(2, 10):
+            for dim in range(2, 10):
                 # switches whether we should use the Pointclouds inputs
                 use_point_clouds_cases = (
                     (True, False) if dim == 3 and n_points > 3 else (False,)
                 )
-                for use_pointclouds in use_point_clouds_cases:
-                    for estimate_scale in (False, True):
-                        for reflect in (False, True):
-                            for allow_reflection in (False, True):
-                                self._test_single_corresponding_points_alignment(
-                                    batch_size=10,
-                                    n_points=n_points,
-                                    dim=int(dim),
-                                    use_pointclouds=use_pointclouds,
-                                    estimate_scale=estimate_scale,
-                                    reflect=reflect,
-                                    allow_reflection=allow_reflection,
-                                )
+                for random_weights in (False, True):
+                    for use_pointclouds in use_point_clouds_cases:
+                        for estimate_scale in (False, True):
+                            for reflect in (False, True):
+                                for allow_reflection in (False, True):
+                                    self._test_single_corresponding_points_alignment(
+                                        batch_size=10,
+                                        n_points=n_points,
+                                        dim=dim,
+                                        use_pointclouds=use_pointclouds,
+                                        estimate_scale=estimate_scale,
+                                        reflect=reflect,
+                                        allow_reflection=allow_reflection,
+                                        random_weights=random_weights,
+                                    )
 
     def _test_single_corresponding_points_alignment(
         self,
@@ -260,6 +280,7 @@ def _test_single_corresponding_points_alignment(
         estimate_scale=False,
         reflect=False,
         allow_reflection=False,
+        random_weights=False,
     ):
         """
         Executes a single test for `corresponding_points_alignment` for a
@@ -294,6 +315,20 @@ def _test_single_corresponding_points_alignment(
             )
             R = torch.bmm(M, R)
 
+        weights = None
+        if random_weights:
+            template = X.points_padded() if use_pointclouds else X
+            weights = torch.rand_like(template[:, :, 0])
+            weights = weights / weights.sum(dim=1, keepdim=True)
+            # zero out some weights, as zero weights are a common use case;
+            # the 0.3-of-mean threshold keeps at least one weight positive
+            weights *= (weights * template.size()[1] > 0.3).to(weights)
+            if use_pointclouds:  # convert to List[Tensor]
+                weights = [
+                    w[:npts]
+                    for w, npts in zip(weights, X.num_points_per_cloud())
+                ]
+
         # apply the generated transformation to the generated
         # point cloud X
         X_t = _apply_pcl_transformation(X, R, T, s=s)
@@ -302,6 +337,7 @@ def _test_single_corresponding_points_alignment(
         R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
             X,
             X_t,
+            weights,
             allow_reflection=allow_reflection,
             estimate_scale=estimate_scale,
         )
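The hunk below extends the failure message and adds a robustness check: noise is injected into the 20% of points with the smallest weights, and the weighted alignment must then beat the unweighted one under the weighted mean-squared error that `align_and_get_mse` computes per batch element:

```latex
\mathrm{wMSE}(\hat{X}, X)
  = \frac{\sum_i w_i^2 \,\lVert \hat{X}_i - X_i \rVert^2}{\sum_i w_i}
```

(the residual is multiplied by the weights before squaring, hence the $w_i^2$ in the numerator).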
@@ -313,9 +349,40 @@ def _test_single_corresponding_points_alignment(
             f"use_pointclouds={use_pointclouds}, "
             f"estimate_scale={estimate_scale}, "
             f"reflect={reflect}, "
-            f"allow_reflection={allow_reflection}."
+            f"allow_reflection={allow_reflection}, "
+            f"random_weights={random_weights}."
         )
 
+        # if we test the weighted case, check that the weights help with noise
+        if random_weights and not use_pointclouds and n_points >= (dim + 10):
+            # add noise to the 20% of points with the smallest weights
+            X_noisy = X_t.clone()
+            _, mink_idx = torch.topk(-weights, int(n_points * 0.2), dim=1)
+            mink_idx = mink_idx[:, :, None].expand(-1, -1, X_t.shape[-1])
+            X_noisy.scatter_add_(
+                1, mink_idx, 0.3 * torch.randn_like(mink_idx, dtype=X_t.dtype)
+            )
+
+            def align_and_get_mse(weights_):
+                R_n, T_n, s_n = points_alignment.corresponding_points_alignment(
+                    X_noisy,
+                    X_t,
+                    weights_,
+                    allow_reflection=allow_reflection,
+                    estimate_scale=estimate_scale,
+                )
+
+                X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
+
+                return (
+                    ((X_t_est - X_t) * weights[..., None]) ** 2
+                ).sum(dim=(1, 2)) / weights.sum(dim=-1)
+
+            # check that using weights leads to a lower weighted MSE(X_noisy, X_t)
+            self.assertTrue(
+                torch.all(align_and_get_mse(weights) <= align_and_get_mse(None))
+            )
+
         if reflect and not allow_reflection:
             # check that all rotations have det=1
             self._assert_all_close(
@@ -325,34 +392,44 @@ def _test_single_corresponding_points_alignment(
                 )
 
         else:
+            # mask out inputs with too few non-degenerate points for assertions
+            w = (
+                torch.ones_like(R_est[:, 0, 0])
+                if weights is None or n_points >= dim + 10
+                else (weights > 0.0).all(dim=1).to(R_est)
+            )
             # check that the estimated transformation is the same
             # as the ground truth
             if n_points >= (dim + 1):
                 # the checks on transforms apply only when
                 # the problem setup is unambiguous
-                self._assert_all_close(R_est, R, assert_error_message)
-                self._assert_all_close(T_est, T, assert_error_message)
-                self._assert_all_close(s_est, s, assert_error_message)
+                msg = assert_error_message
+                self._assert_all_close(R_est, R, msg, w[:, None, None], atol=1e-5)
+                self._assert_all_close(T_est, T, msg, w[:, None])
+                self._assert_all_close(s_est, s, msg, w)
 
                 # check that the orthonormal part of the
                 # transformation has a correct determinant (+1/-1)
                 desired_det = R_est.new_ones(batch_size)
                 if reflect:
                     desired_det *= -1.0
-                self._assert_all_close(
-                    torch.det(R_est), desired_det, assert_error_message
-                )
+                self._assert_all_close(torch.det(R_est), desired_det, msg, w)
 
             # check that the transformed point cloud
             # X matches X_t
             X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est)
             self._assert_all_close(
-                X_t, X_t_est, assert_error_message, atol=1e-5
+                X_t, X_t_est, assert_error_message, w[:, None, None], atol=1e-5
             )
 
-    def _assert_all_close(self, a_, b_, err_message, atol=1e-6):
+    def _assert_all_close(self, a_, b_, err_message, weights=None, atol=1e-6):
         if isinstance(a_, Pointclouds):
             a_ = a_.points_packed()
         if isinstance(b_, Pointclouds):
             b_ = b_.points_packed()
-        self.assertTrue(torch.allclose(a_, b_, atol=atol), err_message)
+        if weights is None:
+            self.assertClose(a_, b_, atol=atol, msg=err_message)
+        else:
+            self.assertClose(
+                a_ * weights, b_ * weights, atol=atol, msg=err_message
+            )
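Taken together, the new API is called as below. This is an illustrative sketch, not part of the diff; `random_rotations` comes from `pytorch3d.transforms`, and the shapes follow the docstring in `points_alignment.py`:

```python
import torch

from pytorch3d.ops import points_alignment
from pytorch3d.transforms import random_rotations

# two corresponding point sets: Y is a rigidly transformed copy of X
X = torch.randn(8, 500, 3)
R_gt = random_rotations(8)     # batch of proper rotation matrices
T_gt = torch.randn(8, 1, 3)
Y = torch.bmm(X, R_gt) + T_gt  # the module uses the `X @ R + T` convention

# per-point confidences; down-weight points believed to be unreliable
# (uniform here, so the result matches the unweighted estimate)
weights = torch.ones(8, 500)

R, T, s = points_alignment.corresponding_points_alignment(
    X, Y, weights, estimate_scale=False, allow_reflection=False
)
# R: (8, 3, 3), T: (8, 3), s: (8,) -- all ones since scale is not estimated
```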