
parameter-specific optimizer configurations
parsing of regex does not work correctly yet
AdrianKs committed Oct 8, 2020
1 parent d539199 commit d275419
Showing 2 changed files with 60 additions and 3 deletions.
14 changes: 14 additions & 0 deletions kge/config-default.yaml
@@ -216,6 +216,20 @@ train:
optimizer_args:
+++: +++

# Override specific optimizer options for parameters matched by regular expressions.
# This allows, for example, defining a separate learning rate for all relation parameters.
# Keys are regular expressions over parameter names; sub-keys are the optimizer args to override.
# Additional regular expressions can be added.
# example keys to set:
# .*_entity_embedder\..*
# lr: 0.1
# .*_relation_embedder\..*
# lr: 0.2
optimizer_args_override:
#.*_entity_embedder\..*:
# lr: 0.1
+++: +++

# Learning rate scheduler to use. Any scheduler from torch.optim.lr_scheduler
# can be used (e.g., ReduceLROnPlateau). When left empty, no LR scheduler is
# used.
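For illustration, once read via config.get("train.optimizer_args_override"), the section above would yield roughly the following mapping; this is a sketch using only the example patterns and learning rates from the comments, not actual defaults:

# Sketch: regex patterns over parameter names mapped to the optimizer args they override.
optimizer_args_override = {
    r".*_entity_embedder\..*": {"lr": 0.1},
    r".*_relation_embedder\..*": {"lr": 0.2},
}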
49 changes: 46 additions & 3 deletions kge/util/optimizer.py
@@ -1,6 +1,9 @@
from kge import Config, Configurable
import torch.optim
from torch.optim.lr_scheduler import _LRScheduler
import re
from operator import or_
from functools import reduce


class KgeOptimizer:
@@ -12,8 +15,8 @@ def create(config, model):
try:
optimizer = getattr(torch.optim, config.get("train.optimizer"))
return optimizer(
[p for p in model.parameters() if p.requires_grad],
**config.get("train.optimizer_args"),
KgeOptimizer._get_parameter_specific_options(config, model),
**config.get("train.optimizer_args"), # default optimizer options
)
except AttributeError:
# perhaps TODO: try class with specified name -> extensibility
@@ -22,6 +25,47 @@ def create(config, model):
f"Please specify an optimizer provided in torch.optim"
)

@staticmethod
def _get_parameter_specific_options(config, model):
named_parameters = dict(model.named_parameters())
override_parameters = config.get("train.optimizer_args_override")
parameter_names_per_search = dict()
# filter named parameters by regex string
for regex_string in override_parameters.keys():
search_pattern = re.compile(regex_string)
filtered_named_parameters = set(
filter(search_pattern.match, named_parameters.keys())
)
parameter_names_per_search[regex_string] = filtered_named_parameters
# check if something was matched by multiple strings
parameter_values = list(parameter_names_per_search.values())
for i, (regex_string, param) in enumerate(parameter_names_per_search.items()):
for j in range(i + 1, len(parameter_names_per_search)):
intersection = set.intersection(param, parameter_values[j])
if len(intersection) > 0:
raise ValueError(
f"The parameters {intersection}, were matched by the override "
f"key {regex_string} and {list(parameter_names_per_search.keys())[j]}"
)
# build the per-parameter-group list expected by torch optimizers: [{"params": [...], **options}, ...]
for regex_string, params in parameter_names_per_search.items():
override_parameters[regex_string]["params"] = [
named_parameters[param] for param in params
]
resulting_parameters = list(override_parameters.values())
# we still need the unmatched parameters...
default_parameter_names = set.difference(
set(named_parameters.keys()),
reduce(or_, list(parameter_names_per_search.values())),
)
resulting_parameters.extend(
[
{"params": named_parameters[default_parameter_name]}
for default_parameter_name in default_parameter_names
]
)
return resulting_parameters
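As a side note, the filtering step above can be illustrated in isolation; the parameter names below are made up for this sketch, while real names come from model.named_parameters(). Since re.match anchors at the beginning of the name, the leading .* in the example patterns keeps them permissive:

import re

# hypothetical parameter names in the style of model.named_parameters()
named_parameters = {
    "_entity_embedder._embeddings.weight": "entity tensor",
    "_relation_embedder._embeddings.weight": "relation tensor",
}
pattern = re.compile(r".*_relation_embedder\..*")
matched = set(filter(pattern.match, named_parameters.keys()))
# matched == {"_relation_embedder._embeddings.weight"}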


class KgeLRScheduler(Configurable):
""" Wraps torch learning rate (LR) schedulers """
@@ -62,7 +106,6 @@ def __init__(self, config: Config, optimizer):
).format(name, args, e)
)


def step(self, metric=None):
if self._lr_scheduler is None:
return
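For context, the list built by _get_parameter_specific_options follows the per-parameter-group format that torch optimizers accept, with the keyword arguments from train.optimizer_args acting as defaults for groups that do not override them. A minimal standalone sketch, with illustrative embedders and values that are not taken from the commit:

import torch

# toy stand-ins for entity and relation embedders
entity_emb = torch.nn.Embedding(100, 16)
relation_emb = torch.nn.Embedding(10, 16)

parameter_groups = [
    {"params": list(entity_emb.parameters()), "lr": 0.1},  # group with an overridden lr
    {"params": list(relation_emb.parameters())},            # group using the defaults
]
# keyword arguments serve as defaults for every group that does not override them
optimizer = torch.optim.Adagrad(parameter_groups, lr=0.01)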
