diff --git a/pytorch3d/datasets/shapenet/__init__.py b/pytorch3d/datasets/shapenet/__init__.py index dd0bc8630..44469dab2 100644 --- a/pytorch3d/datasets/shapenet/__init__.py +++ b/pytorch3d/datasets/shapenet/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + from .shapenet_core import ShapeNetCore diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py index 25b6bca01..e28ae7973 100644 --- a/pytorch3d/datasets/shapenet/shapenet_core.py +++ b/pytorch3d/datasets/shapenet/shapenet_core.py @@ -5,15 +5,16 @@ import warnings from os import path from pathlib import Path +from typing import Dict -import torch +from pytorch3d.datasets.shapenet_base import ShapeNetBase from pytorch3d.io import load_obj SYNSET_DICT_DIR = Path(__file__).resolve().parent -class ShapeNetCore(torch.utils.data.Dataset): +class ShapeNetCore(ShapeNetBase): """ This class loads ShapeNetCore from a given directory into a Dataset object. ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from @@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset): def __init__(self, data_dir, synsets=None, version: int = 1): """ Store each object's synset id and models id from data_dir. + Args: data_dir: Path to ShapeNetCore data. synsets: List of synset categories to load from ShapeNetCore in the form of @@ -38,6 +40,7 @@ def __init__(self, data_dir, synsets=None, version: int = 1): version 1. """ + super().__init__() self.data_dir = data_dir if version not in [1, 2]: raise ValueError("Version number must be either 1 or 2.") @@ -48,7 +51,7 @@ def __init__(self, data_dir, synsets=None, version: int = 1): with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict: self.synset_dict = json.load(read_dict) # Inverse dicitonary mapping synset labels to corresponding offsets. - synset_inv = {label: offset for offset, label in self.synset_dict.items()} + self.synset_inv = {label: offset for offset, label in self.synset_dict.items()} # If categories are specified, check if each category is in the form of either # synset offset or synset label, and if the category exists in the given directory. @@ -60,62 +63,61 @@ def __init__(self, data_dir, synsets=None, version: int = 1): path.isdir(path.join(data_dir, synset)) ): synset_set.add(synset) - elif (synset in synset_inv.keys()) and ( - (path.isdir(path.join(data_dir, synset_inv[synset]))) + elif (synset in self.synset_inv.keys()) and ( + (path.isdir(path.join(data_dir, self.synset_inv[synset]))) ): - synset_set.add(synset_inv[synset]) + synset_set.add(self.synset_inv[synset]) else: - msg = """Synset category %s either not part of ShapeNetCore dataset - or cannot be found in %s.""" % ( - synset, - data_dir, - ) + msg = ( + "Synset category %s either not part of ShapeNetCore dataset " + "or cannot be found in %s." + ) % (synset, data_dir) warnings.warn(msg) # If no category is given, load every category in the given directory. + # Ignore synset folders not included in the official mapping. else: synset_set = { synset for synset in os.listdir(data_dir) if path.isdir(path.join(data_dir, synset)) + and synset in self.synset_dict } - for synset in synset_set: - if synset not in self.synset_dict.keys(): - msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s - but not found in %s.""" % ( - synset, - self.synset_dict[synset], - version, - data_dir, - ) - warnings.warn(msg) + + # Check if there are any categories in the official mapping that are not loaded. + # Update self.synset_inv so that it only includes the loaded categories. + synset_not_present = set(self.synset_dict.keys()).difference(synset_set) + [self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present] + + if len(synset_not_present) > 0: + msg = ( + "The following categories are included in ShapeNetCore ver.%d's " + "official mapping but not found in the dataset location %s: %s" + "" + ) % (version, data_dir, ", ".join(synset_not_present)) + warnings.warn(msg) # Extract model_id of each object from directory names. # Each grandchildren directory of data_dir contains an object, and the name # of the directory is the object's model_id. - self.synset_ids = [] - self.model_ids = [] for synset in synset_set: for model in os.listdir(path.join(data_dir, synset)): if not path.exists(path.join(data_dir, synset, model, self.model_dir)): - msg = """ Object file not found in the model directory %s - under synset directory %s.""" % ( - model, - synset, - ) + msg = ( + "Object file not found in the model directory %s " + "under synset directory %s." + ) % (model, synset) warnings.warn(msg) - else: - self.synset_ids.append(synset) - self.model_ids.append(model) + continue + self.synset_ids.append(synset) + self.model_ids.append(model) - def __len__(self): - """ - Return number of total models in shapenet core. - """ - return len(self.model_ids) - - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Dict: """ Read a model by the given index. + + Args: + idx: The idx of the model to be retrieved in the dataset. + Returns: dictionary with following keys: - verts: FloatTensor of shape (V, 3). @@ -124,9 +126,7 @@ def __getitem__(self, idx): - model_id (str): model id - label (str): synset label. """ - model = {} - model["synset_id"] = self.synset_ids[idx] - model["model_id"] = self.model_ids[idx] + model = self._get_item_ids(idx) model_path = path.join( self.data_dir, model["synset_id"], model["model_id"], self.model_dir ) diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py new file mode 100644 index 000000000..f76546cef --- /dev/null +++ b/pytorch3d/datasets/shapenet_base.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Dict + +import torch +from pytorch3d.renderer import ( + HardPhongShader, + MeshRasterizer, + MeshRenderer, + OpenGLPerspectiveCameras, + PointLights, + RasterizationSettings, +) +from pytorch3d.structures import Meshes, Textures + + +class ShapeNetBase(torch.utils.data.Dataset): + """ + 'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods. + It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__ + and __getitem__ need to be implemented. + """ + + def __init__(self): + """ + Set up lists of synset_ids and model_ids. + """ + self.synset_ids = [] + self.model_ids = [] + + def __len__(self): + """ + Return number of total models in the loaded dataset. + """ + return len(self.model_ids) + + def __getitem__(self, idx) -> Dict: + """ + Read a model by the given index. Need to be implemented for every child class + of ShapeNetBase. + + Args: + idx: The idx of the model to be retrieved in the dataset. + + Returns: + dictionary containing information about the model. + """ + raise NotImplementedError( + "__getitem__ should be implemented in the child class of ShapeNetBase" + ) + + def _get_item_ids(self, idx) -> Dict: + """ + Read a model by the given index. + + Args: + idx: The idx of the model to be retrieved in the dataset. + + Returns: + dictionary with following keys: + - synset_id (str): synset id + - model_id (str): model id + """ + model = {} + model["synset_id"] = self.synset_ids[idx] + model["model_id"] = self.model_ids[idx] + return model + + def render( + self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs + ) -> torch.Tensor: + """ + Renders a model by the given index. + + Args: + idx: The index of model to be rendered in the dataset. + shader_type: select shading. Valid options include HardPhongShader (default), + SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader, + SoftSilhouetteShader. + device: torch.device on which the tensors should be located. + **kwargs: Accepts any of the kwargs that the renderer supports. + + Returns: + Rendered image of shape (1, H, W, 3). + """ + + model = self.__getitem__(idx) + verts, faces = model["verts"], model["faces"] + verts_rgb = torch.ones_like(verts, device=device)[None] + mesh = Meshes( + verts=[verts.to(device)], + faces=[faces.to(device)], + textures=Textures(verts_rgb=verts_rgb.to(device)), + ) + cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=kwargs.get("raster_settings", RasterizationSettings()), + ), + shader=shader_type( + device=device, + cameras=cameras, + lights=kwargs.get("lights", PointLights()).to(device), + ), + ) + return renderer(mesh) diff --git a/tests/data/test_shapenet_core_render_piano.png b/tests/data/test_shapenet_core_render_piano.png new file mode 100644 index 000000000..fc7524c89 Binary files /dev/null and b/tests/data/test_shapenet_core_render_piano.png differ diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py index ff623f783..db92c83a7 100644 --- a/tests/test_shapenet_core.py +++ b/tests/test_shapenet_core.py @@ -6,17 +6,32 @@ import random import unittest import warnings +from pathlib import Path +import numpy as np import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, load_rgb_image +from PIL import Image from pytorch3d.datasets import ShapeNetCore +from pytorch3d.renderer import ( + OpenGLPerspectiveCameras, + PointLights, + RasterizationSettings, + look_at_view_transform, +) SHAPENET_PATH = None +# If DEBUG=True, save out images generated in the tests for debugging. +# All saved images have prefix DEBUG_ +DEBUG = False +DATA_DIR = Path(__file__).resolve().parent / "data" class TestShapenetCore(TestCaseMixin, unittest.TestCase): def test_load_shapenet_core(self): + # Setup + device = torch.device("cuda:0") # The ShapeNet dataset is not provided in the repo. # Download this separately and update the `shapenet_path` @@ -31,7 +46,7 @@ def test_load_shapenet_core(self): warnings.warn(msg) return True - # Try load ShapeNetCore with an invalid version number and catch error. + # Try loading ShapeNetCore with an invalid version number and catch error. with self.assertRaises(ValueError) as err: ShapeNetCore(SHAPENET_PATH, version=3) self.assertTrue("Version number must be either 1 or 2." in str(err.exception)) @@ -93,3 +108,31 @@ def test_load_shapenet_core(self): for offset in subset_offsets ] self.assertEqual(len(shapenet_subset), sum(subset_model_nums)) + + # Render the first image in the piano category. + R, T = look_at_view_transform(1.0, 1.0, 90) + piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"]) + + cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device) + raster_settings = RasterizationSettings(image_size=512) + lights = PointLights( + location=torch.tensor([0.0, 1.0, -2.0], device=device)[None], + # TODO: debug the source of the discrepancy in two images when rendering on GPU. + diffuse_color=((0, 0, 0),), + specular_color=((0, 0, 0),), + device=device, + ) + images = piano_dataset.render( + 0, + device=device, + cameras=cameras, + raster_settings=raster_settings, + lights=lights, + ) + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_shapenet_core_render_piano.png" + ) + image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR) + self.assertClose(rgb, image_ref, atol=0.05)