From 5910d81b7bec2e8d328a8d4e1435e1041c9921a7 Mon Sep 17 00:00:00 2001 From: Emilien Garreau Date: Thu, 6 Jul 2023 02:20:53 -0700 Subject: [PATCH] Add blurpool following MIPNerf paper. Summary: Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415). It has been added has an option for RayPointRefiner. Reviewed By: shapovalov Differential Revision: D46356189 fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111 --- .../implicitron_trainer/tests/experiment.yaml | 4 ++ .../models/renderer/multipass_ea.py | 11 ++++ .../models/renderer/ray_point_refiner.py | 55 ++++++++++++++++++- tests/implicitron/test_ray_point_refiner.py | 32 ++++++++++- 4 files changed, 97 insertions(+), 5 deletions(-) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index d9f1284b7..e0394f220 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -249,6 +249,8 @@ model_factory_ImplicitronModelFactory_args: append_coarse_samples_to_fine: true density_noise_std_train: 0.0 return_weights: false + blurpool_weights: false + sample_pdf_eps: 1.0e-05 raymarcher_CumsumRaymarcher_args: surface_thickness: 1 bg_color: @@ -679,6 +681,8 @@ model_factory_ImplicitronModelFactory_args: append_coarse_samples_to_fine: true density_noise_std_train: 0.0 return_weights: false + blurpool_weights: false + sample_pdf_eps: 1.0e-05 raymarcher_CumsumRaymarcher_args: surface_thickness: 1 bg_color: diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index 18ee8f5b6..03511c295 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -65,6 +65,9 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 opacity field. return_weights: Enables returning the rendering weights of the EA raymarcher. Setting to `True` can lead to a prohibitivelly large memory consumption. + blurpool_weights: Use blurpool defined in [3], on the input weights of + each implicit_function except the first (implicit_functions[0]). + sample_pdf_eps: Padding applied to the weights (alpha in equation 18 of [3]). raymarcher_class_type: The type of self.raymarcher corresponding to a child of `RaymarcherBase` in the registry. raymarcher: The raymarcher object used to convert per-point features @@ -75,6 +78,8 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 Fields for View Synthesis." ECCV 2020. [2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable Volumes from Images." SIGGRAPH 2019. + [3] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation + for Anti-Aliasing Neural Radiance Fields." ICCV 2021. """ @@ -88,6 +93,8 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 append_coarse_samples_to_fine: bool = True density_noise_std_train: float = 0.0 return_weights: bool = False + blurpool_weights: bool = False + sample_pdf_eps: float = 1e-5 def __post_init__(self): self._refiners = { @@ -95,11 +102,15 @@ def __post_init__(self): n_pts_per_ray=self.n_pts_per_ray_fine_training, random_sampling=self.stratified_sampling_coarse_training, add_input_samples=self.append_coarse_samples_to_fine, + blurpool_weights=self.blurpool_weights, + sample_pdf_eps=self.sample_pdf_eps, ), EvaluationMode.EVALUATION: RayPointRefiner( n_pts_per_ray=self.n_pts_per_ray_fine_evaluation, random_sampling=self.stratified_sampling_coarse_evaluation, add_input_samples=self.append_coarse_samples_to_fine, + blurpool_weights=self.blurpool_weights, + sample_pdf_eps=self.sample_pdf_eps, ), } run_auto_creation(self) diff --git a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py index 22db11f4b..ffdfe2a9e 100644 --- a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py +++ b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py @@ -32,16 +32,27 @@ class RayPointRefiner(Configurable, torch.nn.Module): sampling from that distribution. add_input_samples: Concatenates and returns the sampled values together with the input samples. + blurpool_weights: Use blurpool defined in [1], on the input weights. + sample_pdf_eps: A constant preventing division by zero in case empty bins + are present. + + References: + [1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation + for Anti-Aliasing Neural Radiance Fields." ICCV 2021. """ n_pts_per_ray: int random_sampling: bool add_input_samples: bool = True + blurpool_weights: bool = False + sample_pdf_eps: float = 1e-5 def forward( self, input_ray_bundle: ImplicitronRayBundle, ray_weights: torch.Tensor, + blurpool_weights: bool = False, + sample_pdf_padding: float = 1e-5, **kwargs, ) -> ImplicitronRayBundle: """ @@ -49,28 +60,38 @@ def forward( input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the source rays for sampling of the probability distribution. ray_weights: A tensor of shape - `(..., input_ray_bundle.legths.shape[-1])` with non-negative + `(..., input_ray_bundle.lengths.shape[-1])` with non-negative elements defining the probability distribution to sample ray points from. + blurpool_weights: Use blurpool defined in [1], on the input weights. + sample_pdf_padding: A constant preventing division by zero in case empty bins + are present. Returns: ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray points together with `n_pts_per_ray` additionally sampled points per ray. For each ray, the lengths are sorted. + + References: + [1] Jonathan T. Barron, et al. "Mip-NeRF: A Multiscale Representation + for Anti-Aliasing Neural Radiance Fields." ICCV 2021. + """ z_vals = input_ray_bundle.lengths with torch.no_grad(): + if self.blurpool_weights: + ray_weights = apply_blurpool_on_weights(ray_weights) + z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5) z_samples = sample_pdf( z_vals_mid.view(-1, z_vals_mid.shape[-1]), ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1], self.n_pts_per_ray, det=not self.random_sampling, + eps=self.sample_pdf_eps, ).view(*z_vals.shape[:-1], self.n_pts_per_ray) - if self.add_input_samples: - # Add the new samples to the input ones. z_vals = torch.cat((z_vals, z_samples), dim=-1) else: z_vals = z_samples @@ -80,3 +101,31 @@ def forward( new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle)) new_bundle.lengths = z_vals return new_bundle + + +def apply_blurpool_on_weights(weights) -> torch.Tensor: + """ + Filter weights with a 2-tap max filters followed by a 2-tap blur filter, + which produces a wide and smooth upper envelope on the weights. + + Args: + weights: Tensor of shape `(..., dim)` + + Returns: + blured_weights: Tensor of shape `(..., dim)` + """ + weights_pad = torch.concatenate( + [ + weights[..., :1], + weights, + weights[..., -1:], + ], + dim=-1, + ) + + weights_max = torch.nn.functional.max_pool1d( + weights_pad.flatten(end_dim=-2), 2, stride=1 + ) + return torch.lerp(weights_max[..., :-1], weights_max[..., 1:], 0.5).reshape_as( + weights + ) diff --git a/tests/implicitron/test_ray_point_refiner.py b/tests/implicitron/test_ray_point_refiner.py index 9373edc22..c4e7b2208 100644 --- a/tests/implicitron/test_ray_point_refiner.py +++ b/tests/implicitron/test_ray_point_refiner.py @@ -5,9 +5,14 @@ # LICENSE file in the root directory of this source tree. import unittest +from itertools import product import torch -from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner + +from pytorch3d.implicitron.models.renderer.ray_point_refiner import ( + apply_blurpool_on_weights, + RayPointRefiner, +) from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle from tests.common_testing import TestCaseMixin @@ -17,11 +22,12 @@ def test_simple(self): length = 15 n_pts_per_ray = 10 - for add_input_samples in [False, True]: + for add_input_samples, use_blurpool in product([False, True], [False, True]): ray_point_refiner = RayPointRefiner( n_pts_per_ray=n_pts_per_ray, random_sampling=False, add_input_samples=add_input_samples, + blurpool_weights=use_blurpool, ) lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length) bundle = ImplicitronRayBundle( @@ -50,6 +56,7 @@ def test_simple(self): n_pts_per_ray=n_pts_per_ray, random_sampling=True, add_input_samples=add_input_samples, + blurpool_weights=use_blurpool, ) refined_random = ray_point_refiner_random(bundle, weights) lengths_random = refined_random.lengths @@ -62,3 +69,24 @@ def test_simple(self): self.assertTrue( (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all() ) + + def test_apply_blurpool_on_weights(self): + weights = torch.tensor( + [ + [0.5, 0.6, 0.7], + [0.5, 0.3, 0.9], + ] + ) + expected_weights = 0.5 * torch.tensor( + [ + [0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7], + [0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9], + ] + ) + out_weights = apply_blurpool_on_weights(weights) + self.assertTrue(torch.allclose(out_weights, expected_weights)) + + def test_shapes_apply_blurpool_on_weights(self): + weights = torch.randn((5, 4, 3, 2, 1)) + out_weights = apply_blurpool_on_weights(weights) + self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)