From fe5bdb2fb501199e2be915b8f65859a970bbfd60 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 18 Oct 2022 15:58:18 -0700 Subject: [PATCH] different learning rate for different parts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Adds the ability to have different learning rates for different parts of the model. The trainable parts of the implicitron have a new member param_groups: dictionary where keys are names of individual parameters, or module’s members and values are the parameter group where the parameter/member will be sorted to. "self" key is used to denote the parameter group at the module level. Possible keys, including the "self" key do not have to be defined. By default all parameters are put into "default" parameter group and have the learning rate defined in the optimizer, it can be overriden at the: - module level with “self” key, all the parameters and child module s parameters will be put to that parameter group - member level, which is the same as if the `param_groups` in that member has key=“self” and value equal to that parameter group. This is useful if members do not have `param_groups`, for example torch.nn.Linear. - parameter level, parameter with the same name as the key will be put to that parameter group. And in the optimizer factory, parameters and their learning rates are recursively gathered. Reviewed By: shapovalov Differential Revision: D40145802 fbshipit-source-id: 631c02b8d79ee1c0eb4c31e6e42dbd3d2882078a --- .../impl/optimizer_factory.py | 96 ++++++++++- .../implicitron_trainer/tests/experiment.yaml | 1 + .../tests/test_optimizer_factory.py | 162 ++++++++++++++++++ .../implicit_function/decoding_functions.py | 22 ++- .../models/implicit_function/voxel_grid.py | 16 ++ tests/implicitron/test_voxel_grids.py | 1 - 6 files changed, 293 insertions(+), 5 deletions(-) create mode 100644 projects/implicitron_trainer/tests/test_optimizer_factory.py diff --git a/projects/implicitron_trainer/impl/optimizer_factory.py b/projects/implicitron_trainer/impl/optimizer_factory.py index 9e4a52275..8cd75884a 100644 --- a/projects/implicitron_trainer/impl/optimizer_factory.py +++ b/projects/implicitron_trainer/impl/optimizer_factory.py @@ -7,7 +7,9 @@ import inspect import logging import os -from typing import Any, Dict, Optional, Tuple +from collections import defaultdict +from dataclasses import field +from typing import Any, Dict, List, Optional, Tuple import torch.optim @@ -64,6 +66,12 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase): weight_decay: The optimizer weight_decay (L2 penalty on model weights). foreach: Whether to use new "foreach" implementation of optimizer where available (e.g. requires PyTorch 1.12.0 for Adam) + group_learning_rates: Parameters or modules can be assigned to parameter + groups. This dictionary has names of those parameter groups as keys + and learning rates as values. All parameter group names have to be + defined in this dictionary. Parameters which do not have predefined + parameter group are put into "default" parameter group which has + `lr` as its learning rate. """ betas: Tuple[float, ...] = (0.9, 0.999) @@ -78,6 +86,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase): linear_exponential_lr_milestone: int = 200 linear_exponential_start_gamma: float = 0.1 foreach: Optional[bool] = True + group_learning_rates: Dict[str, float] = field(default_factory=lambda: {}) def __post_init__(self): run_auto_creation(self) @@ -115,8 +124,10 @@ def __call__( # pyre-ignore[29] p_groups = model._get_param_groups(self.lr, wd=self.weight_decay) else: - allprm = [prm for prm in model.parameters() if prm.requires_grad] - p_groups = [{"params": allprm, "lr": self.lr}] + p_groups = [ + {"params": params, "lr": self._get_group_learning_rate(group)} + for group, params in self._get_param_groups(model).items() + ] # Intialize the optimizer optimizer_kwargs: Dict[str, Any] = { @@ -233,3 +244,82 @@ def _get_optimizer_state( else: raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.") return optimizer_state + + def _get_param_groups( + self, module: torch.nn.Module + ) -> Dict[str, List[torch.nn.Parameter]]: + """ + Recursively visits all the modules inside the `module` and sorts all the + parameters in parameter groups. + + Uses `param_groups` dictionary member, where keys are names of individual + parameters or module members and values are the names of the parameter groups + for those parameters or members. "self" key is used to denote the parameter groups + at the module level. Possible keys, including the "self" key do not have to + be defined. By default all parameters have the learning rate defined in the + optimizer. This can be overridden by setting the parameter group in `param_groups` + member of a specific module, it can be overridden at the: + - module level with “self” key, all the parameters and child + module's parameters will inherit it + - member level, which is the same as if the `param_groups` in that + member has key=“self” and value equal to that parameter group. + This is useful if members do not have `param_groups`, for + example torch.nn.Linear. + - parameter level, only parameter with the same name as the key + will have it. + + Args: + module: module from which to extract the parameters and their parameter + groups + Returns: + dictionary with parameter groups as keys and lists of parameters as values + """ + + param_groups = defaultdict(list) + + def traverse(module, default_group): + # If key self is defined in param_groups then chenge the default param + # group for all parameters and children in the module. + if hasattr(module, "param_groups") and "self" in module.param_groups: + default_group = module.param_groups["self"] + + # Collect all the parameters that are directly inside the `module`, + # they will be in the default param group if they don't have + # defined group. + for name, param in module.named_parameters(recurse=False): + if param.requires_grad: + if hasattr(module, "param_groups") and name in module.param_groups: + param_groups[module.param_groups[name]].append(param) + else: + param_groups[default_group].append(param) + + # If children have defined default param group then use it else pass + # own default. + for child_name, child in module.named_children(): + if ( + hasattr(module, "param_groups") + and child_name in module.param_groups + ): + traverse(child, module.param_groups[child_name]) + else: + traverse(child, default_group) + + traverse(module, "default") + return param_groups + + def _get_group_learning_rate(self, group_name: str) -> float: + """ + Wraps the `group_learning_rates` dictionary providing errors and returns + `self.lr` for "default" group_name. + + Args: + group_name: a string representing the name of the group + Returns: + learning rate for a specific group + """ + if group_name == "default": + return self.lr + lr = self.group_learning_rates.get(group_name, None) + if lr is None: + raise ValueError(f"no learning rate given for group {group_name}") + return lr diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index d6b6beed8..f9e6e3296 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -409,6 +409,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args: linear_exponential_lr_milestone: 200 linear_exponential_start_gamma: 0.1 foreach: true + group_learning_rates: {} training_loop_ImplicitronTrainingLoop_args: evaluator_class_type: ImplicitronEvaluator evaluator_ImplicitronEvaluator_args: diff --git a/projects/implicitron_trainer/tests/test_optimizer_factory.py b/projects/implicitron_trainer/tests/test_optimizer_factory.py new file mode 100644 index 000000000..23cc5aad0 --- /dev/null +++ b/projects/implicitron_trainer/tests/test_optimizer_factory.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +import torch +from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args + +from ..impl.optimizer_factory import ImplicitronOptimizerFactory + +internal = os.environ.get("FB_TEST", False) + + +class TestOptimizerFactory(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(42) + expand_args_fields(ImplicitronOptimizerFactory) + + def _get_param_groups(self, model): + default_cfg = get_default_args(ImplicitronOptimizerFactory) + factory = ImplicitronOptimizerFactory(default_cfg) + return factory._get_param_groups(model) + + def _assert_allin(self, a, param_groups, key): + with self.subTest(f"Testing key {key}"): + b = param_groups[key] + for el in a: + if el not in b: + raise ValueError( + f"Element {el}\n\n from:\n\n {a}\n\n not in:\n\n {b}\n\n." + + f" Full param groups = \n\n{param_groups}" + ) + for el in b: + if el not in a: + raise ValueError( + f"Element {el}\n\n from:\n\n {b}\n\n not in:\n\n {a}\n\n." + + f" Full param groups = \n\n{param_groups}" + ) + + def test_default_param_group_assignment(self): + pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)] + na, nb = Node(params=[pa]), Node(params=[pb]) + root = Node(children=[na, nb], params=[pc]) + param_groups = self._get_param_groups(root) + self._assert_allin([pa, pb, pc], param_groups, "default") + + def test_member_overrides_default_param_group_assignment(self): + pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)] + na, nb = Node(params=[pa]), Node(params=[pb]) + root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb"}) + param_groups = self._get_param_groups(root) + self._assert_allin([pa, pc], param_groups, "default") + self._assert_allin([pb], param_groups, "pb") + + def test_self_overrides_member_param_group_assignment(self): + pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)] + na, nb = Node(params=[pa]), Node(params=[pb], param_groups={"self": "pb_self"}) + root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"}) + param_groups = self._get_param_groups(root) + self._assert_allin([pa, pc], param_groups, "default") + self._assert_allin([pb], param_groups, "pb_self") + assert len(param_groups["pb_member"]) == 0, param_groups + + def test_param_overrides_self_param_group_assignment(self): + pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)] + na, nb = Node(params=[pa]), Node( + params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"} + ) + root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"}) + param_groups = self._get_param_groups(root) + self._assert_allin([pa, pc], param_groups, "default") + self._assert_allin([pb], param_groups, "pb_self") + assert len(param_groups["pb_member"]) == 0, param_groups + + def test_no_param_groups_defined(self): + pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)] + na, nb = Node(params=[pa]), Node(params=[pb]) + root = Node(children=[na, nb], params=[pc]) + param_groups = self._get_param_groups(root) + self._assert_allin([pa, pb, pc], param_groups, "default") + + def test_tree_param_groups_defined(self): + """ + Test generic tree assignment. + + A0 + |--------------------------- + | | | + Bb M J- + |----- |------- + | | | | + C Ddg K Ll + |-------------- + | | | | + E4 Ff G H- + + All nodes have one parameter. Character next to the capital + letter means they have added something to their `parameter_groups`: + - small letter same as capital means self is set to that letter + - small letter different then capital means that member is set + (the one that is named like that) + - number means parameter's parameter_group is set like that + - "-" means it does not have `parameter_groups` member + """ + p = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(12)] + L = Node(params=[p[11]], param_groups={"self": "l"}) + K = Node(params=[p[10]], param_groups={}) + J = Node(params=[p[9]], param_groups=None, children=[K, L]) + M = Node(params=[p[8]], param_groups={}) + + E = Node(params=[p[4]], param_groups={"p0": "4"}) + F = Node(params=[p[5]], param_groups={"self": "f"}) + G = Node(params=[p[6]], param_groups={}) + H = Node(params=[p[7]], param_groups=None) + + D = Node( + params=[p[3]], param_groups={"self": "d", "m2": "g"}, children=[E, F, G, H] + ) + C = Node(params=[p[2]], param_groups={}) + + B = Node(params=[p[1]], param_groups={"self": "b"}, children=[C, D]) + + A = Node(params=[p[0]], param_groups={"p0": "0"}, children=[B, M, J]) + + param_groups = self._get_param_groups(A) + + # if parts of the group belong to two different categories assert is repeated + # parameter level + self._assert_allin([p[0]], param_groups, "0") + self._assert_allin([p[4]], param_groups, "4") + # self level + self._assert_allin([p[5]], param_groups, "f") + self._assert_allin([p[11]], param_groups, "l") + self._assert_allin([p[2], p[1]], param_groups, "b") + self._assert_allin([p[7], p[3]], param_groups, "d") + # member level + self._assert_allin([p[6]], param_groups, "g") + # inherit level + self._assert_allin([p[7], p[3]], param_groups, "d") + self._assert_allin([p[2], p[1]], param_groups, "b") + # default level + self._assert_allin([p[8], p[9], p[10]], param_groups, "default") + + +class Node(torch.nn.Module): + def __init__(self, children=(), params=(), param_groups=None): + super().__init__() + for i, child in enumerate(children): + self.add_module("m" + str(i), child) + for i, param in enumerate(params): + setattr(self, "p" + str(i), param) + if param_groups is not None: + self.param_groups = param_groups + + def __str__(self): + return ( + "modules:\n" + str(self._modules) + "\nparameters\n" + str(self._parameters) + ) diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py index 2713ea462..fb7494a70 100644 --- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py +++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py @@ -13,9 +13,10 @@ """ import logging +from dataclasses import field from enum import Enum -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch @@ -42,8 +43,27 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module): """ Decoding function is a torch.nn.Module which takes the embedding of a location in space and transforms it into the required quantity (for example density and color). + + Members: + param_groups: dictionary where keys are names of individual parameters + or module members and values are the parameter group where the + parameter/member will be sorted to. "self" key is used to denote the + parameter group at the module level. Possible keys, including the "self" key + do not have to be defined. By default all parameters are put into "default" + parameter group and have the learning rate defined in the optimizer, + it can be overridden at the: + - module level with “self” key, all the parameters and child + module's parameters will be put to that parameter group + - member level, which is the same as if the `param_groups` in that + member has key=“self” and value equal to that parameter group. + This is useful if members do not have `param_groups`, for + example torch.nn.Linear. + - parameter level, parameter with the same name as the key + will be put to that parameter group. """ + param_groups: Dict[str, str] = field(default_factory=lambda: {}) + def __post_init__(self): super().__init__() diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index b9f8c1bf7..2135c8224 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -808,6 +808,21 @@ class VoxelGridModule(Configurable, torch.nn.Module): with mean=init_mean and std=init_std. Default 0. hold_voxel_grid_as_parameters: if True components of the underlying voxel grids will be saved as parameters and therefore be trainable. Default True. + param_groups: dictionary where keys are names of individual parameters + or module members and values are the parameter group where the + parameter/member will be sorted to. "self" key is used to denote the + parameter group at the module level. Possible keys, including the "self" key + do not have to be defined. By default all parameters are put into "default" + parameter group and have the learning rate defined in the optimizer, + it can be overridden at the: + - module level with “self” key, all the parameters and child + module's parameters will be put to that parameter group + - member level, which is the same as if the `param_groups` in that + member has key=“self” and value equal to that parameter group. + This is useful if members do not have `param_groups`, for + example torch.nn.Linear. + - parameter level, parameter with the same name as the key + will be put to that parameter group. """ voxel_grid_class_type: str = "FullResolutionVoxelGrid" @@ -820,6 +835,7 @@ class VoxelGridModule(Configurable, torch.nn.Module): init_mean: float = 0 hold_voxel_grid_as_parameters: bool = True + param_groups: Dict[str, str] = field(default_factory=lambda: {}) def __post_init__(self): super().__init__() diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py index 6fdcdb29f..2b47de7f9 100644 --- a/tests/implicitron/test_voxel_grids.py +++ b/tests/implicitron/test_voxel_grids.py @@ -19,7 +19,6 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import ( CPFactorizedVoxelGrid, FullResolutionVoxelGrid, - FullResolutionVoxelGridValues, VMFactorizedVoxelGrid, VoxelGridModule, )