From ccf22911d4daa74af7fbf70b3373bc0fe46d6d7c Mon Sep 17 00:00:00 2001 From: Ruishen Lyu Date: Tue, 2 Apr 2024 07:50:25 -0700 Subject: [PATCH] Optimize list_to_packed to avoid for loop (#1737) Summary: For larger N and Mi values (e.g. N=154, Mi=238), I notice list_to_packed() has become a bottleneck for my application. By removing the for loop and running on GPU, I see a 10-20x speedup. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1737 Reviewed By: MichaelRamamonjisoa Differential Revision: D54187993 Pulled By: bottler fbshipit-source-id: 16399a24cb63b48c30460c7d960abef603b115d0 --- pytorch3d/structures/utils.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index c6b75d93..6c0c4f73 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -135,22 +135,21 @@ def list_to_packed(x: List[torch.Tensor]): - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the index of the element in the list the item belongs to. 
""" - N = len(x) - num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device) - item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device) - item_packed_to_list_idx = [] - cur = 0 - for i, y in enumerate(x): - num = len(y) - num_items[i] = num - item_packed_first_idx[i] = cur - item_packed_to_list_idx.append( - torch.full((num,), i, dtype=torch.int64, device=y.device) - ) - cur += num - + if not x: + raise ValueError("Input list is empty") + device = x[0].device + sizes = [xi.shape[0] for xi in x] + sizes_total = sum(sizes) + num_items = torch.tensor(sizes, dtype=torch.int64, device=device) + item_packed_first_idx = torch.zeros_like(num_items) + item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0) + item_packed_to_list_idx = torch.arange( + sizes_total, dtype=torch.int64, device=device + ) + item_packed_to_list_idx = ( + torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1 + ) x_packed = torch.cat(x, dim=0) - item_packed_to_list_idx = torch.cat(item_packed_to_list_idx, dim=0) return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx