From 7be49bf46fdce44ec41fc22d29e99d614c1988d6 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 7 Nov 2022 06:41:40 -0800 Subject: [PATCH] allow dots in param_groups Summary: Allow a module's param_group member to specify overrides to the param groups of its members or their members. Also logging for param group assignments. This allows defining `params.basis_matrix` in the param_groups of a voxel_grid. Reviewed By: shapovalov Differential Revision: D41080667 fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e --- .../impl/optimizer_factory.py | 55 ++++++++++++------- .../tests/test_optimizer_factory.py | 25 ++++++++- 2 files changed, 57 insertions(+), 23 deletions(-) diff --git a/projects/implicitron_trainer/impl/optimizer_factory.py b/projects/implicitron_trainer/impl/optimizer_factory.py index 8cd75884a..aafe075bc 100644 --- a/projects/implicitron_trainer/impl/optimizer_factory.py +++ b/projects/implicitron_trainer/impl/optimizer_factory.py @@ -258,15 +258,16 @@ def _get_param_groups( at the module level. Possible keys, including the "self" key do not have to be defined. By default all parameters have the learning rate defined in the optimizer. This can be overridden by setting the parameter group in `param_groups` - member of a specific module, it can be overridden at the: - - module level with “self” key, all the parameters and child - module's parameters will inherit it - - member level, which is the same as if the `param_groups` in that - member has key=“self” and value equal to that parameter group. + member of a specific module. Values are a parameter group name. The keys + specify what parameters will be affected as follows: + - “self”: All the parameters of the module and its child modules + - name of a parameter: A parameter with that name. + - name of a module member: All the parameters of the module and its + child modules. This is useful if members do not have `param_groups`, for example torch.nn.Linear. - - parameter level, only parameter with the same name as the key - will have it. + - .: recursive. Same as if + was used in param_groups of that submodule/member. Args: module: module from which to extract the parameters and their parameter @@ -277,7 +278,18 @@ def _get_param_groups( param_groups = defaultdict(list) - def traverse(module, default_group): + def traverse(module, default_group: str, mapping: Dict[str, str]) -> None: + """ + Visitor for module to assign its parameters to the relevant member of + param_groups. + + Args: + module: the module being visited in a depth-first search + default_group: the param group to assign parameters to unless + otherwise overriden. + mapping: known mappings of parameters to groups for this module, + destructively modified by this function. + """ # If key self is defined in param_groups then chenge the default param # group for all parameters and children in the module. if hasattr(module, "param_groups") and "self" in module.param_groups: @@ -286,25 +298,26 @@ def traverse(module, default_group): # Collect all the parameters that are directly inside the `module`, # they will be in the default param group if they don't have # defined group. + if hasattr(module, "param_groups"): + mapping.update(module.param_groups) + for name, param in module.named_parameters(recurse=False): if param.requires_grad: - if hasattr(module, "param_groups") and name in module.param_groups: - param_groups[module.param_groups[name]].append(param) - else: - param_groups[default_group].append(param) + group_name = mapping.get(name, default_group) + logger.info(f"Assigning {name} to param_group {group_name}") + param_groups[group_name].append(param) # If children have defined default param group then use it else pass # own default. for child_name, child in module.named_children(): - if ( - hasattr(module, "param_groups") - and child_name in module.param_groups - ): - traverse(child, module.param_groups[child_name]) - else: - traverse(child, default_group) - - traverse(module, "default") + mapping_to_add = { + name[len(child_name) + 1 :]: group + for name, group in mapping.items() + if name.startswith(child_name + ".") + } + traverse(child, mapping.get(child_name, default_group), mapping_to_add) + + traverse(module, "default", {}) return param_groups def _get_group_learning_rate(self, group_name: str) -> float: diff --git a/projects/implicitron_trainer/tests/test_optimizer_factory.py b/projects/implicitron_trainer/tests/test_optimizer_factory.py index 23cc5aad0..ef7517fe7 100644 --- a/projects/implicitron_trainer/tests/test_optimizer_factory.py +++ b/projects/implicitron_trainer/tests/test_optimizer_factory.py @@ -4,13 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import os import unittest import torch from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args -from ..impl.optimizer_factory import ImplicitronOptimizerFactory +from ..impl.optimizer_factory import ( + ImplicitronOptimizerFactory, + logger as factory_logger, +) internal = os.environ.get("FB_TEST", False) @@ -23,9 +27,17 @@ def setUp(self) -> None: def _get_param_groups(self, model): default_cfg = get_default_args(ImplicitronOptimizerFactory) factory = ImplicitronOptimizerFactory(default_cfg) - return factory._get_param_groups(model) + oldlevel = factory_logger.level + factory_logger.setLevel(logging.ERROR) + out = factory._get_param_groups(model) + factory_logger.setLevel(oldlevel) + return out def _assert_allin(self, a, param_groups, key): + """ + Asserts that all the parameters in a are in the group + named by key. + """ with self.subTest(f"Testing key {key}"): b = param_groups[key] for el in a: @@ -83,6 +95,15 @@ def test_no_param_groups_defined(self): param_groups = self._get_param_groups(root) self._assert_allin([pa, pb, pc], param_groups, "default") + def test_double_dotted(self): + pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)] + na = Node(params=[pa, pb]) + nb = Node(children=[na]) + root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"}) + param_groups = self._get_param_groups(root) + self._assert_allin([pa], param_groups, "X") + self._assert_allin([pb], param_groups, "Y") + def test_tree_param_groups_defined(self): """ Test generic tree assignment.