Skip to content

Commit

Permalink
pytorch TORCH_CHECK_ARG version compatibility
Browse files Browse the repository at this point in the history
Summary: Restore compatibility with old C++ after recent torch change. #995

Reviewed By: patricklabatut

Differential Revision: D33093174

fbshipit-source-id: 841202fb875d601db265e93dcf9cfa4249d02b25
  • Loading branch information
bottler authored and facebook-github-bot committed Dec 15, 2021
1 parent 9eec430 commit 069c9fd
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
4 changes: 3 additions & 1 deletion pytorch3d/csrc/pulsar/cuda/commands.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ __device__ static float atomicMin(float* address, float val) {
#define IABS(a) abs(a)

// Checks.
#define ARGCHECK TORCH_CHECK_ARG
// like TORCH_CHECK_ARG in PyTorch > 1.10
#define ARGCHECK(cond, argN, ...) \
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)

// Math.
#define NORM3DF(x, y, z) norm3df(x, y, z)
Expand Down
4 changes: 3 additions & 1 deletion pytorch3d/csrc/pulsar/host/commands.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ INLINE void ATOMICADD_F3(T* address, T val) {
#define IABS(a) abs(a)

// Checks.
#define ARGCHECK TORCH_CHECK_ARG
// like TORCH_CHECK_ARG in PyTorch > 1.10
#define ARGCHECK(cond, argN, ...) \
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)

// Math.
#define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z)
Expand Down
6 changes: 6 additions & 0 deletions pytorch3d/csrc/pulsar/pytorch/renderer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
#include <c10/cuda/CUDAGuard.h>
#endif

#ifndef TORCH_CHECK_ARG
// torch <= 1.10
#define TORCH_CHECK_ARG(cond, argN, ...) \
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
#endif

namespace PRE = ::pulsar::Renderer;

namespace pulsar {
Expand Down

0 comments on commit 069c9fd

Please sign in to comment.