Skip to content

Commit

Permalink
Add OpenCV camera conversion; fix bug for camera unified PyTorch3D in…
Browse files Browse the repository at this point in the history
…terface.

Summary: This commit adds a new camera conversion function for OpenCV style parameters to Pulsar parameters to the library. Using this function it addresses a bug reported here: https://fb.workplace.com/groups/629644647557365/posts/1079637302558095, by using the PyTorch3D->OpenCV->Pulsar chain instead of the original direct conversion function. Both conversions are well-tested and an additional test for the full chain has been added, resulting in a more reliable solution requiring less code.

Reviewed By: patricklabatut

Differential Revision: D29322106

fbshipit-source-id: 13df13c2e48f628f75d9f44f19ff7f1646fb7ebd
  • Loading branch information
classner authored and facebook-github-bot committed Jul 10, 2021
1 parent fef5bcd commit 75432a0
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 32 deletions.
52 changes: 29 additions & 23 deletions pytorch3d/renderer/points/pulsar/unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import torch.nn as nn

from ....transforms import matrix_to_rotation_6d
from ....utils import pulsar_from_cameras_projection
from ...cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
Expand Down Expand Up @@ -102,7 +102,7 @@ def __init__(
height=height,
max_num_balls=max_num_spheres,
orthogonal_projection=orthogonal_projection,
right_handed_system=True,
right_handed_system=False,
n_channels=n_channels,
**kwargs,
)
Expand Down Expand Up @@ -359,24 +359,28 @@ def _extract_intrinsics( # noqa: C901
def _extract_extrinsics(
self, kwargs, cloud_idx
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Extract the extrinsic information from the kwargs for a specific point cloud.
Instead of implementing a direct translation from the PyTorch3D to the Pulsar
camera model, we chain the two conversions of PyTorch3D->OpenCV and
OpenCV->Pulsar for better maintainability (PyTorch3D->OpenCV is maintained and
tested by the core PyTorch3D team, whereas OpenCV->Pulsar is maintained and
tested by the Pulsar team).
"""
# Shorthand:
cameras = self.rasterizer.cameras
R = kwargs.get("R", cameras.R)[cloud_idx]
T = kwargs.get("T", cameras.T)[cloud_idx]
norm_mat = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
dtype=torch.float32,
device=R.device,
tmp_cams = PerspectiveCameras(
R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
)
cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
norm_mat = torch.tensor(
[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=torch.float32,
device=R.device,
size_tensor = torch.tensor(
[[self.renderer._renderer.height, self.renderer._renderer.width]]
)
cam_rot = torch.matmul(norm_mat, cam_rot)
cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
cam_pos = pulsar_cam[0, :3]
cam_rot = pulsar_cam[0, 3:9]
return cam_pos, cam_rot

def _get_vert_rad(
Expand Down Expand Up @@ -547,15 +551,17 @@ def forward(self, point_clouds, **kwargs) -> torch.Tensor:
otherargs["bg_col"] = bg_col
# Go!
images.append(
self.renderer(
vert_pos=vert_pos,
vert_col=vert_col,
vert_rad=vert_rad,
cam_params=cam_params,
gamma=gamma,
max_depth=zfar,
min_depth=znear,
**otherargs,
torch.flipud(
self.renderer(
vert_pos=vert_pos,
vert_col=vert_col,
vert_rad=vert_rad,
cam_params=cam_params,
gamma=gamma,
max_depth=zfar,
min_depth=znear,
**otherargs,
)
)
)
return torch.stack(images, dim=0)
2 changes: 2 additions & 0 deletions pytorch3d/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .camera_conversions import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
pulsar_from_opencv_projection,
pulsar_from_cameras_projection,
)
from .ico_sphere import ico_sphere
from .torus import torus
Expand Down
171 changes: 162 additions & 9 deletions pytorch3d/utils/camera_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Tuple

import torch

from ..renderer import PerspectiveCameras
from ..transforms import so3_exp_map, so3_log_map
from ..transforms import matrix_to_rotation_6d


LOGGER = logging.getLogger(__name__)


def cameras_from_opencv_projection(
Expand Down Expand Up @@ -54,7 +58,6 @@ def cameras_from_opencv_projection(
Returns:
cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
"""

focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
principal_point = camera_matrix[:, :2, 2]

Expand All @@ -68,7 +71,7 @@ def cameras_from_opencv_projection(
# For R, T we flip x, y axes (opencv screen space has an opposite
# orientation of screen axes).
# We also transpose R (opencv multiplies points from the opposite=left side).
R_pytorch3d = R.permute(0, 2, 1)
R_pytorch3d = R.clone().permute(0, 2, 1)
T_pytorch3d = tvec.clone()
R_pytorch3d[:, :, :2] *= -1
T_pytorch3d[:, :2] *= -1
Expand Down Expand Up @@ -103,20 +106,22 @@ def opencv_from_cameras_projection(
cameras: A batch of `N` cameras in the PyTorch3D convention.
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
(height, width) attached to each camera.
return_as_rotmat (bool): If set to True, return the full 3x3 rotation
matrices. Otherwise, return an axis-angle vector (default).
Returns:
R: A batch of rotation matrices of shape `(N, 3, 3)`.
tvec: A batch of translation vectors of shape `(N, 3)`.
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
"""
R_pytorch3d = cameras.R
T_pytorch3d = cameras.T
R_pytorch3d = cameras.R.clone() # pyre-ignore
T_pytorch3d = cameras.T.clone() # pyre-ignore
focal_pytorch3d = cameras.focal_length
p0_pytorch3d = cameras.principal_point
T_pytorch3d[:, :2] *= -1 # pyre-ignore
R_pytorch3d[:, :, :2] *= -1 # pyre-ignore
tvec = T_pytorch3d.clone() # pyre-ignore
R = R_pytorch3d.permute(0, 2, 1) # pyre-ignore
T_pytorch3d[:, :2] *= -1
R_pytorch3d[:, :, :2] *= -1
tvec = T_pytorch3d
R = R_pytorch3d.permute(0, 2, 1)

# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
Expand All @@ -130,3 +135,151 @@ def opencv_from_cameras_projection(
camera_matrix[:, 0, 0] = focal_length[:, 0]
camera_matrix[:, 1, 1] = focal_length[:, 1]
return R, tvec, camera_matrix


def pulsar_from_opencv_projection(
R: torch.Tensor,
tvec: torch.Tensor,
camera_matrix: torch.Tensor,
image_size: torch.Tensor,
znear: float = 0.1,
) -> torch.Tensor:
"""
Convert OpenCV style camera parameters to Pulsar style camera parameters.
Note:
* Pulsar does NOT support different focal lengths for x and y.
For conversion, we use the average of fx and fy.
* The Pulsar renderer MUST use a left-handed coordinate system for this
mapping to work.
* The resulting image will be vertically flipped - which has to be
addressed AFTER rendering by the user.
* The parameters `R, tvec, camera_matrix` correspond to the outputs
of `cv2.decomposeProjectionMatrix`.
Args:
R: A batch of rotation matrices of shape `(N, 3, 3)`.
tvec: A batch of translation vectors of shape `(N, 3)`.
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
(height, width) attached to each camera.
znear (float): The near clipping value to use for Pulsar.
Returns:
cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
c_x, c_y).
"""
assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
assert len(R.size()) == 3, "This function requires batched inputs!"
assert len(tvec.size()) in (2, 3), "This function reuqires batched inputs!"

# Validate parameters.
image_size_wh = image_size.to(R).flip(dims=(1,))
assert torch.all(
image_size_wh > 0
), "height and width must be positive but min is: %s" % (
str(image_size_wh.min().item())
)
assert (
camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
camera_matrix.size(1),
camera_matrix.size(2),
)
assert (
R.size(1) == 3 and R.size(2) == 3
), "Incorrect R shape: expected 3x3 but got %dx%d" % (
R.size(1),
R.size(2),
)
if len(tvec.size()) == 2:
tvec = tvec.unsqueeze(2)
assert (
tvec.size(1) == 3 and tvec.size(2) == 1
), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
tvec.size(1),
tvec.size(2),
)
# Check batch size.
batch_size = camera_matrix.size(0)
assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
batch_size,
R.size(0),
)
assert (
tvec.size(0) == batch_size
), "Expected tvec to have batch size %d. Has size %d." % (
batch_size,
tvec.size(0),
)
# Check image sizes.
image_w = image_size_wh[0, 0]
image_h = image_size_wh[0, 1]
assert torch.all(
image_size_wh[:, 0] == image_w
), "All images in a batch must have the same width!"
assert torch.all(
image_size_wh[:, 1] == image_h
), "All images in a batch must have the same height!"
# Focal length.
fx = camera_matrix[:, 0, 0].unsqueeze(1)
fy = camera_matrix[:, 1, 1].unsqueeze(1)
# Check that we introduce less than 1% error by averaging the focal lengths.
fx_y = fx / fy
if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
LOGGER.warning(
"Pulsar only supports a single focal lengths. For converting OpenCV "
"focal lengths, we average them for x and y directions. "
"The focal lengths for x and y you provided differ by more than 1%, "
"which means this could introduce a noticeable error."
)
f = (fx + fy) / 2
# Normalize f into normalized device coordinates.
focal_length_px = f / image_w
# Transfer into focal_length and sensor_width.
focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
focal_length = focal_length[None, :].repeat(batch_size, 1)
sensor_width = focal_length / focal_length_px
# Principal point.
cx = camera_matrix[:, 0, 2].unsqueeze(1)
cy = camera_matrix[:, 1, 2].unsqueeze(1)
# Transfer principal point offset into centered offset.
cx = -(cx - image_w / 2)
cy = cy - image_h / 2
# Concatenate to final vector.
param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
R_trans = R.permute(0, 2, 1)
cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
cam_rot = matrix_to_rotation_6d(R_trans)
cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
return cam_params


def pulsar_from_cameras_projection(
cameras: PerspectiveCameras,
image_size: torch.Tensor,
) -> torch.Tensor:
"""
Convert PyTorch3D `PerspectiveCameras` to Pulsar style camera parameters.
Note:
* Pulsar does NOT support different focal lengths for x and y.
For conversion, we use the average of fx and fy.
* The Pulsar renderer MUST use a left-handed coordinate system for this
mapping to work.
* The resulting image will be vertically flipped - which has to be
addressed AFTER rendering by the user.
Args:
cameras: A batch of `N` cameras in the PyTorch3D convention.
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
(height, width) attached to each camera.
Returns:
cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
c_x, c_y).
"""
opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 75432a0

Please sign in to comment.