diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 1b355f26a..4febd3894 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -41,7 +41,7 @@ Stats are logged and plotted to the file "train_stats.pdf" in the same directory. The stats are also saved as part of the checkpoint file. - Visualizations - Prredictions are plotted to a visdom server running at the + Predictions are plotted to a visdom server running at the port specified by the `visdom_server` and `visdom_port` keys in the config file. diff --git a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py index 9393ac400..39b313d96 100644 --- a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py +++ b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py @@ -9,7 +9,7 @@ import warnings from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np import torch @@ -27,7 +27,9 @@ from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.vis.plotly_vis import plot_scene from tabulate import tabulate -from visdom import Visdom + +if TYPE_CHECKING: + from visdom import Visdom EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9] @@ -43,7 +45,7 @@ class _Visualizer: visdom_env: str = "eval_debug" - _viz: Visdom = field(init=False) + _viz: Optional["Visdom"] = field(init=False) def __post_init__(self): self._viz = vis_utils.get_visdom_connection() @@ -51,6 +53,8 @@ def __post_init__(self): def show_rgb( self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor ): + if self._viz is None: + return self._viz.images( torch.cat( ( @@ -68,7 +72,10 @@ def show_rgb( def show_depth( self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor ): - self._viz.images( + if self._viz is None: + return + viz = self._viz + viz.images( torch.cat( ( make_depth_image(self.depth_render, loss_mask_now), @@ -80,13 +87,13 @@ def show_depth( win="depth_abs" + name_postfix, opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"}, ) - self._viz.images( + viz.images( loss_mask_now, env=self.visdom_env, win="depth_abs" + name_postfix + "_mask", opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"}, ) - self._viz.images( + viz.images( self.depth_mask, env=self.visdom_env, win="depth_abs" + name_postfix + "_maskd", @@ -126,7 +133,7 @@ def show_depth( pointcloud_max_points=10000, pointcloud_marker_size=1, ) - self._viz.plotlyplot( + viz.plotlyplot( plotlyplot, env=self.visdom_env, win=f"pcl{name_postfix}", diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 228bbcec0..3c1715cfe 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -12,7 +12,7 @@ import math import warnings from dataclasses import field -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch import tqdm @@ -34,7 +34,9 @@ from pytorch3d.renderer import utils as rend_utils from pytorch3d.renderer.cameras import CamerasBase -from visdom import Visdom + +if TYPE_CHECKING: + from visdom import Visdom from .base_model import ImplicitronModelBase, ImplicitronRender from .feature_extractor import FeatureExtractorBase @@ -544,7 +546,7 @@ def _get_objective(self, preds) -> Optional[torch.Tensor]: def visualize( self, - viz: Visdom, + viz: Optional["Visdom"], visdom_env_imgs: str, preds: Dict[str, Any], prefix: str, @@ -559,7 +561,7 @@ def visualize( preds: predictions dict like returned by forward() prefix: prepended to the names of images """ - if not viz.check_connection(): + if viz is None or not viz.check_connection(): logger.info("no visdom server! -> skipping batch vis") return diff --git a/pytorch3d/implicitron/models/visualization/render_flyaround.py b/pytorch3d/implicitron/models/visualization/render_flyaround.py index 2c577757d..f3d868d5a 100644 --- a/pytorch3d/implicitron/models/visualization/render_flyaround.py +++ b/pytorch3d/implicitron/models/visualization/render_flyaround.py @@ -10,7 +10,7 @@ import math import os import random -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np import torch @@ -27,7 +27,9 @@ make_depth_image, ) from tqdm import tqdm -from visdom import Visdom + +if TYPE_CHECKING: + from visdom import Visdom logger = logging.getLogger(__name__) @@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T def _show_predictions( preds: List[Dict[str, Any]], sequence_name: str, - viz: Visdom, + viz: "Visdom", viz_env: str = "visualizer", predicted_keys: Sequence[str] = ( "images_render", @@ -318,7 +320,7 @@ def _show_predictions( def _generate_prediction_videos( preds: List[Dict[str, Any]], sequence_name: str, - viz: Optional[Visdom] = None, + viz: Optional["Visdom"] = None, viz_env: str = "visualizer", predicted_keys: Sequence[str] = ( "images_render", diff --git a/pytorch3d/implicitron/tools/stats.py b/pytorch3d/implicitron/tools/stats.py index 012ab54a3..49acfeae3 100644 --- a/pytorch3d/implicitron/tools/stats.py +++ b/pytorch3d/implicitron/tools/stats.py @@ -337,7 +337,7 @@ def plot_stats( novisdom = False viz = get_visdom_connection(server=visdom_server, port=visdom_port) - if not viz.check_connection(): + if viz is None or not viz.check_connection(): print("no visdom server! -> skipping visdom plots") novisdom = True diff --git a/pytorch3d/implicitron/tools/vis_utils.py b/pytorch3d/implicitron/tools/vis_utils.py index 585a15a30..be5eb9d91 100644 --- a/pytorch3d/implicitron/tools/vis_utils.py +++ b/pytorch3d/implicitron/tools/vis_utils.py @@ -5,10 +5,12 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING import torch -from visdom import Visdom + +if TYPE_CHECKING: + from visdom import Visdom logger = logging.getLogger(__name__) @@ -40,9 +42,9 @@ def get_visdom_env(visdom_env: str, exp_dir: str) -> str: def get_visdom_connection( server: str = "http://localhost", port: int = 8097, -) -> Visdom: +) -> Optional["Visdom"]: """ - Obtain a connection to a visdom server. + Obtain a connection to a visdom server if visdom is installed. Args: server: Server address. @@ -51,6 +53,15 @@ def get_visdom_connection( Returns: connection: The connection object. """ + try: + from visdom import Visdom + except ImportError: + logger.debug("Cannot load visdom") + return None + + if server == "None": + return None + global _viz_singleton if _viz_singleton is None: _viz_singleton = Visdom(server=server, port=port) @@ -58,7 +69,7 @@ def get_visdom_connection( def visualize_basics( - viz: Visdom, + viz: "Visdom", preds: Dict[str, Any], visdom_env_imgs: str, title: str = "",