-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Copy the sample_pdf operation from the NeRF project in to PyTorch3D, in preparation for optimizing it. Reviewed By: gkioxari Differential Revision: D27117930 fbshipit-source-id: 20286b007f589a4c4d53ed818c4bc5f2abd22833
- Loading branch information
1 parent
b481cfb
commit 7d7d00f
Showing
3 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import torch | ||
|
||
|
||
def sample_pdf_python( | ||
bins: torch.Tensor, | ||
weights: torch.Tensor, | ||
N_samples: int, | ||
det: bool = False, | ||
eps: float = 1e-5, | ||
) -> torch.Tensor: | ||
""" | ||
Samples probability density functions defined by bin edges `bins` and | ||
the non-negative per-bin probabilities `weights`. | ||
Note: This is a direct conversion of the TensorFlow function from the original | ||
release [1] to PyTorch. | ||
Args: | ||
bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins. | ||
weights: Tensor of shape `(..., n_bins)` containing non-negative numbers | ||
representing the probability of sampling the corresponding bin. | ||
N_samples: The number of samples to draw from each set of bins. | ||
det: If `False`, the sampling is random. `True` yields deterministic | ||
uniformly-spaced sampling from the inverse cumulative density function. | ||
eps: A constant preventing division by zero in case empty bins are present. | ||
Returns: | ||
samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples | ||
drawn from each probability distribution. | ||
Refs: | ||
[1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501 | ||
""" | ||
|
||
# Get pdf | ||
weights = weights + eps # prevent nans | ||
if weights.min() <= 0: | ||
raise ValueError("Negative weights provided.") | ||
pdf = weights / weights.sum(dim=-1, keepdim=True) | ||
cdf = torch.cumsum(pdf, -1) | ||
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) | ||
|
||
# Take uniform samples u of shape (..., N_samples) | ||
if det: | ||
u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype) | ||
u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous() | ||
else: | ||
u = torch.rand( | ||
list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype | ||
) | ||
|
||
# Invert CDF | ||
inds = torch.searchsorted(cdf, u, right=True) | ||
# inds has shape (..., N_samples) identifying the bin of each sample. | ||
below = (inds - 1).clamp(0) | ||
above = inds.clamp(max=cdf.shape[-1] - 1) | ||
# Below and above are of shape (..., N_samples), identifying the bin | ||
# edges surrounding each sample. | ||
|
||
inds_g = torch.stack([below, above], -1).view( | ||
*below.shape[:-1], below.shape[-1] * 2 | ||
) | ||
cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2) | ||
bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2) | ||
# cdf_g and bins_g are of shape (..., N_samples, 2) and identify | ||
# the cdf and the index of the two bin edges surrounding each sample. | ||
|
||
denom = cdf_g[..., 1] - cdf_g[..., 0] | ||
denom = torch.where(denom < eps, torch.ones_like(denom), denom) | ||
t = (u - cdf_g[..., 0]) / denom | ||
# t is of shape (..., N_samples) and identifies how far through | ||
# each sample is in its bin. | ||
|
||
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) | ||
|
||
return samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from itertools import product | ||
|
||
from fvcore.common.benchmark import benchmark | ||
from test_sample_pdf import TestSamplePDF | ||
|
||
|
||
def bm_sample_pdf() -> None: | ||
|
||
backends = ["python_cuda", "python_cpu"] | ||
|
||
kwargs_list = [] | ||
sample_counts = [64] | ||
batch_sizes = [1024, 10240] | ||
bin_counts = [62, 600] | ||
test_cases = product(backends, sample_counts, batch_sizes, bin_counts) | ||
for case in test_cases: | ||
backend, n_samples, batch_size, n_bins = case | ||
kwargs_list.append( | ||
{ | ||
"backend": backend, | ||
"n_samples": n_samples, | ||
"batch_size": batch_size, | ||
"n_bins": n_bins, | ||
} | ||
) | ||
|
||
benchmark(TestSamplePDF.bm_fn, "SAMPLE_PDF", kwargs_list, warmup_iters=1) | ||
|
||
|
||
if __name__ == "__main__": | ||
bm_sample_pdf() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from common_testing import TestCaseMixin | ||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf_python | ||
|
||
|
||
class TestSamplePDF(TestCaseMixin, unittest.TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
torch.manual_seed(1) | ||
|
||
def test_single_bin(self): | ||
bins = torch.arange(2).expand(5, 2) + 17 | ||
weights = torch.ones(5, 1) | ||
output = sample_pdf_python(bins, weights, 100, True) | ||
calc = torch.linspace(17, 18, 100).expand(5, -1) | ||
self.assertClose(output, calc) | ||
|
||
@staticmethod | ||
def bm_fn(*, backend: str, n_samples, batch_size, n_bins): | ||
f = sample_pdf_python | ||
weights = torch.rand(size=(batch_size, n_bins)) | ||
bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1) | ||
|
||
if "cuda" in backend: | ||
weights = weights.cuda() | ||
bins = bins.cuda() | ||
|
||
torch.cuda.synchronize() | ||
|
||
def output(): | ||
f(bins, weights, n_samples) | ||
torch.cuda.synchronize() | ||
|
||
return output |