-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences: - It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors. - As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points. - The neighbors are not sorted - Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2. - The padding value for `idx` is -1 instead of 0. # Note: - Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query. - Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`. Reviewed By: jcjohnson Differential Revision: D30261362 fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
- Loading branch information
1 parent
e5c58a8
commit 103da63
Showing
10 changed files
with
709 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
#include "utils/pytorch3d_cutils.h" | ||
|
||
// A chunk of work is blocksize-many points of P1. | ||
// The number of potential chunks to do is N*(1+(P1-1)/blocksize) | ||
// call (1+(P1-1)/blocksize) chunks_per_cloud | ||
// These chunks are divided among the gridSize-many blocks. | ||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . | ||
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from | ||
// blocksize*(i%chunks_per_cloud). | ||
|
||
// Ball query kernel: one thread handles one point of p1.
//
// For point i of cloud n the thread walks p2[n] in index order and stores the
// index and squared distance of each point whose squared distance is strictly
// below radius2, stopping once K hits have been written or lengths2[n] points
// have been examined. Slots that are never written keep whatever value the
// caller initialised idxs / dists with.
//
// Work distribution: a chunk is blockDim.x consecutive points of one cloud;
// block b processes chunks b, b + gridDim.x, b + 2 * gridDim.x, ...
// No shared memory or block-level synchronisation is used.
template <typename scalar_t>
__global__ void BallQueryKernel(
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths1,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths2,
    at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
    const int64_t K,
    const float radius2) {
  const int64_t num_clouds = p1.size(0);
  const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
  const int64_t total_chunks = num_clouds * chunks_per_cloud;
  const int point_dim = p1.size(2);

  for (int64_t chunk = blockIdx.x; chunk < total_chunks; chunk += gridDim.x) {
    // Decode the chunk id into (cloud, first point of the chunk).
    const int64_t n = chunk / chunks_per_cloud;
    const int64_t chunk_begin = blockDim.x * (chunk % chunks_per_cloud);
    const int64_t i = chunk_begin + threadIdx.x;

    // Only points inside the valid prefix of cloud n are queried; threads
    // that land in the padding simply move on to their next chunk.
    if (i < lengths1[n]) {
      const int64_t num_candidates = lengths2[n];
      int64_t found = 0;

      // Scan candidates in index order and keep the first K inside the ball.
      for (int64_t j = 0; j < num_candidates; ++j) {
        if (found >= K) {
          break;
        }

        // Squared Euclidean distance between p1[n][i] and p2[n][j].
        scalar_t sq_dist = 0.0;
        for (int d = 0; d < point_dim; ++d) {
          const scalar_t delta = p1[n][i][d] - p2[n][j][d];
          sq_dist += (delta * delta);
        }

        if (sq_dist < radius2) {
          // Candidate is inside the ball: record it in the next free slot.
          idxs[n][i][found] = j;
          dists[n][i][found] = sq_dist;
          ++found;
        }
      }
    }
  }
}
|
||
// CUDA entry point for ball query.
//
// Validates the inputs, allocates the outputs and launches BallQueryKernel on
// the current CUDA stream of p1's device.
//
// Args:
//    p1: (N, P1, D) floating point tensor of query points.
//    p2: (N, P2, D) tensor of candidate points, same dtype as p1.
//    lengths1: (N,) int64 tensor, number of valid points in each p1 cloud.
//    lengths2: (N,) int64 tensor, number of valid points in each p2 cloud.
//    K: maximum number of neighbors recorded per query point.
//    radius: ball radius; a candidate is accepted when its squared distance
//        is strictly less than radius * radius.
//
// Returns:
//    (idxs, dists): idxs is an int64 tensor of shape (N, P1, K) initialised
//    to -1, dists has p1's dtype, shape (N, P1, K), initialised to 0. The
//    kernel overwrites the leading slots of each row with the accepted
//    neighbors.
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
    const at::Tensor& p1, // (N, P1, D)
    const at::Tensor& p2, // (N, P2, D)
    const at::Tensor& lengths1, // (N,)
    const at::Tensor& lengths2, // (N,)
    int K,
    float radius) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
  at::CheckedFrom c = "BallQueryCuda";
  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
  at::checkAllSameType(c, {p1_t, p2_t});

  // Set the device for the kernel launch based on the device of p1
  at::cuda::CUDAGuard device_guard(p1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(
      p2.size(2) == p1.size(2), "Point sets must have the same last dimension");

  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int64_t K_64 = K;
  const float radius2 = radius * radius;

  // Output tensor with indices of neighbors for each point in p1
  auto long_dtype = lengths1.options().dtype(at::kLong);
  auto idxs = at::full({N, P1, K}, -1, long_dtype);
  auto dists = at::zeros({N, P1, K}, p1.options());

  // Nothing to compute for empty outputs; skip the launch entirely.
  if (idxs.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(idxs, dists);
  }

  // Fixed launch configuration; the kernel strides over chunks so any
  // problem size is covered.
  const size_t blocks = 256;
  const size_t threads = 256;

  // The accessors must use the dispatched scalar_t: hard-coding float here
  // makes the double dispatch branch fail at accessor construction.
  AT_DISPATCH_FLOATING_TYPES(
      p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
        BallQueryKernel<scalar_t><<<blocks, threads, 0, stream>>>(
            p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
            dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            K_64,
            radius2);
      }));

  // Surface launch-configuration errors from the kernel launch above.
  AT_CUDA_CHECK(cudaGetLastError());

  return std::make_tuple(idxs, dists);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
#include <tuple> | ||
#include "utils/pytorch3d_cutils.h" | ||
|
||
// Compute indices of K neighbors in pointcloud p2 to points | ||
// in pointcloud p1 which fall within a specified radius | ||
// | ||
// Args: | ||
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each | ||
// containing P1 points of dimension D. | ||
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each | ||
// containing P2 points of dimension D. | ||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. | ||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. | ||
// K: Integer giving the upper bound on the number of samples to take | ||
// within the radius | ||
// radius: the radius around each point within which the neighbors need to be | ||
// located | ||
// | ||
// Returns: | ||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where | ||
// p1_neighbor_idx[n, i, k] = j means that the kth | ||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. | ||
// This is padded with -1s both where a cloud in p2 has fewer than | ||
// K points and where a cloud in p1 has fewer than P1 points and | ||
// also if there are fewer than K points which satisfy the radius | ||
// threshold. | ||
// | ||
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared | ||
// distance from each point p1[n, p, :] to its K neighbors | ||
// p2[n, p1_neighbor_idx[n, p, k], :]. | ||
|
||
// CPU implementation | ||
// For each valid point of p1, scans the valid points of p2 in index order | ||
// and records the first K whose squared distance is strictly less than | ||
// radius * radius. Arguments and return values are as described above. | ||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu( | ||
const at::Tensor& p1, | ||
const at::Tensor& p2, | ||
const at::Tensor& lengths1, | ||
const at::Tensor& lengths2, | ||
const int K, | ||
const float radius); | ||
|
||
// CUDA implementation | ||
// Arguments and return values are as described above. All four tensors | ||
// must live on the same GPU, and p1 and p2 must have the same type and the | ||
// same last dimension. | ||
std::tuple<at::Tensor, at::Tensor> BallQueryCuda( | ||
const at::Tensor& p1, | ||
const at::Tensor& p2, | ||
const at::Tensor& lengths1, | ||
const at::Tensor& lengths2, | ||
const int K, | ||
const float radius); | ||
|
||
// Implementation which is exposed | ||
// Note: the backward pass reuses the KNearestNeighborBackward kernel | ||
inline std::tuple<at::Tensor, at::Tensor> BallQuery( | ||
const at::Tensor& p1, | ||
const at::Tensor& p2, | ||
const at::Tensor& lengths1, | ||
const at::Tensor& lengths2, | ||
int K, | ||
float radius) { | ||
if (p1.is_cuda() || p2.is_cuda()) { | ||
#ifdef WITH_CUDA | ||
CHECK_CUDA(p1); | ||
CHECK_CUDA(p2); | ||
return BallQueryCuda( | ||
p1.contiguous(), | ||
p2.contiguous(), | ||
lengths1.contiguous(), | ||
lengths2.contiguous(), | ||
K, | ||
radius); | ||
#else | ||
AT_ERROR("Not compiled with GPU support."); | ||
#endif | ||
} | ||
return BallQueryCpu( | ||
p1.contiguous(), | ||
p2.contiguous(), | ||
lengths1.contiguous(), | ||
lengths2.contiguous(), | ||
K, | ||
radius); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <torch/extension.h> | ||
#include <queue> | ||
#include <tuple> | ||
|
||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu( | ||
const at::Tensor& p1, | ||
const at::Tensor& p2, | ||
const at::Tensor& lengths1, | ||
const at::Tensor& lengths2, | ||
int K, | ||
float radius) { | ||
const int N = p1.size(0); | ||
const int P1 = p1.size(1); | ||
const int D = p1.size(2); | ||
|
||
auto long_opts = lengths1.options().dtype(torch::kInt64); | ||
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts); | ||
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); | ||
const float radius2 = radius * radius; | ||
|
||
auto p1_a = p1.accessor<float, 3>(); | ||
auto p2_a = p2.accessor<float, 3>(); | ||
auto lengths1_a = lengths1.accessor<int64_t, 1>(); | ||
auto lengths2_a = lengths2.accessor<int64_t, 1>(); | ||
auto idxs_a = idxs.accessor<int64_t, 3>(); | ||
auto dists_a = dists.accessor<float, 3>(); | ||
|
||
for (int n = 0; n < N; ++n) { | ||
const int64_t length1 = lengths1_a[n]; | ||
const int64_t length2 = lengths2_a[n]; | ||
for (int64_t i = 0; i < length1; ++i) { | ||
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) { | ||
float dist2 = 0; | ||
for (int d = 0; d < D; ++d) { | ||
float diff = p1_a[n][i][d] - p2_a[n][j][d]; | ||
dist2 += diff * diff; | ||
} | ||
if (dist2 < radius2) { | ||
dists_a[n][i][count] = dist2; | ||
idxs_a[n][i][count] = j; | ||
++count; | ||
} | ||
} | ||
} | ||
} | ||
return std::make_tuple(idxs, dists); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.