diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 8d470a80d..5c2976b00 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -470,7 +470,7 @@ def atlas_packed(self) -> torch.Tensor: def extend(self, N: int) -> "TexturesAtlas": new_props = self._extend(N, ["atlas_padded", "_num_faces_per_mesh"]) - new_tex = TexturesAtlas(atlas=new_props["atlas_padded"]) + new_tex = self.__class__(atlas=new_props["atlas_padded"]) new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] return new_tex @@ -865,7 +865,7 @@ def extend(self, N: int) -> "TexturesUV": "_num_faces_per_mesh", ], ) - new_tex = TexturesUV( + new_tex = self.__class__( maps=new_props["maps_padded"], faces_uvs=new_props["faces_uvs_padded"], verts_uvs=new_props["verts_uvs_padded"], @@ -1339,7 +1339,7 @@ def verts_features_packed(self) -> torch.Tensor: def extend(self, N: int) -> "TexturesVertex": new_props = self._extend(N, ["verts_features_padded", "_num_verts_per_mesh"]) - new_tex = TexturesVertex(verts_features=new_props["verts_features_padded"]) + new_tex = self.__class__(verts_features=new_props["verts_features_padded"]) new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"] return new_tex