diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index f992db4d8..681742de0 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +import os import warnings from typing import List, Optional, Union @@ -636,7 +637,10 @@ def __init__( msg = "R must have shape (3, 3) or (N, 3, 3); got %s" raise ValueError(msg % repr(R.shape)) R = R.to(device=device_, dtype=dtype) - _check_valid_rotation_matrix(R, tol=orthogonal_tol) + if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1": + # Note: aten::all_close in the check is computationally slow, so we + # only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on. + _check_valid_rotation_matrix(R, tol=orthogonal_tol) N = R.shape[0] mat = torch.eye(4, dtype=dtype, device=device_) mat = mat.view(1, 4, 4).repeat(N, 1, 1) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3d8ebd628..bfd67febe 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import math +import os import unittest +from unittest import mock import torch from pytorch3d.transforms import random_rotations @@ -191,7 +192,25 @@ def test_translate(self): self.assertTrue(torch.allclose(points_out, points_out_expected)) self.assertTrue(torch.allclose(normals_out, normals_out_expected)) - def test_rotate(self): + @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True) + def test_rotate_check_rot_valid_on(self): + R = so3_exp_map(torch.randn((1, 3))) + t = Transform3d().rotate(R) + points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( + 1, 3, 3 + ) + normals = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] + ).view(1, 3, 3) + points_out = t.transform_points(points) + normals_out = t.transform_normals(normals) + points_out_expected = torch.bmm(points, R) + normals_out_expected = torch.bmm(normals, R) + self.assertTrue(torch.allclose(points_out, points_out_expected)) + self.assertTrue(torch.allclose(normals_out, normals_out_expected)) + + @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True) + def test_rotate_check_rot_valid_off(self): R = so3_exp_map(torch.randn((1, 3))) t = Transform3d().rotate(R) points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(