-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Implements a backprop-safe version of `torch.acos` that linearly extrapolates the function outside bounds. Below is a plot of the extrapolated acos for different bounds: {F611339485} Reviewed By: bottler, nikhilaravi Differential Revision: D27945714 fbshipit-source-id: fa2e2385b56d6fe534338d5192447c4a3aec540c
- Loading branch information
1 parent
88f5d79
commit dd45123
Showing
4 changed files
with
246 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
import math | ||
from typing import Tuple, Union | ||
|
||
import torch | ||
|
||
|
||
def acos_linear_extrapolation( | ||
x: torch.Tensor, | ||
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4, | ||
) -> torch.Tensor: | ||
""" | ||
Implements `arccos(x)` which is linearly extrapolated outside `x`'s original | ||
domain of `(-1, 1)`. This allows for stable backpropagation in case `x` | ||
is not guaranteed to be strictly within `(-1, 1)`. | ||
More specifically: | ||
``` | ||
if -bound <= x <= bound: | ||
acos_linear_extrapolation(x) = acos(x) | ||
elif x <= -bound: # 1st order Taylor approximation | ||
acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound)) | ||
else: # x >= bound | ||
acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound) | ||
``` | ||
Note that `bound` can be made more specific with setting | ||
`bound=[lower_bound, upper_bound]` as detailed below. | ||
Args: | ||
x: Input `Tensor`. | ||
bound: A float constant or a float 2-tuple defining the region for the | ||
linear extrapolation of `acos`. | ||
If `bound` is a float scalar, linearly interpolates acos for | ||
`x <= -bound` or `bound <= x`. | ||
If `bound` is a 2-tuple, the first/second element of `bound` | ||
describes the lower/upper bound that defines the lower/upper | ||
extrapolation region, i.e. the region where | ||
`x <= bound[0]`/`bound[1] <= x`. | ||
Note that all elements of `bound` have to be within (-1, 1). | ||
Returns: | ||
acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`. | ||
""" | ||
|
||
if isinstance(bound, float): | ||
upper_bound = bound | ||
lower_bound = -bound | ||
else: | ||
lower_bound, upper_bound = bound | ||
|
||
if lower_bound > upper_bound: | ||
raise ValueError("lower bound has to be smaller or equal to upper bound.") | ||
|
||
if lower_bound <= -1.0 or upper_bound >= 1.0: | ||
raise ValueError("Both lower bound and upper bound have to be within (-1, 1).") | ||
|
||
# init an empty tensor and define the domain sets | ||
acos_extrap = torch.empty_like(x) | ||
x_upper = x >= upper_bound | ||
x_lower = x <= lower_bound | ||
x_mid = (~x_upper) & (~x_lower) | ||
|
||
# acos calculation for upper_bound < x < lower_bound | ||
acos_extrap[x_mid] = torch.acos(x[x_mid]) | ||
# the linear extrapolation for x >= upper_bound | ||
acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound) | ||
# the linear extrapolation for x <= lower_bound | ||
acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound) | ||
|
||
return acos_extrap | ||
|
||
|
||
def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor: | ||
""" | ||
Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`. | ||
""" | ||
return (x - x0) * _dacos_dx(x0) + math.acos(x0) | ||
|
||
|
||
def _dacos_dx(x: float) -> float: | ||
""" | ||
Calculates the derivative of `arccos(x)` w.r.t. `x`. | ||
""" | ||
return (-1.0) / math.sqrt(1.0 - x * x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
from fvcore.common.benchmark import benchmark | ||
from test_acos_linear_extrapolation import TestAcosLinearExtrapolation | ||
|
||
|
||
def bm_acos_linear_extrapolation() -> None: | ||
kwargs_list = [ | ||
{"batch_size": 1}, | ||
{"batch_size": 100}, | ||
{"batch_size": 10000}, | ||
{"batch_size": 1000000}, | ||
] | ||
benchmark( | ||
TestAcosLinearExtrapolation.acos_linear_extrapolation, | ||
"ACOS_LINEAR_EXTRAPOLATION", | ||
kwargs_list, | ||
warmup_iters=1, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
bm_acos_linear_extrapolation() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
|
||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
from common_testing import TestCaseMixin | ||
from pytorch3d.transforms import acos_linear_extrapolation | ||
|
||
|
||
class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
torch.manual_seed(42) | ||
np.random.seed(42) | ||
|
||
@staticmethod | ||
def init_acos_boundary_values(batch_size: int = 10000): | ||
""" | ||
Initialize a tensor containing values close to the bounds of the | ||
domain of `acos`, i.e. close to -1 or 1; and random values between (-1, 1). | ||
""" | ||
device = torch.device("cuda:0") | ||
# one quarter are random values between -1 and 1 | ||
x_rand = 2 * torch.rand(batch_size // 4, dtype=torch.float32, device=device) - 1 | ||
x = [x_rand] | ||
for bound in [-1, 1]: | ||
for above_bound in [True, False]: | ||
for noise_std in [1e-4, 1e-2]: | ||
n_generate = (batch_size - batch_size // 4) // 8 | ||
x_add = ( | ||
bound | ||
+ (2 * float(above_bound) - 1) | ||
* torch.randn( | ||
n_generate, device=device, dtype=torch.float32 | ||
).abs() | ||
* noise_std | ||
) | ||
x.append(x_add) | ||
x = torch.cat(x) | ||
return x | ||
|
||
@staticmethod | ||
def acos_linear_extrapolation(batch_size: int): | ||
x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size) | ||
torch.cuda.synchronize() | ||
|
||
def compute_acos(): | ||
acos_linear_extrapolation(x) | ||
torch.cuda.synchronize() | ||
|
||
return compute_acos | ||
|
||
def _test_acos_outside_bounds(self, x, y, dydx, bound): | ||
""" | ||
Check that `acos_linear_extrapolation` yields points on a line with correct | ||
slope, and that the function is continuous around `bound`. | ||
""" | ||
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype) | ||
# fit a line: slope * x + bias = y | ||
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1) | ||
solution = torch.linalg.lstsq(x_1, y[:, None]).solution | ||
slope, bias = solution.view(-1)[:2] | ||
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2) | ||
# test that the desired slope is the same as the fitted one | ||
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2) | ||
# test that the autograd's slope is the same as the desired one | ||
self.assertClose(desired_slope.expand_as(dydx), dydx, atol=1e-2) | ||
# test that the value of the fitted line at x=bound equals | ||
# arccos(x), i.e. the function is continuous around the bound | ||
y_bound_lin = (slope * bound_t + bias).view(1) | ||
y_bound_acos = bound_t.acos().view(1) | ||
self.assertClose(y_bound_lin, y_bound_acos, atol=1e-2) | ||
|
||
def _one_acos_test(self, x: torch.Tensor, lower_bound: float, upper_bound: float): | ||
""" | ||
Test that `acos_linear_extrapolation` returns correct values for | ||
`x` between/above/below `lower_bound`/`upper_bound`. | ||
""" | ||
x.requires_grad = True | ||
x.grad = None | ||
y = acos_linear_extrapolation(x, [lower_bound, upper_bound]) | ||
# compute the gradient of the acos w.r.t. x | ||
y.backward(torch.ones_like(y)) | ||
dacos_dx = x.grad | ||
x_lower = x <= lower_bound | ||
x_upper = x >= upper_bound | ||
x_mid = (~x_lower) & (~x_upper) | ||
# test that between bounds, the function returns plain acos | ||
self.assertClose(x[x_mid].acos(), y[x_mid]) | ||
# test that outside the bounds, the function is linear with the right | ||
# slope and continuous around the bound | ||
self._test_acos_outside_bounds( | ||
x[x_upper], y[x_upper], dacos_dx[x_upper], upper_bound | ||
) | ||
self._test_acos_outside_bounds( | ||
x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound | ||
) | ||
if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound | ||
# check that passing bounds=upper_bound gives the same | ||
# resut as bounds=[lower_bound, upper_bound] | ||
y_one_bound = acos_linear_extrapolation(x, upper_bound) | ||
self.assertClose(y_one_bound, y) | ||
|
||
def test_acos(self, batch_size: int = 10000): | ||
""" | ||
Tests whether the function returns correct outputs | ||
inside/outside the bounds. | ||
""" | ||
x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size) | ||
bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5) | ||
for lower_bound in -bounds: | ||
for upper_bound in bounds: | ||
if upper_bound < lower_bound: | ||
continue | ||
self._one_acos_test(x, float(lower_bound), float(upper_bound)) | ||
|
||
def test_finite_gradient(self, batch_size: int = 10000): | ||
""" | ||
Tests whether gradients stay finite close to the bounds. | ||
""" | ||
x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size) | ||
x.requires_grad = True | ||
bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5) | ||
for lower_bound in -bounds: | ||
for upper_bound in bounds: | ||
if upper_bound < lower_bound: | ||
continue | ||
x.grad = None | ||
y = acos_linear_extrapolation( | ||
x, | ||
[float(lower_bound), float(upper_bound)], | ||
) | ||
self.assertTrue(torch.isfinite(y).all()) | ||
loss = y.mean() | ||
loss.backward() | ||
self.assertIsNotNone(x.grad) | ||
self.assertTrue(torch.isfinite(x.grad).all()) |