diff --git a/pytorch3d/csrc/marching_cubes/marching_cubes.cu b/pytorch3d/csrc/marching_cubes/marching_cubes.cu index e1b4e733..44d50934 100644 --- a/pytorch3d/csrc/marching_cubes/marching_cubes.cu +++ b/pytorch3d/csrc/marching_cubes/marching_cubes.cu @@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel( compactedVoxelArray, const at::PackedTensorAccessor32 voxelOccupied, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 voxelOccupiedScan, uint numVoxels) { uint id = blockIdx.x * blockDim.x + threadIdx.x; @@ -255,7 +255,8 @@ __global__ void GenerateFacesKernel( at::PackedTensorAccessor ids, at::PackedTensorAccessor32 compactedVoxelArray, - at::PackedTensorAccessor32 numVertsScanned, + at::PackedTensorAccessor32 + numVertsScanned, const uint activeVoxels, const at::PackedTensorAccessor32 vol, const at::PackedTensorAccessor32 faceTable, @@ -471,7 +472,7 @@ std::tuple MarchingCubesCuda( auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)}); // number of active voxels - int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item(); + int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item(); const int device_id = vol.device().index(); auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id); @@ -492,7 +493,8 @@ std::tuple MarchingCubesCuda( CompactVoxelsKernel<<>>( d_compVoxelArray.packed_accessor32(), d_voxelOccupied.packed_accessor32(), - d_voxelOccupiedScan_.packed_accessor32(), + d_voxelOccupiedScan_ + .packed_accessor32(), numVoxels); AT_CUDA_CHECK(cudaGetLastError()); cudaDeviceSynchronize(); @@ -502,7 +504,7 @@ std::tuple MarchingCubesCuda( auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)}); // total number of vertices - int totalVerts = d_voxelVertsScan[numVoxels].cpu().item(); + int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item(); // Execute "GenerateFacesKernel" kernel // This runs only on the occupied voxels. @@ -522,7 +524,7 @@ std::tuple MarchingCubesCuda( faces.packed_accessor(), ids.packed_accessor(), d_compVoxelArray.packed_accessor32(), - d_voxelVertsScan_.packed_accessor32(), + d_voxelVertsScan_.packed_accessor32(), activeVoxels, vol.packed_accessor32(), faceTable.packed_accessor32(),