allow dots in param_groups
Summary:
Allow a module's `param_groups` member to specify overrides to the param groups of its members or their members.
Also add logging for param group assignments.

This allows defining `params.basis_matrix` in the param_groups of a voxel_grid.
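
For illustration, a minimal sketch of the motivating case. The voxel_grid object below is hypothetical; only the dotted-key mechanism itself comes from this commit.

# Hypothetical: voxel_grid is a module whose `params` member holds a
# parameter named `basis_matrix`. With this commit, a dotted key in its
# param_groups member routes just that parameter into its own group:
voxel_grid.param_groups = {"params.basis_matrix": "basis"}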

Reviewed By: shapovalov

Differential Revision: D41080667

fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
bottler authored and facebook-github-bot committed Nov 7, 2022
1 parent a1f2ded commit 7be49bf
Showing 2 changed files with 57 additions and 23 deletions.
55 changes: 34 additions & 21 deletions projects/implicitron_trainer/impl/optimizer_factory.py
@@ -258,15 +258,16 @@ def _get_param_groups(
at the module level. Possible keys, including the "self" key, do not have to
be defined. By default all parameters have the learning rate defined in the
optimizer. This can be overridden by setting the parameter group in `param_groups`
member of a specific module, it can be overridden at the:
- module level with “self” key, all the parameters and child
module's parameters will inherit it
- member level, which is the same as if the `param_groups` in that
member has key=“self” and value equal to that parameter group.
member of a specific module. Each value is a parameter group name. The keys
specify which parameters will be affected, as follows:
- “self”: All the parameters of the module and its child modules.
- name of a parameter: A parameter with that name.
- name of a module member: All the parameters of that member and its
child modules.
This is useful if members do not have `param_groups`, for
example torch.nn.Linear.
- parameter level, only parameter with the same name as the key
will have it.
- <name of module member>.<something>: recursive. Same as if <something>
was used in param_groups of that submodule/member.
Args:
module: module from which to extract the parameters and their parameter
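
As an aside between the hunks, a minimal sketch of the four kinds of key. The module and parameter names here (net, encoder, weight) are hypothetical, and the group names are arbitrary strings that the factory later maps to learning rates:

# Hypothetical layout: net has a parameter `weight` and a child module
# `encoder`, which contains a torch.nn.Linear member `linear` (plain
# torch modules have no param_groups member of their own).
net.param_groups = {
    "self": "base",                 # default for net and all descendants
    "weight": "fast",               # just the parameter net.weight
    "encoder": "enc",               # everything under net.encoder
    "encoder.linear.bias": "bias",  # recursive: as if encoder.linear had
                                    # param_groups == {"bias": "bias"}
}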
@@ -277,7 +278,18 @@

param_groups = defaultdict(list)

def traverse(module, default_group):
def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
"""
Visitor for module to assign its parameters to the relevant member of
param_groups.
Args:
module: the module being visited in a depth-first search
default_group: the param group to assign parameters to unless
otherwise overridden.
mapping: known mappings of parameters to groups for this module,
destructively modified by this function.
"""
# If key "self" is defined in param_groups then change the default param
# group for all parameters and children in the module.
if hasattr(module, "param_groups") and "self" in module.param_groups:
@@ -286,25 +298,26 @@ def traverse(module, default_group):
# Collect all the parameters that are directly inside the `module`;
# they will go in the default param group if they don't have a
# defined group.
if hasattr(module, "param_groups"):
mapping.update(module.param_groups)

for name, param in module.named_parameters(recurse=False):
if param.requires_grad:
if hasattr(module, "param_groups") and name in module.param_groups:
param_groups[module.param_groups[name]].append(param)
else:
param_groups[default_group].append(param)
group_name = mapping.get(name, default_group)
logger.info(f"Assigning {name} to param_group {group_name}")
param_groups[group_name].append(param)

# If a child has a defined default param group then use it, else pass
# on our own default.
for child_name, child in module.named_children():
if (
hasattr(module, "param_groups")
and child_name in module.param_groups
):
traverse(child, module.param_groups[child_name])
else:
traverse(child, default_group)

traverse(module, "default")
mapping_to_add = {
name[len(child_name) + 1 :]: group
for name, group in mapping.items()
if name.startswith(child_name + ".")
}
traverse(child, mapping.get(child_name, default_group), mapping_to_add)

traverse(module, "default", {})
return param_groups

def _get_group_learning_rate(self, group_name: str) -> float:
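Taken together, the updated method amounts to the following self-contained sketch. It is a simplified restatement for reading alongside the diff, not the committed code itself; the real method lives on ImplicitronOptimizerFactory and also logs each assignment.

from collections import defaultdict
from typing import Dict, List

import torch


def get_param_groups(
    root: torch.nn.Module,
) -> Dict[str, List[torch.nn.Parameter]]:
    param_groups = defaultdict(list)

    def traverse(
        module: torch.nn.Module, default_group: str, mapping: Dict[str, str]
    ) -> None:
        own = getattr(module, "param_groups", {})
        # A "self" key changes the default group for this whole subtree.
        if "self" in own:
            default_group = own["self"]
        # Overrides defined here win over dotted keys inherited from parents.
        mapping.update(own)

        # Direct parameters: a plain-name key in the mapping beats the default.
        for name, param in module.named_parameters(recurse=False):
            if param.requires_grad:
                param_groups[mapping.get(name, default_group)].append(param)

        # Children: strip the "<child_name>." prefix from dotted keys, so they
        # behave as if written in the child's own param_groups.
        for child_name, child in module.named_children():
            child_mapping = {
                name[len(child_name) + 1:]: group
                for name, group in mapping.items()
                if name.startswith(child_name + ".")
            }
            traverse(child, mapping.get(child_name, default_group), child_mapping)

    traverse(root, "default", {})
    return param_groups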
25 changes: 23 additions & 2 deletions projects/implicitron_trainer/tests/test_optimizer_factory.py
@@ -4,13 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import unittest

import torch
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args

from ..impl.optimizer_factory import ImplicitronOptimizerFactory
from ..impl.optimizer_factory import (
ImplicitronOptimizerFactory,
logger as factory_logger,
)

internal = os.environ.get("FB_TEST", False)

@@ -23,9 +27,17 @@ def setUp(self) -> None:
def _get_param_groups(self, model):
default_cfg = get_default_args(ImplicitronOptimizerFactory)
factory = ImplicitronOptimizerFactory(default_cfg)
return factory._get_param_groups(model)
oldlevel = factory_logger.level
factory_logger.setLevel(logging.ERROR)
out = factory._get_param_groups(model)
factory_logger.setLevel(oldlevel)
return out

def _assert_allin(self, a, param_groups, key):
"""
Asserts that all the parameters in a are in the group
named by key.
"""
with self.subTest(f"Testing key {key}"):
b = param_groups[key]
for el in a:
@@ -83,6 +95,15 @@ def test_no_param_groups_defined(self):
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pb, pc], param_groups, "default")

def test_double_dotted(self):
pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
na = Node(params=[pa, pb])
nb = Node(children=[na])
root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
param_groups = self._get_param_groups(root)
self._assert_allin([pa], param_groups, "X")
self._assert_allin([pb], param_groups, "Y")

def test_tree_param_groups_defined(self):
"""
Test generic tree assignment.
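These tests build trees out of a Node helper defined earlier in test_optimizer_factory.py, outside this diff. A plausible minimal stand-in, consistent with the m0/p0 member names that keys like "m0.m0.p0" assume (the real class may differ):

import torch


class Node(torch.nn.Module):
    # Children are registered as m0, m1, ...; parameters as p0, p1, ...,
    # which is what makes keys like "m0.m0.p0" in the tests resolve.
    def __init__(self, children=(), params=(), param_groups=None):
        super().__init__()
        for i, child in enumerate(children):
            self.add_module(f"m{i}", child)
        for i, param in enumerate(params):
            setattr(self, f"p{i}", param)
        if param_groups is not None:
            self.param_groups = param_groups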
