Commit

[Feat] added self-labeling training algorithm
LTluttmann committed Sep 12, 2024
1 parent 6334786 commit d0348fa
Showing 11 changed files with 246 additions and 5 deletions.
1 change: 1 addition & 0 deletions configs/experiment/scheduling/am-ppo.yaml
@@ -45,6 +45,7 @@ model:
test_batch_size: 64
train_data_size: 2000
mini_batch_size: 512
max_grad_norm: 1

env:
stepwise_reward: True
1 change: 0 additions & 1 deletion configs/experiment/scheduling/base.yaml
@@ -37,4 +37,3 @@ model:
lr_scheduler_kwargs:
gamma: 0.95
reward_scale: scale
max_grad_norm: 1
35 changes: 35 additions & 0 deletions configs/experiment/scheduling/episodic-ppo.yaml
@@ -0,0 +1,35 @@
# @package _global_

defaults:
- scheduling/base

logger:
wandb:
tags: ["hgnn-ppo", "${env.name}"]
name: "hgnn-ppo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m"

# params from Song et al.
model:
_target_: rl4co.models.L2DModel
policy_kwargs:
embed_dim: 128
num_encoder_layers: 3
scaling_factor: ${scaling_factor}
max_grad_norm: 1
ppo_epochs: 3
het_emb: True
batch_size: 128
val_batch_size: 512
test_batch_size: 64
mini_batch_size: 512
# reward_scale: scale
optimizer_kwargs:
lr: 1e-4

trainer:
max_epochs: 10


env:
stepwise_reward: False
_torchrl_mode: False
1 change: 1 addition & 0 deletions configs/experiment/scheduling/gnn-ppo.yaml
@@ -23,6 +23,7 @@ model:
val_batch_size: 512
test_batch_size: 64
mini_batch_size: 512
max_grad_norm: 1


trainer:
1 change: 1 addition & 0 deletions configs/experiment/scheduling/hgnn-ppo.yaml
@@ -22,6 +22,7 @@ model:
val_batch_size: 512
test_batch_size: 64
mini_batch_size: 512
max_grad_norm: 1

env:
stepwise_reward: True
1 change: 1 addition & 0 deletions configs/experiment/scheduling/matnet-ppo.yaml
@@ -37,6 +37,7 @@ model:
val_batch_size: 512
test_batch_size: 64
mini_batch_size: 512
max_grad_norm: 1

env:
stepwise_reward: True
40 changes: 40 additions & 0 deletions configs/experiment/scheduling/sl.yaml
@@ -0,0 +1,40 @@
# @package _global_

defaults:
- scheduling/base

logger:
wandb:
tags: ["matnet-pomo", "${env.name}"]
name: "matnet-pomo-${env.name}-${env.generator_params.num_jobs}j-${env.generator_params.num_machines}m"

embed_dim: 256

model:
_target_: rl4co.models.SelfLabeling
policy:
_target_: rl4co.models.L2DPolicy4PPO
decoder:
_target_: rl4co.models.zoo.l2d.decoder.L2DDecoder
env_name: ${env.name}
embed_dim: ${embed_dim}
het_emb: True
feature_extractor:
_target_: rl4co.models.zoo.matnet.matnet_w_sa.Encoder
embed_dim: ${embed_dim}
num_heads: 8
num_layers: 4
normalization: "batch"
init_embedding:
_target_: rl4co.models.nn.env_embeddings.init.FJSPMatNetInitEmbedding
embed_dim: ${embed_dim}
scaling_factor: ${scaling_factor}
env_name: ${env.name}
embed_dim: ${embed_dim}
scaling_factor: ${scaling_factor}
het_emb: True
batch_size: 64
num_starts: 10
metrics:
val: ["reward", "max_reward"]
test: ${model.metrics.val}
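For reference, the nested `_target_` entries above resolve to roughly the following direct instantiation. This is a sketch, assuming each constructor accepts the keyword arguments listed in the config (Hydra performs the real instantiation); the FJSP environment, its generator parameters, and the scaling-factor value are placeholders, not part of this commit.

from rl4co.envs import FJSPEnv
from rl4co.models import L2DPolicy4PPO, SelfLabeling
from rl4co.models.nn.env_embeddings.init import FJSPMatNetInitEmbedding
from rl4co.models.zoo.l2d.decoder import L2DDecoder
from rl4co.models.zoo.matnet.matnet_w_sa import Encoder

embed_dim = 256
scaling_factor = 20  # placeholder for ${scaling_factor} from the base config
env = FJSPEnv(generator_params={"num_jobs": 10, "num_machines": 5})  # assumed environment

# MatNet-style feature extractor with the FJSP init embedding, as in the config
feature_extractor = Encoder(
    embed_dim=embed_dim,
    num_heads=8,
    num_layers=4,
    normalization="batch",
    init_embedding=FJSPMatNetInitEmbedding(
        embed_dim=embed_dim, scaling_factor=scaling_factor
    ),
)
decoder = L2DDecoder(
    env_name=env.name,
    embed_dim=embed_dim,
    het_emb=True,
    feature_extractor=feature_extractor,
)
policy = L2DPolicy4PPO(
    decoder=decoder,
    env_name=env.name,
    embed_dim=embed_dim,
    scaling_factor=scaling_factor,
    het_emb=True,
)
model = SelfLabeling(env, policy, batch_size=64, num_starts=10)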
2 changes: 1 addition & 1 deletion rl4co/models/__init__.py
@@ -14,7 +14,7 @@
NonAutoregressivePolicy,
)
from rl4co.models.common.transductive import TransductiveModel
from rl4co.models.rl import StepwisePPO
from rl4co.models.rl import SelfLabeling, StepwisePPO
from rl4co.models.rl.a2c.a2c import A2C
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.models.rl.ppo.ppo import PPO
1 change: 1 addition & 0 deletions rl4co/models/rl/__init__.py
@@ -4,3 +4,4 @@
from rl4co.models.rl.ppo.ppo import PPO
from rl4co.models.rl.ppo.stepwise_ppo import StepwisePPO
from rl4co.models.rl.reinforce.reinforce import REINFORCE
from rl4co.models.rl.self_supervised.self_labeling import SelfLabeling
160 changes: 160 additions & 0 deletions rl4co/models/rl/self_supervised/self_labeling.py
@@ -0,0 +1,160 @@
import copy

from typing import Any, Union

import torch
import torch.nn as nn

from torch.nn import CrossEntropyLoss
from torchrl.data.replay_buffers import (
LazyMemmapStorage,
ListStorage,
SamplerWithoutReplacement,
TensorDictReplayBuffer,
)

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.utils.ops import batchify, unbatchify
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


def make_replay_buffer(buffer_size, batch_size, device="cpu"):
if device == "cpu":
storage = LazyMemmapStorage(buffer_size, device="cpu")
prefetch = 3
else:
storage = ListStorage(buffer_size)
prefetch = None
return TensorDictReplayBuffer(
storage=storage,
batch_size=batch_size,
sampler=SamplerWithoutReplacement(drop_last=True),
pin_memory=False,
prefetch=prefetch,
)


class SelfLabeling(RL4COLitModule):
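    """Self-labeling training module.

    Rolls out several solutions per instance with a frozen copy of the policy,
    keeps the best trajectory per instance as a pseudo-label, and trains the
    current policy on its actions with a cross-entropy loss.
    """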
def __init__(
self,
env: RL4COEnvBase,
policy: nn.Module,
clip_range: float = 0.2, # epsilon of PPO
update_timestep: int = 1,
buffer_size: int = 100_000,
sl_epochs: int = 1, # inner epoch, K
batch_size: int = 256,
mini_batch_size: int = 256,
vf_lambda: float = 0.5, # lambda of Value function fitting
entropy_lambda: float = 0.01, # lambda of entropy bonus
max_grad_norm: float = 0.5, # max gradient norm
buffer_storage_device: str = "gpu",
metrics: dict = {
"train": ["loss", "surrogate_loss", "value_loss", "entropy"],
},
reward_scale: Union[str, int] = None,
num_starts: int = None,
**kwargs,
):
super().__init__(env, policy, metrics=metrics, batch_size=batch_size, **kwargs)

self.policy_old = copy.deepcopy(self.policy)
self.automatic_optimization = False  # manual optimization: backward, clipping and optimizer steps happen in `update`
self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
self.sl_epochs = sl_epochs
self.max_grad_norm = max_grad_norm
self.update_timestep = update_timestep
self.mini_batch_size = mini_batch_size
self.num_starts = num_starts

def update(self, eval_td, device):
losses = []
# self-labeling inner epochs (K)
for _ in range(self.sl_epochs):
for sub_td in self.rb:
sub_td = sub_td.to(device)

logprobs, _, _ = self.policy.evaluate(sub_td, return_selected=False)

criterion = CrossEntropyLoss(reduction="mean")
# cross-entropy between the per-action log-probabilities and the self-labeled actions
loss = criterion(logprobs, sub_td["action"])

opt = self.optimizers()
opt.zero_grad()
self.manual_backward(loss)
if self.max_grad_norm is not None:
self.clip_gradients(
opt,
gradient_clip_val=self.max_grad_norm,
gradient_clip_algorithm="norm",
)

opt.step()
losses.append(loss)

# run the policy in eval mode (phase="val") so decoding is greedy
out = self.policy.generate(eval_td, self.env, phase="val")
# add loss to metrics
out["loss"] = torch.stack(losses, dim=0)
return out

def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
):
orig_td = self.env.reset(batch)
device = orig_td.device
n_start = (
self.env.get_num_starts(orig_td)
if self.num_starts is None
else self.num_starts
)
next_td = batchify(orig_td.clone(), n_start)
td_stack = []

if phase == "train":
while not next_td["done"].all():

with torch.no_grad():
td = self.policy_old.act(next_td, self.env, phase="train")

# get next state
next_td = self.env.step(td)["next"]

# collect the step tensordict (action and logprob information) for self-labeling
td_stack.append(td)
# (bs * #samples, #steps)
td_stack = torch.stack(td_stack, dim=1)
# (bs, #samples, #steps)
td_stack_unbs = unbatchify(td_stack, n_start)
# (bs * #samples)
rewards = self.env.get_reward(next_td, None)
# (bs)
_, best_idx = unbatchify(rewards, n_start).max(dim=1)
td_best = td_stack_unbs.gather(
1, best_idx[:, None, None].expand(-1, 1, td_stack_unbs.size(2))
).squeeze(1)
# flatten so that every step is an experience TODO can we enhance this?
self.rb.extend(td_best.flatten())

# update the policy every `update_timestep` batches (update_timestep = 1 in the paper)
if batch_idx % self.update_timestep == 0:

out = self.update(orig_td, device)

# TODO check the details of this: if out["reward"].mean() > max_rew.mean():
# Copy new weights into old policy:
self.policy_old.load_state_dict(self.policy.state_dict())
# only clear the rb if we improved on the old model, otherwise the experience is still useful
self.rb.empty()

else:
out = self.policy.generate(
next_td, self.env, phase=phase # , select_best=True, multisample=True
)

metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
return {"loss": out.get("loss", None), **metrics}
8 changes: 5 additions & 3 deletions rl4co/models/zoo/l2d/policy.py
@@ -203,7 +203,7 @@ def __init__(
self.encoder, NoEncoder
), "Define a feature extractor for decoder rather than an encoder in stepwise PPO"

def evaluate(self, td):
def evaluate(self, td, return_selected=True):
# Encoder: get encoder output and initial embeddings from initial state
hidden, _ = self.decoder.feature_extractor(td)
# pool the embeddings for the critic
@@ -220,10 +220,12 @@ def evaluate(self, td):
logits, mask = self.decoder.actor(td, *hidden)
# get logprobs and entropy over logp distribution
logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping)
action_logprobs = gather_by_index(logprobs, td["action"], dim=1)
dist_entropys = Categorical(logprobs.exp()).entropy()

return action_logprobs, value_pred, dist_entropys
if return_selected:
logprobs = gather_by_index(logprobs, td["action"], dim=1)

return logprobs, value_pred, dist_entropys

def act(self, td, env, phase: str = "train"):
logits, mask = self.decoder(td, hidden=None, num_starts=0)
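The return_selected flag added to evaluate above lets SelfLabeling retrieve the full per-action log-probability matrix for its cross-entropy objective, while return_selected=True keeps the previous stepwise-PPO behaviour of returning only the log-probability of the stored action. Conceptually, for a single decoding step (hypothetical shapes):

import torch

logprobs = torch.log_softmax(torch.randn(8, 5), dim=-1)    # (batch, num_actions)
action = torch.randint(0, 5, (8,))                         # actions stored in the tensordict

# return_selected=True: log-probability of the chosen action only, shape (batch,)
selected = logprobs.gather(1, action[:, None]).squeeze(1)

# return_selected=False: the full matrix, which SelfLabeling passes to CrossEntropyLoss
# together with `action` as the target
full = logprobs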
