Skip to content

Commit

Permalink
TexturesVertex._num_verts_per_mesh deep copy (#623)
Browse files Browse the repository at this point in the history
Summary:
When a list of Meshes is `join_batched()`, the `num_verts_per_mesh` in the list would be unexpectedly modified.

Also some cleanup around `_num_verts_per_mesh`.

Pull Request resolved: #623

Test Plan: A modification to an existing test checks this.

Reviewed By: nikhilaravi

Differential Revision: D27682104

Pulled By: bottler

fbshipit-source-id: 9d00913dfb4869bd6c7d3f5cc9156b7b6f1aecc9
  • Loading branch information
JudyYe authored and facebook-github-bot committed Apr 20, 2021
1 parent 8660db9 commit eb04a48
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
23 changes: 6 additions & 17 deletions pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,8 +842,7 @@ def verts_uvs_list(self) -> List[torch.Tensor]:
else:
# The number of vertices in the mesh and in verts_uvs can differ
# e.g. if a vertex is shared between 3 faces, it can
# have up to 3 different uv coordinates. Therefore we cannot
# convert directly from padded to list using _num_verts_per_mesh
# have up to 3 different uv coordinates.
self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
return self._verts_uvs_list

Expand Down Expand Up @@ -1283,25 +1282,15 @@ def clone(self):
tex = self.__class__(self.verts_features_padded().clone())
if self._verts_features_list is not None:
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
num_verts = (
self._num_verts_per_mesh.clone()
if torch.is_tensor(self._num_verts_per_mesh)
else self._num_verts_per_mesh
)
tex._num_verts_per_mesh = num_verts
tex._num_verts_per_mesh = self._num_verts_per_mesh.copy()
tex.valid = self.valid.clone()
return tex

def detach(self):
tex = self.__class__(self.verts_features_padded().detach())
if self._verts_features_list is not None:
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
num_verts = (
self._num_verts_per_mesh.detach()
if torch.is_tensor(self._num_verts_per_mesh)
else self._num_verts_per_mesh
)
tex._num_verts_per_mesh = num_verts
tex._num_verts_per_mesh = self._num_verts_per_mesh.copy()
tex.valid = self.valid.detach()
return tex

Expand Down Expand Up @@ -1414,13 +1403,13 @@ def join_batch(self, textures: List["TexturesVertex"]) -> "TexturesVertex":

verts_features_list = []
verts_features_list += self.verts_features_list()
num_faces_per_mesh = self._num_verts_per_mesh
num_verts_per_mesh = self._num_verts_per_mesh.copy()
for tex in textures:
verts_features_list += tex.verts_features_list()
num_faces_per_mesh += tex._num_verts_per_mesh
num_verts_per_mesh += tex._num_verts_per_mesh

new_tex = self.__class__(verts_features=verts_features_list)
new_tex._num_verts_per_mesh = num_faces_per_mesh
new_tex._num_verts_per_mesh = num_verts_per_mesh
return new_tex

def join_scene(self) -> "TexturesVertex":
Expand Down
5 changes: 5 additions & 0 deletions tests/test_io_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,11 @@ def check_item(x, y):
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)
nums_rgb = mesh_rgb.textures._num_verts_per_mesh
nums_rgb3 = mesh_rgb3.textures._num_verts_per_mesh
self.assertEqual(type(nums_rgb), list)
self.assertEqual(type(nums_rgb3), list)
self.assertListEqual(nums_rgb * 3, nums_rgb3)

# meshes with texture atlas, join into a batch.
device = "cuda:0"
Expand Down

0 comments on commit eb04a48

Please sign in to comment.