From aa9bcaf04c520b5fd0aa8d6d807b8090ea43d61c Mon Sep 17 00:00:00 2001 From: David Novotny Date: Tue, 5 Jan 2021 03:37:38 -0800 Subject: [PATCH] Point clouds to volumes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Conversion from point clouds to volumes ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_1000 43219 44067 12 ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_10000 43274 45313 12 ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_100000 46281 47100 11 ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_1000 51224 51912 10 ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_10000 52092 54487 10 ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_100000 59262 60514 9 ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_1000 15998 17237 32 ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_10000 15964 16994 32 ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_100000 16881 19286 30 ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_1000 19150 25277 27 ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_10000 18746 19999 27 ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_100000 22321 24568 23 ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_1000 49693 50288 11 ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_10000 51429 52449 10 ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_100000 237076 237377 3 ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_1000 81875 82597 7 ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_10000 106671 107045 5 ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_100000 483740 484607 2 ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_1000 16667 18143 31 ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_10000 17682 18922 29 ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_100000 65463 67116 8 ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_1000 48058 48826 11 ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_10000 53529 53998 10 ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_100000 123684 123901 5 -------------------------------------------------------------------------------- ``` Output with `DEBUG=True` {F338561209} Reviewed By: nikhilaravi Differential Revision: D22017500 fbshipit-source-id: ed3e8ed13940c593841d93211623dd533974012f --- pytorch3d/ops/__init__.py | 4 + pytorch3d/ops/points_to_volumes.py | 491 +++++++++++++++++++++++++++++ tests/bm_points_to_volumes.py | 24 ++ tests/test_points_to_volumes.py | 385 ++++++++++++++++++++++ 4 files changed, 904 insertions(+) create mode 100644 pytorch3d/ops/points_to_volumes.py create mode 100644 tests/bm_points_to_volumes.py create mode 100644 tests/test_points_to_volumes.py diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 15f2f6f45..861791532 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -14,6 +14,10 @@ estimate_pointcloud_local_coord_frames, estimate_pointcloud_normals, ) +from .points_to_volumes import ( + add_pointclouds_to_volumes, + add_points_features_to_volume_densities_features, +) from .sample_points_from_meshes import sample_points_from_meshes from .subdivide_meshes import SubdivideMeshes from .utils import ( diff --git a/pytorch3d/ops/points_to_volumes.py b/pytorch3d/ops/points_to_volumes.py new file mode 100644 index 000000000..dd3df490a --- /dev/null +++ b/pytorch3d/ops/points_to_volumes.py @@ -0,0 +1,491 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
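+
+# This module converts batches of point clouds into volumetric feature and
+# density grids, either by trilinear splatting or by nearest-voxel rounding.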
+from typing import TYPE_CHECKING, Optional, Tuple + +import torch + + +if TYPE_CHECKING: + from ..structures import Pointclouds, Volumes + + +def add_pointclouds_to_volumes( + pointclouds: "Pointclouds", + initial_volumes: "Volumes", + mode: str = "trilinear", + min_weight: float = 1e-4, +) -> "Volumes": + """ + Add a batch of point clouds represented with a `Pointclouds` structure + `pointclouds` to a batch of existing volumes represented with a + `Volumes` structure `initial_volumes`. + + More specifically, the method casts a set of weighted votes (the weights are + determined based on `mode="trilinear"|"nearest"`) into the pre-initialized + `features` and `densities` fields of `initial_volumes`. + + The method returns an updated `Volumes` object that contains a copy + of `initial_volumes` with its `features` and `densities` updated with the + result of the pointcloud addition. + + Example: + ``` + # init a random point cloud + pointclouds = Pointclouds( + points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5) + ) + # init an empty volume centered around [0.5, 0.5, 0.5] in world coordinates + # with a voxel size of 1.0. + initial_volumes = Volumes( + features = torch.zeros(4, 5, 25, 25, 25), + densities = torch.zeros(4, 1, 25, 25, 25), + volume_translation = [-0.5, -0.5, -0.5], + voxel_size = 1.0, + ) + # add the pointcloud to the 'initial_volumes' buffer using + # trilinear splatting + updated_volumes = add_pointclouds_to_volumes( + pointclouds=pointclouds, + initial_volumes=initial_volumes, + mode="trilinear", + ) + ``` + + Args: + pointclouds: Batch of 3D pointclouds represented with a `Pointclouds` + structure. Note that `pointclouds.features` have to be defined. + initial_volumes: Batch of initial `Volumes` with pre-initialized 1-dimensional + densities which contain non-negative numbers corresponding to the + opaqueness of each voxel (the higher, the less transparent). + mode: The mode of the conversion of individual points into the volume. + Set either to `nearest` or `trilinear`: + `nearest`: Each 3D point is first rounded to the volumetric + lattice. Each voxel is then labeled with the average + over features that fall into the given voxel. + The gradients of nearest neighbor conversion w.r.t. the + 3D locations of the points in `pointclouds` are *not* defined. + `trilinear`: Each 3D point casts 8 weighted votes to the 8-neighborhood + of its floating point coordinate. The weights are + determined using a trilinear interpolation scheme. + Trilinear splatting is fully differentiable w.r.t. all input arguments. + min_weight: A scalar controlling the lowest possible total per-voxel + weight used to normalize the features accumulated in a voxel. + Only active for `mode==trilinear`. + + Returns: + updated_volumes: Output `Volumes` structure containing the conversion result. + """ + + if len(initial_volumes) != len(pointclouds): + raise ValueError( + "'initial_volumes' and 'pointclouds' have to have the same batch size." 
+ ) + + # obtain the features and densities + pcl_feats = pointclouds.features_padded() + pcl_3d = pointclouds.points_padded() + + if pcl_feats is None: + raise ValueError("'pointclouds' have to have their 'features' defined.") + + # obtain the conversion mask + n_per_pcl = pointclouds.num_points_per_cloud().type_as(pcl_feats) + mask = torch.arange(n_per_pcl.max(), dtype=pcl_feats.dtype, device=pcl_feats.device) + mask = (mask[None, :] < n_per_pcl[:, None]).type_as(mask) + + # convert to the coord frame of the volume + pcl_3d_local = initial_volumes.world_to_local_coords(pcl_3d) + + features_new, densities_new = add_points_features_to_volume_densities_features( + points_3d=pcl_3d_local, + points_features=pcl_feats, + volume_features=initial_volumes.features(), + volume_densities=initial_volumes.densities(), + min_weight=min_weight, + grid_sizes=initial_volumes.get_grid_sizes(), + mask=mask, + mode=mode, + ) + + return initial_volumes.update_padded( + new_densities=densities_new, new_features=features_new + ) + + +def add_points_features_to_volume_densities_features( + points_3d: torch.Tensor, + points_features: torch.Tensor, + volume_densities: torch.Tensor, + volume_features: Optional[torch.Tensor], + mode: str = "trilinear", + min_weight: float = 1e-4, + mask: Optional[torch.Tensor] = None, + grid_sizes: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert a batch of point clouds represented with tensors of per-point + 3d coordinates and their features to a batch of volumes represented + with tensors of densities and features. + + Args: + points_3d: Batch of 3D point cloud coordinates of shape + `(minibatch, N, 3)` where N is the number of points + in each point cloud. Coordinates have to be specified in the + local volume coordinates (ranging in [-1, 1]). + points_features: Features of shape `(minibatch, N, feature_dim)` corresponding + to the points of the input point clouds `pointcloud`. + volume_densities: Batch of input feature volume densities of shape + `(minibatch, 1, D, H, W)`. Each voxel should + contain a non-negative number corresponding to its + opaqueness (the higher, the less transparent). + volume_features: Batch of input feature volumes of shape + `(minibatch, feature_dim, D, H, W)` + If set to `None`, the `volume_features` will be automatically + instantiatied with a correct size and filled with 0s. + mode: The mode of the conversion of individual points into the volume. + Set either to `nearest` or `trilinear`: + `nearest`: Each 3D point is first rounded to the volumetric + lattice. Each voxel is then labeled with the average + over features that fall into the given voxel. + The gradients of nearest neighbor rounding w.r.t. the + input point locations `points_3d` are *not* defined. + `trilinear`: Each 3D point casts 8 weighted votes to the 8-neighborhood + of its floating point coordinate. The weights are + determined using a trilinear interpolation scheme. + Trilinear splatting is fully differentiable w.r.t. all input arguments. + mask: A binary mask of shape `(minibatch, N)` determining which 3D points + are going to be converted to the resulting volume. + Set to `None` if all points are valid. + min_weight: A scalar controlling the lowest possible total per-voxel + weight used to normalize the features accumulated in a voxel. + Only active for `mode==trilinear`. 
+ Returns: + volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)` + volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)` + containing the total amount of votes cast to each of the voxels. + """ + + # number of points in the point cloud, its dim and batch size + ba, n_points, feature_dim = points_features.shape + ba_volume, density_dim = volume_densities.shape[:2] + + if density_dim != 1: + raise ValueError("Only one-dimensional densities are allowed.") + + # init the volumetric grid sizes if uninitialized + if grid_sizes is None: + grid_sizes = torch.LongTensor(list(volume_densities.shape[2:])).to( + volume_densities + ) + + # flatten densities and features + v_shape = volume_densities.shape[2:] + volume_densities_flatten = volume_densities.view(ba, -1, 1) + n_voxels = volume_densities_flatten.shape[1] + + if volume_features is None: + # initialize features if not passed in + volume_features_flatten = volume_densities.new_zeros(ba, feature_dim, n_voxels) + else: + # otherwise just flatten + volume_features_flatten = volume_features.view(ba, feature_dim, n_voxels) + + if mode == "trilinear": # do the splatting (trilinear interp) + volume_features, volume_densities = splat_points_to_volumes( + points_3d, + points_features, + volume_densities_flatten, + volume_features_flatten, + grid_sizes, + mask=mask, + min_weight=min_weight, + ) + elif mode == "nearest": # nearest neighbor interp + volume_features, volume_densities = round_points_to_volumes( + points_3d, + points_features, + volume_densities_flatten, + volume_features_flatten, + grid_sizes, + mask=mask, + ) + else: + raise ValueError('No such interpolation mode "%s"' % mode) + + # reshape into the volume shape + volume_features = volume_features.view(ba, feature_dim, *v_shape) + volume_densities = volume_densities.view(ba, 1, *v_shape) + + return volume_features, volume_densities + + +def _check_points_to_volumes_inputs( + points_3d: torch.Tensor, + points_features: torch.Tensor, + volume_densities: torch.Tensor, + volume_features: torch.Tensor, + grid_sizes: torch.LongTensor, + mask: Optional[torch.Tensor] = None, +): + + max_grid_size = grid_sizes.max(dim=0).values + if torch.prod(max_grid_size) > volume_densities.shape[1]: + raise ValueError( + "One of the grid sizes corresponds to a larger number" + + " of elements than the number of elements in volume_densities." + ) + + _, n_voxels, density_dim = volume_densities.shape + + if density_dim != 1: + raise ValueError("Only one-dimensional densities are allowed.") + + ba, n_points, feature_dim = points_features.shape + + if volume_features.shape[1] != feature_dim: + raise ValueError( + "volume_features have a different number of channels" + + " than points_features." + ) + + if volume_features.shape[2] != n_voxels: + raise ValueError( + "volume_features have a different number of elements" + + " than volume_densities." + ) + + +def splat_points_to_volumes( + points_3d: torch.Tensor, + points_features: torch.Tensor, + volume_densities: torch.Tensor, + volume_features: torch.Tensor, + grid_sizes: torch.LongTensor, + min_weight: float = 1e-4, + mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert a batch of point clouds to a batch of volumes using trilinear + splatting into a volume. + + Args: + points_3d: Batch of 3D point cloud coordinates of shape + `(minibatch, N, 3)` where N is the number of points + in each point cloud. 
Coordinates have to be specified in the
+            local volume coordinates (ranging in [-1, 1]).
+        points_features: Features of shape `(minibatch, N, feature_dim)`
+            corresponding to the points of the input point cloud `points_3d`.
+        volume_features: Batch of input *flattened* feature volumes
+            of shape `(minibatch, feature_dim, N_voxels)`
+        volume_densities: Batch of input *flattened* feature volume densities
+            of shape `(minibatch, 1, N_voxels)`. Each voxel should
+            contain a non-negative number corresponding to its
+            opaqueness (the higher, the less transparent).
+        grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
+            spatial resolutions of each of the non-flattened `volumes` tensors.
+            Note that the following has to hold:
+            `torch.prod(grid_sizes, dim=1)==N_voxels`
+        min_weight: A scalar controlling the lowest possible total per-voxel
+            weight used to normalize the features accumulated in a voxel.
+        mask: A binary mask of shape `(minibatch, N)` determining which 3D points
+            are going to be converted to the resulting volume.
+            Set to `None` if all points are valid.
+    Returns:
+        volume_features: Output volume of shape `(minibatch, feature_dim, N_voxels)`.
+        volume_densities: Occupancy volume of shape `(minibatch, 1, N_voxels)`
+            containing the total amount of votes cast to each of the voxels.
+    """
+
+    _check_points_to_volumes_inputs(
+        points_3d,
+        points_features,
+        volume_densities,
+        volume_features,
+        grid_sizes,
+        mask=mask,
+    )
+
+    _, n_voxels, density_dim = volume_densities.shape
+    ba, n_points, feature_dim = points_features.shape
+
+    # minibatch x n_points x feature_dim -> minibatch x feature_dim x n_points
+    points_features = points_features.permute(0, 2, 1).contiguous()
+
+    # XYZ = the lower-corner volume index of the 8-neighborhood of every point
+    # grid_sizes is of the form (minibatch, depth-height-width)
+    grid_sizes_xyz = grid_sizes[:, [2, 1, 0]]
+
+    # Convert from points_3d in the range [-1, 1] to
+    # indices in the volume grid in the range [0, grid_sizes_xyz-1]
+    points_3d_indices = ((points_3d + 1) * 0.5) * (
+        grid_sizes_xyz[:, None].type_as(points_3d) - 1
+    )
+    XYZ = points_3d_indices.floor().long()
+    rXYZ = points_3d_indices - XYZ.type_as(points_3d)  # remainder of floor
+
+    # split into separate coordinate vectors
+    X, Y, Z = XYZ.split(1, dim=2)
+    # rX/rY/rZ are the fractional offsets within each cell; along each axis,
+    # (1 - r) weights the vote to the lower neighbor and r to the upper one
+    rX, rY, rZ = rXYZ.split(1, dim=2)
+
+    # get random indices for the purpose of adding out-of-bounds values
+    rand_idx = X.new_zeros(X.shape).random_(0, n_voxels)
+
+    # iterate over the x, y, z indices of the 8-neighborhood (xdiff, ydiff, zdiff)
+    for xdiff in (0, 1):
+        X_ = X + xdiff
+        wX = (1 - xdiff) + (2 * xdiff - 1) * rX
+        for ydiff in (0, 1):
+            Y_ = Y + ydiff
+            wY = (1 - ydiff) + (2 * ydiff - 1) * rY
+            for zdiff in (0, 1):
+                Z_ = Z + zdiff
+                wZ = (1 - zdiff) + (2 * zdiff - 1) * rZ
+
+                # weight of each vote into the given cell of the 8-neighborhood
+                w = wX * wY * wZ
+
+                # valid - binary indicators of votes that fall into the volume
+                valid = (
+                    (0 <= X_)
+                    * (X_ < grid_sizes_xyz[:, None, 0:1])
+                    * (0 <= Y_)
+                    * (Y_ < grid_sizes_xyz[:, None, 1:2])
+                    * (0 <= Z_)
+                    * (Z_ < grid_sizes_xyz[:, None, 2:3])
+                ).long()
+
+                # linearized indices into the volume
+                idx = (Z_ * grid_sizes[:, None, 1:2] + Y_) * grid_sizes[
+                    :, None, 2:3
+                ] + X_
+
+                # out-of-bounds features added to a random voxel idx with weight=0.
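+                # (scatter_add_ requires all indices to be in range, so a
+                # random in-bounds target plus a zero weight makes each
+                # out-of-bounds vote a safe no-op)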
+                idx_valid = idx * valid + rand_idx * (1 - valid)
+                w_valid = w * valid.type_as(w)
+                if mask is not None:
+                    w_valid = w_valid * mask.type_as(w)[:, :, None]
+
+                # scatter add casts the votes into the weight accumulator
+                # and the feature accumulator
+                volume_densities.scatter_add_(1, idx_valid, w_valid)
+
+                # reshape idx_valid -> (minibatch, feature_dim, n_points)
+                idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
+                w_valid = w_valid.view(ba, 1, n_points)
+
+                # volume_features of shape (minibatch, feature_dim, n_voxels)
+                volume_features.scatter_add_(2, idx_valid, w_valid * points_features)
+
+    # divide each feature by the total weight of the votes
+    volume_features = volume_features / volume_densities.view(ba, 1, n_voxels).clamp(
+        min_weight
+    )
+
+    return volume_features, volume_densities
+
+
+def round_points_to_volumes(
+    points_3d: torch.Tensor,
+    points_features: torch.Tensor,
+    volume_densities: torch.Tensor,
+    volume_features: torch.Tensor,
+    grid_sizes: torch.LongTensor,
+    mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Convert a batch of point clouds to a batch of volumes using rounding to the
+    nearest integer coordinate of the volume. Features that fall into the same
+    voxel are averaged.
+
+    Args:
+        points_3d: Batch of 3D point cloud coordinates of shape
+            `(minibatch, N, 3)` where N is the number of points
+            in each point cloud. Coordinates have to be specified in the
+            local volume coordinates (ranging in [-1, 1]).
+        points_features: Features of shape `(minibatch, N, feature_dim)`
+            corresponding to the points of the input point cloud `points_3d`.
+        volume_features: Batch of input *flattened* feature volumes
+            of shape `(minibatch, feature_dim, N_voxels)`
+        volume_densities: Batch of input *flattened* feature volume densities
+            of shape `(minibatch, 1, N_voxels)`. Each voxel should
+            contain a non-negative number corresponding to its
+            opaqueness (the higher, the less transparent).
+        grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
+            spatial resolutions of each of the non-flattened `volumes` tensors.
+            Note that the following has to hold:
+            `torch.prod(grid_sizes, dim=1)==N_voxels`
+        mask: A binary mask of shape `(minibatch, N)` determining which 3D points
+            are going to be converted to the resulting volume.
+            Set to `None` if all points are valid.
+    Returns:
+        volume_features: Output volume of shape `(minibatch, feature_dim, N_voxels)`.
+        volume_densities: Occupancy volume of shape `(minibatch, 1, N_voxels)`
+            containing the total amount of votes cast to each of the voxels.
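+            Unlike in `splat_points_to_volumes`, the feature normalization
+            divides by this vote count clamped from below at 1.0, so voxels
+            that receive no votes keep their input features unchanged.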
+ """ + + _check_points_to_volumes_inputs( + points_3d, + points_features, + volume_densities, + volume_features, + grid_sizes, + mask=mask, + ) + + _, n_voxels, density_dim = volume_densities.shape + ba, n_points, feature_dim = points_features.shape + + # minibatch x n_points x feature_dim-> minibatch x feature_dim x n_points + points_features = points_features.permute(0, 2, 1).contiguous() + + # round the coordinates to nearest integer + # grid_sizes is of the form (minibatch, depth-height-width) + grid_sizes_xyz = grid_sizes[:, [2, 1, 0]] + XYZ = ((points_3d.detach() + 1) * 0.5) * ( + grid_sizes_xyz[:, None].type_as(points_3d) - 1 + ) + XYZ = torch.round(XYZ).long() + + # split into separate coordinate vectors + X, Y, Z = XYZ.split(1, dim=2) + + # get random indices for the purpose of adding out-of-bounds values + rand_idx = X.new_zeros(X.shape).random_(0, n_voxels) + + # valid - binary indicators of votes that fall into the volume + grid_sizes = grid_sizes.type_as(XYZ) + valid = ( + (0 <= X) + * (X < grid_sizes_xyz[:, None, 0:1]) + * (0 <= Y) + * (Y < grid_sizes_xyz[:, None, 1:2]) + * (0 <= Z) + * (Z < grid_sizes_xyz[:, None, 2:3]) + ).long() + + # get random indices for the purpose of adding out-of-bounds values + rand_idx = valid.new_zeros(X.shape).random_(0, n_voxels) + + # linearized indices into the volume + idx = (Z * grid_sizes[:, None, 1:2] + Y) * grid_sizes[:, None, 2:3] + X + + # out-of-bounds features added to a random voxel idx with weight=0. + idx_valid = idx * valid + rand_idx * (1 - valid) + w_valid = valid.type_as(volume_features) + + # scatter add casts the votes into the weight accumulator + # and the feature accumulator + volume_densities.scatter_add_(1, idx_valid, w_valid) + + # reshape idx_valid -> (minibatch, feature_dim, n_points) + idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features) + w_valid = w_valid.view(ba, 1, n_points) + + # volume_features of shape (minibatch, feature_dim, n_voxels) + volume_features.scatter_add_(2, idx_valid, w_valid * points_features) + + # divide each feature by the total weight of the votes + volume_features = volume_features / volume_densities.view(ba, 1, n_voxels).clamp( + 1.0 + ) + + return volume_features, volume_densities diff --git a/tests/bm_points_to_volumes.py b/tests/bm_points_to_volumes.py new file mode 100644 index 000000000..2eca07ae1 --- /dev/null +++ b/tests/bm_points_to_volumes.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools + +from fvcore.common.benchmark import benchmark +from test_points_to_volumes import TestPointsToVolumes + + +def bm_points_to_volumes() -> None: + case_grid = { + "batch_size": [10, 100], + "interp_mode": ["trilinear", "nearest"], + "volume_size": [[25, 25, 25], [101, 111, 121]], + "n_points": [1000, 10000, 100000], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + benchmark( + TestPointsToVolumes.add_points_to_volumes, + "ADD_POINTS_TO_VOLUMES", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/test_points_to_volumes.py b/tests/test_points_to_volumes.py new file mode 100644 index 000000000..70a862124 --- /dev/null +++ b/tests/test_points_to_volumes.py @@ -0,0 +1,385 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+import unittest
+from typing import Tuple
+
+import numpy as np
+import torch
+from common_testing import TestCaseMixin
+from pytorch3d.ops import add_pointclouds_to_volumes
+from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
+from pytorch3d.structures.meshes import Meshes
+from pytorch3d.structures.pointclouds import Pointclouds
+from pytorch3d.structures.volumes import Volumes
+from pytorch3d.transforms.so3 import so3_exponential_map
+
+
+DEBUG = False
+if DEBUG:
+    import os
+    import tempfile
+    from PIL import Image
+
+
+def init_cube_point_cloud(
+    batch_size: int = 10, n_points: int = 100000, rotate_y: bool = True
+):
+    """
+    Generate a random point cloud of `n_points` points sampled from
+    the faces of a 3D cube.
+    """
+
+    # create the cube mesh batch_size times
+    meshes = TestPointsToVolumes.init_cube_mesh(batch_size)
+
+    # generate point clouds by sampling points from the meshes
+    pcl = sample_points_from_meshes(meshes, num_samples=n_points, return_normals=False)
+
+    # colors of the cube sides
+    clrs = [
+        [1.0, 0.0, 0.0],
+        [1.0, 1.0, 0.0],
+        [0.0, 1.0, 0.0],
+        [0.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0],
+        [1.0, 0.0, 1.0],
+    ]
+
+    # init the color tensor "rgb"
+    rgb = torch.zeros_like(pcl)
+
+    # color each side of the cube with a constant color
+    clri = 0
+    for dim in (0, 1, 2):
+        for offs in (0.0, 1.0):
+            current_face_verts = (pcl[:, :, dim] - offs).abs() <= 1e-2
+            for bi in range(batch_size):
+                rgb[bi, current_face_verts[bi], :] = torch.tensor(clrs[clri]).type_as(
+                    pcl
+                )
+            clri += 1
+
+    if rotate_y:
+        # uniformly spaced rotations around y axis
+        R = init_uniform_y_rotations(batch_size=batch_size)
+        # rotate the point clouds around y axis
+        pcl = torch.bmm(pcl - 0.5, R) + 0.5
+
+    return pcl, rgb
+
+
+def init_volume_boundary_pointcloud(
+    batch_size: int,
+    volume_size: Tuple[int, int, int],
+    n_points: int,
+    interp_mode: str,
+    require_grad: bool = False,
+):
+    """
+    Initialize a point cloud that closely follows the boundary of
+    a volume of a given size. The volume buffer is initialized as well.
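+    Returns a `(pointclouds, initial_volumes)` pair. Setting `require_grad=True`
+    marks the point coordinates and colors as differentiable so the tests can
+    check gradient propagation through the conversion.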
+ """ + + # generate a 3D point cloud sampled from sides of a [0,1] cube + xyz, rgb = init_cube_point_cloud(batch_size, n_points=n_points, rotate_y=True) + + # make volume_size tensor + volume_size_t = torch.tensor(volume_size, dtype=xyz.dtype, device=xyz.device) + + if interp_mode == "trilinear": + # make the xyz locations fall on the boundary of the + # first/last two voxels along each spatial dimension of the + # volume - this properly checks the correctness of the + # trilinear interpolation scheme + xyz = (xyz - 0.5) * ((volume_size_t - 2) / (volume_size_t - 1))[[2, 1, 0]] + 0.5 + + # rescale the cube pointcloud to overlap with the volume sides + # of the volume + rel_scale = volume_size_t / volume_size[0] + xyz = xyz * rel_scale[[2, 1, 0]][None, None] + + # enable grad accumulation for the differentiability check + xyz.requires_grad = require_grad + rgb.requires_grad = require_grad + + # create the pointclouds structure + pointclouds = Pointclouds(xyz, features=rgb) + + # set the volume translation so that the point cloud is centered + # around 0 + volume_translation = -0.5 * rel_scale[[2, 1, 0]] + + # set the voxel size to 1 / (volume_size-1) + volume_voxel_size = 1 / (volume_size[0] - 1.0) + + # instantiate the volumes + initial_volumes = Volumes( + features=xyz.new_zeros(batch_size, 3, *volume_size), + densities=xyz.new_zeros(batch_size, 1, *volume_size), + volume_translation=volume_translation, + voxel_size=volume_voxel_size, + ) + + return pointclouds, initial_volumes + + +def init_uniform_y_rotations(batch_size: int = 10): + """ + Generate a batch of `batch_size` 3x3 rotation matrices around y-axis + whose angles are uniformly distributed between 0 and 2 pi. + """ + device = torch.device("cuda:0") + axis = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32) + angles = torch.linspace(0, 2.0 * np.pi, batch_size + 1, device=device) + angles = angles[:batch_size] + log_rots = axis[None, :] * angles[:, None] + R = so3_exponential_map(log_rots) + return R + + +class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + np.random.seed(42) + torch.manual_seed(42) + + @staticmethod + def add_points_to_volumes( + batch_size: int, + volume_size: Tuple[int, int, int], + n_points: int, + interp_mode: str, + ): + (pointclouds, initial_volumes) = init_volume_boundary_pointcloud( + batch_size=batch_size, + volume_size=volume_size, + n_points=n_points, + interp_mode=interp_mode, + require_grad=False, + ) + + def _add_points_to_volumes(): + add_pointclouds_to_volumes(pointclouds, initial_volumes, mode=interp_mode) + + return _add_points_to_volumes + + @staticmethod + def stack_4d_tensor_to_3d(arr): + n = arr.shape[0] + H = int(np.ceil(np.sqrt(n))) + W = int(np.ceil(n / H)) + n_add = H * W - n + arr = torch.cat((arr, torch.zeros_like(arr[:1]).repeat(n_add, 1, 1, 1))) + rows = torch.chunk(arr, chunks=W, dim=0) + arr3d = torch.cat([torch.cat(list(row), dim=2) for row in rows], dim=1) + return arr3d + + @staticmethod + def init_cube_mesh(batch_size: int = 10): + """ + Generate a batch of `batch_size` cube meshes. 
+ """ + + device = torch.device("cuda:0") + + verts, faces = [], [] + + for _ in range(batch_size): + v = torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + ], + dtype=torch.float32, + device=device, + ) + verts.append(v) + faces.append( + torch.tensor( + [ + [0, 2, 1], + [0, 3, 2], + [2, 3, 4], + [2, 4, 5], + [1, 2, 5], + [1, 5, 6], + [0, 7, 4], + [0, 4, 3], + [5, 4, 7], + [5, 7, 6], + [0, 6, 7], + [0, 1, 6], + ], + dtype=torch.int64, + device=device, + ) + ) + + faces = torch.stack(faces) + verts = torch.stack(verts) + + simpleces = Meshes(verts=verts, faces=faces) + + return simpleces + + def test_from_point_cloud(self, interp_mode="trilinear"): + """ + Generates a volume from a random point cloud sampled from faces + of a 3D cube. Since each side of the cube is homogenously colored with + a different color, this should result in a volume with a + predefined homogenous color of the cells along its borders + and black interior. The test is run for both cube and non-cube shaped + volumes. + """ + + # batch_size = 4 sides of the cube + batch_size = 4 + + for volume_size in ([25, 25, 25], [30, 25, 15]): + + for interp_mode in ("trilinear", "nearest"): + + (pointclouds, initial_volumes) = init_volume_boundary_pointcloud( + volume_size=volume_size, + n_points=int(1e5), + interp_mode=interp_mode, + batch_size=batch_size, + require_grad=True, + ) + + volumes = add_pointclouds_to_volumes( + pointclouds, initial_volumes, mode=interp_mode + ) + + V_color, V_density = volumes.features(), volumes.densities() + + # expected colors of different cube sides + clr_sides = torch.tensor( + [ + [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]], + [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]], + [[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]], + ], + dtype=V_color.dtype, + device=V_color.device, + ) + clr_ambient = torch.tensor( + [0.0, 0.0, 0.0], dtype=V_color.dtype, device=V_color.device + ) + clr_top_bot = torch.tensor( + [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]], + dtype=V_color.dtype, + device=V_color.device, + ) + + if DEBUG: + outdir = tempfile.gettempdir() + "/test_points_to_volumes" + os.makedirs(outdir, exist_ok=True) + + for slice_dim in (1, 2): + for vidx in range(V_color.shape[0]): + vim = V_color.detach()[vidx].split(1, dim=slice_dim) + vim = torch.stack([v.squeeze() for v in vim]) + vim = TestPointsToVolumes.stack_4d_tensor_to_3d(vim.cpu()) + im = Image.fromarray( + (vim.numpy() * 255.0) + .astype(np.uint8) + .transpose(1, 2, 0) + ) + outfile = ( + outdir + + f"/rgb_{interp_mode}" + + f"_{str(volume_size).replace(' ','')}" + + f"_{vidx:003d}_sldim{slice_dim}.png" + ) + im.save(outfile) + print("exported %s" % outfile) + + # check the density V_density + # first binarize the density + V_density_bin = (V_density > 1e-4).type_as(V_density) + d_one = V_density.new_ones(1) + d_zero = V_density.new_zeros(1) + for vidx in range(V_color.shape[0]): + # the first/last depth-wise slice has to be filled with 1.0 + self._check_volume_slice_color_density( + V_density_bin[vidx], 1, interp_mode, d_one, "first" + ) + self._check_volume_slice_color_density( + V_density_bin[vidx], 1, interp_mode, d_one, "last" + ) + # the middle depth-wise slices have to be empty + self._check_volume_slice_color_density( + V_density_bin[vidx], 1, interp_mode, d_zero, "middle" + ) + # the top/bottom slices have to be filled with 1.0 + self._check_volume_slice_color_density( + V_density_bin[vidx], 2, interp_mode, d_one, 
"first" + ) + self._check_volume_slice_color_density( + V_density_bin[vidx], 2, interp_mode, d_one, "last" + ) + + # check the colors + for vidx in range(V_color.shape[0]): + self._check_volume_slice_color_density( + V_color[vidx], 1, interp_mode, clr_sides[vidx][0], "first" + ) + self._check_volume_slice_color_density( + V_color[vidx], 1, interp_mode, clr_sides[vidx][1], "last" + ) + self._check_volume_slice_color_density( + V_color[vidx], 1, interp_mode, clr_ambient, "middle" + ) + self._check_volume_slice_color_density( + V_color[vidx], 2, interp_mode, clr_top_bot[0], "first" + ) + self._check_volume_slice_color_density( + V_color[vidx], 2, interp_mode, clr_top_bot[1], "last" + ) + + # check differentiability + loss = V_color.mean() + V_density.mean() + loss.backward() + rgb = pointclouds.features_padded() + xyz = pointclouds.points_padded() + for field in (xyz, rgb): + if interp_mode == "nearest" and (field is xyz): + # this does not produce grads w.r.t. xyz + self.assertIsNone(field.grad) + else: + self.assertTrue(field.grad.data.isfinite().all()) + + def _check_volume_slice_color_density( + self, V, split_dim, interp_mode, clr_gt, slice_type, border=3 + ): + # decompose the volume to individual slices along split_dim + vim = V.detach().split(1, dim=split_dim) + vim = torch.stack([v.squeeze(split_dim) for v in vim]) + + # determine which slices should be compared to clr_gt based on + # the 'slice_type' input + if slice_type == "first": + slice_dims = (0, 1) if interp_mode == "trilinear" else (0,) + elif slice_type == "last": + slice_dims = (-1, -2) if interp_mode == "trilinear" else (-1,) + elif slice_type == "middle": + internal_border = 2 if interp_mode == "trilinear" else 1 + slice_dims = torch.arange(internal_border, vim.shape[0] - internal_border) + else: + raise ValueError(slice_type) + + # compute the average error within each slice + clr_diff = ( + vim[slice_dims, :, border:-border, border:-border] + - clr_gt[None, :, None, None] + ) + clr_diff = clr_diff.abs().mean(dim=(2, 3)).view(-1) + + # check that all per-slice avg errors vanish + self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)