Implement multitexture obj high precision #1574

Open
wants to merge 2 commits into base: main
611 changes: 611 additions & 0 deletions docs/tutorials/multitexture_obj_IO_and_point_sampling.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pytorch3d/io/__init__.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from .obj_io import load_obj, load_objs_as_meshes, save_obj
from .obj_io import load_obj, load_objs_as_meshes, save_obj, subset_obj
from .pluggable import IO
from .ply_io import load_ply, save_ply

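The only change in this file is the new public export. A minimal import sketch follows; no call to `subset_obj` is shown because its signature lives in the `obj_io.py` diff, which is not rendered above.

```python
# The new subset_obj helper is exported alongside the existing OBJ utilities.
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj, subset_obj
```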
581 changes: 486 additions & 95 deletions pytorch3d/io/obj_io.py

Large diffs are not rendered by default.
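Since the reworked `obj_io.py` is not rendered here, this sketch only exercises the pre-existing `load_obj` return structure (`aux.texture_images` maps material names to texture images), which is the part of the API the multitexture work builds on. The `"scene.obj"` path is a placeholder.

```python
# Sketch against the existing load_obj return values; nothing below depends
# on the unrendered obj_io.py changes. "scene.obj" is a placeholder path.
from pytorch3d.io import load_obj

verts, faces, aux = load_obj("scene.obj", load_textures=True)

# An OBJ that references several materials yields one texture image per material.
for material_name, image in (aux.texture_images or {}).items():
    print(material_name, tuple(image.shape))  # typically (H, W, 3)

# Per-face UV indices and the UV coordinates themselves, when present.
if aux.verts_uvs is not None:
    print(faces.textures_idx.shape, aux.verts_uvs.shape)
```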

1 change: 0 additions & 1 deletion pytorch3d/io/utils.py
@@ -16,7 +16,6 @@

from ..common.datatypes import Device


PathOrStr = Union[pathlib.Path, str]


1 change: 1 addition & 0 deletions pytorch3d/ops/__init__.py
@@ -27,6 +27,7 @@
)
from .sample_farthest_points import sample_farthest_points
from .sample_points_from_meshes import sample_points_from_meshes
from .sample_points_from_obj import sample_points_from_obj
from .subdivide_meshes import SubdivideMeshes
from .utils import (
convert_pointclouds_to_tensor,
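Similarly, this file only adds the export. A minimal import sketch; `sample_points_from_obj` itself is implemented in a module that is not part of the excerpt above, so its arguments are not guessed here.

```python
# The new OBJ point sampler sits next to the existing mesh sampler.
from pytorch3d.ops import sample_points_from_meshes, sample_points_from_obj
```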
6 changes: 4 additions & 2 deletions pytorch3d/ops/interp_face_attrs.py
@@ -48,14 +48,16 @@ def interpolate_face_attributes(
# On CPU use the python version
# TODO: Implement a C++ version of this function
if not pix_to_face.is_cuda:
args = (pix_to_face, barycentric_coords, face_attributes)
# accommodate high_precision inputs and force float types where needed
args = (pix_to_face, barycentric_coords.float(), face_attributes.float())
return interpolate_face_attributes_python(*args)

# Otherwise flatten and call the custom autograd function
N, H, W, K = pix_to_face.shape
pix_to_face = pix_to_face.view(-1)
barycentric_coords = barycentric_coords.view(N * H * W * K, 3)
args = (pix_to_face, barycentric_coords, face_attributes)
# accommodate high_precision inputs and force float types where needed
args = (pix_to_face, barycentric_coords.float(), face_attributes.float())
out = _InterpFaceAttrs.apply(*args)
out = out.view(N, H, W, K, -1)
return out
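The change above is only a cast: float64 ("high_precision") barycentric coordinates and face attributes are downcast to float32 so the existing kernels can be reused. A minimal sketch of the same pattern on the CPU path, with made-up tensor shapes:

```python
import torch
from pytorch3d.ops import interpolate_face_attributes

# Hypothetical float64 inputs, e.g. from a double-precision OBJ pipeline.
pix_to_face = torch.zeros(1, 4, 4, 1, dtype=torch.int64)   # (N, H, W, K)
bary = torch.rand(1, 4, 4, 1, 3, dtype=torch.float64)      # (N, H, W, K, 3)
bary = bary / bary.sum(dim=-1, keepdim=True)                # normalize the weights
face_attrs = torch.rand(8, 3, 5, dtype=torch.float64)       # (F, 3, D)

# After this PR the .float() cast happens inside interpolate_face_attributes;
# the explicit cast below shows the equivalent behaviour.
out = interpolate_face_attributes(pix_to_face, bary.float(), face_attrs.float())
print(out.shape)  # (1, 4, 4, 1, 5)
```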
201 changes: 159 additions & 42 deletions pytorch3d/ops/sample_points_from_meshes.py
@@ -16,13 +16,15 @@
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
from pytorch3d.ops.packed_to_padded import packed_to_padded
from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments
from pytorch3d.structures import Meshes


def sample_points_from_meshes(
meshes,
num_samples: int = 10000,
return_normals: bool = False,
return_textures: bool = False,
return_mappers: bool = False,
) -> Union[
torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
@@ -38,9 +40,9 @@ def sample_points_from_meshes(
num_samples: Integer giving the number of point samples per mesh.
return_normals: If True, return normals for the sampled points.
return_textures: If True, return textures for the sampled points.

return_mappers: If True, return a mapping of each point to its origin face.
Returns:
3-element tuple containing
4-element tuple containing

- **samples**: FloatTensor of shape (N, num_samples, 3) giving the
coordinates of sampled points for each mesh in the batch. For empty
Expand All @@ -53,8 +55,10 @@ def sample_points_from_meshes(
texture vector to each sampled point. Only returned if return_textures is True.
For empty meshes the corresponding row in the textures array will
be filled with 0.
- **mappers**: IntTensor of shape (N, num_samples) giving, for each sampled point,
the index of the face it was sampled from. Only returned if return_mappers is True.

Note that in a future releases, we will replace the 3-element tuple output
Note that in a future release, we will replace the 4-element tuple output
with a `Pointclouds` datastructure, as follows

.. code-block:: python
@@ -64,6 +68,9 @@
if meshes.isempty():
raise ValueError("Meshes are empty.")

# initialize all return values
samples, normals, textures, mappers = None, None, None, None

verts = meshes.verts_packed()
if not torch.isfinite(verts).all():
raise ValueError("Meshes contain nan or inf.")
@@ -73,12 +80,8 @@

faces = meshes.faces_packed()
mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
num_meshes = len(meshes)
num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.

# Initialize samples tensor with fill value 0 for empty meshes.
samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)

# Only compute samples for non empty meshes
with torch.no_grad():
areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero.
@@ -91,51 +94,32 @@
sample_face_idxs = areas_padded.multinomial(
num_samples, replacement=True
) # (N, num_samples)
sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)

# Get the vertex coordinates of the sampled faces.
face_verts = verts[faces]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
if return_mappers:
# for each mesh, map each sampled point to the index of its originating face
mappers = sample_face_idxs.clone()

# Randomly generate barycentric coords.
w0, w1, w2 = _rand_barycentric_coords(
num_valid_meshes, num_samples, verts.dtype, verts.device
)
sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)

# Use the barycentric coords to get a point on each sampled face.
a = v0[sample_face_idxs] # (N, num_samples, 3)
b = v1[sample_face_idxs]
c = v2[sample_face_idxs]
samples[meshes.valid] = w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
(samples, (v0, v1, v2), (w0, w1, w2)) = _sample_points(
meshes,
num_samples,
sample_face_idxs,
verts,
faces,
)

if return_normals:
# Initialize normals tensor with fill value 0 for empty meshes.
# Normals for the sampled points are face normals computed from
# the vertices of the face in which the sampled point lies.
normals = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
vert_normals = vert_normals / vert_normals.norm(dim=1, p=2, keepdim=True).clamp(
min=sys.float_info.epsilon
)
vert_normals = vert_normals[sample_face_idxs]
normals[meshes.valid] = vert_normals
normals = _sample_normals(meshes, num_samples, sample_face_idxs, v0, v1, v2)

if return_textures:
# fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1.
pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1) # NxSx1x1
bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3
# zbuf and dists are not used in `sample_textures` so we initialize them with dummy
dummy = torch.zeros(
(len(meshes), num_samples, 1, 1), device=meshes.device, dtype=torch.float32
) # NxSx1x1
fragments = MeshFragments(
pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy
)
textures = meshes.sample_textures(fragments) # NxSx1x1xC
textures = textures[:, :, 0, 0, :] # NxSxC
textures = _sample_textures(meshes, num_samples, sample_face_idxs, w0, w1, w2)

# return
# TODO(gkioxari) consider returning a Pointclouds instance [breaking]
if return_mappers:
# return a 4-element tuple
return samples, normals, textures, mappers
if return_normals and return_textures:
# pyre-fixme[61]: `normals` may not be initialized here.
# pyre-fixme[61]: `textures` may not be initialized here.
@@ -173,3 +157,136 @@ def _rand_barycentric_coords(
w1 = u_sqrt * (1.0 - v)
w2 = u_sqrt * v
return w0, w1, w2
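For reference, and unchanged by this PR, the weights above come from the standard square-root parameterization, which makes samples uniform over each triangle's area (u and v are independent uniform draws on [0, 1]):

```latex
w_0 = 1 - \sqrt{u}, \qquad w_1 = \sqrt{u}\,(1 - v), \qquad w_2 = \sqrt{u}\,v,
\qquad w_0 + w_1 + w_2 = 1 .
```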


def _sample_points(
meshes: Meshes,
num_samples: int,
sample_face_idxs: torch.Tensor,
verts: torch.Tensor,
faces: torch.Tensor,
) -> Tuple[
torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""This is a helper function that re-packages the core sampling function for points.
Args:
meshes: A Meshes object to sample points from.
num_samples: Integer number of samples to generate per mesh.
sample_face_idxs: An array of face indices to sample from Meshes.
verts: torch.Tensor of verts, typically meshes.verts_packed().
faces: torch.Tensor of faces, typically meshes.faces_packed().

Returns:
A 3-tuple of (samples, (v0, v1, v2), (w0, w1, w2)): the sampled points tensor,
the face vertex tensors, and the barycentric coordinate tensors.

"""
# Initialize samples tensor with fill value 0 for empty meshes.
samples = _empty_sample(len(meshes), num_samples, verts.device, verts.dtype)

# Get the vertex coordinates of the sampled faces.
face_verts = verts[faces]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

# Randomly generate barycentric coords.
w0, w1, w2 = _rand_barycentric_coords(
torch.sum(meshes.valid), num_samples, verts.dtype, verts.device
)

# Use the barycentric coords to get a point on each sampled face.
a = v0[sample_face_idxs] # (N, num_samples, 3)
b = v1[sample_face_idxs]
c = v2[sample_face_idxs]
samples[meshes.valid] = w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
return samples, (v0, v1, v2), (w0, w1, w2)


def _sample_normals(
meshes: Meshes,
num_samples: int,
sample_face_idxs: torch.Tensor,
v0: torch.Tensor,
v1: torch.Tensor,
v2: torch.Tensor,
) -> torch.Tensor:
"""This is a helper function that implements the core sampling function for point normals.

Args:
meshes: A Meshes object to sample points from.
num_samples: Integer number of samples to generate per mesh.
sample_face_idxs: An array of face indices to sample from Meshes.
v0, v1, v2: torch.Tensors of the face vertex coordinates (from face_verts).

Returns:
A torch.Tensor of normals for the sampled points.
"""
# Initialize normals tensor with fill value 0 for empty meshes.
# Normals for the sampled points are face normals computed from
# the vertices of the face in which the sampled point lies.
normals = torch.zeros(
(len(meshes), num_samples, 3), device=meshes.device, dtype=v0.dtype
)
vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
vert_normals = vert_normals / vert_normals.norm(dim=1, p=2, keepdim=True).clamp(
min=sys.float_info.epsilon
)
vert_normals = vert_normals[sample_face_idxs]
normals[meshes.valid] = vert_normals
return normals


def _sample_textures(
meshes: Meshes,
num_samples: int,
sample_face_idxs: torch.Tensor,
w0: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> torch.Tensor:
"""This is a helper function that implements the core sampling function for point textures.

Args:
meshes: A Meshes object from which to sample textures.
num_samples: Integer value for number of texture samples.
sample_face_idxs: An array of face indices to sample from Meshes.
w0, w1, w2: Tensors giving random barycentric coordinates from _sample_points.

Returns:
A torch.Tensor of sampled textures for Meshes.
"""
# fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1.
pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1) # NxSx1x1
bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3
# zbuf and dists are not used in `sample_textures` so we initialize them with dummy
dummy = torch.zeros(
(len(meshes), num_samples, 1, 1), device=meshes.device, dtype=bary.dtype
) # NxSx1x1
fragments = MeshFragments(
pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy
)
textures = meshes.sample_textures(fragments) # NxSx1x1xC
textures = textures[:, :, 0, 0, :] # NxSxC
return textures


def _empty_sample(
num_meshes: int, num_samples: int, device: torch.device, dtype: torch.dtype = None
) -> torch.Tensor:
"""This is a helper function that returns an empty (zeros) tensor to initialize a point sample.

Args:
num_meshes: Integer value for dim 0 of the array.
num_samples: Integer value for dim 1 of the array.
device: torch.device
dtype: Optionally specify the torch.dtype to force a specific type.

Returns:
A torch.zeros Tensor in the shape of (num_meshes x num_samples x 3)
"""
if dtype is not None:
return torch.zeros((num_meshes, num_samples, 3), device=device, dtype=dtype)
else:
return torch.zeros((num_meshes, num_samples, 3), device=device)
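Tying the new flag together, a minimal end-to-end sketch ("model.obj" is a placeholder asset): per the diff above, return_mappers=True makes the function return a 4-element tuple, with None for any output that was not requested.

```python
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes

# Placeholder asset path; any mesh works.
meshes = load_objs_as_meshes(["model.obj"], load_textures=False)

samples, normals, textures, mappers = sample_points_from_meshes(
    meshes, num_samples=5000, return_normals=True, return_mappers=True
)
print(samples.shape)  # (N, 5000, 3) sampled points
print(normals.shape)  # (N, 5000, 3) per-point face normals
print(textures)       # None, since return_textures was left False
print(mappers.shape)  # (N, 5000) index of the face each point was drawn from
```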