Skip to content

Commit

Permalink
Add blurpool following MIPNerf paper.
Browse files Browse the repository at this point in the history
Summary:
Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415).
It has been added has an option for RayPointRefiner.

Reviewed By: shapovalov

Differential Revision: D46356189

fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
  • Loading branch information
EmGarr authored and facebook-github-bot committed Jul 6, 2023
1 parent ccf860f commit 5910d81
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 5 deletions.
4 changes: 4 additions & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ model_factory_ImplicitronModelFactory_args:
append_coarse_samples_to_fine: true
density_noise_std_train: 0.0
return_weights: false
blurpool_weights: false
sample_pdf_eps: 1.0e-05
raymarcher_CumsumRaymarcher_args:
surface_thickness: 1
bg_color:
Expand Down Expand Up @@ -679,6 +681,8 @@ model_factory_ImplicitronModelFactory_args:
append_coarse_samples_to_fine: true
density_noise_std_train: 0.0
return_weights: false
blurpool_weights: false
sample_pdf_eps: 1.0e-05
raymarcher_CumsumRaymarcher_args:
surface_thickness: 1
bg_color:
Expand Down
11 changes: 11 additions & 0 deletions pytorch3d/implicitron/models/renderer/multipass_ea.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
opacity field.
return_weights: Enables returning the rendering weights of the EA raymarcher.
Setting to `True` can lead to a prohibitivelly large memory consumption.
blurpool_weights: Use blurpool defined in [3], on the input weights of
each implicit_function except the first (implicit_functions[0]).
sample_pdf_eps: Padding applied to the weights (alpha in equation 18 of [3]).
raymarcher_class_type: The type of self.raymarcher corresponding to
a child of `RaymarcherBase` in the registry.
raymarcher: The raymarcher object used to convert per-point features
Expand All @@ -75,6 +78,8 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
Fields for View Synthesis." ECCV 2020.
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
Volumes from Images." SIGGRAPH 2019.
[3] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
"""

Expand All @@ -88,18 +93,24 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
append_coarse_samples_to_fine: bool = True
density_noise_std_train: float = 0.0
return_weights: bool = False
blurpool_weights: bool = False
sample_pdf_eps: float = 1e-5

def __post_init__(self):
self._refiners = {
EvaluationMode.TRAINING: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_training,
random_sampling=self.stratified_sampling_coarse_training,
add_input_samples=self.append_coarse_samples_to_fine,
blurpool_weights=self.blurpool_weights,
sample_pdf_eps=self.sample_pdf_eps,
),
EvaluationMode.EVALUATION: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
random_sampling=self.stratified_sampling_coarse_evaluation,
add_input_samples=self.append_coarse_samples_to_fine,
blurpool_weights=self.blurpool_weights,
sample_pdf_eps=self.sample_pdf_eps,
),
}
run_auto_creation(self)
Expand Down
55 changes: 52 additions & 3 deletions pytorch3d/implicitron/models/renderer/ray_point_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,66 @@ class RayPointRefiner(Configurable, torch.nn.Module):
sampling from that distribution.
add_input_samples: Concatenates and returns the sampled values
together with the input samples.
blurpool_weights: Use blurpool defined in [1], on the input weights.
sample_pdf_eps: A constant preventing division by zero in case empty bins
are present.
References:
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
"""

n_pts_per_ray: int
random_sampling: bool
add_input_samples: bool = True
blurpool_weights: bool = False
sample_pdf_eps: float = 1e-5

def forward(
self,
input_ray_bundle: ImplicitronRayBundle,
ray_weights: torch.Tensor,
blurpool_weights: bool = False,
sample_pdf_padding: float = 1e-5,
**kwargs,
) -> ImplicitronRayBundle:
"""
Args:
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
elements defining the probability distribution to sample
ray points from.
blurpool_weights: Use blurpool defined in [1], on the input weights.
sample_pdf_padding: A constant preventing division by zero in case empty bins
are present.
Returns:
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
points together with `n_pts_per_ray` additionally sampled
points per ray. For each ray, the lengths are sorted.
References:
[1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
"""

z_vals = input_ray_bundle.lengths
with torch.no_grad():
if self.blurpool_weights:
ray_weights = apply_blurpool_on_weights(ray_weights)

z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
det=not self.random_sampling,
eps=self.sample_pdf_eps,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)

if self.add_input_samples:
# Add the new samples to the input ones.
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
z_vals = z_samples
Expand All @@ -80,3 +101,31 @@ def forward(
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle


def apply_blurpool_on_weights(weights) -> torch.Tensor:
"""
Filter weights with a 2-tap max filters followed by a 2-tap blur filter,
which produces a wide and smooth upper envelope on the weights.
Args:
weights: Tensor of shape `(..., dim)`
Returns:
blured_weights: Tensor of shape `(..., dim)`
"""
weights_pad = torch.concatenate(
[
weights[..., :1],
weights,
weights[..., -1:],
],
dim=-1,
)

weights_max = torch.nn.functional.max_pool1d(
weights_pad.flatten(end_dim=-2), 2, stride=1
)
return torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5).reshape_as(
weights
)
32 changes: 30 additions & 2 deletions tests/implicitron/test_ray_point_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
# LICENSE file in the root directory of this source tree.

import unittest
from itertools import product

import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner

from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
apply_blurpool_on_weights,
RayPointRefiner,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from tests.common_testing import TestCaseMixin

Expand All @@ -17,11 +22,12 @@ def test_simple(self):
length = 15
n_pts_per_ray = 10

for add_input_samples in [False, True]:
for add_input_samples, use_blurpool in product([False, True], [False, True]):
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
blurpool_weights=use_blurpool,
)
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
bundle = ImplicitronRayBundle(
Expand Down Expand Up @@ -50,6 +56,7 @@ def test_simple(self):
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
blurpool_weights=use_blurpool,
)
refined_random = ray_point_refiner_random(bundle, weights)
lengths_random = refined_random.lengths
Expand All @@ -62,3 +69,24 @@ def test_simple(self):
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)

def test_apply_blurpool_on_weights(self):
weights = torch.tensor(
[
[0.5, 0.6, 0.7],
[0.5, 0.3, 0.9],
]
)
expected_weights = 0.5 * torch.tensor(
[
[0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
[0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
]
)
out_weights = apply_blurpool_on_weights(weights)
self.assertTrue(torch.allclose(out_weights, expected_weights))

def test_shapes_apply_blurpool_on_weights(self):
weights = torch.randn((5, 4, 3, 2, 1))
out_weights = apply_blurpool_on_weights(weights)
self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)

0 comments on commit 5910d81

Please sign in to comment.