Skip to content

Commit

Permalink
Add check for verts and faces being on same device and also checks fo…
Browse files Browse the repository at this point in the history
…r pointclouds/features/normals being on the same device (#384)

Summary: Pull Request resolved: #384

Test Plan: `test_meshes` and `test_points`

Reviewed By: gkioxari

Differential Revision: D24730524

Pulled By: nikhilaravi

fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
  • Loading branch information
Randl authored and facebook-github-bot committed Dec 15, 2020
1 parent 1934046 commit 569e522
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 2 deletions.
12 changes: 11 additions & 1 deletion pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ def __init__(self, verts=None, faces=None, textures=None):
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0:
self.device = self._verts_list[0].device
if not (
all(v.device == self.device for v in verts)
and all(f.device == self.device for f in faces)
):
raise ValueError(
"All Verts and Faces tensors should be on same device."
)
self._num_verts_per_mesh = torch.tensor(
[len(v) for v in self._verts_list], device=self.device
)
Expand All @@ -341,7 +348,6 @@ def __init__(self, verts=None, faces=None, textures=None):
dtype=torch.bool,
device=self.device,
)

if (len(self._num_verts_per_mesh.unique()) == 1) and (
len(self._num_faces_per_mesh.unique()) == 1
):
Expand All @@ -355,6 +361,10 @@ def __init__(self, verts=None, faces=None, textures=None):
self._N = self._verts_padded.shape[0]
self._V = self._verts_padded.shape[1]

if verts.device != faces.device:
msg = "Verts and Faces tensors should be on same device. \n Got {} and {}."
raise ValueError(msg.format(verts.device, faces.device))

self.device = self._verts_padded.device
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0:
Expand Down
12 changes: 11 additions & 1 deletion pytorch3d/structures/pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,13 @@ def __init__(self, points, normals=None, features=None):
self._num_points_per_cloud = []

if self._N > 0:
self.device = self._points_list[0].device
for p in self._points_list:
if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
raise ValueError("Clouds in list must be of shape Px3 or empty")
if p.device != self.device:
raise ValueError("All points must be on the same device")

self.device = self._points_list[0].device
num_points_per_cloud = torch.tensor(
[len(p) for p in self._points_list], device=self.device
)
Expand Down Expand Up @@ -261,6 +263,10 @@ def _parse_auxiliary_input(self, aux_input):
raise ValueError(
"A cloud has mismatched numbers of points and inputs"
)
if d.device != self.device:
raise ValueError(
"All auxillary inputs must be on the same device as the points."
)
if p > 0:
if d.dim() != 2:
raise ValueError(
Expand All @@ -283,6 +289,10 @@ def _parse_auxiliary_input(self, aux_input):
"Inputs tensor must have the right maximum \
number of points in each cloud."
)
if aux_input.device != self.device:
raise ValueError(
"All auxillary inputs must be on the same device as the points."
)
aux_input_C = aux_input.shape[2]
return None, aux_input, aux_input_C
else:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_meshes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import random
import unittest

import numpy as np
Expand Down Expand Up @@ -162,6 +163,29 @@ def test_simple(self):
torch.tensor([0, 3, 8], dtype=torch.int64),
)

def test_init_error(self):
# Check if correct errors are raised when verts/faces are on
# different devices

mesh = TestMeshes.init_mesh(10, 10, 100)
verts_list = mesh.verts_list() # all tensors on cpu
verts_list = [
v.to("cuda:0") if random.uniform(0, 1) > 0.5 else v for v in verts_list
]
faces_list = mesh.faces_list()

with self.assertRaises(ValueError) as cm:
Meshes(verts=verts_list, faces=faces_list)
self.assertTrue("same device" in cm.msg)

verts_padded = mesh.verts_padded() # on cpu
verts_padded = verts_padded.to("cuda:0")
faces_padded = mesh.faces_padded()

with self.assertRaises(ValueError) as cm:
Meshes(verts=verts_padded, faces=faces_padded)
self.assertTrue("same device" in cm.msg)

def test_simple_random_meshes(self):

# Define the test mesh object either as a list or tensor of faces/verts.
Expand Down
39 changes: 39 additions & 0 deletions tests/test_pointclouds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import random
import unittest

import numpy as np
Expand Down Expand Up @@ -126,6 +127,44 @@ def test_simple(self):
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
)

def test_init_error(self):
# Check if correct errors are raised when verts/faces are on
# different devices

clouds = self.init_cloud(10, 100, 5)
points_list = clouds.points_list() # all tensors on cuda:0
points_list = [
p.to("cpu") if random.uniform(0, 1) > 0.5 else p for p in points_list
]
features_list = clouds.features_list()
normals_list = clouds.normals_list()

with self.assertRaises(ValueError) as cm:
Pointclouds(
points=points_list, features=features_list, normals=normals_list
)
self.assertTrue("same device" in cm.msg)

points_list = clouds.points_list()
features_list = [
f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list
]
with self.assertRaises(ValueError) as cm:
Pointclouds(
points=points_list, features=features_list, normals=normals_list
)
self.assertTrue("same device" in cm.msg)

points_padded = clouds.points_padded() # on cuda:0
features_padded = clouds.features_padded().to("cpu")
normals_padded = clouds.normals_padded()

with self.assertRaises(ValueError) as cm:
Pointclouds(
points=points_padded, features=features_padded, normals=normals_padded
)
self.assertTrue("same device" in cm.msg)

def test_all_constructions(self):
public_getters = [
"points_list",
Expand Down

0 comments on commit 569e522

Please sign in to comment.