diff --git a/pytorch3d/utils/__init__.py b/pytorch3d/utils/__init__.py
index b5b3f46ab..90e7fe8a3 100644
--- a/pytorch3d/utils/__init__.py
+++ b/pytorch3d/utils/__init__.py
@@ -4,7 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .camera_conversions import cameras_from_opencv_projection
+from .camera_conversions import (
+    cameras_from_opencv_projection,
+    opencv_from_cameras_projection,
+)
 from .ico_sphere import ico_sphere
 from .torus import torus
 
diff --git a/pytorch3d/utils/camera_conversions.py b/pytorch3d/utils/camera_conversions.py
index 077d74b23..30477d86f 100644
--- a/pytorch3d/utils/camera_conversions.py
+++ b/pytorch3d/utils/camera_conversions.py
@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Tuple
+
 import torch
 
 from ..renderer import PerspectiveCameras
-from ..transforms import so3_exponential_map
+from ..transforms import so3_exponential_map, so3_log_map
 
 
 def cameras_from_opencv_projection(
@@ -35,7 +37,7 @@ def cameras_from_opencv_projection(
     followed by the homogenization of `x_screen_opencv`.
 
     Note:
-        The parameters `rvec, tvec, camera_matrix` correspond e.g. to the inputs
-        of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
+        The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
+        of `cv2.projectPoints`, or to the outputs of `cv2.calibrateCamera`.
 
     Args:
@@ -74,3 +76,60 @@ def cameras_from_opencv_projection(
         focal_length=focal_pytorch3d,
         principal_point=p0_pytorch3d,
     )
+
+
+def opencv_from_cameras_projection(
+    cameras: PerspectiveCameras,
+    image_size: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Converts a batch of `PerspectiveCameras` into OpenCV-convention
+    axis-angle rotation vectors `rvec`, translation vectors `tvec`, and the camera
+    calibration matrices `camera_matrix`. This operation is exactly the inverse
+    of `cameras_from_opencv_projection`.
+
+    Note:
+        The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
+        of `cv2.projectPoints`, or to the outputs of `cv2.calibrateCamera`.
+
+    Note:
+        The input `cameras` are not modified: `cameras.R` and `cameras.T`
+        are copied before the signs of their x and y components are flipped.
+
+    Args:
+        cameras: A batch of `N` cameras in the PyTorch3D convention.
+        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
+            (height, width) attached to each camera.
+
+    Returns:
+        rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
+        tvec: A batch of translation vectors of shape `(N, 3)`.
+        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
+    """
+    # Work on copies: the sign flips below are in-place and would otherwise
+    # overwrite the `R` and `T` tensors stored on the caller's `cameras`.
+    R_pytorch3d = cameras.R.clone()  # pyre-ignore
+    T_pytorch3d = cameras.T.clone()  # pyre-ignore
+    focal_pytorch3d = cameras.focal_length
+    p0_pytorch3d = cameras.principal_point
+
+    # Negate the x and y components and transpose the rotation to obtain
+    # the OpenCV-convention extrinsics.
+    T_pytorch3d[:, :2] *= -1
+    R_pytorch3d[:, :, :2] *= -1
+    tvec = T_pytorch3d
+    R = R_pytorch3d.permute(0, 2, 1)
+
+    # Retype the image_size correctly and flip to width, height.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+
+    principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh)  # pyre-ignore
+    focal_length = focal_pytorch3d * (0.5 * image_size_wh)
+
+    camera_matrix = torch.zeros_like(R)
+    camera_matrix[:, :2, 2] = principal_point
+    camera_matrix[:, 2, 2] = 1.0
+    camera_matrix[:, 0, 0] = focal_length[:, 0]
+    camera_matrix[:, 1, 1] = focal_length[:, 1]
+    rvec = so3_log_map(R)
+    return rvec, tvec, camera_matrix
diff --git a/tests/test_camera_conversions.py b/tests/test_camera_conversions.py
index 851b93bf3..7c44a9d48 100644
--- a/tests/test_camera_conversions.py
+++ b/tests/test_camera_conversions.py
@@ -15,6 +15,7 @@
 from pytorch3d.transforms import so3_exponential_map, so3_log_map
 from pytorch3d.utils import (
     cameras_from_opencv_projection,
+    opencv_from_cameras_projection,
 )
 
 DATA_DIR = get_tests_dir() / "data"
@@ -151,3 +152,15 @@ def test_opencv_conversion(self):
         self.assertClose(
             pts_proj_opencv_in_pytorch3d_screen, pts_proj_pytorch3d, atol=1e-5
         )
+
+        # Check the inverse, and that it leaves the input cameras untouched.
+        R_before = cameras_opencv_to_pytorch3d.R.clone()
+        T_before = cameras_opencv_to_pytorch3d.T.clone()
+        rvec_i, tvec_i, camera_matrix_i = opencv_from_cameras_projection(
+            cameras_opencv_to_pytorch3d, image_size
+        )
+        self.assertClose(rvec, rvec_i)
+        self.assertClose(tvec, tvec_i)
+        self.assertClose(camera_matrix, camera_matrix_i)
+        self.assertClose(cameras_opencv_to_pytorch3d.R, R_before)
+        self.assertClose(cameras_opencv_to_pytorch3d.T, T_before)