Skip to content

Commit

Permalink
Barycentric clipping in the renderer and flat shading
Browse files Browse the repository at this point in the history
Summary:
Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step.

Also added support for flat shading.

Reviewed By: jcjohnson

Differential Revision: D19934259

fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Feb 29, 2020
1 parent f358b9b commit ff19c64
Show file tree
Hide file tree
Showing 14 changed files with 254 additions and 108 deletions.
13 changes: 7 additions & 6 deletions pytorch3d/renderer/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
return torch.flip(pixel_colors, [1])


def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
def softmax_rgb_blend(
colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
) -> torch.Tensor:
"""
RGB and alpha channel blending to return an RGBA image based on the method
proposed in [0]
Expand Down Expand Up @@ -118,13 +120,16 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
exponential function used to control the opacity of the color.
- background_color: (3) element list/tuple/torch.Tensor specifying
the RGB values for the background color.
znear: float, near clipping plane in the z direction
zfar: float, far clipping plane in the z direction
Returns:
RGBA pixel_colors: (N, H, W, 4)
[0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
Image-based 3D Reasoning'
"""

N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pix_colors = torch.ones(
Expand All @@ -140,11 +145,6 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
delta = torch.tensor(delta, device=device)

# Near and far clipping planes.
# TODO: add zfar/znear as input params.
zfar = 100.0
znear = 1.0

# Mask for padded pixels.
mask = fragments.pix_to_face >= 0

Expand All @@ -164,6 +164,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
# Weights for each face. Adjust the exponential by the max z to prevent
# overflow. zbuf shape (N, H, W, K), find max over K.
# TODO: there may still be some instability in the exponent calculation.

z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
Expand Down
11 changes: 6 additions & 5 deletions pytorch3d/renderer/mesh/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from .texturing import ( # isort:skip
interpolate_texture_map,
interpolate_vertex_colors,
)
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer
Expand All @@ -13,10 +18,6 @@
TexturedSoftPhongShader,
)
from .shading import gouraud_shading, phong_shading
from .texturing import ( # isort: skip
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
)
from .utils import interpolate_face_attributes

__all__ = [k for k in globals().keys() if not k.startswith("_")]
31 changes: 31 additions & 0 deletions pytorch3d/renderer/mesh/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import torch
import torch.nn as nn

from .rasterizer import Fragments
from .utils import _clip_barycentric_coordinates, _interpolate_zbuf

# A renderer class should be initialized with a
# function for rasterization and a function for shading.
# The rasterizer should:
Expand Down Expand Up @@ -34,6 +37,34 @@ def __init__(self, rasterizer, shader):
self.shader = shader

def forward(self, meshes_world, **kwargs) -> torch.Tensor:
"""
Render a batch of images from a batch of meshes by rasterizing and then shading.
NOTE: If the blur radius for rasterization is > 0.0, some pixels can have one or
more barycentric coordinates lying outside the range [0, 1]. For a pixel with
out of bounds barycentric coordinates with respect to a face f, clipping is required
before interpolating the texture uv coordinates and z buffer so that the colors and
depths are limited to the range for the corresponding face.
"""
fragments = self.rasterizer(meshes_world, **kwargs)
raster_settings = kwargs.get(
"raster_settings", self.rasterizer.raster_settings
)
if raster_settings.blur_radius > 0.0:
# TODO: potentially move barycentric clipping to the rasterizer
# if no downstream functions requires unclipped values.
# This will avoid unnecssary re-interpolation of the z buffer.
clipped_bary_coords = _clip_barycentric_coordinates(
fragments.bary_coords
)
clipped_zbuf = _interpolate_zbuf(
fragments.pix_to_face, clipped_bary_coords, meshes_world
)
fragments = Fragments(
bary_coords=clipped_bary_coords,
zbuf=clipped_zbuf,
dists=fragments.dists,
pix_to_face=fragments.pix_to_face,
)
images = self.shader(fragments, meshes_world, **kwargs)
return images
3 changes: 2 additions & 1 deletion pytorch3d/renderer/mesh/shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
Expand All @@ -278,7 +279,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, self.blend_params)
images = softmax_rgb_blend(colors, fragments, blend_params)
return images


Expand Down
12 changes: 9 additions & 3 deletions pytorch3d/renderer/mesh/shading.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def phong_shading(
vertex_normals = meshes.verts_normals_packed() # (V, 3)
faces_verts = verts[faces]
faces_normals = vertex_normals[faces]
pixel_coords = interpolate_face_attributes(fragments, faces_verts)
pixel_normals = interpolate_face_attributes(fragments, faces_normals)
pixel_coords = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts
)
pixel_normals = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_normals
)
ambient, diffuse, specular = _apply_lighting(
pixel_coords, pixel_normals, lights, cameras, materials
)
Expand Down Expand Up @@ -122,7 +126,9 @@ def gouraud_shading(
)
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
face_colors = verts_colors_shaded[faces]
colors = interpolate_face_attributes(fragments, face_colors)
colors = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_colors
)
return colors


Expand Down
82 changes: 9 additions & 73 deletions pytorch3d/renderer/mesh/texturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,75 +7,7 @@

from pytorch3d.structures.textures import Textures


def _clip_barycentric_coordinates(bary) -> torch.Tensor:
"""
Args:
bary: barycentric coordinates of shape (...., 3) where `...` represents
an arbitrary number of dimensions
Returns:
bary: All barycentric coordinate values clipped to the range [0, 1]
and renormalized. The output is the same shape as the input.
"""
if bary.shape[-1] != 3:
msg = "Expected barycentric coords to have last dim = 3; got %r"
raise ValueError(msg % bary.shape)
clipped = bary.clamp(min=0, max=1)
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
clipped = clipped / clipped_sum
return clipped


def interpolate_face_attributes(
fragments, face_attributes: torch.Tensor, bary_clip: bool = False
) -> torch.Tensor:
"""
Interpolate arbitrary face attributes using the barycentric coordinates
for each pixel in the rasterized output.
Args:
fragments:
The outputs of rasterization. From this we use
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
face_attributes: packed attributes of shape (total_faces, 3, D),
specifying the value of the attribute for each
vertex in the face.
bary_clip: Bool to indicate if barycentric_coords should be clipped
before being used for interpolation.
Returns:
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
value of the face attribute for each pixel.
"""
pix_to_face = fragments.pix_to_face
barycentric_coords = fragments.bary_coords
F, FV, D = face_attributes.shape
if FV != 3:
raise ValueError("Faces can only have three vertices; got %r" % FV)
N, H, W, K, _ = barycentric_coords.shape
if pix_to_face.shape != (N, H, W, K):
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
raise ValueError(msg % pix_to_face.shape)
if bary_clip:
barycentric_coords = _clip_barycentric_coordinates(barycentric_coords)

# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
pix_to_face[mask] = 0
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
pixel_vals[mask] = 0 # Replace masked values in output.
return pixel_vals
from .utils import interpolate_face_attributes


def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
Expand All @@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
relative to the faces (in the packed
representation) which overlap the pixel.
meshes: Meshes representing a batch of meshes. It is expected that
meshes has a textures attribute which is an instance of the
Textures class.
meshes has a textures attribute which is an instance of the
Textures class.
Returns:
texels: tensor of shape (N, H, W, K, C) giving the interpolated
Expand All @@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
texture_maps = meshes.textures.maps_padded()

# pixel_uvs: (N, H, W, K, 2)
pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs)
pixel_uvs = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
)

N, H_out, W_out, K = fragments.pix_to_face.shape
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
Expand Down Expand Up @@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
faces_packed = meshes.faces_packed()
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
texels = interpolate_face_attributes(fragments, faces_textures)
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_textures
)
return texels
100 changes: 100 additions & 0 deletions pytorch3d/renderer/mesh/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import torch


def _clip_barycentric_coordinates(bary) -> torch.Tensor:
"""
Args:
bary: barycentric coordinates of shape (...., 3) where `...` represents
an arbitrary number of dimensions
Returns:
bary: Barycentric coordinates clipped (i.e any values < 0 are set to 0)
and renormalized. We only clip the negative values. Values > 1 will fall
into the [0, 1] range after renormalization.
The output is the same shape as the input.
"""
if bary.shape[-1] != 3:
msg = "Expected barycentric coords to have last dim = 3; got %r"
raise ValueError(msg % bary.shape)
clipped = bary.clamp(min=0.0)
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
clipped = clipped / clipped_sum
return clipped


def interpolate_face_attributes(
pix_to_face: torch.Tensor,
barycentric_coords: torch.Tensor,
face_attributes: torch.Tensor,
) -> torch.Tensor:
"""
Interpolate arbitrary face attributes using the barycentric coordinates
for each pixel in the rasterized output.
Args:
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
face_attributes: packed attributes of shape (total_faces, 3, D),
specifying the value of the attribute for each
vertex in the face.
Returns:
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
value of the face attribute for each pixel.
"""
F, FV, D = face_attributes.shape
if FV != 3:
raise ValueError("Faces can only have three vertices; got %r" % FV)
N, H, W, K, _ = barycentric_coords.shape
if pix_to_face.shape != (N, H, W, K):
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
raise ValueError(msg % pix_to_face.shape)

# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
pix_to_face[mask] = 0
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
pixel_vals[mask] = 0 # Replace masked values in output.
return pixel_vals


def _interpolate_zbuf(
pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
) -> torch.Tensor:
"""
A helper function to calculate the z buffer for each pixel in the
rasterized output.
Args:
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
meshes: Meshes object representing a batch of meshes.
Returns:
zbuffer: (N, H, W, K) FloatTensor
"""
verts = meshes.verts_packed()
faces = meshes.faces_packed()
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
return interpolate_face_attributes(
pix_to_face, barycentric_coords, faces_verts_z
)[
..., 0
] # (1, H, W, K)
Binary file added tests/data/test_blurry_textured_rendering.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_simple_sphere_light_flat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions tests/test_mesh_rendering_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import unittest
import torch

from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates


class TestMeshRenderingUtils(unittest.TestCase):
def test_bary_clip(self):
N = 10
bary = torch.randn(size=(N, 3))
# randomly make some values negative
bary[bary < 0.3] *= -1.0
# randomly make some values be greater than 1
bary[bary > 0.8] *= 2.0
negative_mask = bary < 0.0
positive_mask = bary > 1.0
clipped = _clip_barycentric_coordinates(bary)
self.assertTrue(clipped[negative_mask].sum() == 0)
self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0)
self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N)))
Loading

0 comments on commit ff19c64

Please sign in to comment.