Add UNet 1d for RL model for planning + colab (#105)
* re-add RL model code

* match model forward api

* add register_to_config, pass training tests

* fix tests, update forward outputs

* remove unused code, some comments

* add to docs

* remove extra embedding code

* unify time embedding

* remove conv1d output sequential

* remove sequential from conv1dblock

* style and deleting duplicated code

* clean files

* remove unused variables

* clean variables

* add 1d resnet block structure for downsample

* rename as unet1d

* fix renaming

* rename files

* add get_block(...) api

* unify args for model1d like model2d

* minor cleaning

* fix docs

* improve 1d resnet blocks

* fix tests, remove permutes

* fix style

* add output activation

* rename flax blocks file

* Add Value Function and corresponding example script to Diffuser implementation (#884)

* valuefunction code

* start example scripts

* missing imports

* bug fixes and placeholder example script

* add value function scheduler

* load value function from hub and get best actions in example

* very close to working example

* larger batch size for planning

* more tests

* merge unet1d changes

* wandb for debugging, use newer models

* success!

* turns out we just need more diffusion steps

* run on modal

* merge and code cleanup

* use same api for rl model

* fix variance type

* wrong normalization function

* add tests

* style

* style and quality

* edits based on comments

* style and quality

* remove unused var

* hack unet1d into a value function

* add pipeline

* fix arg order

* add pipeline to core library

* community pipeline

* fix couple shape bugs

* style

* Apply suggestions from code review

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* update post merge of scripts

* add midblock / outblock architecture

* Pipeline cleanup (#947)

* clean up comments

* convert older script to using pipeline and add readme

* rename scripts

* style, update tests

* delete unet rl model file

* remove imports in src

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* Update src/diffusers/models/unet_1d_blocks.py

* Update tests/test_models_unet.py

* RL Cleanup v2 (#965)

* add specific vf block and update tests

* style

* Update tests/test_models_unet.py

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* fix quality in tests

* fix quality style, split test file

* fix checks / tests

* make timesteps closer to main

* unify block API

* unify forward api

* delete lines in examples

* style

* examples style

* all tests pass

* make style

* make dance_diff test pass

* Refactoring RL PR (#1200)

* init file changes

* add import utils

* finish cleaning files, imports

* remove import flags

* clean examples

* fix imports, tests for merge

* update readmes

* hotfix for tests

* quality

* fix some tests

* change defaults

* more mps test fixes

* unet1d defaults

* do not default import experimental

* defaults for tests

* fix tests

* fix-copies

* fix

* changes per Patrick's comments (#1285)

* changes per Patrick's comments

* update conversion script

* fix renaming

* skip more mps tests

* last test fix

* Update examples/rl/README.md

Co-authored-by: Ben Glickenhaus <benglickenhaus@gmail.com>
Nathan Lambert and bglick13 authored Nov 14, 2022
1 parent a8d0977 commit 7c5fef8
Showing 18 changed files with 1,176 additions and 65 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -163,4 +163,6 @@ tags
*.lock

# DS_Store (MacOS)
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4
9 changes: 6 additions & 3 deletions docs/source/api/models.mdx
@@ -22,12 +22,15 @@ The models are built on the base class [`ModelMixin`] that is a `torch.nn.Module`
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput

## UNet2DModel
[[autodoc]] UNet2DModel

## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput

## UNet1DModel
[[autodoc]] UNet1DModel

## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput

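For orientation, here is a hedged sketch of instantiating the new 1D model directly. The constructor arguments are copied from the hor=32 planner config in `scripts/convert_models_diffuser_to_diffusers.py` further down in this commit; the forward call follows the unified `sample, timestep` API the commit log describes, and the shapes are illustrative:

```python
import torch

from diffusers import UNet1DModel

# config copied from the hor=32 planner in the conversion script below
model = UNet1DModel(
    in_channels=14,
    out_channels=14,
    down_block_types=("DownResnetBlock1D",) * 4,
    up_block_types=("UpResnetBlock1D",) * 3,
    mid_block_type="MidResTemporalBlock1D",
    out_block_type="OutConv1DBlock",
    block_out_channels=(32, 64, 128, 256),
    layers_per_block=1,
    use_timestep_embedding=True,
    time_embedding_type="positional",
    flip_sin_to_cos=False,
    freq_shift=1,
    norm_num_groups=8,
    act_fn="mish",
    sample_size=65536,
    extra_in_channels=0,
    downsample_each_block=False,
)

# a batch of one noisy hopper plan: (batch, observation + action channels, horizon)
sample = torch.randn(1, 14, 32)
timestep = torch.tensor([10])
out = model(sample, timestep).sample  # UNet1DOutput.sample, same shape as the input
print(out.shape)  # torch.Size([1, 14, 32])
```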
2 changes: 1 addition & 1 deletion examples/README.md
@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| [**Text-to-Image fine-tuning**](./text_to_image) |||
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffuser_locomotion.py) | - | - | coming soon.

## Community

19 changes: 19 additions & 0 deletions examples/rl/README.md
@@ -0,0 +1,19 @@
# Overview

These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
There are two scripts:
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.

You will need some RL-specific requirements to run the examples:

```
pip install -f https://download.pytorch.org/whl/torch_stable.html \
free-mujoco-py \
einops \
gym==0.24.1 \
protobuf==3.20.1 \
git+https://github.com/rail-berkeley/d4rl.git \
mediapy \
Pillow==9.0.0
```
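Once the requirements are installed, the core of both scripts fits in a few lines. This is a minimal sketch distilled from the scripts in this commit (the hub model name and the pipeline call signature are taken directly from them):

```python
# importing d4rl registers the offline-RL environments with gym
import d4rl  # noqa
import gym

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32", env=env
)

obs = env.reset()
denorm_actions = pipeline(obs, planning_horizon=32)  # a denormalized action for env.step
obs, reward, done, info = env.step(denorm_actions)
```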
57 changes: 57 additions & 0 deletions examples/rl/run_diffuser_gen_trajectories.py
@@ -0,0 +1,57 @@
import d4rl  # noqa
import gym
import tqdm

from diffusers.experimental import ValueGuidedRLPipeline


# sampling hyperparameters for this example
config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=0,  # 0 guide steps: sample trajectories without value-function guidance
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,
    t_grad_cutoff=2,
    device="cpu",
)


if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]
    try:
        for t in tqdm.tqdm(range(T)):
            # call the policy
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)

            # update return, then compute the normalized score from the running return
            total_reward += reward
            score = env.get_normalized_score(total_reward)
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )
            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
57 changes: 57 additions & 0 deletions examples/rl/run_diffuser_locomotion.py
@@ -0,0 +1,57 @@
import d4rl  # noqa
import gym
import tqdm

from diffusers.experimental import ValueGuidedRLPipeline


# sampling hyperparameters for this example
config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=2,  # use the value function to guide sampling toward high-return plans
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,
    t_grad_cutoff=2,
    device="cpu",
)


if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]
    try:
        for t in tqdm.tqdm(range(T)):
            # call the policy
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)

            # update return, then compute the normalized score from the running return
            total_reward += reward
            score = env.get_normalized_score(total_reward)
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )
            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
100 changes: 100 additions & 0 deletions scripts/convert_models_diffuser_to_diffusers.py
@@ -0,0 +1,100 @@
import json
import os

import torch

from diffusers import UNet1DModel


os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)

os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)


def unet(hor):
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")

    model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
    state_dict = model.state_dict()
    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )
    hf_unet = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of diffusers unet dict: {len(hf_unet.state_dict().keys())}")

    # rename the original parameters to the diffusers names; this relies on both
    # state dicts enumerating their parameters in the same order
    mapping = dict(zip(model.state_dict().keys(), hf_unet.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_unet.load_state_dict(state_dict)

    torch.save(hf_unet.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
    with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
        json.dump(config, f)


def value_function():
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )

    # this checkpoint was saved as a raw state dict rather than a module
    state_dict = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
    hf_value_function = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

    # same order-based key mapping as for the planner above
    mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)

    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
    with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
        json.dump(config, f)


if __name__ == "__main__":
    unet(32)
    # unet(128)
    value_function()
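A hedged usage sketch: after running the conversion, each saved `config.json` + `diffusion_pytorch_model.bin` pair forms a standard diffusers model folder, so the weights should load back through the normal API (the paths are the local output directories created above):

```python
from diffusers import UNet1DModel

unet = UNet1DModel.from_pretrained("hub/hopper-medium-v2/unet/hor32")
value_function = UNet1DModel.from_pretrained("hub/hopper-medium-v2/value_function")
```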
5 changes: 5 additions & 0 deletions src/diffusers/experimental/README.md
@@ -0,0 +1,5 @@
# 🧨 Diffusers Experimental

We are adding experimental code to support novel applications and uses of the Diffusers library.
Currently, the following experiments are supported:
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
1 change: 1 addition & 0 deletions src/diffusers/experimental/__init__.py
@@ -0,0 +1 @@
from .rl import ValueGuidedRLPipeline
1 change: 1 addition & 0 deletions src/diffusers/experimental/rl/__init__.py
@@ -0,0 +1 @@
from .value_guided_sampling import ValueGuidedRLPipeline
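The pipeline itself lives in `src/diffusers/experimental/rl/value_guided_sampling.py`, one of the changed files not rendered above. As a rough, hedged sketch of the value-guided denoising idea from the Diffuser paper — an illustration under assumed names and signatures, not the shipped implementation:

```python
import torch


def guided_denoising_step(unet, value_function, x, t, scheduler, n_guide_steps=2, scale=0.1):
    # nudge the noisy trajectory x toward higher predicted return before denoising
    for _ in range(n_guide_steps):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_()
            value = value_function(x_in, t).sample            # predicted return of the noisy plan
            grad = torch.autograd.grad(value.sum(), x_in)[0]  # ascent direction on the value
        x = x.detach() + scale * grad                         # gradient step on the sample
    # then take an ordinary diffusion denoising step
    noise_pred = unet(x, t).sample
    x = scheduler.step(noise_pred, t, x).prev_sample
    return x
```

The two example configs above differ in exactly this knob: `n_guide_steps=0` samples plans from the unconditioned diffusion model, while `n_guide_steps=2` steers them toward high predicted return.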