
give option to create parameter groups and provide group specific options.

Parameters are grouped by regex expressions.
AdrianKs committed Oct 9, 2020
1 parent d275419 commit 87c5463
Showing 3 changed files with 76 additions and 44 deletions.
42 changes: 22 additions & 20 deletions kge/config-default.yaml
@@ -208,26 +208,28 @@ train:
   num_workers: 0
 
   # Optimizer used for training.
-  optimizer: Adagrad            # e.g., Adagrad, SGD, Adam
-
-  # Additional arguments for the optimizer. Arbitrary key-value pairs can be
-  # added here and will be passed along to the optimizer. E.g., use entry lr:0.1
-  # to set the learning rate to 0.1.
-  optimizer_args:
-    +++: +++
-
-  # Override specific optimizer options for parameters matched by regex
-  # expressions. Allows, for example, defining a separate learning rate for all
-  # relation parameters. Keys here are regex expressions; sub-keys are the
-  # override optimizer args. Additional regex expressions can be added.
-  # Example keys to set:
-  #   .*_entity_embedder\..*:
-  #     lr: 0.1
-  #   .*_relation_embedder\..*:
-  #     lr: 0.2
-  optimizer_args_override:
-    #.*_entity_embedder\..*:
-    #  lr: 0.1
+  optimizer:
+    default:
+      type: Adagrad             # e.g., Adagrad, SGD, Adam
+
+      # Additional arguments for the optimizer. Arbitrary key-value pairs can
+      # be added here and will be passed along to the optimizer. E.g., use
+      # entry lr:0.1 to set the learning rate to 0.1.
+      args:
+        +++: +++
+
+    # Optimizer options for parameters matched by a regex expression can be
+    # overridden per group. This allows, for example, defining a separate
+    # learning rate for all relation parameters.
+    # Example:
+    # optimizer:
+    #   relation:
+    #     regex: .*_relation_embedder.*
+    #     args:
+    #       lr: 0.1
+    # The names of the child keys of optimizer are used as parameter group
+    # names. Parameters are named by their variable names, which can be
+    # retrieved via model.named_parameters().
+    +++: +++
 
   # Learning rate scheduler to use. Any scheduler from torch.optim.lr_scheduler
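To make the grouping mechanism concrete, here is a small self-contained sketch of how a group's regex partitions parameter names. The parameter names below are illustrative only; the real names depend on the model and embedder classes and can be listed with model.named_parameters(). The matching mirrors the implementation, which applies re.match via filter:

import re

# Illustrative parameter names, shaped like the keys returned by
# dict(model.named_parameters()); actual names vary by model.
parameter_names = [
    "_entity_embedder._embeddings.weight",
    "_relation_embedder._embeddings.weight",
]

# A group as it would appear under train.optimizer in the config.
group_name = "relation"
group_regex = r".*_relation_embedder.*"

pattern = re.compile(group_regex)
matched = set(filter(pattern.match, parameter_names))
print(group_name, "->", matched)
# relation -> {'_relation_embedder._embeddings.weight'}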
4 changes: 4 additions & 0 deletions kge/config.py
@@ -720,6 +720,10 @@ def rename_value_re(key_regex, old_value, new_value):
                 renamed_keys.add(key)
         return renamed_keys
 
+    # 09.10.20
+    rename_key("train.optimizer", "train.optimizer.default.type")
+    rename_key("train.optimizer_args", "train.optimizer.default.args")
+
     # 30.9.2020
     if "verbose" in options:
         rename_key("verbose", "console.quiet")
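For illustration, the effect of the two new rename_key calls on an old-style configuration can be sketched as a simple key substitution (this is not LibKGE's actual migration code, just a model of what it achieves):

# Hypothetical flat options dict, as an older config file would produce.
old_options = {
    "train.optimizer": "Adagrad",
    "train.optimizer_args": {"lr": 0.1},
}

# The renames introduced by this commit.
renames = {
    "train.optimizer": "train.optimizer.default.type",
    "train.optimizer_args": "train.optimizer.default.args",
}

new_options = {renames.get(key, key): value for key, value in old_options.items()}
# {'train.optimizer.default.type': 'Adagrad',
#  'train.optimizer.default.args': {'lr': 0.1}}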
74 changes: 50 additions & 24 deletions kge/util/optimizer.py
@@ -13,10 +13,10 @@ class KgeOptimizer:
     def create(config, model):
         """ Factory method for optimizer creation """
         try:
-            optimizer = getattr(torch.optim, config.get("train.optimizer"))
+            optimizer = getattr(torch.optim, config.get("train.optimizer.default.type"))
             return optimizer(
-                KgeOptimizer._get_parameter_specific_options(config, model),
-                **config.get("train.optimizer_args"),  # default optimizer options
+                KgeOptimizer._get_parameters_and_optimizer_args(config, model),
+                **config.get("train.optimizer.default.args"),
             )
         except AttributeError:
             # perhaps TODO: try class with specified name -> extensibility
@@ -26,44 +26,70 @@ def create(config, model):
             )
 
     @staticmethod
-    def _get_parameter_specific_options(config, model):
+    def _get_parameters_and_optimizer_args(config, model):
+        """
+        Group named parameters by the regex strings provided with the optimizer args.
+        Constructs a list of dictionaries of the form:
+        [
+            {
+                "name": name of the parameter group,
+                "params": list of parameters to optimize,
+                # parameter-specific options, e.g. the learning rate
+                ...
+            },
+            ...
+        ]
+        """
 
         named_parameters = dict(model.named_parameters())
-        override_parameters = config.get("train.optimizer_args_override")
+        optimizer_settings = config.get("train.optimizer")
         parameter_names_per_search = dict()
         # filter named parameters by regex string
-        for regex_string in override_parameters.keys():
-            search_pattern = re.compile(regex_string)
+        for group_name, parameter_group in optimizer_settings.items():
+            if group_name == "default":
+                continue
+            search_pattern = re.compile(parameter_group["regex"])
             filtered_named_parameters = set(
                 filter(search_pattern.match, named_parameters.keys())
            )
-            parameter_names_per_search[regex_string] = filtered_named_parameters
+            parameter_names_per_search[group_name] = filtered_named_parameters
 
         # check if something was matched by multiple strings
         parameter_values = list(parameter_names_per_search.values())
-        for i, (regex_string, param) in enumerate(parameter_names_per_search.items()):
+        for i, (group_name, param) in enumerate(parameter_names_per_search.items()):
             for j in range(i + 1, len(parameter_names_per_search)):
                 intersection = set.intersection(param, parameter_values[j])
                 if len(intersection) > 0:
                     raise ValueError(
-                        f"The parameters {intersection} were matched by the override "
-                        f"key {regex_string} and {list(parameter_names_per_search.keys())[j]}"
+                        f"The parameters {intersection} were matched by the optimizer "
+                        f"groups {group_name} and {list(parameter_names_per_search.keys())[j]}"
                     )
         # now we need to create a list like [{params: [parameters], options}, ...]
-        for regex_string, params in parameter_names_per_search.items():
-            override_parameters[regex_string]["params"] = [
+        resulting_parameters = []
+        for group_name, params in parameter_names_per_search.items():
+            optimizer_settings[group_name]["args"]["params"] = [
                 named_parameters[param] for param in params
             ]
-        resulting_parameters = list(override_parameters.values())
-        # we still need the unmatched parameters...
-        default_parameter_names = set.difference(
-            set(named_parameters.keys()),
-            reduce(or_, list(parameter_names_per_search.values())),
-        )
-        resulting_parameters.extend(
-            [
-                {"params": named_parameters[default_parameter_name]}
-                for default_parameter_name in default_parameter_names
-            ]
-        )
+            optimizer_settings[group_name]["args"]["name"] = group_name
+            resulting_parameters.append(optimizer_settings[group_name]["args"])
+
+        # add unmatched parameters to the default group
+        if len(parameter_names_per_search) > 0:
+            default_parameter_names = set.difference(
+                set(named_parameters.keys()),
+                reduce(or_, list(parameter_names_per_search.values())),
+            )
+            default_parameters = [
+                named_parameters[default_parameter_name]
+                for default_parameter_name in default_parameter_names
+            ]
+            resulting_parameters.append(
+                {"params": default_parameters, "name": "default"}
+            )
+        else:
+            # no parameters matched, add everything to the default group
+            resulting_parameters.append(
+                {"params": model.parameters(), "name": "default"}
+            )
         return resulting_parameters


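The list built by _get_parameters_and_optimizer_args is consumed through PyTorch's standard per-parameter-group optimizer API: each dict needs a "params" entry, group-specific options such as lr override the optimizer-wide defaults, and extra keys such as "name" are simply kept on the group. A minimal sketch with dummy tensors (shapes and values are illustrative):

import torch

# Dummy parameters standing in for relation and entity embedding weights.
relation_weights = torch.nn.Parameter(torch.randn(5, 3))
entity_weights = torch.nn.Parameter(torch.randn(7, 3))

param_groups = [
    # group-specific learning rate, as produced from a regex group's args
    {"name": "relation", "params": [relation_weights], "lr": 0.1},
    # the default group inherits the optimizer-wide defaults below
    {"name": "default", "params": [entity_weights]},
]

# The keyword arguments play the role of train.optimizer.default.args.
optimizer = torch.optim.Adagrad(param_groups, lr=0.2)
for group in optimizer.param_groups:
    print(group["name"], group["lr"])
# relation 0.1
# default 0.2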
