Skip to content

Commit

Permalink
skeleton of pluggable IO
Browse files Browse the repository at this point in the history
Summary: Unified interface for loading and saving meshes and pointclouds.

Reviewed By: nikhilaravi

Differential Revision: D25372968

fbshipit-source-id: 6fe57cc3704a89d81d13e959bee707b0c7b57d3b
  • Loading branch information
bottler authored and facebook-github-bot committed Jan 7, 2021
1 parent 9fc661f commit b183dcb
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 1 deletion.
24 changes: 24 additions & 0 deletions docs/notes/io.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
hide_title: true
sidebar_label: File IO
---

# File IO
There is a flexible interface for loading and saving point clouds and meshes from different formats.

The main usage is via the `pytorch3d.io.IO` object, and its methods
`load_mesh`, `save_mesh`, `load_point_cloud` and `save_point_cloud`.

For example, to load a mesh you might do
```
from pytorch3d.io import IO
device=torch.device("cuda:0")
mesh = IO().load_mesh("mymesh.ply", device=device)
```

and to save a pointcloud you might do
```
pcl = Pointclouds(...)
IO().save_point_cloud(pcl, "output_poincloud.obj")
```
1 change: 1 addition & 0 deletions pytorch3d/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


from .obj_io import load_obj, load_objs_as_meshes, save_obj
from .pluggable import IO
from .ply_io import load_ply, save_ply


Expand Down
208 changes: 208 additions & 0 deletions pytorch3d/io/pluggable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from collections import deque
from pathlib import Path
from typing import Deque, Optional, Union

from iopath.common.file_io import PathManager
from pytorch3d.structures import Meshes, Pointclouds

from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter


"""
This module has the master functions for loading and saving data.
The main usage is via the IO object, and its methods
`load_mesh`, `save_mesh`, `load_pointcloud` and `save_pointcloud`.
For example, to load a mesh you might do
```
from pytorch3d.io import IO
mesh = IO().load_mesh("mymesh.obj")
```
and to save a point cloud you might do
```
pcl = Pointclouds(...)
IO().save_pointcloud(pcl, "output_poincloud.obj")
```
"""


class IO:
"""
This class is the interface to flexible loading and saving of meshes and point clouds.
In simple cases the user will just initialise an instance of this class as `IO()`
and then use its load and save functions. The arguments of the initializer are not
usually needed.
The user can add their own formats for saving and loading by passing their own objects
to the register_* functions.
Args:
include_default_formats: If False, the built-in file formats will not be available.
Then only user-registered formats can be used.
path_manager: Used to customise how paths given as strings are interpreted.
"""

def __init__(
self,
include_default_formats: bool = True,
path_manager: Optional[PathManager] = None,
):
if path_manager is None:
self.path_manager = PathManager()
else:
self.path_manager = path_manager

self.mesh_interpreters: Deque[MeshFormatInterpreter] = deque()
self.pointcloud_interpreters: Deque[PointcloudFormatInterpreter] = deque()

if include_default_formats:
self.register_default_formats()

def register_default_formats(self) -> None:
# This will be populated in later diffs
pass

def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
"""
Register a new interpreter for a new mesh file format.
Args:
interpreter: the new interpreter to use, which must be an instance
of a class which inherits MeshFormatInterpreter.
"""
self.mesh_interpreters.appendleft(interpreter)

def register_pointcloud_format(
self, interpreter: PointcloudFormatInterpreter
) -> None:
"""
Register a new interpreter for a new point cloud file format.
Args:
interpreter: the new interpreter to use, which must be an instance
of a class which inherits PointcloudFormatInterpreter.
"""
self.pointcloud_interpreters.appendleft(interpreter)

def load_mesh(
self,
path: Union[str, Path],
include_textures: bool = True,
device="cpu",
**kwargs,
) -> Meshes:
"""
Attempt to load a mesh from the given file, using a registered format.
Materials are not returned. If you have a .obj file with materials
you might want to load them with the load_obj function instead.
Args:
path: file to read
include_textures: whether to try to load texture information
device: device on which to leave the data.
Returns:
new Meshes object containing one mesh.
"""
for mesh_interpreter in self.mesh_interpreters:
mesh = mesh_interpreter.read(
path,
include_textures=include_textures,
path_manager=self.path_manager,
device=device,
**kwargs,
)
if mesh is not None:
return mesh

raise ValueError(f"No mesh interpreter found to read {path}.")

def save_mesh(
self,
data: Meshes,
path: Union[str, Path],
binary: Optional[bool] = None,
include_textures: bool = True,
**kwargs,
) -> None:
"""
Attempt to save a mesh to the given file, using a registered format.
Args:
data: a 1-element Meshes
path: file to write
binary: If there is a choice, whether to save in a binary format.
include_textures: If textures are present, whether to try to save
them.
"""
if len(data) != 1:
raise ValueError("Can only save a single mesh.")

for mesh_interpreter in self.mesh_interpreters:
success = mesh_interpreter.save(
data, path, path_manager=self.path_manager, binary=binary, **kwargs
)
if success:
return

raise ValueError(f"No mesh interpreter found to write to {path}.")

def load_pointcloud(
self, path: Union[str, Path], device="cpu", **kwargs
) -> Pointclouds:
"""
Attempt to load a point cloud from the given file, using a registered format.
Args:
path: file to read
device: torch.device on which to load the data.
Returns:
new Pointclouds object containing one mesh.
"""
for pointcloud_interpreter in self.pointcloud_interpreters:
pointcloud = pointcloud_interpreter.read(
path, path_manager=self.path_manager, device=device, **kwargs
)
if pointcloud is not None:
return pointcloud

raise ValueError(f"No point cloud interpreter found to read {path}.")

def save_pointcloud(
self,
data: Pointclouds,
path: Union[str, Path],
binary: Optional[bool] = None,
**kwargs,
) -> None:
"""
Attempt to save a point cloud to the given file, using a registered format.
Args:
data: a 1-element Pointclouds
path: file to write
binary: If there is a choice, whether to save in a binary format.
"""
if len(data) != 1:
raise ValueError("Can only save a single point cloud.")

for pointcloud_interpreter in self.pointcloud_interpreters:
success = pointcloud_interpreter.save(
data, path, path_manager=self.path_manager, binary=binary, **kwargs
)
if success:
return

raise ValueError(f"No point cloud interpreter found to write to {path}.")
136 changes: 136 additions & 0 deletions pytorch3d/io/pluggable_formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from typing import Optional, Tuple, Union

from iopath.common.file_io import PathManager
from pytorch3d.structures import Meshes, Pointclouds


"""
This module has the base classes which must be extended to define
an interpreter for loading and saving data in a particular format.
These can be registered on an IO object so that they can be used in
its load_* and save_* functions.
"""


def endswith(path, suffixes: Tuple[str, ...]) -> bool:
"""
Returns whether the path ends with one of the given suffixes.
If `path` is not actually a path, returns True. This is useful
for allowing interpreters to bypass inappropriate paths, but
always accepting streams.
"""
if isinstance(path, Path):
return path.suffix.lower() in suffixes
if isinstance(path, str):
return path.lower().endswith(suffixes)
return True


class MeshFormatInterpreter:
"""
This is a base class for an interpreter which can read or write
a mesh in a particular format.
"""

def read(
self,
path: Union[str, Path],
include_textures: bool,
device,
path_manager: PathManager,
**kwargs,
) -> Optional[Meshes]:
"""
Read the data from the specified file and return it as
a Meshes object.
Args:
path: path to load.
include_textures: whether to try to load texture information.
device: torch.device to load data on to.
path_manager: PathManager to interpret the path.
Returns:
None if self is not the appropriate object to interpret the given
path.
Otherwise, the read Meshes object.
"""
raise NotImplementedError()

def save(
self,
data: Meshes,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
**kwargs,
) -> bool:
"""
Save the given Meshes object to the given path.
Args:
data: mesh to save
path: path to save to, which may be overwritten.
path_manager: PathManager to interpret the path.
binary: If there is a choice, whether to save in a binary format.
Returns:
False: if self is not the appropriate object to write to the given path.
True: on success.
"""
raise NotImplementedError()


class PointcloudFormatInterpreter:
"""
This is a base class for an interpreter which can read or write
a point cloud in a particular format.
"""

def read(
self, path: Union[str, Path], device, path_manager: PathManager, **kwargs
) -> Optional[Pointclouds]:
"""
Read the data from the specified file and return it as
a Pointclouds object.
Args:
path: path to load.
device: torch.device to load data on to.
path_manager: PathManager to interpret the path.
Returns:
None if self is not the appropriate object to interpret the given
path.
Otherwise, the read Pointclouds object.
"""
raise NotImplementedError()

def save(
self,
data: Pointclouds,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
**kwargs,
) -> bool:
"""
Save the given Pointclouds object to the given path.
Args:
data: point cloud object to save
path: path to save to, which may be overwritten.
path_manager: PathManager to interpret the path.
binary: If there is a choice, whether to save in a binary format.
Returns:
False: if self is not the appropriate object to write to the given path.
True: on success.
"""
raise NotImplementedError()
2 changes: 1 addition & 1 deletion pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def split(self, split_sizes: list):
returned.
Returns:
list[PointClouds].
list[Pointclouds].
"""
if not all(isinstance(x, int) for x in split_sizes):
raise ValueError("Value of split_sizes must be a list of integers.")
Expand Down

0 comments on commit b183dcb

Please sign in to comment.