From 3eb4233844a8a3c6441e91ebe22a4354da8f5fae Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 24 Jan 2022 10:51:03 -0800 Subject: [PATCH] New raysamplers Summary: New MultinomialRaysampler succeeds GridRaysampler bringing masking and subsampling. Correspondingly, NDCMultinomialRaysampler succeeds NDCGridRaysampler. Reviewed By: nikhilaravi, shapovalov Differential Revision: D33256897 fbshipit-source-id: cd80ec6f35b110d1d20a75c62f4e889ba8fa5d45 --- projects/nerf/nerf/implicit_function.py | 2 +- pytorch3d/renderer/__init__.py | 2 + pytorch3d/renderer/implicit/__init__.py | 9 +- pytorch3d/renderer/implicit/raysampling.py | 341 ++++++++++++++++++--- tests/benchmarks/bm_raysampling.py | 10 +- tests/common_testing.py | 11 + tests/test_raysampling.py | 96 ++++-- 7 files changed, 411 insertions(+), 60 deletions(-) diff --git a/projects/nerf/nerf/implicit_function.py b/projects/nerf/nerf/implicit_function.py index d84986ff5..7a1ad60f1 100644 --- a/projects/nerf/nerf/implicit_function.py +++ b/projects/nerf/nerf/implicit_function.py @@ -7,7 +7,7 @@ from typing import Tuple import torch -from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points, HarmonicEmbedding +from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points from .linear_with_repeat import LinearWithRepeat diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 0b3c7d444..ef82733d2 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -32,7 +32,9 @@ HarmonicEmbedding, ImplicitRenderer, MonteCarloRaysampler, + MultinomialRaysampler, NDCGridRaysampler, + NDCMultinomialRaysampler, RayBundle, VolumeRenderer, VolumeSampler, diff --git a/pytorch3d/renderer/implicit/__init__.py b/pytorch3d/renderer/implicit/__init__.py index dd9f5827b..f3ec515e4 100644 --- a/pytorch3d/renderer/implicit/__init__.py +++ b/pytorch3d/renderer/implicit/__init__.py @@ -6,7 +6,13 @@ from .harmonic_embedding import HarmonicEmbedding from 
.raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher -from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler +from .raysampling import ( + GridRaysampler, + MonteCarloRaysampler, + MultinomialRaysampler, + NDCGridRaysampler, + NDCMultinomialRaysampler, +) from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler from .utils import ( RayBundle, @@ -14,4 +20,5 @@ ray_bundle_variables_to_ray_points, ) + __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index ef7591671..c4e4139d2 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -4,22 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch +import warnings +from typing import Optional -from ..cameras import CamerasBase -from .utils import RayBundle +import torch +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit.utils import RayBundle +from torch.nn import functional as F """ This file defines three raysampling techniques: - - GridRaysampler which can be used to sample rays from pixels of an image grid - - NDCGridRaysampler which can be used to sample rays from pixels of an image grid, + - MultinomialRaysampler which can be used to sample rays from pixels of an image grid + - NDCMultinomialRaysampler which can be used to sample rays from pixels of an image grid, which follows the pytorch3d convention for image grid coordinates - - MonteCarloRaysampler which randomly selects image pixels and emits rays from them + - MonteCarloRaysampler which randomly selects real-valued locations in the image plane + and emits rays from them """ -class GridRaysampler(torch.nn.Module): +class MultinomialRaysampler(torch.nn.Module): """ Samples a fixed number of points along 
rays which are regularly distributed in a batch of rectangular image grids. Points along each ray @@ -44,19 +48,20 @@ class GridRaysampler(torch.nn.Module): < --- image_width --- > ``` - In order to generate ray points, `GridRaysampler` takes each 3D point of + In order to generate ray points, `MultinomialRaysampler` takes each 3D point of the grid (with coordinates `[x, y, depth]`) and unprojects it with `cameras.unproject_points([x, y, depth])`, where `cameras` are an additional input to the `forward` function. Note that this is a generic implementation that can support any image grid coordinate convention. For a raysampler which follows the PyTorch3D - coordinate conventions please refer to `NDCGridRaysampler`. - As such, `NDCGridRaysampler` is a special case of `GridRaysampler`. + coordinate conventions please refer to `NDCMultinomialRaysampler`. + As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`. """ def __init__( self, + *, min_x: float, max_x: float, min_y: float, @@ -66,6 +71,9 @@ def __init__( n_pts_per_ray: int, min_depth: float, max_depth: float, + n_rays_per_image: Optional[int] = None, + unit_directions: bool = False, + stratified_sampling: bool = False, ) -> None: """ Args: @@ -78,11 +86,18 @@ def __init__( n_pts_per_ray: The number of points sampled along each ray. min_depth: The minimum depth of a ray-point. max_depth: The maximum depth of a ray-point. + n_rays_per_image: If given, this amount of rays are sampled from the grid. + unit_directions: whether to normalize direction vectors in ray bundle. + stratified_sampling: if set, performs stratified random sampling + along the ray; otherwise takes ray points at deterministic offsets. 
""" super().__init__() self._n_pts_per_ray = n_pts_per_ray self._min_depth = min_depth self._max_depth = max_depth + self._n_rays_per_image = n_rays_per_image + self._unit_directions = unit_directions + self._stratified_sampling = stratified_sampling # get the initial grid of image xy coords _xy_grid = torch.stack( @@ -96,69 +111,127 @@ def __init__( ), dim=-1, ) + self.register_buffer("_xy_grid", _xy_grid, persistent=False) - def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: + def forward( + self, + cameras: CamerasBase, + *, + mask: Optional[torch.Tensor] = None, + min_depth: Optional[float] = None, + max_depth: Optional[float] = None, + n_rays_per_image: Optional[int] = None, + n_pts_per_ray: Optional[int] = None, + stratified_sampling: bool = False, + **kwargs, + ) -> RayBundle: """ Args: cameras: A batch of `batch_size` cameras from which the rays are emitted. + mask: if given, the rays are sampled from the mask. Should be of size + (batch_size, image_height, image_width). + min_depth: The minimum depth of a ray-point. + max_depth: The maximum depth of a ray-point. + n_rays_per_image: If given, this amount of rays are sampled from the grid. + n_pts_per_ray: The number of points sampled along each ray. + stratified_sampling: if set, performs stratified sampling in n_pts_per_ray + bins for each ray; otherwise takes n_pts_per_ray deterministic points + on each ray with uniform offsets. Returns: A named tuple RayBundle with the following fields: origins: A tensor of shape - `(batch_size, image_height, image_width, 3)` + `(batch_size, s1, s2, 3)` denoting the locations of ray origins in the world coordinates. directions: A tensor of shape - `(batch_size, image_height, image_width, 3)` + `(batch_size, s1, s2, 3)` denoting the directions of each ray in the world coordinates. 
lengths: A tensor of shape - `(batch_size, image_height, image_width, n_pts_per_ray)` + `(batch_size, s1, s2, n_pts_per_ray)` containing the z-coordinate (=depth) of each ray in world units. xys: A tensor of shape - `(batch_size, image_height, image_width, 2)` - containing the 2D image coordinates of each ray. + `(batch_size, s1, s2, 2)` + containing the 2D image coordinates of each ray or, + if mask is given, `(batch_size, n, 1, 2)` + Here `s1, s2` refer to spatial dimensions. Unless the mask is + given, they equal `(image_height, image_width)`, otherwise `(n, 1)`, + where `n` is `n_rays_per_image` if provided, otherwise the minimum + cardinality of the mask in the batch. """ - batch_size = cameras.R.shape[0] - device = cameras.device # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2) - xy_grid = self._xy_grid.to(device)[None].expand( - batch_size, *self._xy_grid.shape + xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1) + + num_rays = n_rays_per_image or self._n_rays_per_image + if mask is not None and num_rays is None: + # if num rays not given, sample according to the smallest mask + num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item() + + if num_rays is not None: + if mask is not None: + assert mask.shape == xy_grid.shape[:3] + weights = mask.reshape(batch_size, -1) + else: + # it is probably more efficient to use torch.randperm + # for uniform weights but it is unlikely given that randperm + # is not batched and does not support partial permutation + _, width, height, _ = xy_grid.shape + weights = xy_grid.new_ones(batch_size, width * height) + rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2) + + xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[ + :, :, None + ] + + min_depth = min_depth if min_depth is not None else self._min_depth + max_depth = max_depth if max_depth is not None else self._max_depth + n_pts_per_ray = ( + n_pts_per_ray if n_pts_per_ray is not None else 
self._n_pts_per_ray + ) + stratified_sampling = ( + stratified_sampling + if stratified_sampling is not None + else self._stratified_sampling ) return _xy_to_ray_bundle( - cameras, xy_grid, self._min_depth, self._max_depth, self._n_pts_per_ray + cameras, + xy_grid, + min_depth, + max_depth, + n_pts_per_ray, + self._unit_directions, + stratified_sampling, ) -class NDCGridRaysampler(GridRaysampler): +class NDCMultinomialRaysampler(MultinomialRaysampler): """ Samples a fixed number of points along rays which are regularly distributed in a batch of rectangular image grids. Points along each ray have uniformly-spaced z-coordinates between a predefined minimum and maximum depth. - `NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds` + `NDCMultinomialRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds` renderers. I.e. the pixel coordinates are in [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image. + + For the description of arguments, see the documentation to MultinomialRaysampler. """ def __init__( self, + *, image_width: int, image_height: int, n_pts_per_ray: int, min_depth: float, max_depth: float, + n_rays_per_image: Optional[int] = None, + unit_directions: bool = False, + stratified_sampling: bool = False, ) -> None: - """ - Args: - image_width: The horizontal size of the image grid. - image_height: The vertical size of the image grid. - n_pts_per_ray: The number of points sampled along each ray. - min_depth: The minimum depth of a ray-point. - max_depth: The maximum depth of a ray-point. 
- """ if image_width >= image_height: range_x = image_width / image_height range_y = 1.0 @@ -178,6 +251,9 @@ def __init__( n_pts_per_ray=n_pts_per_ray, min_depth=min_depth, max_depth=max_depth, + n_rays_per_image=n_rays_per_image, + unit_directions=unit_directions, + stratified_sampling=stratified_sampling, ) @@ -187,6 +263,9 @@ class MonteCarloRaysampler(torch.nn.Module): For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced z-coordinates such that the z-coordinates range between a predefined minimum and maximum depth. + + For practical purposes, this is similar to MultinomialRaysampler without a mask, + however sampling at real-valued locations bypassing replacement checks may be faster. """ def __init__( @@ -199,6 +278,9 @@ def __init__( n_pts_per_ray: int, min_depth: float, max_depth: float, + *, + unit_directions: bool = False, + stratified_sampling: bool = False, ) -> None: """ Args: @@ -210,6 +292,10 @@ def __init__( n_pts_per_ray: The number of points sampled along each ray. min_depth: The minimum depth of each ray-point. max_depth: The maximum depth of each ray-point. + unit_directions: whether to normalize direction vectors in ray bundle. + stratified_sampling: if set, performs stratified sampling in n_pts_per_ray + bins for each ray; otherwise takes n_pts_per_ray deterministic points + on each ray with uniform offsets. """ super().__init__() self._min_x = min_x @@ -220,11 +306,18 @@ def __init__( self._n_pts_per_ray = n_pts_per_ray self._min_depth = min_depth self._max_depth = max_depth + self._unit_directions = unit_directions + self._stratified_sampling = stratified_sampling - def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: + def forward( + self, cameras: CamerasBase, *, stratified_sampling: bool = False, **kwargs + ) -> RayBundle: """ Args: cameras: A batch of `batch_size` cameras from which the rays are emitted. 
+ stratified_sampling: if set, performs stratified sampling in n_pts_per_ray + bins for each ray; otherwise takes n_pts_per_ray deterministic points + on each ray with uniform offsets. Returns: A named tuple RayBundle with the following fields: origins: A tensor of shape @@ -264,17 +357,141 @@ def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: dim=2, ) + stratified_sampling = ( + stratified_sampling + if stratified_sampling is not None + else self._stratified_sampling + ) + return _xy_to_ray_bundle( - cameras, rays_xy, self._min_depth, self._max_depth, self._n_pts_per_ray + cameras, + rays_xy, + self._min_depth, + self._max_depth, + self._n_pts_per_ray, + self._unit_directions, + stratified_sampling, ) +# Settings for backwards compatibility +def GridRaysampler( + min_x: float, + max_x: float, + min_y: float, + max_y: float, + image_width: int, + image_height: int, + n_pts_per_ray: int, + min_depth: float, + max_depth: float, +) -> "MultinomialRaysampler": + """ + GridRaysampler has been DEPRECATED. Use MultinomialRaysampler instead. + Preserving GridRaysampler for backward compatibility. + """ + + warnings.warn( + """GridRaysampler is deprecated, + Use MultinomialRaysampler instead. + GridRaysampler will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return MultinomialRaysampler( + min_x=min_x, + max_x=max_x, + min_y=min_y, + max_y=max_y, + image_width=image_width, + image_height=image_height, + n_pts_per_ray=n_pts_per_ray, + min_depth=min_depth, + max_depth=max_depth, + ) + + +# Settings for backwards compatibility +def NDCGridRaysampler( + image_width: int, + image_height: int, + n_pts_per_ray: int, + min_depth: float, + max_depth: float, +) -> "NDCMultinomialRaysampler": + """ + NDCGridRaysampler has been DEPRECATED. Use NDCMultinomialRaysampler instead. + Preserving NDCGridRaysampler for backward compatibility. + """ + + warnings.warn( + """NDCGridRaysampler is deprecated, + Use NDCMultinomialRaysampler instead. 
+ NDCGridRaysampler will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return NDCMultinomialRaysampler( + image_width=image_width, + image_height=image_height, + n_pts_per_ray=n_pts_per_ray, + min_depth=min_depth, + max_depth=max_depth, + ) + + +def _safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor: + """ + Wrapper around torch.multinomial that attempts sampling without replacement + when possible, otherwise resorts to sampling with replacement. + + Args: + input: tensor of shape [B, n] containing non-negative values; + rows are interpreted as unnormalized event probabilities + in categorical distributions. + num_samples: number of samples to take. + + Returns: + LongTensor of shape [B, num_samples] containing + values from {0, ..., n - 1} where the elements [i, :] of row i make + (1) if there are num_samples or more non-zero values in input[i], + a random subset of the indices of those values, with + probabilities proportional to the values in input[i, :]. + + (2) if not, a random sample with replacement of the indices of + those values, with probabilities proportional to them. + This sample might not contain all the indices of the + non-zero values. + Behavior undetermined if there are no non-zero values in a whole row + or if there are negative values. + """ + try: + res = torch.multinomial(input, num_samples, replacement=False) + except RuntimeError: + # this is probably rare, so we don't mind sampling twice + res = torch.multinomial(input, num_samples, replacement=True) + no_repl = (input > 0.0).sum(dim=-1) >= num_samples + res[no_repl] = torch.multinomial(input[no_repl], num_samples, replacement=False) + return res + + # in some versions of PyTorch, zero probability samples can be drawn without an error + # due to this bug: https://github.com/pytorch/pytorch/issues/50034.
Handle this case: + repl = (input > 0.0).sum(dim=-1) < num_samples + # pyre-fixme[16]: Undefined attribute `torch.ByteTensor` has no attribute `any`. + if repl.any(): + res[repl] = torch.multinomial(input[repl], num_samples, replacement=True) + + return res + + def _xy_to_ray_bundle( cameras: CamerasBase, xy_grid: torch.Tensor, min_depth: float, max_depth: float, n_pts_per_ray: int, + unit_directions: bool, + stratified_sampling: bool = False, ) -> RayBundle: """ Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays. @@ -283,16 +500,36 @@ def _xy_to_ray_bundle( The extended grid is then unprojected with `cameras` to yield ray origins, directions and depths. + + Args: + cameras: cameras object representing a batch of cameras. + xy_grid: torch.tensor grid of image xy coords. + min_depth: The minimum depth of each ray-point. + max_depth: The maximum depth of each ray-point. + n_pts_per_ray: The number of points sampled along each ray. + unit_directions: whether to normalize direction vectors in ray bundle. + stratified_sampling: if set, performs stratified sampling in n_pts_per_ray + bins for each ray; otherwise takes n_pts_per_ray deterministic points + on each ray with uniform offsets. 
""" batch_size = xy_grid.shape[0] spatial_size = xy_grid.shape[1:-1] n_rays_per_image = spatial_size.numel() # pyre-ignore # ray z-coords - depths = torch.linspace( - min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device - ) - rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray) + rays_zs = xy_grid.new_empty((0,)) + if n_pts_per_ray > 0: + depths = torch.linspace( + min_depth, + max_depth, + n_pts_per_ray, + dtype=xy_grid.dtype, + device=xy_grid.device, + ) + rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray) + + if stratified_sampling: + rays_zs = _jiggle_within_stratas(rays_zs) # make two sets of points at a constant depth=1 and 2 to_unproject = torch.cat( @@ -320,6 +557,8 @@ def _xy_to_ray_bundle( # directions are the differences between the two planes of points rays_directions_world = rays_plane_2_world - rays_plane_1_world + if unit_directions: + rays_directions_world = F.normalize(rays_directions_world, dim=-1) # origins are given by subtracting the ray directions from the first plane rays_origins_world = rays_plane_1_world - rays_directions_world @@ -330,3 +569,31 @@ def _xy_to_ray_bundle( rays_zs.view(batch_size, *spatial_size, n_pts_per_ray), xy_grid, ) + + +def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor: + """ + Performs sampling of 1 point per bin given the bin centers. + + More specifically, it replaces each point's value `z` + with a sample from a uniform random distribution on + `[z - delta_−, z + delta_+]`, where `delta_−` is half of the difference + between `z` and the previous point, and `delta_+` is half of the difference + between the next point and `z`. For the first and last items, the + corresponding boundary deltas are assumed zero. + + Args: + `bin_centers`: The input points of size (..., N); the result is broadcast + along all but the last dimension (the rows). Each row should be + sorted in ascending order. 
+ + Returns: + a tensor of size (..., N) with the locations jiggled within stratas/bins. + """ + # Get intervals between bin centers. + mids = 0.5 * (bin_centers[..., 1:] + bin_centers[..., :-1]) + upper = torch.cat((mids, bin_centers[..., -1:]), dim=-1) + lower = torch.cat((bin_centers[..., :1], mids), dim=-1) + # Samples in those intervals. + jiggled = lower + (upper - lower) * torch.rand_like(lower) + return jiggled diff --git a/tests/benchmarks/bm_raysampling.py b/tests/benchmarks/bm_raysampling.py index 180f6d0e9..99df8150f 100644 --- a/tests/benchmarks/bm_raysampling.py +++ b/tests/benchmarks/bm_raysampling.py @@ -10,9 +10,9 @@ from pytorch3d.renderer import ( FoVOrthographicCameras, FoVPerspectiveCameras, - GridRaysampler, MonteCarloRaysampler, - NDCGridRaysampler, + MultinomialRaysampler, + NDCMultinomialRaysampler, OrthographicCameras, PerspectiveCameras, ) @@ -21,7 +21,11 @@ def bm_raysampling() -> None: case_grid = { - "raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler], + "raysampler_type": [ + MultinomialRaysampler, + NDCMultinomialRaysampler, + MonteCarloRaysampler, + ], "camera_type": [ PerspectiveCameras, OrthographicCameras, diff --git a/tests/common_testing.py b/tests/common_testing.py index c4f364bee..0e62795a9 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -6,6 +6,7 @@ import os import unittest +from numbers import Real from pathlib import Path from typing import Callable, Optional, Union @@ -190,3 +191,13 @@ def assertClose( if msg is not None: self.fail(f"{msg} {err}") self.fail(err) + + def assertConstant(self, input: TensorOrArray, value: Real) -> None: + """ + Asserts input is entirely filled with value. 
+ + Args: + input: tensor or array + """ + self.assertEqual(input.min(), value) + self.assertEqual(input.max(), value) diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index a280bb4ac..cb99ed314 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -5,17 +5,27 @@ # LICENSE file in the root directory of this source tree. import unittest +from typing import Callable import torch from common_testing import TestCaseMixin from pytorch3d.ops import eyes -from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler +from pytorch3d.renderer import ( + MonteCarloRaysampler, + MultinomialRaysampler, + NDCGridRaysampler, + NDCMultinomialRaysampler, +) from pytorch3d.renderer.cameras import ( FoVOrthographicCameras, FoVPerspectiveCameras, OrthographicCameras, PerspectiveCameras, ) +from pytorch3d.renderer.implicit.raysampling import ( + _jiggle_within_stratas, + _safe_multinomial, +) from pytorch3d.renderer.implicit.utils import ( ray_bundle_to_ray_points, ray_bundle_variables_to_ray_points, @@ -93,14 +103,16 @@ def setUp(self) -> None: @staticmethod def raysampler( - raysampler_type=GridRaysampler, - camera_type=PerspectiveCameras, - n_pts_per_ray=10, - batch_size=1, - image_width=10, - image_height=20, - ): - + raysampler_type, + camera_type, + n_pts_per_ray: int, + batch_size: int, + image_width: int, + image_height: int, + ) -> Callable[[], None]: + """ + Used for benchmarks. 
+ """ device = torch.device("cuda") # init raysamplers @@ -120,7 +132,7 @@ def raysampler( # init a batch of random cameras cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device) - def run_raysampler(): + def run_raysampler() -> None: raysampler(cameras=cameras) torch.cuda.synchronize() @@ -128,7 +140,7 @@ def run_raysampler(): @staticmethod def init_raysampler( - raysampler_type=GridRaysampler, + raysampler_type, min_x=-1.0, max_x=1.0, min_y=-1.0, @@ -149,7 +161,7 @@ def init_raysampler( "max_depth": max_depth, } - if issubclass(raysampler_type, GridRaysampler): + if issubclass(raysampler_type, MultinomialRaysampler): raysampler_params.update( {"image_width": image_width, "image_height": image_height} ) @@ -158,7 +170,7 @@ def init_raysampler( else: raise ValueError(str(raysampler_type)) - if issubclass(raysampler_type, NDCGridRaysampler): + if issubclass(raysampler_type, NDCMultinomialRaysampler): # NDCGridRaysampler does not use min/max_x/y for k in ("min_x", "max_x", "min_y", "max_y"): del raysampler_params[k] @@ -191,8 +203,8 @@ def test_raysamplers( for raysampler_type in ( MonteCarloRaysampler, - GridRaysampler, - NDCGridRaysampler, + MultinomialRaysampler, + NDCMultinomialRaysampler, ): raysampler = TestRaysampling.init_raysampler( @@ -208,7 +220,7 @@ def test_raysamplers( n_pts_per_ray=n_pts_per_ray, ) - if issubclass(raysampler_type, NDCGridRaysampler): + if issubclass(raysampler_type, NDCMultinomialRaysampler): # adjust the gt bounds for NDCGridRaysampler if image_width >= image_height: range_x = image_width / image_height @@ -297,7 +309,7 @@ def _check_raysampler_output_shapes( Checks the shapes of raysampler outputs. 
""" - if isinstance(raysampler, GridRaysampler): + if isinstance(raysampler, MultinomialRaysampler): spatial_size = [image_height, image_width] elif isinstance(raysampler, MonteCarloRaysampler): spatial_size = [image_height * image_width] @@ -386,7 +398,7 @@ def _check_raysampler_ray_points( # check that projected world points' xy coordinates # range correctly between [minx/y, max/y] - if isinstance(raysampler, GridRaysampler): + if isinstance(raysampler, MultinomialRaysampler): # get the expected coordinates along each grid axis ys, xs = [ torch.linspace( @@ -518,3 +530,51 @@ def test_load_state_different_resolution(self): ) state = module1.state_dict() module2.load_state_dict(state) + + def test_jiggle(self): + # random data which is in ascending order along the last dimension + scale = 180 + data = scale * torch.cumsum(torch.rand(8, 3, 4, 20), dim=-1) + + out = _jiggle_within_stratas(data) + self.assertTupleEqual(out.shape, data.shape) + + # Check `out` is in ascending order + self.assertGreater(torch.diff(out, dim=-1).min(), 0) + + self.assertConstant(out[..., :-1] < data[..., 1:], True) + self.assertConstant(data[..., :-1] < out[..., 1:], True) + + jiggles = out - data + # jiggles is random between -scale/2 and scale/2 + self.assertLess(jiggles.min(), -0.4 * scale) + self.assertGreater(jiggles.min(), -0.5 * scale) + self.assertGreater(jiggles.max(), 0.4 * scale) + self.assertLess(jiggles.max(), 0.5 * scale) + + def test_safe_multinomial(self): + mask = [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 1, 0], + ] + tmask = torch.tensor(mask, dtype=torch.float32) + + for _ in range(5): + random_scalar = torch.rand(1) + samples = _safe_multinomial(tmask * random_scalar, 3) + self.assertTupleEqual(samples.shape, (4, 3)) + + # samples[0] is exactly determined + self.assertConstant(samples[0], 0) + + self.assertGreaterEqual(samples[1].min(), 0) + self.assertLessEqual(samples[1].max(), 1) + + # samples[2] is exactly determined + 
self.assertSetEqual(set(samples[2].tolist()), {0, 1, 2}) + + # samples[3] has enough sources, so must contain 3 distinct values. + self.assertLessEqual(samples[3].max(), 3) + self.assertEqual(len(set(samples[3].tolist())), 3)