Skip to content

Commit

Permalink
Adding renderer for ShapeNetBase
Browse files Browse the repository at this point in the history
Summary: Adding a renderer to ShapeNetCore (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering)

Reviewed By: nikhilaravi

Differential Revision: D22102673

fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7
  • Loading branch information
megluyagao authored and facebook-github-bot committed Jul 14, 2020
1 parent 09c1762 commit 358e211
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 43 deletions.
1 change: 1 addition & 0 deletions pytorch3d/datasets/shapenet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .shapenet_core import ShapeNetCore


Expand Down
82 changes: 41 additions & 41 deletions pytorch3d/datasets/shapenet/shapenet_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
import warnings
from os import path
from pathlib import Path
from typing import Dict

import torch
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj


SYNSET_DICT_DIR = Path(__file__).resolve().parent


class ShapeNetCore(torch.utils.data.Dataset):
class ShapeNetCore(ShapeNetBase):
"""
This class loads ShapeNetCore from a given directory into a Dataset object.
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
Expand All @@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
def __init__(self, data_dir, synsets=None, version: int = 1):
"""
Store each object's synset id and models id from data_dir.
Args:
data_dir: Path to ShapeNetCore data.
synsets: List of synset categories to load from ShapeNetCore in the form of
Expand All @@ -38,6 +40,7 @@ def __init__(self, data_dir, synsets=None, version: int = 1):
version 1.
"""
super().__init__()
self.data_dir = data_dir
if version not in [1, 2]:
raise ValueError("Version number must be either 1 or 2.")
Expand All @@ -48,7 +51,7 @@ def __init__(self, data_dir, synsets=None, version: int = 1):
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
self.synset_dict = json.load(read_dict)
# Inverse dicitonary mapping synset labels to corresponding offsets.
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}

# If categories are specified, check if each category is in the form of either
# synset offset or synset label, and if the category exists in the given directory.
Expand All @@ -60,62 +63,61 @@ def __init__(self, data_dir, synsets=None, version: int = 1):
path.isdir(path.join(data_dir, synset))
):
synset_set.add(synset)
elif (synset in synset_inv.keys()) and (
(path.isdir(path.join(data_dir, synset_inv[synset])))
elif (synset in self.synset_inv.keys()) and (
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
):
synset_set.add(synset_inv[synset])
synset_set.add(self.synset_inv[synset])
else:
msg = """Synset category %s either not part of ShapeNetCore dataset
or cannot be found in %s.""" % (
synset,
data_dir,
)
msg = (
"Synset category %s either not part of ShapeNetCore dataset "
"or cannot be found in %s."
) % (synset, data_dir)
warnings.warn(msg)
# If no category is given, load every category in the given directory.
# Ignore synset folders not included in the official mapping.
else:
synset_set = {
synset
for synset in os.listdir(data_dir)
if path.isdir(path.join(data_dir, synset))
and synset in self.synset_dict
}
for synset in synset_set:
if synset not in self.synset_dict.keys():
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
but not found in %s.""" % (
synset,
self.synset_dict[synset],
version,
data_dir,
)
warnings.warn(msg)

# Check if there are any categories in the official mapping that are not loaded.
# Update self.synset_inv so that it only includes the loaded categories.
synset_not_present = set(self.synset_dict.keys()).difference(synset_set)
[self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present]

if len(synset_not_present) > 0:
msg = (
"The following categories are included in ShapeNetCore ver.%d's "
"official mapping but not found in the dataset location %s: %s"
""
) % (version, data_dir, ", ".join(synset_not_present))
warnings.warn(msg)

# Extract model_id of each object from directory names.
# Each grandchildren directory of data_dir contains an object, and the name
# of the directory is the object's model_id.
self.synset_ids = []
self.model_ids = []
for synset in synset_set:
for model in os.listdir(path.join(data_dir, synset)):
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
msg = """ Object file not found in the model directory %s
under synset directory %s.""" % (
model,
synset,
)
msg = (
"Object file not found in the model directory %s "
"under synset directory %s."
) % (model, synset)
warnings.warn(msg)
else:
self.synset_ids.append(synset)
self.model_ids.append(model)
continue
self.synset_ids.append(synset)
self.model_ids.append(model)

def __len__(self):
"""
Return number of total models in shapenet core.
"""
return len(self.model_ids)

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Dict:
"""
Read a model by the given index.
Args:
idx: The idx of the model to be retrieved in the dataset.
Returns:
dictionary with following keys:
- verts: FloatTensor of shape (V, 3).
Expand All @@ -124,9 +126,7 @@ def __getitem__(self, idx):
- model_id (str): model id
- label (str): synset label.
"""
model = {}
model["synset_id"] = self.synset_ids[idx]
model["model_id"] = self.model_ids[idx]
model = self._get_item_ids(idx)
model_path = path.join(
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
)
Expand Down
107 changes: 107 additions & 0 deletions pytorch3d/datasets/shapenet_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Dict

import torch
from pytorch3d.renderer import (
HardPhongShader,
MeshRasterizer,
MeshRenderer,
OpenGLPerspectiveCameras,
PointLights,
RasterizationSettings,
)
from pytorch3d.structures import Meshes, Textures


class ShapeNetBase(torch.utils.data.Dataset):
"""
'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods.
It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__
and __getitem__ need to be implemented.
"""

def __init__(self):
"""
Set up lists of synset_ids and model_ids.
"""
self.synset_ids = []
self.model_ids = []

def __len__(self):
"""
Return number of total models in the loaded dataset.
"""
return len(self.model_ids)

def __getitem__(self, idx) -> Dict:
"""
Read a model by the given index. Need to be implemented for every child class
of ShapeNetBase.
Args:
idx: The idx of the model to be retrieved in the dataset.
Returns:
dictionary containing information about the model.
"""
raise NotImplementedError(
"__getitem__ should be implemented in the child class of ShapeNetBase"
)

def _get_item_ids(self, idx) -> Dict:
"""
Read a model by the given index.
Args:
idx: The idx of the model to be retrieved in the dataset.
Returns:
dictionary with following keys:
- synset_id (str): synset id
- model_id (str): model id
"""
model = {}
model["synset_id"] = self.synset_ids[idx]
model["model_id"] = self.model_ids[idx]
return model

def render(
self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
) -> torch.Tensor:
"""
Renders a model by the given index.
Args:
idx: The index of model to be rendered in the dataset.
shader_type: select shading. Valid options include HardPhongShader (default),
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
SoftSilhouetteShader.
device: torch.device on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports.
Returns:
Rendered image of shape (1, H, W, 3).
"""

model = self.__getitem__(idx)
verts, faces = model["verts"], model["faces"]
verts_rgb = torch.ones_like(verts, device=device)[None]
mesh = Meshes(
verts=[verts.to(device)],
faces=[faces.to(device)],
textures=Textures(verts_rgb=verts_rgb.to(device)),
)
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=kwargs.get("raster_settings", RasterizationSettings()),
),
shader=shader_type(
device=device,
cameras=cameras,
lights=kwargs.get("lights", PointLights()).to(device),
),
)
return renderer(mesh)
Binary file added tests/data/test_shapenet_core_render_piano.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 45 additions & 2 deletions tests/test_shapenet_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,32 @@
import random
import unittest
import warnings
from pathlib import Path

import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.datasets import ShapeNetCore
from pytorch3d.renderer import (
OpenGLPerspectiveCameras,
PointLights,
RasterizationSettings,
look_at_view_transform,
)


SHAPENET_PATH = None
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"


class TestShapenetCore(TestCaseMixin, unittest.TestCase):
def test_load_shapenet_core(self):
# Setup
device = torch.device("cuda:0")

# The ShapeNet dataset is not provided in the repo.
# Download this separately and update the `shapenet_path`
Expand All @@ -31,7 +46,7 @@ def test_load_shapenet_core(self):
warnings.warn(msg)
return True

# Try load ShapeNetCore with an invalid version number and catch error.
# Try loading ShapeNetCore with an invalid version number and catch error.
with self.assertRaises(ValueError) as err:
ShapeNetCore(SHAPENET_PATH, version=3)
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
Expand Down Expand Up @@ -93,3 +108,31 @@ def test_load_shapenet_core(self):
for offset in subset_offsets
]
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))

# Render the first image in the piano category.
R, T = look_at_view_transform(1.0, 1.0, 90)
piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])

cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
raster_settings = RasterizationSettings(image_size=512)
lights = PointLights(
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
diffuse_color=((0, 0, 0),),
specular_color=((0, 0, 0),),
device=device,
)
images = piano_dataset.render(
0,
device=device,
cameras=cameras,
raster_settings=raster_settings,
lights=lights,
)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
)
image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)

0 comments on commit 358e211

Please sign in to comment.