Skip to content

Commit

Permalink
save_ply binary
Browse files Browse the repository at this point in the history
Summary:
Make save_ply save to binary instead of ascii. An option makes the previous functionality available. save_ply's API accepts a stream, but this is undocumented; that stream must now be a binary stream not a text stream.

Avoiding warnings about making tensors from immutable numpy arrays.

Possible performance improvement when reading binary files.

Fix reading zero-length binary lists.

Reviewed By: nikhilaravi

Differential Revision: D22333118

fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
  • Loading branch information
bottler authored and facebook-github-bot committed Sep 21, 2020
1 parent ebe2693 commit 197f1d6
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 44 deletions.
89 changes: 62 additions & 27 deletions pytorch3d/io/ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import warnings
from collections import namedtuple
from io import BytesIO
from typing import Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary(
np_type = ply_type.np_type
type_size = ply_type.size
needed_length = definition.count * len(definition.properties)
needed_bytes = needed_length * type_size
bytes_data = f.read(needed_bytes)
if len(bytes_data) != needed_bytes:
raise ValueError("Not enough data for %s." % definition.name)
data = np.frombuffer(bytes_data, dtype=np_type)
if isinstance(f, BytesIO):
# np.fromfile is faster but won't work on a BytesIO
needed_bytes = needed_length * type_size
bytes_data = bytearray(needed_bytes)
n_bytes_read = f.readinto(bytes_data)
if n_bytes_read != needed_bytes:
raise ValueError("Not enough data for %s." % definition.name)
data = np.frombuffer(bytes_data, dtype=np_type)
else:
data = np.fromfile(f, dtype=np_type, count=needed_length)
if data.shape[0] != needed_length:
raise ValueError("Not enough data for %s." % definition.name)

if (sys.byteorder == "big") != big_endian:
data = data.byteswap()
Expand Down Expand Up @@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary(
If every element has the same size, 2D numpy array corresponding to the
data. The rows are the different values. Otherwise None.
"""
if definition.count == 0:
return []
property = definition.properties[0]
endian_str = ">" if big_endian else "<"
length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char
Expand Down Expand Up @@ -689,6 +699,7 @@ def _save_ply(
verts: torch.Tensor,
faces: torch.LongTensor,
verts_normals: torch.Tensor,
ascii: bool,
decimal_places: Optional[int] = None,
) -> None:
"""
Expand All @@ -699,52 +710,75 @@ def _save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shsape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving.
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
"""
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
assert not len(verts_normals) or (
verts_normals.dim() == 2 and verts_normals.size(1) == 3
)

print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f)
print("property float x", file=f)
print("property float y", file=f)
print("property float z", file=f)
if ascii:
f.write(b"ply\nformat ascii 1.0\n")
elif sys.byteorder == "big":
f.write(b"ply\nformat binary_big_endian 1.0\n")
else:
f.write(b"ply\nformat binary_little_endian 1.0\n")
f.write(f"element vertex {verts.shape[0]}\n".encode("ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if verts_normals.numel() > 0:
print("property float nx", file=f)
print("property float ny", file=f)
print("property float nz", file=f)
print(f"element face {faces.shape[0]}", file=f)
print("property list uchar int vertex_index", file=f)
print("end_header", file=f)
f.write(b"property float nx\n")
f.write(b"property float ny\n")
f.write(b"property float nz\n")
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")

if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
return

if decimal_places is None:
float_str = "%f"
vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
if ascii:
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(f, vert_data, float_str)
else:
float_str = "%" + ".%df" % decimal_places

vert_data = torch.cat((verts, verts_normals), dim=1)
np.savetxt(f, vert_data.detach().numpy(), float_str)
assert vert_data.dtype == np.float32
if isinstance(f, BytesIO):
# tofile only works with real files, but is faster than this.
f.write(vert_data.tobytes())
else:
vert_data.tofile(f)

faces_array = faces.detach().numpy()

_check_faces_indices(faces, max_index=verts.shape[0])

if len(faces_array):
np.savetxt(f, faces_array, "3 %d %d %d")
if ascii:
np.savetxt(f, faces_array, "3 %d %d %d")
else:
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
if isinstance(f, BytesIO):
f.write(faces_uints.tobytes())
else:
faces_uints.tofile(f)


def save_ply(
f,
verts: torch.Tensor,
faces: Optional[torch.LongTensor] = None,
verts_normals: Optional[torch.Tensor] = None,
ascii: bool = False,
decimal_places: Optional[int] = None,
) -> None:
"""
Expand All @@ -755,7 +789,8 @@ def save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving.
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
"""

verts_normals = (
Expand All @@ -781,5 +816,5 @@ def save_ply(
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)

with _open_file(f, "w") as f:
_save_ply(f, verts, faces, verts_normals, decimal_places)
with _open_file(f, "wb") as f:
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
57 changes: 40 additions & 17 deletions tests/test_ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import struct
import unittest
from io import BytesIO, StringIO
from tempfile import TemporaryFile

import torch
from common_testing import TestCaseMixin
Expand Down Expand Up @@ -144,7 +145,7 @@ def test_save_ply_invalid_shapes(self):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]])
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
Expand All @@ -154,7 +155,7 @@ def test_save_ply_invalid_shapes(self):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
Expand All @@ -165,14 +166,14 @@ def test_save_ply_invalid_indices(self):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)

faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)

def _test_save_load(self, verts, faces):
f = StringIO()
f = BytesIO()
save_ply(f, verts, faces)
f.seek(0)
# raise Exception(f.getvalue())
Expand All @@ -193,7 +194,7 @@ def test_normals_save(self):
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
)
file = StringIO()
file = BytesIO()
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
file.close()

Expand Down Expand Up @@ -237,15 +238,31 @@ def test_empty_save_load(self):

def test_simple_save(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
file = StringIO()
save_ply(file, verts=verts, faces=faces)
file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)
for filetype in BytesIO, TemporaryFile:
lengths = {}
for ascii in [True, False]:
file = filetype()
save_ply(file, verts=verts, faces=faces, ascii=ascii)
lengths[ascii] = file.tell()

file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)

file.seek(0)
if ascii:
file.read().decode("ascii")
else:
with self.assertRaises(UnicodeDecodeError):
file.read().decode("ascii")

if filetype is TemporaryFile:
file.close()
self.assertLess(lengths[False], lengths[True], "ascii should be longer")

def test_load_simple_binary(self):
for big_endian in [True, False]:
Expand Down Expand Up @@ -488,15 +505,21 @@ def test_bad_ply_syntax(self):

@staticmethod
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places)
return lambda: save_ply(
BytesIO(),
verts=verts,
faces=faces,
ascii=True,
decimal_places=decimal_places,
)

@staticmethod
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
f = StringIO()
save_ply(f, verts, faces, decimal_places)
f = BytesIO()
save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
s = f.getvalue()
# Recreate stream so it's unaffected by how it was created.
return lambda: load_ply(StringIO(s))
return lambda: load_ply(BytesIO(s))

@staticmethod
def bm_save_simple_ply_with_init(V: int, F: int):
Expand Down

0 comments on commit 197f1d6

Please sign in to comment.