Skip to content

Commit

Permalink
Fix batching bug from TexturesUV packed ambiguity, other textures tidyup
Browse files Browse the repository at this point in the history
Summary:
faces_uvs_packed and verts_uvs_packed were only used in one place and the definition of the former was ambiguous. This meant that the wrong coordinates could be used for meshes other than the first in the batch. I have therefore removed both functions and build their common result inline. Added a test that a simple batch of two meshes is rendered consistently with the rendering of each alone. This test would have failed before.

I hope this fixes #283.

Some other small improvements to the textures code.

Reviewed By: nikhilaravi

Differential Revision: D23161936

fbshipit-source-id: f99b560a46f6b30262e07028b049812bc04350a7
  • Loading branch information
bottler authored and facebook-github-bot committed Aug 21, 2020
1 parent 9aaba04 commit 9a50cf8
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 65 deletions.
4 changes: 2 additions & 2 deletions pytorch3d/csrc/utils/pytorch3d_cutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#pragma once
#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.")
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
2 changes: 1 addition & 1 deletion pytorch3d/datasets/shapenet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
FoVPerspectiveCameras,
HardPhongShader,
MeshRasterizer,
MeshRenderer,
FoVPerspectiveCameras,
PointLights,
RasterizationSettings,
TexturesVertex,
Expand Down
5 changes: 4 additions & 1 deletion pytorch3d/renderer/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
# Mask for the background.
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)

background_color = colors.new_tensor(blend_params.background_color) # (3)
if torch.is_tensor(blend_params.background_color):
background_color = blend_params.background_color
else:
background_color = colors.new_tensor(blend_params.background_color) # (3)

# Find out how much background_color needs to be expanded to be used for masked_scatter.
num_background_pixels = is_background.sum()
Expand Down
54 changes: 13 additions & 41 deletions pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _pad_texture_maps(
# This is also useful to have so that inside `Meshes`
# we can allow the input textures to be any texture
# type which is an instance of the base class.
class TexturesBase(object):
class TexturesBase:
def __init__(self):
self._N = 0
self.valid = None
Expand Down Expand Up @@ -262,9 +262,6 @@ class attributes for item i. Then, a new
"""
raise NotImplementedError()

def __repr__(self):
return "TexturesBase"


def Textures(
maps: Union[List, torch.Tensor, None] = None,
Expand Down Expand Up @@ -385,14 +382,6 @@ def __init__(self, atlas: Union[torch.Tensor, List, None]):
# refer to the __init__ of Meshes.
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)

# This is a hack to allow the child classes to also have the same representation
# as the parent. In meshes.py we check that the input textures have the correct
# type. However due to circular imports issues, we can't import the texture
# classes into any files in pytorch3d.structures. Instead we check
# for repr(textures) == "TexturesBase".
def __repr__(self):
return super().__repr__()

def clone(self):
tex = self.__class__(atlas=self.atlas_padded().clone())
if self._atlas_list is not None:
Expand Down Expand Up @@ -556,10 +545,7 @@ def __init__(
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
Note: only the padded and list representation of the textures are stored
and the packed representations is computed on the fly and
not cached.
(a FloatTensor with values between 0 and 1)
"""
super().__init__()
if isinstance(faces_uvs, (list, tuple)):
Expand Down Expand Up @@ -611,9 +597,6 @@ def __init__(
"verts_uvs and faces_uvs must have the same batch dimension"
)
if not all(v.device == self.device for v in verts_uvs):
import pdb

pdb.set_trace()
raise ValueError("verts_uvs and faces_uvs must be on the same device")

# These values may be overridden when textures is
Expand Down Expand Up @@ -669,9 +652,6 @@ def __init__(

self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)

def __repr__(self):
return super().__repr__()

def clone(self):
tex = self.__class__(
self.maps_padded().clone(),
Expand Down Expand Up @@ -759,12 +739,6 @@ def faces_uvs_list(self) -> List[torch.Tensor]:
)
return self._faces_uvs_list

def faces_uvs_packed(self) -> torch.Tensor:
if self.isempty():
return torch.zeros((self._N, 3), dtype=torch.float32, device=self.device)
faces_uvs_list = self.faces_uvs_list()
return list_to_packed(faces_uvs_list)[0]

def verts_uvs_padded(self) -> torch.Tensor:
if self._verts_uvs_padded is None:
if self.isempty():
Expand All @@ -789,12 +763,6 @@ def verts_uvs_list(self) -> List[torch.Tensor]:
)
return self._verts_uvs_list

def verts_uvs_packed(self) -> torch.Tensor:
if self.isempty():
return torch.zeros((self._N, 2), dtype=torch.float32, device=self.device)
verts_uvs_list = self.verts_uvs_list()
return list_to_packed(verts_uvs_list)[0]

# Currently only the padded maps are used.
def maps_padded(self) -> torch.Tensor:
return self._maps_padded
Expand Down Expand Up @@ -850,9 +818,15 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
texels: tensor of shape (N, H, W, K, C) giving the interpolated
texture for each pixel in the rasterized image.
"""
verts_uvs = self.verts_uvs_packed()
faces_uvs = self.faces_uvs_packed()
faces_verts_uvs = verts_uvs[faces_uvs]
if self.isempty():
faces_verts_uvs = torch.zeros(
(self._N, 3, 2), dtype=torch.float32, device=self.device
)
else:
packing_list = [
i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
]
faces_verts_uvs = torch.cat(packing_list)
texture_maps = self.maps_padded()

# pixel_uvs: (N, H, W, K, 2)
Expand Down Expand Up @@ -890,6 +864,7 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
# texels now has shape (NK, C, H_out, W_out)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels

Expand Down Expand Up @@ -990,9 +965,6 @@ def __init__(
# refer to the __init__ of Meshes.
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)

def __repr__(self):
return super().__repr__()

def clone(self):
tex = self.__class__(self.verts_features_padded().clone())
if self._verts_features_list is not None:
Expand Down Expand Up @@ -1048,7 +1020,7 @@ def verts_features_list(self) -> List[torch.Tensor]:
if self._verts_features_list is None:
if self.isempty():
self._verts_features_list = [
torch.empty((0, 3, 0), dtype=torch.float32, device=self.device)
torch.empty((0, 3), dtype=torch.float32, device=self.device)
] * self._N
else:
self._verts_features_list = padded_to_list(
Expand Down
4 changes: 2 additions & 2 deletions pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def __init__(self, verts=None, faces=None, textures=None):
Refer to comments above for descriptions of List and Padded representations.
"""
self.device = None
if textures is not None and not repr(textures) == "TexturesBase":
if textures is not None and not hasattr(textures, "sample_textures"):
msg = "Expected textures to be an instance of type TexturesBase; got %r"
raise ValueError(msg % repr(textures))
raise ValueError(msg % type(textures))
self.textures = textures

# Indicates whether the meshes in the list/batch have the same number
Expand Down
83 changes: 82 additions & 1 deletion tests/test_render_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from pytorch3d.structures.meshes import Meshes, join_mesh
from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch
from pytorch3d.utils.ico_sphere import ico_sphere
from pytorch3d.utils.torus import torus


# If DEBUG=True, save out images generated in the tests for debugging.
Expand Down Expand Up @@ -490,6 +491,86 @@ def test_texture_map(self):

self.assertClose(rgb, image_ref, atol=0.05)

def test_batch_uvs(self):
"""Test that two random tori with TexturesUV render the same as each individually."""
torch.manual_seed(1)
device = torch.device("cuda:0")
plain_torus = torus(r=1, R=4, sides=10, rings=10, device=device)
[verts] = plain_torus.verts_list()
[faces] = plain_torus.faces_list()
nocolor = torch.zeros((100, 100), device=device)
color_gradient = torch.linspace(0, 1, steps=100, device=device)
color_gradient1 = color_gradient[None].expand_as(nocolor)
color_gradient2 = color_gradient[:, None].expand_as(nocolor)
colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2)
colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2)
verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device)
verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device)

textures1 = TexturesUV(
maps=[colors1], faces_uvs=[faces], verts_uvs=[verts_uvs1]
)
textures2 = TexturesUV(
maps=[colors2], faces_uvs=[faces], verts_uvs=[verts_uvs2]
)
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
mesh2 = Meshes(verts=[verts], faces=[faces], textures=textures2)
mesh_both = join_meshes_as_batch([mesh1, mesh2])

R, T = look_at_view_transform(10, 10, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

raster_settings = RasterizationSettings(
image_size=128, blur_radius=0.0, faces_per_pixel=1
)

# Init shader settings
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
# Init renderer
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(
device=device, lights=lights, cameras=cameras, blend_params=blend_params
),
)

outputs = []
for meshes in [mesh_both, mesh1, mesh2]:
outputs.append(renderer(meshes))

if DEBUG:
Image.fromarray(
(outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_batch_uvs0.png")
Image.fromarray(
(outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_batch_uvs1.png")
Image.fromarray(
(outputs[0][1, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_batch_uvs2.png")
Image.fromarray(
(outputs[2][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_batch_uvs3.png")

diff = torch.abs(outputs[0][0, ..., :3] - outputs[1][0, ..., :3])
Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save(
DATA_DIR / "test_batch_uvs01.png"
)
diff = torch.abs(outputs[0][1, ..., :3] - outputs[2][0, ..., :3])
Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save(
DATA_DIR / "test_batch_uvs23.png"
)

self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5)
self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5)

def test_joined_spheres(self):
"""
Test a list of Meshes can be joined as a single mesh and
Expand Down
19 changes: 2 additions & 17 deletions tests/test_texturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def tryindex(self, index, tex, meshes, source):
basic = basic[None]

if len(basic) == 0:
self.assertEquals(len(from_texture), 0)
self.assertEquals(len(from_meshes), 0)
self.assertEqual(len(from_texture), 0)
self.assertEqual(len(from_meshes), 0)
else:
self.assertClose(basic, from_texture)
self.assertClose(basic, from_meshes)
Expand Down Expand Up @@ -608,12 +608,8 @@ def test_extend(self):
[
tex_init.faces_uvs_padded(),
new_tex.faces_uvs_padded(),
tex_init.faces_uvs_packed(),
new_tex.faces_uvs_packed(),
tex_init.verts_uvs_padded(),
new_tex.verts_uvs_padded(),
tex_init.verts_uvs_packed(),
new_tex.verts_uvs_packed(),
tex_init.maps_padded(),
new_tex.maps_padded(),
]
Expand Down Expand Up @@ -646,11 +642,9 @@ def test_padded_to_packed(self):
tex1 = tex.clone()
tex1._num_faces_per_mesh = num_faces_per_mesh
tex1._num_verts_per_mesh = num_verts_per_mesh
verts_packed = tex1.verts_uvs_packed()
verts_list = tex1.verts_uvs_list()
verts_padded = tex1.verts_uvs_padded()

faces_packed = tex1.faces_uvs_packed()
faces_list = tex1.faces_uvs_list()
faces_padded = tex1.faces_uvs_padded()

Expand All @@ -660,9 +654,7 @@ def test_padded_to_packed(self):
for f1, f2 in zip(verts_list, verts_uvs_list):
self.assertTrue((f1 == f2).all().item())

self.assertTrue(faces_packed.shape == (3 + 2, 3))
self.assertTrue(faces_padded.shape == (2, 3, 3))
self.assertTrue(verts_packed.shape == (9 + 6, 2))
self.assertTrue(verts_padded.shape == (2, 9, 2))

# Case where num_faces_per_mesh is not set and faces_verts_uvs
Expand All @@ -672,16 +664,9 @@ def test_padded_to_packed(self):
verts_uvs=verts_padded,
faces_uvs=faces_padded,
)
faces_packed = tex2.faces_uvs_packed()
faces_list = tex2.faces_uvs_list()
verts_packed = tex2.verts_uvs_packed()
verts_list = tex2.verts_uvs_list()

# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(faces_packed.shape == (3 * 2, 3))
self.assertTrue(verts_packed.shape == (9 * 2, 2))

for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
Expand Down

0 comments on commit 9a50cf8

Please sign in to comment.