diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 3f84bf036..2122845c8 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -8,9 +8,10 @@ import numpy as np import torch +import torch.nn as nn -class TensorAccessor(object): +class TensorAccessor(nn.Module): """ A helper class to be used with the __getitem__ method. This can be used for getting/setting the values for an attribute of a class at one particular @@ -82,7 +83,7 @@ def __getattr__(self, name: str): BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray) -class TensorProperties(object): +class TensorProperties(nn.Module): """ A mix-in class for storing tensors as properties with helper methods. """ diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 511f95db9..4f0c2ba27 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -1042,67 +1042,3 @@ def test_simple_sphere_outside_zfar(self): ) self.assertClose(rgb, image_ref, atol=0.05) - - def test_to(self): - # Test moving all the tensors in the renderer to a new device - # to support multigpu rendering. - device1 = torch.device("cpu") - - R, T = look_at_view_transform(1500, 0.0, 0.0) - - # Init shader settings - materials = Materials(device=device1) - lights = PointLights(device=device1) - lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None] - - raster_settings = RasterizationSettings( - image_size=256, blur_radius=0.0, faces_per_pixel=1 - ) - cameras = FoVPerspectiveCameras( - device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100 - ) - rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) - - blend_params = BlendParams( - 1e-4, - 1e-4, - background_color=torch.zeros(3, dtype=torch.float32, device=device1), - ) - - shader = SoftPhongShader( - lights=lights, - cameras=cameras, - materials=materials, - blend_params=blend_params, - ) - renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) - - def _check_props_on_device(renderer, device): - self.assertEqual(renderer.rasterizer.cameras.device, device) - self.assertEqual(renderer.shader.cameras.device, device) - self.assertEqual(renderer.shader.lights.device, device) - self.assertEqual(renderer.shader.lights.ambient_color.device, device) - self.assertEqual(renderer.shader.materials.device, device) - self.assertEqual(renderer.shader.materials.ambient_color.device, device) - - mesh = ico_sphere(2, device1) - verts_padded = mesh.verts_padded() - textures = TexturesVertex( - verts_features=torch.ones_like(verts_padded, device=device1) - ) - mesh.textures = textures - _check_props_on_device(renderer, device1) - - # Test rendering on cpu - output_images = renderer(mesh) - self.assertEqual(output_images.device, device1) - - # Move renderer and mesh to another device and re render - # This also tests that background_color is correctly moved to - # the new device - device2 = torch.device("cuda:0") - renderer.to(device2) - mesh = mesh.to(device2) - _check_props_on_device(renderer, device2) - output_images = renderer(mesh) - self.assertEqual(output_images.device, device2) diff --git a/tests/test_render_multigpu.py b/tests/test_render_multigpu.py new file mode 100644 index 000000000..298ddb7f9 --- /dev/null +++ b/tests/test_render_multigpu.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import torch +import torch.nn as nn +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.renderer import ( + BlendParams, + HardGouraudShader, + Materials, + MeshRasterizer, + MeshRenderer, + PointLights, + RasterizationSettings, + SoftPhongShader, + TexturesVertex, +) +from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform +from pytorch3d.structures.meshes import Meshes +from pytorch3d.utils.ico_sphere import ico_sphere + + +# Set the number of GPUS you want to test with +NUM_GPUS = 3 +GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)}) +print("GPUs: %s" % ", ".join(GPU_LIST)) + + +class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase): + def _check_mesh_renderer_props_on_device(self, renderer, device): + """ + Helper function to check that all the properties of the mesh + renderer have been moved to the correct device. + """ + # Cameras + self.assertEqual(renderer.rasterizer.cameras.device, device) + self.assertEqual(renderer.rasterizer.cameras.R.device, device) + self.assertEqual(renderer.rasterizer.cameras.T.device, device) + self.assertEqual(renderer.shader.cameras.device, device) + self.assertEqual(renderer.shader.cameras.R.device, device) + self.assertEqual(renderer.shader.cameras.T.device, device) + + # Lights and Materials + self.assertEqual(renderer.shader.lights.device, device) + self.assertEqual(renderer.shader.lights.ambient_color.device, device) + self.assertEqual(renderer.shader.materials.device, device) + self.assertEqual(renderer.shader.materials.ambient_color.device, device) + + def test_mesh_renderer_to(self): + """ + Test moving all the tensors in the mesh renderer to a new device. + """ + + device1 = torch.device("cpu") + + R, T = look_at_view_transform(1500, 0.0, 0.0) + + # Init shader settings + materials = Materials(device=device1) + lights = PointLights(device=device1) + lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None] + + raster_settings = RasterizationSettings( + image_size=256, blur_radius=0.0, faces_per_pixel=1 + ) + cameras = FoVPerspectiveCameras( + device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100 + ) + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + + blend_params = BlendParams( + 1e-4, + 1e-4, + background_color=torch.zeros(3, dtype=torch.float32, device=device1), + ) + + shader = SoftPhongShader( + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, + ) + renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) + + mesh = ico_sphere(2, device1) + verts_padded = mesh.verts_padded() + textures = TexturesVertex( + verts_features=torch.ones_like(verts_padded, device=device1) + ) + mesh.textures = textures + self._check_mesh_renderer_props_on_device(renderer, device1) + + # Test rendering on cpu + output_images = renderer(mesh) + self.assertEqual(output_images.device, device1) + + # Move renderer and mesh to another device and re render + # This also tests that background_color is correctly moved to + # the new device + device2 = torch.device("cuda:0") + renderer.to(device2) + mesh = mesh.to(device2) + self._check_mesh_renderer_props_on_device(renderer, device2) + output_images = renderer(mesh) + self.assertEqual(output_images.device, device2) + + def test_render_meshes(self): + test = self + + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + mesh = ico_sphere(3) + self.register_buffer("faces", mesh.faces_padded()) + self.renderer = self.init_render() + + def init_render(self): + + cameras = FoVPerspectiveCameras() + raster_settings = RasterizationSettings( + image_size=128, blur_radius=0.0, faces_per_pixel=1 + ) + lights = PointLights( + ambient_color=((1.0, 1.0, 1.0),), + diffuse_color=((0, 0.0, 0),), + specular_color=((0.0, 0, 0),), + location=((0.0, 0.0, 1e5),), + ) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, raster_settings=raster_settings + ), + shader=HardGouraudShader(cameras=cameras, lights=lights), + ) + return renderer + + def forward(self, verts, texs): + batch_size = verts.size(0) + self.renderer.to(verts.device) + tex = TexturesVertex(verts_features=texs) + faces = self.faces.expand(batch_size, -1, -1).to(verts.device) + mesh = Meshes(verts, faces, tex).to(verts.device) + + test._check_mesh_renderer_props_on_device(self.renderer, verts.device) + img_render = self.renderer(mesh) + return img_render[:, :, :, :3] + + # DataParallel requires every input tensor be provided + # on the first device in its device_ids list. + verts = ico_sphere(3).verts_padded() + texs = verts.new_ones(verts.shape) + model = Model() + model = nn.DataParallel(model, device_ids=GPU_LIST) + model.to(f"cuda:{model.device_ids[0]}") + + # Test a few iterations + for _ in range(100): + model(verts, texs)