-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: 1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own. 2. Refactored to use the same code for the Pointclouds mask and passed weights. 3. Added test cases with random weights. 4. Fixed a bug in tests that calls the function with 0 points (fails randomly in Pytorch 1.3, will be fixed in the next release: pytorch/pytorch#31421 ). Reviewed By: gkioxari Differential Revision: D20070293 fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
- Loading branch information
1 parent
e5b1d6d
commit e37085d
Showing
6 changed files
with
278 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
from typing import Optional, Tuple, Union | ||
|
||
import torch | ||
|
||
|
||
def wmean( | ||
x: torch.Tensor, | ||
weight: Optional[torch.Tensor] = None, | ||
dim: Union[int, Tuple[int]] = -2, | ||
keepdim: bool = True, | ||
eps: float = 1e-9, | ||
) -> torch.Tensor: | ||
""" | ||
Finds the mean of the input tensor across the specified dimension. | ||
If the `weight` argument is provided, computes weighted mean. | ||
Args: | ||
x: tensor of shape `(*, D)`, where D is assumed to be spatial; | ||
weights: if given, non-negative tensor of shape `(*,)`. It must be | ||
broadcastable to `x.shape[:-1]`. Note that the weights for | ||
the last (spatial) dimension are assumed same; | ||
dim: dimension(s) in `x` to average over; | ||
keepdim: tells whether to keep the resulting singleton dimension. | ||
eps: minumum clamping value in the denominator. | ||
Returns: | ||
the mean tensor: | ||
* if `weights` is None => `mean(x, dim)`, | ||
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`. | ||
""" | ||
args = dict(dim=dim, keepdim=keepdim) | ||
|
||
if weight is None: | ||
return x.mean(**args) | ||
|
||
if any( | ||
xd != wd and xd != 1 and wd != 1 | ||
for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1]) | ||
): | ||
raise ValueError("wmean: weights are not compatible with the tensor") | ||
|
||
return ( | ||
(x * weight[..., None]).sum(**args) | ||
/ weight[..., None].sum(**args).clamp(eps) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from common_testing import TestCaseMixin | ||
|
||
from pytorch3d.ops import utils as oputil | ||
|
||
class TestOpsUtils(TestCaseMixin, unittest.TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
torch.manual_seed(42) | ||
np.random.seed(42) | ||
|
||
def test_wmean(self): | ||
device = torch.device("cuda:0") | ||
n_points = 20 | ||
|
||
x = torch.rand(n_points, 3, device=device) | ||
weight = torch.rand(n_points, device=device) | ||
x_np = x.cpu().data.numpy() | ||
weight_np = weight.cpu().data.numpy() | ||
|
||
# test unweighted | ||
mean = oputil.wmean(x, keepdim=False) | ||
mean_gt = np.average(x_np, axis=-2) | ||
self.assertClose(mean.cpu().data.numpy(), mean_gt) | ||
|
||
# test weighted | ||
mean = oputil.wmean(x, weight=weight, keepdim=False) | ||
mean_gt = np.average(x_np, axis=-2, weights=weight_np) | ||
self.assertClose(mean.cpu().data.numpy(), mean_gt) | ||
|
||
# test keepdim | ||
mean = oputil.wmean(x, weight=weight, keepdim=True) | ||
self.assertClose(mean[0].cpu().data.numpy(), mean_gt) | ||
|
||
# test binary weigths | ||
mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False) | ||
mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5) | ||
self.assertClose(mean.cpu().data.numpy(), mean_gt) | ||
|
||
# test broadcasting | ||
x = torch.rand(10, n_points, 3, device=device) | ||
x_np = x.cpu().data.numpy() | ||
mean = oputil.wmean(x, weight=weight, keepdim=False) | ||
mean_gt = np.average(x_np, axis=-2, weights=weight_np) | ||
self.assertClose(mean.cpu().data.numpy(), mean_gt) | ||
|
||
weight = weight[None, None, :].repeat(3, 1, 1) | ||
mean = oputil.wmean(x, weight=weight, keepdim=False) | ||
self.assertClose(mean[0].cpu().data.numpy(), mean_gt) | ||
|
||
# test failing broadcasting | ||
weight = torch.rand(x.shape[0], device=device) | ||
with self.assertRaises(ValueError) as context: | ||
oputil.wmean(x, weight=weight, keepdim=False) | ||
self.assertTrue("weights are not compatible" in str(context.exception)) | ||
|
||
# test dim | ||
weight = torch.rand(x.shape[0], n_points, device=device) | ||
weight_np = np.tile( | ||
weight[:, :, None].cpu().data.numpy(), | ||
(1, 1, x_np.shape[-1]), | ||
) | ||
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False) | ||
mean_gt = np.average(x_np, axis=0, weights=weight_np) | ||
self.assertClose(mean.cpu().data.numpy(), mean_gt) | ||
|
||
# test dim tuple | ||
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False) | ||
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np) | ||
self.assertClose(mean.cpu().data.numpy(), mean_gt) |
Oops, something went wrong.