
Revert "separate relation optimizer"
This reverts commit 7e5626d.
AdrianKs committed Oct 8, 2020
1 parent 7e5626d commit d539199
Showing 3 changed files with 7 additions and 52 deletions.
11 changes: 0 additions & 11 deletions kge/config-default.yaml
@@ -216,17 +216,6 @@ train:
  optimizer_args:
    +++: +++
-
-  # Optimizer used for training relations.
-  # When left empty don't use separate optimizer for relations
-  relation_optimizer: ""
-
-  # Additional arguments for the relation optimizer. Arbitrary key-value pairs can be
-  # added here and will be passed along to the relation optimizer. E.g., use entry lr:0.1
-  # to set the learning rate to 0.1.
-  # only used of relation_optimizer is not empty
-  relation_optimizer_args:
-    +++: +++

  # Learning rate scheduler to use. Any scheduler from torch.optim.lr_scheduler
  # can be used (e.g., ReduceLROnPlateau). When left empty, no LR scheduler is
  # used.
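
The surviving train.optimizer / train.optimizer_args entries are resolved by name against torch.optim. A minimal sketch of that mapping, with a plain dict standing in for the parsed YAML (the dict, the Linear model, and the Adagrad choice are illustrative, not LibKGE objects):

import torch

# Illustrative stand-in for the parsed "train" section of config-default.yaml.
config = {
    "train.optimizer": "Adagrad",
    "train.optimizer_args": {"lr": 0.1},
}

model = torch.nn.Linear(4, 2)  # any torch.nn.Module with trainable parameters

# Look up the optimizer class by name and pass the extra args through.
optimizer_class = getattr(torch.optim, config["train.optimizer"])
optimizer = optimizer_class(
    [p for p in model.parameters() if p.requires_grad],
    **config["train.optimizer_args"],
)
print(type(optimizer).__name__)  # Adagrad
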
22 changes: 1 addition & 21 deletions kge/job/train.py
@@ -65,9 +65,8 @@ def __init__(
            self.model: KgeModel = KgeModel.create(config, dataset)
        else:
            self.model: KgeModel = model
-        self.optimizer, self.relation_optimizer = KgeOptimizer.create(config, self.model)
+        self.optimizer = KgeOptimizer.create(config, self.model)
        self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
-        self.relation_kge_lr_scheduler = KgeLRScheduler(config, self.relation_optimizer)
        self.loss = KgeLoss.create(config)
        self.abort_on_nan: bool = config.get("train.abort_on_nan")
        self.batch_size: int = config.get("train.batch_size")
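
After the revert, __init__ goes back to a single optimizer over all trainable parameters and one LR scheduler bound to it. In plain PyTorch terms (the Embedding model, Adam, and StepLR choices are illustrative, not LibKGE's defaults):

import torch

model = torch.nn.Embedding(100, 16)  # stand-in for a KgeModel

# One optimizer over all trainable parameters ...
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=0.01
)
# ... and one scheduler bound to that optimizer.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
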
@@ -196,10 +195,8 @@ def _run(self) -> None:

                # metric-based scheduler step
                self.kge_lr_scheduler.step(trace_entry[metric_name])
-                self.relation_kge_lr_scheduler.step(trace_entry[metric_name])
            else:
                self.kge_lr_scheduler.step()
-                self.relation_kge_lr_scheduler.step()

            # create checkpoint and delete old one, if necessary
            self.save(self.config.checkpoint_file(self.epoch))
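
The remaining scheduler is stepped in one of two ways: metric-based schedulers such as ReduceLROnPlateau receive the validation metric, all others are stepped without arguments. A small sketch of the distinction (the 0.42 metric value is made up):

import torch

model = torch.nn.Linear(4, 2)

# Metric-based: step() consumes the validation metric (here a made-up value).
opt_a = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_a, mode="max", patience=2)
plateau.step(0.42)

# Epoch-based: step() takes no metric.
opt_b = torch.optim.SGD(model.parameters(), lr=0.1)
step_lr = torch.optim.lr_scheduler.StepLR(opt_b, step_size=10)
step_lr.step()
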
@@ -253,11 +250,8 @@ def save_to(self, checkpoint: Dict) -> Dict:
"model": self.model.save(),
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": self.kge_lr_scheduler.state_dict(),
"relation_lr_scheduler_state_dict": self.relation_kge_lr_scheduler.state_dict(),
"job_id": self.job_id,
}
if self.relation_optimizer is not None:
train_checkpoint["relation_optimizer_state_dict"] = self.relation_optimizer.state_dict()
train_checkpoint = self.config.save_to(train_checkpoint)
checkpoint.update(train_checkpoint)
return checkpoint
@@ -266,13 +260,9 @@ def _load(self, checkpoint: Dict) -> str:
if checkpoint["type"] != "train":
raise ValueError("Training can only be continued on trained checkpoints")
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.relation_optimizer is not None and "relation_optimizer_state_dict" in checkpoint:
self.relation_optimizer.load_state_dict(checkpoint["relation_optimizer_state_dict"])
if "lr_scheduler_state_dict" in checkpoint:
# new format
self.kge_lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
if "relation_lr_scheduler_state_dict" in checkpoint:
self.relation_kge_lr_scheduler.load_state_dict(checkpoint["relation_lr_scheduler_state_dict"])
self.epoch = checkpoint["epoch"]
self.valid_trace = checkpoint["valid_trace"]
self.model.train()
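
Checkpointing after the revert stores exactly one optimizer state and one scheduler state, and loading only guards for the scheduler key (older checkpoints may lack it). A minimal round-trip sketch in plain PyTorch; note that LibKGE's self.model.save() is its own serialization, so a bare state_dict() is used here instead:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

# Save a single optimizer state and a single scheduler state.
checkpoint = {
    "model": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "lr_scheduler_state_dict": scheduler.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Restore; tolerate older checkpoints that predate the scheduler state.
loaded = torch.load("checkpoint.pt")
model.load_state_dict(loaded["model"])
optimizer.load_state_dict(loaded["optimizer_state_dict"])
if "lr_scheduler_state_dict" in loaded:
    scheduler.load_state_dict(loaded["lr_scheduler_state_dict"])
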
@@ -301,10 +291,6 @@ def run_epoch(self) -> Dict[str, Any]:
            size=self.num_examples,
            lr=[group["lr"] for group in self.optimizer.param_groups],
        )
-        if self.relation_optimizer is not None:
-            self.current_trace["epoch"]["relation_lr"] = [
-                group["lr"] for group in self.relation_optimizer.param_groups
-            ]

        # run pre-epoch hooks (may modify trace)
        for f in self.pre_epoch_hooks:
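
The epoch trace records the current learning rate(s) by reading optimizer.param_groups, one entry per parameter group. A short sketch of that lookup (the two-group SGD setup is illustrative):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(
    [
        {"params": [model.weight], "lr": 0.1},
        {"params": [model.bias], "lr": 0.01},
    ]
)

# One learning rate per parameter group, as written into the trace's "lr" field.
lrs = [group["lr"] for group in optimizer.param_groups]
print(lrs)  # [0.1, 0.01]
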
@@ -332,10 +318,6 @@ def run_epoch(self) -> Dict[str, Any]:
"batches": len(self.loader),
"lr": [group["lr"] for group in self.optimizer.param_groups],
}
if self.relation_optimizer is not None:
self.current_trace["batch"]["relation_lr"] = [
group["lr"] for group in self.relation_optimizer.param_groups
]

# run the pre-batch hooks (may update the trace)
for f in self.pre_batch_hooks:
@@ -439,8 +421,6 @@ def run_epoch(self) -> Dict[str, Any]:
            # update parameters
            batch_optimizer_time = -time.time()
            self.optimizer.step()
-            if self.relation_optimizer is not None:
-                self.relation_optimizer.step()
            batch_optimizer_time += time.time()

            # update batch trace with the results
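
The optimizer step is the only part of the update that is timed separately. A bare-bones sketch of the surrounding sequence (the dummy batch and squared-output loss are placeholders for LibKGE's real loss):

import time
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

optimizer.zero_grad()
loss = model(torch.randn(8, 4)).pow(2).mean()  # placeholder loss
loss.backward()

# Time only the parameter update, mirroring batch_optimizer_time above.
batch_optimizer_time = -time.time()
optimizer.step()
batch_optimizer_time += time.time()
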
26 changes: 6 additions & 20 deletions kge/util/optimizer.py
@@ -11,25 +11,10 @@ def create(config, model):
""" Factory method for optimizer creation """
try:
optimizer = getattr(torch.optim, config.get("train.optimizer"))
relation_parameters = set()
relation_optimizer = None
if config.get("train.relation_optimizer"):
relation_optimizer = getattr(
torch.optim, config.get("train.relation_optimizer")
)
relation_parameters = set(
p for p in model.get_p_embedder().parameters() if p.requires_grad
)
relation_optimizer = relation_optimizer(
relation_parameters, **config.get("train.relation_optimizer_args")
)
parameters = [
p
for p in model.parameters()
if p.requires_grad and p not in relation_parameters
]
optimizer = optimizer(parameters, **config.get("train.optimizer_args"))
return optimizer, relation_optimizer
return optimizer(
[p for p in model.parameters() if p.requires_grad],
**config.get("train.optimizer_args"),
)
except AttributeError:
# perhaps TODO: try class with specified name -> extensibility
raise ValueError(
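
The restored factory resolves the optimizer class by name on torch.optim and maps an unknown name to a ValueError. A condensed sketch of that pattern outside LibKGE (create_optimizer is a hypothetical helper, not part of the library):

import torch

def create_optimizer(name, parameters, **kwargs):
    # getattr raises AttributeError for names torch.optim does not provide.
    try:
        optimizer_class = getattr(torch.optim, name)
    except AttributeError:
        raise ValueError(f"Could not create optimizer {name}")
    return optimizer_class([p for p in parameters if p.requires_grad], **kwargs)

model = torch.nn.Linear(4, 2)
opt = create_optimizer("Adagrad", model.parameters(), lr=0.2)
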
@@ -46,7 +31,7 @@ def __init__(self, config: Config, optimizer):
        name = config.get("train.lr_scheduler")
        args = config.get("train.lr_scheduler_args")
        self._lr_scheduler: _LRScheduler = None
-        if name != "" and optimizer is not None:
+        if name != "":
            # check for consistency of metric-based scheduler
            self._metric_based = name in ["ReduceLROnPlateau"]
            if self._metric_based:
@@ -77,6 +62,7 @@ def __init__(self, config: Config, optimizer):
                ).format(name, args, e)
            )

+
    def step(self, metric=None):
        if self._lr_scheduler is None:
            return
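
KgeLRScheduler resolves the scheduler class by name from torch.optim.lr_scheduler, flags ReduceLROnPlateau as metric-based, and makes step() a no-op when no scheduler is configured. A minimal sketch of the same lookup (the name, args, and metric value are illustrative):

import torch
from torch.optim import lr_scheduler

optimizer = torch.optim.SGD(torch.nn.Linear(4, 2).parameters(), lr=0.1)

name = "ReduceLROnPlateau"
args = {"mode": "max", "patience": 3}
metric_based = name in ["ReduceLROnPlateau"]

# Resolve the scheduler class by name and pass the configured args through.
scheduler = getattr(lr_scheduler, name)(optimizer, **args) if name != "" else None

if scheduler is not None:
    if metric_based:
        scheduler.step(0.42)  # made-up validation metric
    else:
        scheduler.step()
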

