Skip to content

Commit

Permalink
add existing mesh formats to pluggable
Browse files Browse the repository at this point in the history
Summary: We already have code for obj and ply formats. Here we actually make it available in `IO.load_mesh` and `IO.save_mesh`.

Reviewed By: theschnitz, nikhilaravi

Differential Revision: D25400650

fbshipit-source-id: f26d6d7fc46c48634a948eea4d255afad13b807b
  • Loading branch information
bottler authored and facebook-github-bot committed Jan 7, 2021
1 parent b183dcb commit 89532a8
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 65 deletions.
57 changes: 56 additions & 1 deletion pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import os
import warnings
from collections import namedtuple
from typing import List, Optional
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
Expand All @@ -15,6 +16,8 @@
from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch

from .pluggable_formats import MeshFormatInterpreter, endswith


# Faces & Aux type returned from load_obj function.
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
Expand Down Expand Up @@ -286,6 +289,58 @@ def load_objs_as_meshes(
return join_meshes_as_batch(mesh_list)


class MeshObjFormat(MeshFormatInterpreter):
def __init__(self):
self.known_suffixes = (".obj",)

def read(
self,
path: Union[str, Path],
include_textures: bool,
device,
path_manager: PathManager,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
**kwargs,
) -> Optional[Meshes]:
if not endswith(path, self.known_suffixes):
return None
mesh = load_objs_as_meshes(
files=[path],
device=device,
load_textures=include_textures,
create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap,
path_manager=path_manager,
)
return mesh

def save(
self,
data: Meshes,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
**kwargs,
) -> bool:
if not endswith(path, self.known_suffixes):
return False

verts = data.verts_list()[0]
faces = data.faces_list()[0]
save_obj(
f=path,
verts=verts,
faces=faces,
decimal_places=decimal_places,
path_manager=path_manager,
)
return True


def _parse_face(
line,
tokens,
Expand Down
6 changes: 4 additions & 2 deletions pytorch3d/io/pluggable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from iopath.common.file_io import PathManager
from pytorch3d.structures import Meshes, Pointclouds

from .obj_io import MeshObjFormat
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
from .ply_io import MeshPlyFormat


"""
Expand Down Expand Up @@ -70,8 +72,8 @@ def __init__(
self.register_default_formats()

def register_default_formats(self) -> None:
# This will be populated in later diffs
pass
self.register_meshes_format(MeshObjFormat())
self.register_meshes_format(MeshPlyFormat())

def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
"""
Expand Down
54 changes: 51 additions & 3 deletions pytorch3d/io/ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
import warnings
from collections import namedtuple
from io import BytesIO
from typing import Optional, Tuple
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.structures import Meshes

from .pluggable_formats import MeshFormatInterpreter, endswith


_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
Expand Down Expand Up @@ -679,8 +683,7 @@ def load_ply(f, path_manager: Optional[PathManager] = None):
# but we don't need to enforce this.

if not len(face):
# pyre-fixme[28]: Unexpected keyword argument `size`.
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
faces = torch.zeros((0, 3), dtype=torch.int64)
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
if face.shape[1] < 3:
raise ValueError("Faces must have at least 3 vertices.")
Expand Down Expand Up @@ -831,3 +834,48 @@ def save_ply(
path_manager = PathManager()
with _open_file(f, path_manager, "wb") as f:
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)


class MeshPlyFormat(MeshFormatInterpreter):
def __init__(self):
self.known_suffixes = (".ply",)

def read(
self,
path: Union[str, Path],
include_textures: bool,
device,
path_manager: PathManager,
**kwargs,
) -> Optional[Meshes]:
if not endswith(path, self.known_suffixes):
return None

verts, faces = load_ply(f=path, path_manager=path_manager)
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
return mesh

def save(
self,
data: Meshes,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
**kwargs,
) -> bool:
if not endswith(path, self.known_suffixes):
return False

# TODO: normals are not saved. We only want to save them if they already exist.
verts = data.verts_list()[0]
faces = data.faces_list()[0]
save_ply(
f=path,
verts=verts,
faces=faces,
ascii=binary is False,
decimal_places=decimal_places,
path_manager=path_manager,
)
return True
71 changes: 68 additions & 3 deletions tests/test_obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import warnings
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile

import torch
from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample,
_bilinear_interpolation_vectorized,
Expand Down Expand Up @@ -145,6 +146,70 @@ def test_load_obj_complex(self):
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)

def test_load_obj_complex_pluggable(self):
"""
This won't work on Windows due to the behavior of NamedTemporaryFile
"""
obj_file = "\n".join(
[
"# this is a comment", # Comments should be ignored.
"v 0.1 0.2 0.3",
"v 0.2 0.3 0.4",
"v 0.3 0.4 0.5",
"v 0.4 0.5 0.6",
"vn 0.000000 0.000000 -1.000000",
"vn -1.000000 -0.000000 -0.000000",
"vn -0.000000 -0.000000 1.000000", # Normals should not be ignored.
"v 0.5 0.6 0.7",
"vt 0.749279 0.501284 0.0", # Some files add 0.0 - ignore this.
"vt 0.999110 0.501077",
"vt 0.999455 0.750380",
"f 1 2 3",
"f 1 2 4 3 5", # Polygons should be split into triangles
"f 2/1/2 3/1/2 4/2/2", # Texture/normals are loaded correctly.
"f -1 -2 1", # Negative indexing counts from the end.
]
)
io = IO()
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
mesh = io.load_mesh(f.name)
mesh_from_path = io.load_mesh(Path(f.name))

with NamedTemporaryFile(mode="w", suffix=".ply") as f:
f.write(obj_file)
f.flush()
with self.assertRaisesRegex(ValueError, "Invalid file header."):
io.load_mesh(f.name)

expected_verts = torch.tensor(
[
[0.1, 0.2, 0.3],
[0.2, 0.3, 0.4],
[0.3, 0.4, 0.5],
[0.4, 0.5, 0.6],
[0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 4], # Second face (polygon)
[1, 2, 3], # Third face (normals / texture)
[4, 3, 0], # Fourth face (negative indices)
],
dtype=torch.int64,
)
self.assertClose(mesh.verts_padded(), expected_verts[None])
self.assertClose(mesh.faces_padded(), expected_faces[None])
self.assertClose(mesh_from_path.verts_padded(), expected_verts[None])
self.assertClose(mesh_from_path.faces_padded(), expected_faces[None])
self.assertIsNone(mesh.textures)

def test_load_obj_normals_only(self):
obj_file = "\n".join(
[
Expand Down Expand Up @@ -588,8 +653,8 @@ def test_load_obj_mlt_no_image(self):
expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32)
expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1)
self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas))
self.assertEquals(len(aux.material_colors.keys()), 1)
self.assertEquals(list(aux.material_colors.keys()), ["material_1"])
self.assertEqual(len(aux.material_colors.keys()), 1)
self.assertEqual(list(aux.material_colors.keys()), ["material_1"])

def test_load_obj_missing_texture(self):
DATA_DIR = Path(__file__).resolve().parent / "data"
Expand Down
Loading

0 comments on commit 89532a8

Please sign in to comment.