Skip to content

Commit

Permalink
mesh_normal_consistency speedup
Browse files Browse the repository at this point in the history
Summary: One step in finding all the pairs of vertices which share faces is a simple calculation but annoying to parallelize. It was implemented in pure Python. We move it to C++. We still pull the data to the CPU and put the answer back on the device.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D26073475

fbshipit-source-id: ffbf4e2c347a511ab5084bceff600465812b6a52
  • Loading branch information
bottler authored and facebook-github-bot committed Feb 11, 2021
1 parent 5ac2f42 commit 4bfe715
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 28 deletions.
3 changes: 3 additions & 0 deletions pytorch3d/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "gather_scatter/gather_scatter.h"
#include "interp_face_attrs/interp_face_attrs.h"
#include "knn/knn.h"
#include "mesh_normal_consistency/mesh_normal_consistency.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
Expand All @@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);
Expand Down
24 changes: 24 additions & 0 deletions pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#pragma once
#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"

// For mesh_normal_consistency, find pairs of vertices opposite the same edge.
//
// Args:
// edge_num: int64 Tensor of shape (E,) giving the number of vertices
// corresponding to each edge.
//
// Returns:
// pairs: int64 Tensor of shape (N,2)

at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num);

// Exposed implementation.
at::Tensor MeshNormalConsistencyFindVertices(const at::Tensor& edge_num) {
if (edge_num.is_cuda()) {
AT_ERROR("This function needs a CPU tensor.");
}
return MeshNormalConsistencyFindVerticesCpu(edge_num);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
#include <utility>
#include <vector>

at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num) {
// We take a LongTensor of shape (E,) giving the number of things intersecting
// each edge. The things are taken to be numbered in order.
// (In fact, the "things" are opposite vertices to edges, renumbered).
// We return a tensor of shape (?, 2) where for every pair of things which
// intersect the same edge there is a row of their numbers in the output.

// Example possible inputs and outputs (order of output is not specified):
// [1,0,1,1,0] => [[]]
// [3] => [[0,1], [0,2], [1,2]]
// [0,3] => [[0,1], [0,2], [1,2]]
// [1,3] => [[1,2], [1,3], [2,3]]
//[1,0,2,1,0,2] => [[1,2], [4,5]]

const auto num_edges = edge_num.size(0);
auto edges_a = edge_num.accessor<int64_t, 1>();

int64_t vert_idx = 0;
std::vector<std::pair<int64_t, int64_t>> pairs;
for (int64_t i_edge = 0; i_edge < num_edges; ++i_edge) {
int64_t e = edges_a[i_edge];
for (int64_t j = 0; j < e; ++j) {
for (int64_t i = 0; i < j; ++i) {
pairs.emplace_back(vert_idx + i, vert_idx + j);
}
}
vert_idx += e;
}

// Convert from std::vector by copying over the items to a new empty torch
// tensor.
auto pairs_tensor = at::empty({(int64_t)pairs.size(), 2}, edge_num.options());
auto pairs_a = pairs_tensor.accessor<int64_t, 2>();
for (int64_t i_pair = 0; i_pair < pairs.size(); ++i_pair) {
auto accessor = pairs_a[i_pair];
accessor[0] = pairs[i_pair].first;
accessor[1] = pairs[i_pair].second;
}

return pairs_tensor;
}
38 changes: 10 additions & 28 deletions pytorch3d/loss/mesh_normal_consistency.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from itertools import islice

import torch

# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
from pytorch3d import _C


def mesh_normal_consistency(meshes):
r"""
Expand Down Expand Up @@ -71,9 +71,9 @@ def mesh_normal_consistency(meshes):
F = faces_packed.shape[0] # sum(F_n)

# We don't want gradients for the following operation. The goal is to
# find for each edge e all the vertices associated with e. In the example above,
# the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1)
# and points connected on faces to e (=a, b).
# find for each edge e all the vertices associated with e. In the example
# above, the vertices associated with e are (a, b), i.e. the points connected
# on faces to e.
with torch.no_grad():
edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges
vert_idx = (
Expand All @@ -95,23 +95,10 @@ def mesh_normal_consistency(meshes):
# the number of vertices which are associated with each edge.
# There can be a different number for each edge.
edge_num = edge_idx.bincount(minlength=E)
# Create pairs of vertices associated to e. We generate a list of lists:
# each list has the indices of the vertices which are opposite to one edge.
# The length of the list for each edge will vary.
vert_edge_pair_idx = split_list(
list(range(edge_idx.shape[0])), edge_num.tolist()
)
# For each list find all combinations of pairs in the list. This represents
# all pairs of vertices which are opposite to the same edge.
vert_edge_pair_idx = [
[e[i], e[j]]
for e in vert_edge_pair_idx
for i in range(len(e) - 1)
for j in range(1, len(e))
if i < j
]
vert_edge_pair_idx = torch.tensor(
vert_edge_pair_idx, device=meshes.device, dtype=torch.int64

# This calculates all pairs of vertices which are opposite to the same edge.
vert_edge_pair_idx = _C.mesh_normal_consistency_find_verts(edge_num.cpu()).to(
edge_num.device
)

if vert_edge_pair_idx.shape[0] == 0:
Expand Down Expand Up @@ -141,8 +128,3 @@ def mesh_normal_consistency(meshes):

loss = loss * weights
return loss.sum() / N


def split_list(input, length_to_split):
inputt = iter(input)
return [list(islice(inputt, elem)) for elem in length_to_split]

0 comments on commit 4bfe715

Please sign in to comment.