diff --git a/pytorch3d/implicitron/dataset/orm_types.py b/pytorch3d/implicitron/dataset/orm_types.py
index 5736ab4ba..2e916021a 100644
--- a/pytorch3d/implicitron/dataset/orm_types.py
+++ b/pytorch3d/implicitron/dataset/orm_types.py
@@ -33,7 +33,35 @@
 
 
 # these produce policies to serialize structured types to blobs
-def ArrayTypeFactory(shape):
+def ArrayTypeFactory(shape=None):
+    if shape is None:
+
+        class VariableShapeNumpyArrayType(TypeDecorator):
+            impl = LargeBinary
+
+            def process_bind_param(self, value, dialect):
+                if value is None:
+                    return None
+
+                ndim_bytes = np.int32(value.ndim).tobytes()
+                shape_bytes = np.array(value.shape, dtype=np.int64).tobytes()
+                value_bytes = value.astype(np.float32).tobytes()
+                return ndim_bytes + shape_bytes + value_bytes
+
+            def process_result_value(self, value, dialect):
+                if value is None:
+                    return None
+
+                ndim = np.frombuffer(value[:4], dtype=np.int32)[0]
+                value_start = 4 + 8 * ndim
+                shape = np.frombuffer(value[4:value_start], dtype=np.int64)
+                assert shape.shape == (ndim,)
+                return np.frombuffer(value[value_start:], dtype=np.float32).reshape(
+                    shape
+                )
+
+        return VariableShapeNumpyArrayType
+
     class NumpyArrayType(TypeDecorator):
         impl = LargeBinary
 
@@ -158,4 +186,4 @@ class SqlSequenceAnnotation(Base):
         mapped_column("_point_cloud_n_points", nullable=True),
     )
     # the bigger the better
-    viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
+    viewpoint_quality_score: Mapped[Optional[float]] = mapped_column()
diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py
index 4c9d3bb5e..2c74e56c5 100644
--- a/pytorch3d/implicitron/dataset/sql_dataset.py
+++ b/pytorch3d/implicitron/dataset/sql_dataset.py
@@ -142,8 +142,10 @@ def __post_init__(self) -> None:
         run_auto_creation(self)
         self.frame_data_builder.path_manager = self.path_manager
 
-        # pyre-ignore
-        self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
+        # pyre-ignore  # NOTE: sqlite-specific args (read-only mode).
+        self._sql_engine = sa.create_engine(
+            f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
+        )
 
         sequences = self._get_filtered_sequences_if_any()
 
diff --git a/tests/implicitron/test_orm_types.py b/tests/implicitron/test_orm_types.py
index 7570b002b..e6f94c010 100644
--- a/tests/implicitron/test_orm_types.py
+++ b/tests/implicitron/test_orm_types.py
@@ -8,7 +8,7 @@
 
 import numpy as np
 
-from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
+from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
 
 
 class TestOrmTypes(unittest.TestCase):
@@ -35,3 +35,28 @@ def test_tuple_serialization_2d(self):
         self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
         # we use float32 to serialise
         np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
+
+    def test_array_serialization_none(self):
+        ttype = ArrayTypeFactory((3, 3))()
+        output = ttype.process_bind_param(None, None)
+        self.assertIsNone(output)
+        output = ttype.process_result_value(output, None)
+        self.assertIsNone(output)
+
+    def test_array_serialization(self):
+        for input_list in [[1, 2, 3], [[4.5, 6.7], [8.9, 10.0]]]:
+            input_array = np.array(input_list)
+
+            # first, dynamic-size array
+            ttype = ArrayTypeFactory()()
+            output = ttype.process_bind_param(input_array, None)
+            input_hat = ttype.process_result_value(output, None)
+            self.assertEqual(input_hat.dtype, np.float32)
+            np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
+
+            # second, fixed-size array
+            ttype = ArrayTypeFactory(tuple(input_array.shape))()
+            output = ttype.process_bind_param(input_array, None)
+            input_hat = ttype.process_result_value(output, None)
+            self.assertEqual(input_hat.dtype, np.float32)
+            np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
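For context (not part of the diff): VariableShapeNumpyArrayType added above packs an array into a single blob as an int32 ndim header, the shape as int64 values, and the data as float32 in C order; process_result_value reverses this. Below is a minimal standalone sketch of that layout under those assumptions, with illustrative variable names only.

import numpy as np

arr = np.arange(6, dtype=np.float64).reshape(2, 3)

# Encode: [ndim as int32 | shape as int64[ndim] | values as float32].
blob = (
    np.int32(arr.ndim).tobytes()                     # 4 bytes
    + np.array(arr.shape, dtype=np.int64).tobytes()  # 2 * 8 = 16 bytes
    + arr.astype(np.float32).tobytes()               # 6 * 4 = 24 bytes
)
assert len(blob) == 4 + 16 + 24

# Decode: read ndim, slice out the shape, reshape the float32 payload.
ndim = np.frombuffer(blob[:4], dtype=np.int32)[0]
values_start = 4 + 8 * ndim
shape = np.frombuffer(blob[4:values_start], dtype=np.int64)
restored = np.frombuffer(blob[values_start:], dtype=np.float32).reshape(shape)
assert restored.shape == (2, 3)

This mirrors the value_start arithmetic in process_result_value: the float32 payload begins at byte 4 + 8 * ndim, immediately after the ndim header and the int64 shape vector.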