-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Adding a renderer to ShapeNetCore (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering) Reviewed By: nikhilaravi Differential Revision: D22102673 fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7
- Loading branch information
1 parent
09c1762
commit 358e211
Showing
5 changed files
with
194 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
from .shapenet_core import ShapeNetCore | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
from typing import Dict | ||
|
||
import torch | ||
from pytorch3d.renderer import ( | ||
HardPhongShader, | ||
MeshRasterizer, | ||
MeshRenderer, | ||
OpenGLPerspectiveCameras, | ||
PointLights, | ||
RasterizationSettings, | ||
) | ||
from pytorch3d.structures import Meshes, Textures | ||
|
||
|
||
class ShapeNetBase(torch.utils.data.Dataset): | ||
""" | ||
'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods. | ||
It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__ | ||
and __getitem__ need to be implemented. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Set up lists of synset_ids and model_ids. | ||
""" | ||
self.synset_ids = [] | ||
self.model_ids = [] | ||
|
||
def __len__(self): | ||
""" | ||
Return number of total models in the loaded dataset. | ||
""" | ||
return len(self.model_ids) | ||
|
||
def __getitem__(self, idx) -> Dict: | ||
""" | ||
Read a model by the given index. Need to be implemented for every child class | ||
of ShapeNetBase. | ||
Args: | ||
idx: The idx of the model to be retrieved in the dataset. | ||
Returns: | ||
dictionary containing information about the model. | ||
""" | ||
raise NotImplementedError( | ||
"__getitem__ should be implemented in the child class of ShapeNetBase" | ||
) | ||
|
||
def _get_item_ids(self, idx) -> Dict: | ||
""" | ||
Read a model by the given index. | ||
Args: | ||
idx: The idx of the model to be retrieved in the dataset. | ||
Returns: | ||
dictionary with following keys: | ||
- synset_id (str): synset id | ||
- model_id (str): model id | ||
""" | ||
model = {} | ||
model["synset_id"] = self.synset_ids[idx] | ||
model["model_id"] = self.model_ids[idx] | ||
return model | ||
|
||
def render( | ||
self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs | ||
) -> torch.Tensor: | ||
""" | ||
Renders a model by the given index. | ||
Args: | ||
idx: The index of model to be rendered in the dataset. | ||
shader_type: select shading. Valid options include HardPhongShader (default), | ||
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader, | ||
SoftSilhouetteShader. | ||
device: torch.device on which the tensors should be located. | ||
**kwargs: Accepts any of the kwargs that the renderer supports. | ||
Returns: | ||
Rendered image of shape (1, H, W, 3). | ||
""" | ||
|
||
model = self.__getitem__(idx) | ||
verts, faces = model["verts"], model["faces"] | ||
verts_rgb = torch.ones_like(verts, device=device)[None] | ||
mesh = Meshes( | ||
verts=[verts.to(device)], | ||
faces=[faces.to(device)], | ||
textures=Textures(verts_rgb=verts_rgb.to(device)), | ||
) | ||
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device) | ||
renderer = MeshRenderer( | ||
rasterizer=MeshRasterizer( | ||
cameras=cameras, | ||
raster_settings=kwargs.get("raster_settings", RasterizationSettings()), | ||
), | ||
shader=shader_type( | ||
device=device, | ||
cameras=cameras, | ||
lights=kwargs.get("lights", PointLights()).to(device), | ||
), | ||
) | ||
return renderer(mesh) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters