diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 620380318..f992db4d8 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -13,6 +13,7 @@ from ..common.datatypes import Device, get_device, make_device from ..common.workaround import _safe_det_3x3 from .rotation_conversions import _axis_angle_rotation +from .se3 import se3_log_map class Transform3d: @@ -130,13 +131,13 @@ class Transform3d: [Tx, Ty, Tz, 1], ] - To apply the transformation to points which are row vectors, the M matrix - can be pre multiplied by the points: + To apply the transformation to points, which are row vectors, the latter are + converted to homogeneous (4D) coordinates and right-multiplied by the M matrix: .. code-block:: python points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point - transformed_points = points * M + [transformed_points, 1] ∝ [points, 1] @ M """ @@ -218,9 +219,10 @@ def compose(self, *others: "Transform3d") -> "Transform3d": def get_matrix(self) -> torch.Tensor: """ - Return a matrix which is the result of composing this transform - with others stored in self.transforms. Where necessary transforms - are broadcast against each other. + Returns a 4×4 matrix corresponding to each transform in the batch. + + If the transform was composed from others, the matrix for the composite + transform will be returned. For example, if self.transforms contains transforms t1, t2, and t3, and given a set of points x, the following should be true: @@ -230,8 +232,11 @@ def get_matrix(self) -> torch.Tensor: y2 = t3.transform(t2.transform(t1.transform(x))) y1.get_matrix() == y2.get_matrix() + Where necessary, those transforms are broadcast against each other. + Returns: - A transformation matrix representing the composed inputs. + A (N, 4, 4) batch of transformation matrices representing + the stored transforms. See the class documentation for the conventions. """ composed_matrix = self._matrix.clone() if len(self._transforms) > 0: @@ -240,6 +245,49 @@ def get_matrix(self) -> torch.Tensor: composed_matrix = _broadcast_bmm(composed_matrix, other_matrix) return composed_matrix + def get_se3_log(self, eps: float = 1e-4, cos_bound: float = 1e-4) -> torch.Tensor: + """ + Returns a 6D SE(3) log vector corresponding to each transform in the batch. + + In the SE(3) logarithmic representation SE(3) matrices are + represented as 6-dimensional vectors `[log_translation | log_rotation]`, + i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`. + + The conversion from the 4x4 SE(3) matrix `transform` to the + 6D representation `log_transform = [log_translation | log_rotation]` + is done as follows: + ``` + log_transform = log(transform.get_matrix()) + log_translation = log_transform[3, :3] + log_rotation = inv_hat(log_transform[:3, :3]) + ``` + where `log` is the matrix logarithm + and `inv_hat` is the inverse of the Hat operator [2]. + + See the docstring for `se3.se3_log_map` and [1], Sec 9.4.2. for more + detailed description. + + Args: + eps: A threshold for clipping the squared norm of the rotation logarithm + to avoid division by zero in the singular case. + cos_bound: Clamps the cosine of the rotation angle to + [-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs. + The non-finite outputs can be caused by passing small rotation angles + to the `acos` function in `so3_rotation_angle` of `so3_log_map`. + + Returns: + A (N, 6) tensor, rows of which represent the individual transforms + stored in the object as SE(3) logarithms. + + Raises: + ValueError if the stored transform is not Euclidean (e.g. R is not a rotation + matrix or the last column has non-zeros in the first three places). + + [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf + [2] https://en.wikipedia.org/wiki/Hat_operator + """ + return se3_log_map(self.get_matrix(), eps, cos_bound) + def _get_matrix_inverse(self) -> torch.Tensor: """ Return the inverse of self._matrix. diff --git a/tests/test_transforms.py b/tests/test_transforms.py index a10d12a38..3d8ebd628 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -10,6 +10,7 @@ import torch from pytorch3d.transforms import random_rotations +from pytorch3d.transforms.se3 import se3_log_map from pytorch3d.transforms.so3 import so3_exp_map from pytorch3d.transforms.transform3d import ( Rotate, @@ -161,6 +162,16 @@ def test_init_with_custom_matrix_errors(self): matrix = torch.randn(*bad_shape).float() self.assertRaises(ValueError, Transform3d, matrix=matrix) + def test_get_se3(self): + N = 16 + random_rotations(N) + tr = Translate(torch.rand((N, 3))) + R = Rotate(random_rotations(N)) + transform = Transform3d().compose(R, tr) + se3_log = transform.get_se3_log() + gt_se3_log = se3_log_map(transform.get_matrix()) + self.assertClose(se3_log, gt_se3_log) + def test_translate(self): t = Transform3d().translate(1, 2, 3) points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(