Skip to content

Commit

Permalink
Support reading uv and uv map for ply format if texture_uv exists in …
Browse files Browse the repository at this point in the history
…ply file (#1100)

Summary:
When the ply format looks as follows:
  ```
comment TextureFile ***.png
element vertex 892
property double x
property double y
property double z
property double nx
property double ny
property double nz
property double texture_u
property double texture_v
```
`MeshPlyFormat` class will read uv from the ply file and read the uv map as commented as TextureFile.

Pull Request resolved: #1100

Reviewed By: MichaelRamamonjisoa

Differential Revision: D50885176

Pulled By: bottler

fbshipit-source-id: be75b1ec9a17a1ed87dbcf846a9072ea967aec37
  • Loading branch information
YangHai-1218 authored and facebook-github-bot committed Nov 14, 2023
1 parent f4f2209 commit 55638f3
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 6 deletions.
59 changes: 54 additions & 5 deletions pytorch3d/io/ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
meshes and point clouds as PLY files.
"""
import itertools
import os
import struct
import sys
import warnings
Expand All @@ -21,8 +22,14 @@
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file, PathOrStr
from pytorch3d.renderer import TexturesVertex
from pytorch3d.io.utils import (
_check_faces_indices,
_make_tensor,
_open_file,
_read_image,
PathOrStr,
)
from pytorch3d.renderer import TexturesUV, TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds

from .pluggable_formats import (
Expand Down Expand Up @@ -804,6 +811,7 @@ class _VertsColumnIndices:
color_idxs: Optional[List[int]]
color_scale: float
normal_idxs: Optional[List[int]]
texture_uv_idxs: Optional[List[int]]


def _get_verts_column_indices(
Expand All @@ -827,6 +835,8 @@ def _get_verts_column_indices(
property uchar red
property uchar green
property uchar blue
property double texture_u
property double texture_v
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
Expand All @@ -839,6 +849,7 @@ def _get_verts_column_indices(
point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None]
normal_idxs: List[Optional[int]] = [None, None, None]
texture_uv_idxs: List[Optional[int]] = [None, None]
for i, prop in enumerate(vertex_head.properties):
if prop.list_size_type is not None:
raise ValueError("Invalid vertices in file: did not expect list.")
Expand All @@ -851,6 +862,9 @@ def _get_verts_column_indices(
for j, name in enumerate(["nx", "ny", "nz"]):
if prop.name == name:
normal_idxs[j] = i
for j, name in enumerate(["texture_u", "texture_v"]):
if prop.name == name:
texture_uv_idxs[j] = i
if None in point_idxs:
raise ValueError("Invalid vertices in file.")
color_scale = 1.0
Expand All @@ -864,6 +878,7 @@ def _get_verts_column_indices(
color_idxs=None if None in color_idxs else color_idxs,
color_scale=color_scale,
normal_idxs=None if None in normal_idxs else normal_idxs,
texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
)


Expand All @@ -880,6 +895,7 @@ class _VertsData:
verts: torch.Tensor
verts_colors: Optional[torch.Tensor] = None
verts_normals: Optional[torch.Tensor] = None
verts_texture_uvs: Optional[torch.Tensor] = None


def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
Expand Down Expand Up @@ -922,6 +938,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:

vertex_colors = None
vertex_normals = None
vertex_texture_uvs = None

if len(vertex) == 1:
# This is the case where the whole vertex element has one type,
Expand All @@ -935,6 +952,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
vertex_normals = torch.tensor(
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
)
if column_idxs.texture_uv_idxs is not None:
vertex_texture_uvs = torch.tensor(
vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32
)
else:
# The vertex element is heterogeneous. It was read as several arrays,
# part by part, where a part is a set of properties with the same type.
Expand Down Expand Up @@ -973,11 +994,19 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
for axis in range(3):
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]

if column_idxs.texture_uv_idxs is not None:
vertex_texture_uvs = torch.empty(
size=(vertex_head.count, 2),
dtype=torch.float32,
)
for axis in range(2):
partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]]
vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col]
return _VertsData(
verts=verts,
verts_colors=vertex_colors,
verts_normals=vertex_normals,
verts_texture_uvs=vertex_texture_uvs,
)


Expand All @@ -998,6 +1027,7 @@ class _PlyData:
faces: Optional[torch.Tensor]
verts_colors: Optional[torch.Tensor]
verts_normals: Optional[torch.Tensor]
verts_texture_uvs: Optional[torch.Tensor]


def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
Expand Down Expand Up @@ -1358,8 +1388,27 @@ def read(
faces = torch.zeros(0, 3, dtype=torch.int64)

texture = None
if include_textures and data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)])
if include_textures:
if data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)])
elif data.verts_texture_uvs is not None:
texture_file_path = None
for comment in data.header.comments:
if "TextureFile" in comment:
given_texture_file = comment.split(" ")[-1]
texture_file_path = os.path.join(
os.path.dirname(str(path)), given_texture_file
)
if texture_file_path is not None:
texture_map = _read_image(
texture_file_path, path_manager, format="RGB"
)
texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255.0
texture = TexturesUV(
[texture_map.to(device)],
[faces.to(device)],
[data.verts_texture_uvs.to(device)],
)

verts_normals = None
if data.verts_normals is not None:
Expand Down
28 changes: 28 additions & 0 deletions tests/data/uvs.ply
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
ply
format ascii 1.0
comment made by Greg Turk
comment this file is a cube
comment TextureFile test_nd_sphere.png
element vertex 8
property float x
property float y
property float z
property float texture_u
property float texture_v
element face 6
property list uchar int vertex_index
end_header
0 0 0 0 0
0 0 1 0.2 0.3
0 1 1 0.2 0.3
0 1 0 0.2 0.3
1 0 0 0.2 0.3
1 0 1 0.2 0.3
1 1 1 0.2 0.3
1 1 0 0.4 0.5
4 0 1 2 3
4 7 6 5 4
4 0 4 5 1
4 1 5 6 2
4 2 6 7 3
4 3 7 4 0
16 changes: 15 additions & 1 deletion tests/test_io_ply.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils import torus

from .common_testing import TestCaseMixin
from .common_testing import get_tests_dir, TestCaseMixin


global_path_manager = PathManager()
DATA_DIR = get_tests_dir() / "data"


def _load_ply_raw(stream):
Expand Down Expand Up @@ -778,6 +779,19 @@ def test_load_simple_binary(self):
data["minus_ones"], [-1, 255, -1, 65535, -1, 4294967295]
)

def test_load_uvs(self):
io = IO()
mesh = io.load_mesh(DATA_DIR / "uvs.ply")
self.assertEqual(mesh.textures.verts_uvs_padded().shape, (1, 8, 2))
self.assertClose(
mesh.textures.verts_uvs_padded()[0],
torch.tensor([[0, 0]] + [[0.2, 0.3]] * 6 + [[0.4, 0.5]]),
)
self.assertEqual(
mesh.textures.faces_uvs_padded().shape, mesh.faces_padded().shape
)
self.assertEqual(mesh.textures.maps_padded().shape, (1, 512, 512, 3))

def test_bad_ply_syntax(self):
"""Some syntactically bad ply files."""
lines = [
Expand Down

0 comments on commit 55638f3

Please sign in to comment.