Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trimming option in ICP #1866

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
29 changes: 27 additions & 2 deletions pytorch3d/ops/points_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def iterative_closest_point(
X: Union[torch.Tensor, "Pointclouds"],
Y: Union[torch.Tensor, "Pointclouds"],
init_transform: Optional[SimilarityTransform] = None,
trim_fraction: Union[float, torch.Tensor] = 0.,
max_iterations: int = 100,
relative_rmse_thr: float = 1e-6,
estimate_scale: bool = False,
Expand Down Expand Up @@ -67,6 +68,11 @@ def iterative_closest_point(
shape `(minibatch, d, d)`, `T` is a batch of translations
of shape `(minibatch, d)` and `s` is a batch of scaling factors
of shape `(minibatch,)`.
**trim_fraction**: A float or 1d `Tensor` of shape `(minibatch,)` in [0, 1]
specifying the ratio of outliers in each point cloud. If float, assume
the same outliers ratio for all point clouds in the batch. Outliers will
be detected by taking the `trim_fraction * num_points_X` highest values of
`s[i] X[i] R[i] + T[i] = Y[NN[i]]`.
**max_iterations**: The maximum number of ICP iterations.
**relative_rmse_thr**: A threshold on the relative root mean squared error
used to terminate the algorithm.
Expand Down Expand Up @@ -152,6 +158,17 @@ def iterative_closest_point(
T = Xt.new_zeros((b, dim))
s = Xt.new_ones(b)

# initialize trim fraction parameter
if isinstance(trim_fraction, float):
trim_fraction = torch.as_tensor(trim_fraction)
trim_fraction = trim_fraction.to(Xt.device) # type: ignore
if trim_fraction.ndim == 0:
trim_fraction = trim_fraction.repeat(b)
trim = trim_fraction.min() > 0.0

# initial mask: no trim considered, only padding
mask = mask_X.bool().clone()

prev_rmse = None
rmse = None
iteration = -1
Expand All @@ -170,7 +187,7 @@ def iterative_closest_point(
R, T, s = corresponding_points_alignment(
Xt_init,
Xt_nn_points,
weights=mask_X,
weights=mask,
estimate_scale=estimate_scale,
allow_reflection=allow_reflection,
)
Expand All @@ -184,7 +201,15 @@ def iterative_closest_point(
# compute the root mean squared error
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

# trimming: select `1 - trim_fraction` lowest distances.
if trim:
diff_thresholds = Xt_sq_diff[mask_X.bool()].quantile(1 - trim_fraction)
mask_trim = Xt_sq_diff < diff_thresholds[:, None]
# final mask is (trim_mask AND pad_mask)
mask = torch.logical_and(mask_trim, mask_X)

rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask).sqrt()[:, 0, 0]

# compute the relative rmse
if prev_rmse is None:
Expand Down