diff --git a/docs/tutorials/fit_textured_volume.ipynb b/docs/tutorials/fit_textured_volume.ipynb new file mode 100644 index 000000000..a82efb607 --- /dev/null +++ b/docs/tutorials/fit_textured_volume.ipynb @@ -0,0 +1,456 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fit a volume via raymarching\n", + "\n", + "This tutorial shows how to fit a volume given a set of views of a scene using differentiable volumetric rendering.\n", + "\n", + "More specificially, this tutorial will explain how to:\n", + "1. Create a differentiable volumetric renderer.\n", + "2. Create a Volumetric model (including how to use the `Volumes` class).\n", + "3. Fit the volume based on the images using the differentiable volumetric renderer. \n", + "4. Visualize the predicted volume." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 0. Install and Import modules\n", + "If `torch` and `pytorch3d` are not installed, run the following cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install torch\n", + "# import sys\n", + "# import torch\n", + "# if torch.__version__=='1.6.0+cu101' and sys.platform.startswith('linux'):\n", + "# !pip install pytorch3d\n", + "# else:\n", + "# !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import time\n", + "import json\n", + "import glob\n", + "import torch\n", + "import math\n", + "from tqdm.notebook import tqdm\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "from IPython import display\n", + "\n", + "# Data structures and functions for rendering\n", + "from pytorch3d.structures import Volumes\n", + "from pytorch3d.renderer import (\n", + " FoVPerspectiveCameras, \n", + " VolumeRenderer,\n", + " NDCGridRaysampler,\n", + " EmissionAbsorptionRaymarcher\n", + ")\n", + "from pytorch3d.transforms import so3_exponential_map\n", + "\n", + "# add path for demo utils functions \n", + "sys.path.append(os.path.abspath(''))\n", + "from utils.plot_image_grid import image_grid\n", + "from utils.generate_cow_renders import generate_cow_renders\n", + "\n", + "# obtain the utilized device\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda:0\")\n", + " torch.cuda.set_device(device)\n", + "else:\n", + " device = torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Generate images of the scene and masks\n", + "\n", + "The following cell generates our training data.\n", + "It renders the cow mesh from the `fit_textured_mesh.ipynb` tutorial from several viewpoints and returns:\n", + "1. A batch of image and silhouette tensors that are produced by the cow mesh renderer.\n", + "2. A set of cameras corresponding to each render.\n", + "\n", + "Note: For the purpose of this tutorial, which aims at explaining the details of volumetric rendering, we do not explain how the mesh rendering, implemented in the `generate_cow_renders` function, works. Please refer to `fit_textured_mesh.ipynb` for a detailed explanation of mesh rendering." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target_cameras, target_images, target_silhouettes = generate_cow_renders(num_views=40)\n", + "print(f'Generated {len(target_images)} images/silhouettes/cameras.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Initialize the volumetric renderer\n", + "\n", + "The following initializes a volumetric renderer that emits a ray from each pixel of a target image and samples a set of uniformly-spaced points along the ray. At each ray-point, the corresponding density and color value is obtained by querying the corresponding location in the volumetric model of the scene (the model is described & instantiated in a later cell).\n", + "\n", + "The renderer is composed of a *raymarcher* and a *raysampler*.\n", + "- The *raysampler* is responsible for emiting rays from image pixels and sampling the points along them. Here, we use the `NDCGridRaysampler` which follows the standard PyTorch3D coordinate grid convention (+X from right to left; +Y from bottom to top; +Z away from the user).\n", + "- The *raymarcher* takes the densities and colors sampled along each ray and renders each ray into a color and an opacity value of the ray's source pixel. Here we use the `EmissionAbsorptionRaymarcher` which implements the standard Emission-Absorption raymarching algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# render_size describes the size of both sides of the \n", + "# rendered images in pixels. We set this to the same size\n", + "# as the target images. I.e. we render at the same\n", + "# size as the ground truth images.\n", + "render_size = target_images.shape[1]\n", + "\n", + "# Our rendered scene is centered around (0,0,0) \n", + "# and is enclosed inside a bounding box\n", + "# whose side is roughly equal to 3.0 (world units).\n", + "volume_extent_world = 3.0\n", + "\n", + "# 1) Instantiate the raysampler.\n", + "# Here, NDCGridRaysampler generates a rectangular image\n", + "# grid of rays whose coordinates follow the pytorch3d\n", + "# coordinate conventions.\n", + "# Since we use a volume of size 128^3, we sample n_pts_per_ray=150,\n", + "# which roughly corresponds to a one ray-point per voxel.\n", + "# We futher set the min_depth=0.1 since there is no surface within\n", + "# 0.1 units of any camera plane.\n", + "raysampler = NDCGridRaysampler(\n", + " image_width=render_size,\n", + " image_height=render_size,\n", + " n_pts_per_ray=150,\n", + " min_depth=0.1,\n", + " max_depth=volume_extent_world,\n", + ")\n", + "\n", + "\n", + "# 2) Instantiate the raymarcher.\n", + "# Here, we use the standard EmissionAbsorptionRaymarcher \n", + "# which marches along each ray in order to render\n", + "# each ray into a single 3D color vector \n", + "# and an opacity scalar.\n", + "raymarcher = EmissionAbsorptionRaymarcher()\n", + "\n", + "# Finally, instantiate the volumetric render\n", + "# with the raysampler and raymarcher objects.\n", + "renderer = VolumeRenderer(\n", + " raysampler=raysampler, raymarcher=raymarcher,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Initialize the volumetric model\n", + "\n", + "Next we instantiate a volumetric model of the scene. This quantizes the 3D space to cubical voxels, where each voxel is described with a 3D vector representing the voxel's RGB color and a density scalar which describes the opacity of the voxel (ranging between [0-1], the higher the more opaque).\n", + "\n", + "In order to ensure the range of densities and colors is between [0-1], we represent both volume colors and densities in the logarithmic space. During the forward function of the model, the log-space values are passed through the sigmoid function to bring the log-space values to the correct range.\n", + "\n", + "Additionally, `VolumeModel` contains the renderer object. This object stays unaltered throughout the optimization.\n", + "\n", + "In this cell we also define the `huber` loss function which computes the discrepancy between the rendered colors and masks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class VolumeModel(torch.nn.Module):\n", + " def __init__(self, renderer, volume_size=[64] * 3, voxel_size=0.1):\n", + " super().__init__()\n", + " # After evaluating torch.sigmoid(self.log_colors), we get \n", + " # densities close to zero.\n", + " self.log_densities = torch.nn.Parameter(-4.0 * torch.ones(1, *volume_size))\n", + " # After evaluating torch.sigmoid(self.log_colors), we get \n", + " # a neutral gray color everywhere.\n", + " self.log_colors = torch.nn.Parameter(torch.zeros(3, *volume_size))\n", + " self._voxel_size = voxel_size\n", + " # Store the renderer module as well.\n", + " self._renderer = renderer\n", + " \n", + " def forward(self, cameras):\n", + " batch_size = cameras.R.shape[0]\n", + "\n", + " # Convert the log-space values to the densities/colors\n", + " densities = torch.sigmoid(self.log_densities)\n", + " colors = torch.sigmoid(self.log_colors)\n", + " \n", + " # Instantiate the Volumes object, making sure\n", + " # the densities and colors are correctly\n", + " # expanded batch_size-times.\n", + " volumes = Volumes(\n", + " densities = densities[None].expand(\n", + " batch_size, *self.log_densities.shape),\n", + " features = colors[None].expand(\n", + " batch_size, *self.log_colors.shape),\n", + " voxel_size=self._voxel_size,\n", + " )\n", + " \n", + " # Given cameras and volumes, run the renderer\n", + " # and return only the first output value \n", + " # (the 2nd output is a representation of the sampled\n", + " # rays which can be omitted for our purpose).\n", + " return self._renderer(cameras=cameras, volumes=volumes)[0]\n", + " \n", + "# A helper function for evaluating the smooth L1 (huber) loss\n", + "# between the rendered silhouettes and colors.\n", + "def huber(x, y, scaling=0.1):\n", + " diff_sq = (x - y) ** 2\n", + " loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Fit the volume\n", + "\n", + "Here we carry out the volume fitting with differentiable rendering.\n", + "\n", + "In order to fit the volume, we render it from the viewpoints of the `target_cameras`\n", + "and compare the resulting renders with the observed `target_images` and `target_silhouettes`.\n", + "\n", + "The comparison is done by evaluating the mean huber (smooth-l1) error between corresponding\n", + "pairs of `target_images`/`rendered_images` and `target_silhouettes`/`rendered_silhouettes`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# First move all relevant variables to the correct device.\n", + "target_cameras = target_cameras.to(device)\n", + "target_images = target_images.to(device)\n", + "target_silhouettes = target_silhouettes.to(device)\n", + "\n", + "# Instantiate the volumetric model.\n", + "# We use a cubical volume with the size of \n", + "# one side = 128. The size of each voxel of the volume \n", + "# is set to volume_extent_world / volume_size s.t. the\n", + "# volume represents the space enclosed in a 3D bounding box\n", + "# centered at (0, 0, 0) with the size of each side equal to 3.\n", + "volume_size = 128\n", + "volume_model = VolumeModel(\n", + " renderer,\n", + " volume_size=[volume_size] * 3, \n", + " voxel_size = volume_extent_world / volume_size,\n", + ").to(device)\n", + "\n", + "# Instantiate the Adam optimizer. We set its master learning rate to 0.1.\n", + "lr = 0.1\n", + "optimizer = torch.optim.Adam(volume_model.parameters(), lr=lr)\n", + "\n", + "# We do 300 Adam iterations and sample 10 random images in each minibatch.\n", + "batch_size = 10\n", + "n_iter = 300\n", + "for iteration in range(n_iter):\n", + "\n", + " # In case we reached the last 75% of iterations,\n", + " # decrease the learning rate of the optimizer 10-fold.\n", + " if iteration == round(n_iter * 0.75):\n", + " print('Decreasing LR 10-fold ...')\n", + " optimizer = torch.optim.Adam(\n", + " volume_model.parameters(), lr=lr * 0.1\n", + " )\n", + " \n", + " # Zero the optimizer gradient.\n", + " optimizer.zero_grad()\n", + " \n", + " # Sample random batch indices.\n", + " batch_idx = torch.randperm(len(target_cameras))[:batch_size]\n", + " \n", + " # Sample the minibatch of cameras.\n", + " batch_cameras = FoVPerspectiveCameras(\n", + " R = target_cameras.R[batch_idx], \n", + " T = target_cameras.T[batch_idx], \n", + " znear = target_cameras.znear[batch_idx],\n", + " zfar = target_cameras.zfar[batch_idx],\n", + " aspect_ratio = target_cameras.aspect_ratio[batch_idx],\n", + " fov = target_cameras.fov[batch_idx],\n", + " device = device,\n", + " )\n", + " \n", + " # Evaluate the volumetric model.\n", + " rendered_images, rendered_silhouettes = volume_model(\n", + " batch_cameras\n", + " ).split([3, 1], dim=-1)\n", + " \n", + " # Compute the silhoutte error as the mean huber\n", + " # loss between the predicted masks and the\n", + " # target silhouettes.\n", + " sil_err = huber(\n", + " rendered_silhouettes[..., 0], target_silhouettes[batch_idx],\n", + " ).abs().mean()\n", + "\n", + " # Compute the color error as the mean huber\n", + " # loss between the rendered colors and the\n", + " # target ground truth images.\n", + " color_err = huber(\n", + " rendered_images, target_images[batch_idx],\n", + " ).abs().mean()\n", + " \n", + " # The optimization loss is a simple\n", + " # sum of the color and silhouette errors.\n", + " loss = color_err + sil_err \n", + " \n", + " # Print the current values of the losses.\n", + " if iteration % 10 == 0:\n", + " print(\n", + " f'Iteration {iteration:05d}:'\n", + " + f' color_err = {float(color_err):1.2e}'\n", + " + f' mask_err = {float(sil_err):1.2e}'\n", + " )\n", + " \n", + " # Take the optimization step.\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " # Visualize the renders every 40 iterations.\n", + " if iteration % 40 == 0:\n", + " # Visualize only a single randomly selected element of the batch.\n", + " im_show_idx = int(torch.randint(low=0, high=batch_size, size=(1,)))\n", + " fig, ax = plt.subplots(2, 2, figsize=(10, 10))\n", + " ax = ax.ravel()\n", + " clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()\n", + " ax[0].imshow(clamp_and_detach(rendered_images[im_show_idx]))\n", + " ax[1].imshow(clamp_and_detach(target_images[batch_idx[im_show_idx], ..., :3]))\n", + " ax[2].imshow(clamp_and_detach(rendered_silhouettes[im_show_idx, ..., 0]))\n", + " ax[3].imshow(clamp_and_detach(target_silhouettes[batch_idx[im_show_idx]]))\n", + " for ax_, title_ in zip(\n", + " ax, \n", + " (\"rendered image\", \"target image\", \"rendered silhouette\", \"target silhouette\")\n", + " ):\n", + " ax_.grid(\"off\")\n", + " ax_.axis(\"off\")\n", + " ax_.set_title(title_)\n", + " fig.canvas.draw(); fig.show()\n", + " display.clear_output(wait=True)\n", + " display.display(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Visualizing the optimized volume\n", + "\n", + "Finally, we visualize the optimized volume by rendering from multiple viewpoints that rotate around the volume's y-axis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_rotating_volume(volume_model, n_frames = 50):\n", + " logRs = torch.zeros(n_frames, 3, device=device)\n", + " logRs[:, 1] = torch.linspace(0.0, 2.0 * 3.14, n_frames, device=device)\n", + " Rs = so3_exponential_map(logRs)\n", + " Ts = torch.zeros(n_frames, 3, device=device)\n", + " Ts[:, 2] = 2.7\n", + " frames = []\n", + " print('Generating rotating volume ...')\n", + " for R, T in zip(tqdm(Rs), Ts):\n", + " camera = FoVPerspectiveCameras(\n", + " R=R[None], \n", + " T=T[None], \n", + " znear = target_cameras.znear[0],\n", + " zfar = target_cameras.zfar[0],\n", + " aspect_ratio = target_cameras.aspect_ratio[0],\n", + " fov = target_cameras.fov[0],\n", + " device=device,\n", + " )\n", + " frames.append(volume_model(camera)[..., :3].clamp(0.0, 1.0))\n", + " return torch.cat(frames)\n", + " \n", + "with torch.no_grad():\n", + " rotating_volume_frames = generate_rotating_volume(volume_model, n_frames=7*4)\n", + "\n", + "image_grid(rotating_volume_frames.clamp(0., 1.).cpu().numpy(), rows=4, cols=7, rgb=True, fill=True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Conclusion\n", + "\n", + "In this tutorial, we have shown how to optimize a 3D volumetric representation of a scene such that the renders of the volume from known viewpoints match the observed images for each viewpoint. The rendering was carried out using the PyTorch3D's volumetric renderer composed of an `NDCGridRaysampler` and an `EmissionAbsorptionRaymarcher`." + ] + } + ], + "metadata": { + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "kernelspec": { + "display_name": "pytorch3d_etc (local)", + "language": "python", + "name": "pytorch3d_etc_local" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5+" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorials/utils/generate_cow_renders.py b/docs/tutorials/utils/generate_cow_renders.py new file mode 100644 index 000000000..fc080e461 --- /dev/null +++ b/docs/tutorials/utils/generate_cow_renders.py @@ -0,0 +1,162 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import os + +import numpy as np +import torch + +# Util function for loading meshes +from pytorch3d.io import load_objs_as_meshes +from pytorch3d.renderer import ( + BlendParams, + FoVPerspectiveCameras, + MeshRasterizer, + MeshRenderer, + PointLights, + RasterizationSettings, + SoftPhongShader, + SoftSilhouetteShader, + look_at_view_transform, +) + +# create the default data directory +current_dir = os.path.dirname(os.path.realpath(__file__)) +DATA_DIR = os.path.join(current_dir, "..", "data", "cow_mesh") + + +def generate_cow_renders(num_views: int = 40, data_dir: str = DATA_DIR): + """ + This function generates `num_views` renders of a cow mesh. + The renders are generated from viewpoints sampled at uniformly distributed + azimuth intervals. The elevation is kept constant so that the camera's + vertical position coincides with the equator. + + For a more detailed explanation of this code, please refer to the + docs/tutorials/fit_textured_mesh.ipynb notebook. + + Args: + num_views: The number of generated renders. + data_dir: The folder that contains the cow mesh files. If the cow mesh + files do not exist in the folder, this function will automatically + download them. + + Returns: + cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the + images are rendered. + images: A tensor of shape `(num_views, height, width, 3)` containing + the rendered images. + silhouettes: A tensor of shape `(num_views, height, width)` containing + the rendered silhouettes. + """ + + # set the paths + + # download the cow mesh if not done before + cow_mesh_files = [ + os.path.join(data_dir, fl) for fl in ("cow.obj", "cow.mtl", "cow_texture.png") + ] + if any(not os.path.isfile(f) for f in cow_mesh_files): + os.makedirs(data_dir, exis_ok=True) + os.system( + f"wget -P {data_dir} " + + "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj" + ) + os.system( + f"wget -P {data_dir} " + + "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl" + ) + os.system( + f"wget -P {data_dir} " + + "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png" + ) + + # Setup + if torch.cuda.is_available(): + device = torch.device("cuda:0") + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + + # Load obj file + obj_filename = os.path.join(data_dir, "cow.obj") + mesh = load_objs_as_meshes([obj_filename], device=device) + + # We scale normalize and center the target mesh to fit in a sphere of radius 1 + # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh + # to its original center and scale. Note that normalizing the target mesh, + # speeds up the optimization but is not necessary! + verts = mesh.verts_packed() + N = verts.shape[0] + center = verts.mean(0) + scale = max((verts - center).abs().max(0)[0]) + mesh.offset_verts_(-(center.expand(N, 3))) + mesh.scale_verts_((1.0 / float(scale))) + + # Get a batch of viewing angles. + elev = torch.linspace(0, 0, num_views) # keep constant + azim = torch.linspace(-180, 180, num_views) + + # Place a point light in front of the object. As mentioned above, the front of + # the cow is facing the -z direction. + lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]]) + + # Initialize an OpenGL perspective camera that represents a batch of different + # viewing angles. All the cameras helper methods support mixed type inputs and + # broadcasting. So we can view the camera from the a distance of dist=2.7, and + # then specify elevation and azimuth angles for each viewpoint as tensors. + R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + + # Define the settings for rasterization and shading. Here we set the output + # image to be of size 128X128. As we are rendering images for visualization + # purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to + # rasterize_meshes.py for explanations of these parameters. We also leave + # bin_size and max_faces_per_bin to their default values of None, which sets + # their values using huristics and ensures that the faster coarse-to-fine + # rasterization method is used. Refer to docs/notes/renderer.md for an + # explanation of the difference between naive and coarse-to-fine rasterization. + raster_settings = RasterizationSettings( + image_size=128, blur_radius=0.0, faces_per_pixel=1 + ) + + # Create a phong renderer by composing a rasterizer and a shader. The textured + # phong shader will interpolate the texture uv coordinates for each vertex, + # sample from a texture image and apply the Phong lighting model + blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0)) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=SoftPhongShader( + device=device, cameras=cameras, lights=lights, blend_params=blend_params + ), + ) + + # Create a batch of meshes by repeating the cow mesh and associated textures. + # Meshes has a useful `extend` method which allows us do this very easily. + # This also extends the textures. + meshes = mesh.extend(num_views) + + # Render the cow mesh from each viewing angle + target_images = renderer(meshes, cameras=cameras, lights=lights) + + # Rasterization settings for silhouette rendering + sigma = 1e-4 + raster_settings_silhouette = RasterizationSettings( + image_size=128, blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, faces_per_pixel=50 + ) + + # Silhouette renderer + renderer_silhouette = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, raster_settings=raster_settings_silhouette + ), + shader=SoftSilhouetteShader(), + ) + + # Render silhouette images. The 3rd channel of the rendering output is + # the alpha/silhouette channel + silhouette_images = renderer_silhouette(meshes, cameras=cameras, lights=lights) + + # binary silhouettes + silhouette_binary = (silhouette_images[..., 3] > 1e-4).float() + + return cameras, target_images[..., :3], silhouette_binary