diff --git a/kge/config-default.yaml b/kge/config-default.yaml
index 352bd98b1..813fc36e3 100644
--- a/kge/config-default.yaml
+++ b/kge/config-default.yaml
@@ -216,17 +216,6 @@ train:
   optimizer_args:
     +++: +++
 
-  # Optimizer used for training relations.
-  # When left empty don't use separate optimizer for relations
-  relation_optimizer: ""
-
-  # Additional arguments for the relation optimizer. Arbitrary key-value pairs can be
-  # added here and will be passed along to the relation optimizer. E.g., use entry lr:0.1
-  # to set the learning rate to 0.1.
-  # only used of relation_optimizer is not empty
-  relation_optimizer_args:
-    +++: +++
-
   # Learning rate scheduler to use. Any scheduler from torch.optim.lr_scheduler
   # can be used (e.g., ReduceLROnPlateau). When left empty, no LR scheduler is
   # used.
diff --git a/kge/job/train.py b/kge/job/train.py
index 3fbb0682c..74a71283b 100644
--- a/kge/job/train.py
+++ b/kge/job/train.py
@@ -65,9 +65,8 @@ def __init__(
             self.model: KgeModel = KgeModel.create(config, dataset)
         else:
             self.model: KgeModel = model
-        self.optimizer, self.relation_optimizer = KgeOptimizer.create(config, self.model)
+        self.optimizer = KgeOptimizer.create(config, self.model)
         self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
-        self.relation_kge_lr_scheduler = KgeLRScheduler(config, self.relation_optimizer)
         self.loss = KgeLoss.create(config)
         self.abort_on_nan: bool = config.get("train.abort_on_nan")
         self.batch_size: int = config.get("train.batch_size")
@@ -196,10 +195,8 @@ def _run(self) -> None:
 
                 # metric-based scheduler step
                 self.kge_lr_scheduler.step(trace_entry[metric_name])
-                self.relation_kge_lr_scheduler.step(trace_entry[metric_name])
             else:
                 self.kge_lr_scheduler.step()
-                self.relation_kge_lr_scheduler.step()
 
             # create checkpoint and delete old one, if necessary
             self.save(self.config.checkpoint_file(self.epoch))
@@ -253,11 +250,8 @@ def save_to(self, checkpoint: Dict) -> Dict:
             "model": self.model.save(),
             "optimizer_state_dict": self.optimizer.state_dict(),
             "lr_scheduler_state_dict": self.kge_lr_scheduler.state_dict(),
-            "relation_lr_scheduler_state_dict": self.relation_kge_lr_scheduler.state_dict(),
             "job_id": self.job_id,
         }
-        if self.relation_optimizer is not None:
-            train_checkpoint["relation_optimizer_state_dict"] = self.relation_optimizer.state_dict()
         train_checkpoint = self.config.save_to(train_checkpoint)
         checkpoint.update(train_checkpoint)
         return checkpoint
@@ -266,13 +260,9 @@ def _load(self, checkpoint: Dict) -> str:
         if checkpoint["type"] != "train":
             raise ValueError("Training can only be continued on trained checkpoints")
         self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-        if self.relation_optimizer is not None and "relation_optimizer_state_dict" in checkpoint:
-            self.relation_optimizer.load_state_dict(checkpoint["relation_optimizer_state_dict"])
         if "lr_scheduler_state_dict" in checkpoint:
             # new format
             self.kge_lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
-        if "relation_lr_scheduler_state_dict" in checkpoint:
-            self.relation_kge_lr_scheduler.load_state_dict(checkpoint["relation_lr_scheduler_state_dict"])
         self.epoch = checkpoint["epoch"]
         self.valid_trace = checkpoint["valid_trace"]
         self.model.train()
@@ -301,10 +291,6 @@ def run_epoch(self) -> Dict[str, Any]:
             size=self.num_examples,
             lr=[group["lr"] for group in self.optimizer.param_groups],
         )
-        if self.relation_optimizer is not None:
-            self.current_trace["epoch"]["relation_lr"] = [
-                group["lr"] for group in self.relation_optimizer.param_groups
-            ]
 
         # run pre-epoch hooks (may modify trace)
         for f in self.pre_epoch_hooks:
@@ -332,10 +318,6 @@ def run_epoch(self) -> Dict[str, Any]:
                 "batches": len(self.loader),
                 "lr": [group["lr"] for group in self.optimizer.param_groups],
             }
-            if self.relation_optimizer is not None:
-                self.current_trace["batch"]["relation_lr"] = [
-                    group["lr"] for group in self.relation_optimizer.param_groups
-                ]
 
             # run the pre-batch hooks (may update the trace)
             for f in self.pre_batch_hooks:
@@ -439,8 +421,6 @@ def run_epoch(self) -> Dict[str, Any]:
             # update parameters
             batch_optimizer_time = -time.time()
             self.optimizer.step()
-            if self.relation_optimizer is not None:
-                self.relation_optimizer.step()
             batch_optimizer_time += time.time()
 
             # update batch trace with the results
diff --git a/kge/util/optimizer.py b/kge/util/optimizer.py
index dc421a9c8..5893edd41 100644
--- a/kge/util/optimizer.py
+++ b/kge/util/optimizer.py
@@ -11,25 +11,10 @@ def create(config, model):
         """ Factory method for optimizer creation """
         try:
             optimizer = getattr(torch.optim, config.get("train.optimizer"))
-            relation_parameters = set()
-            relation_optimizer = None
-            if config.get("train.relation_optimizer"):
-                relation_optimizer = getattr(
-                    torch.optim, config.get("train.relation_optimizer")
-                )
-                relation_parameters = set(
-                    p for p in model.get_p_embedder().parameters() if p.requires_grad
-                )
-                relation_optimizer = relation_optimizer(
-                    relation_parameters, **config.get("train.relation_optimizer_args")
-                )
-            parameters = [
-                p
-                for p in model.parameters()
-                if p.requires_grad and p not in relation_parameters
-            ]
-            optimizer = optimizer(parameters, **config.get("train.optimizer_args"))
-            return optimizer, relation_optimizer
+            return optimizer(
+                [p for p in model.parameters() if p.requires_grad],
+                **config.get("train.optimizer_args"),
+            )
         except AttributeError:
             # perhaps TODO: try class with specified name -> extensibility
             raise ValueError(
@@ -46,7 +31,7 @@ def __init__(self, config: Config, optimizer):
         name = config.get("train.lr_scheduler")
         args = config.get("train.lr_scheduler_args")
         self._lr_scheduler: _LRScheduler = None
-        if name != "" and optimizer is not None:
+        if name != "":
            # check for consistency of metric-based scheduler
            self._metric_based = name in ["ReduceLROnPlateau"]
            if self._metric_based:
@@ -77,6 +62,7 @@ def __init__(self, config: Config, optimizer):
                     ).format(name, args, e)
                 )
 
+
     def step(self, metric=None):
         if self._lr_scheduler is None:
             return
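
For context, a minimal standalone sketch of the optimizer-creation path that remains after this change. It mirrors the simplified KgeOptimizer.create (look up the optimizer class by name in torch.optim and build a single optimizer over all trainable parameters); ToyModel, create_optimizer, and the hard-coded "Adagrad"/lr values are illustrative stand-ins, not part of the patch or the LibKGE API.

# Illustrative sketch only; names outside the diff (ToyModel, create_optimizer)
# are hypothetical.
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Tiny stand-in for a KGE model with entity and relation embeddings."""

    def __init__(self):
        super().__init__()
        self.entity_embedder = nn.Embedding(100, 16)
        self.relation_embedder = nn.Embedding(10, 16)

    def forward(self, s, p):
        return (self.entity_embedder(s) * self.relation_embedder(p)).sum(dim=-1)


def create_optimizer(model, name="Adagrad", args=None):
    """Create a single optimizer over every trainable parameter, as in the
    simplified factory: resolve the class by name in torch.optim and pass
    the configured keyword arguments through."""
    args = args or {"lr": 0.1}
    try:
        optimizer_class = getattr(torch.optim, name)
    except AttributeError:
        raise ValueError(f"Could not create optimizer {name}")
    return optimizer_class(
        [p for p in model.parameters() if p.requires_grad], **args
    )


model = ToyModel()
optimizer = create_optimizer(model)
print([group["lr"] for group in optimizer.param_groups])  # -> [0.1]

If different learning rates for relation parameters are still needed, standard PyTorch parameter groups (a list of dicts passed to a single optimizer) can express that without a second optimizer or scheduler.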