Weighted Umeyama.
Summary:
1. Introduced weights to the Umeyama implementation. This will be needed for weighted ePnP but is useful on its own.
2. Refactored so that the Pointclouds padding mask and user-passed weights share the same code path.
3. Added test cases with random weights.
4. Fixed a bug in a test that calls the function with 0 points (it fails randomly in PyTorch 1.3 and will be fixed in the next release: pytorch/pytorch#31421).

Reviewed By: gkioxari

Differential Revision: D20070293

fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
shapovalov authored and facebook-github-bot committed Apr 3, 2020
1 parent e5b1d6d commit e37085d
Showing 6 changed files with 278 additions and 50 deletions.
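For orientation, the weighted Umeyama estimate that this commit implements in batched form can be written compactly for a single pair of point sets. The sketch below is illustrative only and is not the PyTorch3D code: it assumes `X` and `Y` are `(n, d)` tensors, `w` is a non-negative `(n,)` weight vector, and it follows the library's row-vector convention `Y ≈ s * X @ R + T`.

import torch

def weighted_umeyama(X, Y, w, estimate_scale=True, eps=1e-8):
    # Normalize the weights so they sum to one.
    w = w / w.sum().clamp(eps)
    # Weighted centroids, then center both point sets.
    Xmu, Ymu = w @ X, w @ Y
    Xc, Yc = X - Xmu, Y - Ymu
    # Weighted cross-covariance and its SVD.
    cov = Xc.t() @ (w[:, None] * Yc)
    U, S, V = torch.svd(cov)
    # Sign-fix matrix E keeps R a proper rotation (det R = +1).
    E = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
    E[-1, -1] = torch.sign(torch.det(U @ V.t()))
    R = U @ E @ V.t()
    if estimate_scale:
        # Weighted variance of the centered source points.
        Xvar = (w * (Xc ** 2).sum(1)).sum().clamp(eps)
        s = (torch.diag(E) * S).sum() / Xvar
    else:
        s = X.new_ones(())
    T = Ymu - s * (Xmu @ R)
    return R, T, s

With uniform weights this reduces to the classic Umeyama estimator; zero weights simply drop points from every sum, which is how the diff below unifies padded Pointclouds entries with user-supplied weights.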
68 changes: 47 additions & 21 deletions pytorch3d/ops/points_alignment.py
@@ -1,16 +1,18 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import warnings
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
import torch

from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures import utils as strutil
from pytorch3d.ops import utils as oputil


def corresponding_points_alignment(
X: Union[torch.Tensor, Pointclouds],
Y: Union[torch.Tensor, Pointclouds],
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
estimate_scale: bool = False,
allow_reflection: bool = False,
eps: float = 1e-8,
@@ -28,9 +30,14 @@ def corresponding_points_alignment(
Args:
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
or a `Pointclouds` object.
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
or a `Pointclouds` object.
weights: Batch of non-negative weights of
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
tensors that may have different shapes; in that case, the length of the
i-th tensor should equal the number of points in X_i and Y_i.
Passing `None` means uniform weights.
estimate_scale: If `True`, also estimates a scaling component `s`
of the transformation. Otherwise assumes an identity
scale and returns a tensor of ones.
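A usage sketch for the extended signature (shapes follow the docstring above; the tensors are placeholders, and the example assumes the function still returns the `R, T, s` triple assembled later in this file):

import torch
from pytorch3d.ops.points_alignment import corresponding_points_alignment

X = torch.rand(4, 100, 3)   # (minibatch, num_point, d)
Y = torch.rand(4, 100, 3)

# Uniform weights: identical to the behaviour before this commit.
R, T, s = corresponding_points_alignment(X, Y, estimate_scale=True)

# Per-point weights passed as a padded (minibatch, num_point) tensor;
# a list of per-cloud 1-D tensors is also accepted for heterogeneous clouds.
w = torch.rand(4, 100)
R, T, s = corresponding_points_alignment(X, Y, weights=w, estimate_scale=True)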
@@ -59,25 +66,45 @@
"Point sets X and Y have to have the same \
number of batches, points and dimensions."
)
if weights is not None:
if isinstance(weights, list):
if any(np != w.shape[0] for np, w in zip(num_points, weights)):
raise ValueError(
"number of weights should equal to the "
+ "number of points in the point cloud."
)
weights = [w[..., None] for w in weights]
weights = strutil.list_to_padded(weights)[..., 0]

if Xt.shape[:2] != weights.shape:
raise ValueError(
"weights should have the same first two dimensions as X."
)

b, n, dim = Xt.shape

-# compute the centroids of the point sets
-Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
-Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)

-# mean-center the point sets
-Xc = Xt - Xmu[:, None]
-Yc = Yt - Ymu[:, None]

if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
mask = (
-torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
+torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
< num_points[:, None]
-).type_as(Xc)
-Xc *= mask[:, :, None]
-Yc *= mask[:, :, None]
+).type_as(Xt)
+weights = mask if weights is None else mask * weights.type_as(Xt)

+# compute the centroids of the point sets
+Xmu = oputil.wmean(Xt, weights, eps=eps)
+Ymu = oputil.wmean(Yt, weights, eps=eps)

+# mean-center the point sets
+Xc = Xt - Xmu
+Yc = Yt - Ymu

total_weight = torch.clamp(num_points, 1)
# special handling for heterogeneous point clouds and/or input weights
if weights is not None:
Xc *= weights[:, :, None]
Yc *= weights[:, :, None]
total_weight = torch.clamp(weights.sum(1), eps)

if (num_points < (dim + 1)).any():
warnings.warn(
@@ -87,7 +114,7 @@ def corresponding_points_alignment(

# compute the covariance XYcov between the point sets Xc, Yc
XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
-XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
+XYcov = XYcov / total_weight[:, None, None]

# decompose the covariance matrix XYcov
U, S, V = torch.svd(XYcov)
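The covariance above relies on the padded-entry mask having been folded into `weights`, so a single weighted code path serves both Pointclouds inputs and explicit weights. A standalone illustration of that folding (hypothetical helper, not library code):

import torch

def fold_mask_into_weights(num_points, n, weights=None):
    # num_points: (batch,) tensor with the true size of each cloud;
    # n: padded length; weights: optional (batch, n) per-point weights.
    mask = (torch.arange(n)[None] < num_points[:, None]).float()
    return mask if weights is None else mask * weights.float()

num_points = torch.tensor([3, 5])
w = fold_mask_into_weights(num_points, n=5)
# w[0] is [1, 1, 1, 0, 0]: padded entries get zero weight and therefore
# vanish from the centroids, the covariance XYcov, and the scale estimate.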
@@ -111,17 +138,16 @@ def corresponding_points_alignment(
if estimate_scale:
# estimate the scaling component of the transformation
trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
-Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
+Xcov = (Xc * Xc).sum((1, 2)) / total_weight

# the scaling component
s = trace_ES / torch.clamp(Xcov, eps)

# translation component
-T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]

+T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
else:
# translation component
-T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
+T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

# unit scaling since we do not estimate scale
s = T.new_ones(b)
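A quick round-trip check of the weighted path (illustrative; the ground-truth transform, the tolerances, and the `R, T, s` return order are assumptions consistent with the code above):

import torch
from pytorch3d.ops.points_alignment import corresponding_points_alignment

torch.manual_seed(0)
X = torch.rand(2, 50, 3)
w = torch.rand(2, 50)

# Build an exact similarity transform and apply it: Y = s * X @ R + T.
R_gt, _ = torch.qr(torch.randn(3, 3))
R_gt = R_gt * torch.sign(torch.det(R_gt))   # force a proper rotation
s_gt, T_gt = 2.0, torch.tensor([0.5, -1.0, 0.2])
Y = s_gt * X @ R_gt + T_gt

R, T, s = corresponding_points_alignment(X, Y, weights=w, estimate_scale=True)
assert torch.allclose(s, torch.full((2,), s_gt), atol=1e-4)
assert torch.allclose(R, R_gt.expand(2, 3, 3), atol=1e-4)

Because the correspondences here are noise-free, any strictly positive weighting recovers the same transform; the weights only matter once the point pairs disagree.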
44 changes: 44 additions & 0 deletions pytorch3d/ops/utils.py
@@ -0,0 +1,44 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional, Tuple, Union

import torch


def wmean(
x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
dim: Union[int, Tuple[int]] = -2,
keepdim: bool = True,
eps: float = 1e-9,
) -> torch.Tensor:
"""
Finds the mean of the input tensor across the specified dimension.
If the `weight` argument is provided, computes weighted mean.
Args:
x: tensor of shape `(*, D)`, where D is assumed to be spatial;
weight: if given, non-negative tensor of shape `(*,)`. It must be
broadcastable to `x.shape[:-1]`. Note that the weights for
the last (spatial) dimension are assumed to be the same;
dim: dimension(s) in `x` to average over;
keepdim: tells whether to keep the resulting singleton dimension.
eps: minimum clamping value in the denominator.
Returns:
the mean tensor:
* if `weight` is None => `mean(x, dim)`,
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
"""
args = dict(dim=dim, keepdim=keepdim)

if weight is None:
return x.mean(**args)

if any(
xd != wd and xd != 1 and wd != 1
for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
):
raise ValueError("wmean: weights are not compatible with the tensor")

return (
(x * weight[..., None]).sum(**args)
/ weight[..., None].sum(**args).clamp(eps)
)
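Usage sketch for the new helper (shapes are arbitrary examples):

import torch
from pytorch3d.ops import utils as oputil

x = torch.rand(2, 10, 3)    # (batch, points, spatial)
w = torch.rand(2, 10)       # one weight per point

mean_uniform = oputil.wmean(x)                              # (2, 1, 3), plain mean
mean_weighted = oputil.wmean(x, weight=w)                   # (2, 1, 3), weighted mean
mean_squeezed = oputil.wmean(x, weight=w, keepdim=False)    # (2, 3)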
1 change: 1 addition & 0 deletions tests/bm_points_alignment.py
@@ -16,6 +16,7 @@ def bm_corresponding_points_alignment() -> None:
"dim": [3, 20],
"estimate_scale": [True, False],
"n_points": [100, 10000],
"random_weights": [False, True],
"use_pointclouds": [False],
}
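The benchmark presumably draws per-point weights when `random_weights` is set; a hypothetical sketch of such a case (the helper name and shapes are assumptions, not the benchmark code):

import torch

def make_weights(batch_size, n_points, random_weights):
    # None means uniform weighting; otherwise random non-negative weights.
    return torch.rand(batch_size, n_points) if random_weights else None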

11 changes: 8 additions & 3 deletions tests/common_testing.py
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Optional

import unittest

@@ -35,13 +36,15 @@ def assertClose(
*,
rtol: float = 1e-05,
atol: float = 1e-08,
-equal_nan: bool = False
+equal_nan: bool = False,
+msg: Optional[str] = None,
) -> None:
"""
Verify that two tensors or arrays are the same shape and close.
Args:
input, other: two tensors or two arrays.
rtol, atol, equal_nan: as for torch.allclose.
msg: message in case the assertion is violated.
Note:
Optional arguments here are all keyword-only, to avoid confusion
with msg arguments on other assert functions.
@@ -54,5 +57,7 @@ def assertClose(
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
)
else:
-close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
-self.assertTrue(close)
+close = np.allclose(
+input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
+)
+self.assertTrue(close, msg)
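With the new keyword, a failing comparison can carry context about what was being checked; for example, inside any test case that mixes in TestCaseMixin (the two tensors here are hypothetical placeholders):

self.assertClose(estimated_scale, expected_scale, atol=1e-6, msg="scale estimate mismatch")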
75 changes: 75 additions & 0 deletions tests/test_ops_utils.py
@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest

import numpy as np
import torch

from common_testing import TestCaseMixin

from pytorch3d.ops import utils as oputil

class TestOpsUtils(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)

def test_wmean(self):
device = torch.device("cuda:0")
n_points = 20

x = torch.rand(n_points, 3, device=device)
weight = torch.rand(n_points, device=device)
x_np = x.cpu().data.numpy()
weight_np = weight.cpu().data.numpy()

# test unweighted
mean = oputil.wmean(x, keepdim=False)
mean_gt = np.average(x_np, axis=-2)
self.assertClose(mean.cpu().data.numpy(), mean_gt)

# test weighted
mean = oputil.wmean(x, weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)

# test keepdim
mean = oputil.wmean(x, weight=weight, keepdim=True)
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)

# test binary weights
mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
self.assertClose(mean.cpu().data.numpy(), mean_gt)

# test broadcasting
x = torch.rand(10, n_points, 3, device=device)
x_np = x.cpu().data.numpy()
mean = oputil.wmean(x, weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)

weight = weight[None, None, :].repeat(3, 1, 1)
mean = oputil.wmean(x, weight=weight, keepdim=False)
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)

# test failing broadcasting
weight = torch.rand(x.shape[0], device=device)
with self.assertRaises(ValueError) as context:
oputil.wmean(x, weight=weight, keepdim=False)
self.assertTrue("weights are not compatible" in str(context.exception))

# test dim
weight = torch.rand(x.shape[0], n_points, device=device)
weight_np = np.tile(
weight[:, :, None].cpu().data.numpy(),
(1, 1, x_np.shape[-1]),
)
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=0, weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)

# test dim tuple
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)