Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UNet 1d for RL model for planning + colab #105

Merged
merged 69 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 68 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
8d1a17c
re-add RL model code
natolambert Jul 19, 2022
84e94d7
match model forward api
natolambert Jul 19, 2022
f67b036
add register_to_config, pass training tests
natolambert Jul 26, 2022
e42d1c0
fix tests, update forward outputs
natolambert Oct 3, 2022
2dd514e
remove unused code, some comments
natolambert Oct 3, 2022
b4c6188
add to docs
natolambert Oct 3, 2022
c53bba9
remove extra embedding code
natolambert Oct 6, 2022
effcbdb
unify time embedding
natolambert Oct 7, 2022
7865231
remove conv1d output sequential
natolambert Oct 8, 2022
35b0a43
remove sequential from conv1dblock
natolambert Oct 8, 2022
9b1379d
style and deleting duplicated code
natolambert Oct 8, 2022
e97a610
clean files
natolambert Oct 8, 2022
8642560
remove unused variables
natolambert Oct 10, 2022
f58c915
clean variables
natolambert Oct 10, 2022
ad8376d
Merge branch 'main' into rl
natolambert Oct 10, 2022
3b08bea
add 1d resnet block structure for downsample
natolambert Oct 10, 2022
aae2a9a
rename as unet1d
natolambert Oct 10, 2022
dd872af
fix renaming
natolambert Oct 10, 2022
9b67bb7
rename files
natolambert Oct 12, 2022
db012eb
add get_block(...) api
natolambert Oct 12, 2022
4db6e0b
unify args for model1d like model2d
natolambert Oct 12, 2022
634a526
minor cleaning
natolambert Oct 12, 2022
aebf547
fix docs
natolambert Oct 12, 2022
305ecd8
improve 1d resnet blocks
natolambert Oct 12, 2022
42855b9
Merge branch 'main' into rl
natolambert Oct 12, 2022
95d3a1c
fix tests, remove permuts
natolambert Oct 12, 2022
6cbb73b
fix style
natolambert Oct 12, 2022
ffb7355
add output activation
natolambert Oct 18, 2022
a6314f6
rename flax blocks file
natolambert Oct 18, 2022
48a7414
Add Value Function and corresponding example script to Diffuser imple…
bglick13 Oct 21, 2022
3acddb5
update post merge of scripts
natolambert Oct 21, 2022
713e8f2
add mdiblock / outblock architecture
natolambert Oct 24, 2022
268ebdf
Pipeline cleanup (#947)
bglick13 Oct 24, 2022
daa05fb
Update src/diffusers/models/unet_1d_blocks.py
Oct 24, 2022
ea5f231
Update tests/test_models_unet.py
Oct 24, 2022
4f7a3a4
RL Cleanup v2 (#965)
bglick13 Oct 24, 2022
d90b8b1
fix quality in tests
natolambert Oct 24, 2022
ad8b6cf
fix quality style, split test file
natolambert Oct 24, 2022
e06a4a4
Merge branch 'main' into rl
natolambert Oct 24, 2022
99b2c81
fix checks / tests
natolambert Oct 24, 2022
de4b6e4
make timesteps closer to main
natolambert Oct 25, 2022
ef6ca1f
unify block API
natolambert Oct 25, 2022
6e3485c
Merge branch 'main' into rl
natolambert Oct 25, 2022
e6f1a83
unify forward api
natolambert Oct 25, 2022
c35a925
delete lines in examples
natolambert Oct 25, 2022
949b93a
style
natolambert Oct 25, 2022
2f6462b
examples style
natolambert Oct 25, 2022
a2dd559
all tests pass
natolambert Oct 26, 2022
39dff73
make style
natolambert Oct 26, 2022
d5eedff
make dance_diff test pass
natolambert Oct 26, 2022
faeacd5
Refactoring RL PR (#1200)
Nov 8, 2022
be25030
Merge branch 'main' into rl
natolambert Nov 8, 2022
72b7ee8
hotfix for tests
natolambert Nov 8, 2022
cf76a2d
quality
natolambert Nov 8, 2022
2290356
fix some tests
natolambert Nov 9, 2022
a061f7e
change defaults
natolambert Nov 9, 2022
0c58758
more mps test fixes
natolambert Nov 9, 2022
691ddee
unet1d defaults
natolambert Nov 9, 2022
4948ca7
do not default import experimental
natolambert Nov 9, 2022
ac88677
defaults for tests
natolambert Nov 9, 2022
ba204db
fix tests
natolambert Nov 9, 2022
915c41e
fix-copies
natolambert Nov 9, 2022
c901889
Merge branch 'main' into rl
natolambert Nov 14, 2022
becc803
fix
natolambert Nov 14, 2022
9b8e5ee
changes per Patrik's comments (#1285)
bglick13 Nov 14, 2022
3684a8c
fix renaming
natolambert Nov 14, 2022
ebdef16
skip more mps tests
natolambert Nov 14, 2022
a259aae
last test fix
natolambert Nov 14, 2022
1f7702c
Update examples/rl/README.md
Nov 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,6 @@ tags
*.lock

# DS_Store (MacOS)
.DS_Store
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4
9 changes: 6 additions & 3 deletions docs/source/api/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput

## UNet1DModel
[[autodoc]] UNet1DModel

## UNet2DModel
[[autodoc]] UNet2DModel

## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput

## UNet1DModel
[[autodoc]] UNet1DModel

## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput

Expand Down
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)

| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool!


## Community

Expand Down
19 changes: 19 additions & 0 deletions examples/rl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Overview

These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers.
There are four scripts,
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.

You will need some RL specific requirements to run the examples:

```
pip install -f https://download.pytorch.org/whl/torch_stable.html \
free-mujoco-py \
einops \
gym \
natolambert marked this conversation as resolved.
Show resolved Hide resolved
protobuf==3.20.1 \
git+https://github.com/rail-berkeley/d4rl.git \
mediapy \
Pillow==9.0.0
```
57 changes: 57 additions & 0 deletions examples/rl/run_diffuser_gen_trajectories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import d4rl # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect :-)



config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=0,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)


if __name__ == "__main__":
env_name = "hopper-medium-v2"
env = gym.make(env_name)

pipeline = ValueGuidedRLPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
env=env,
)

env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# Call the policy
denorm_actions = pipeline(obs, planning_horizon=32)

# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())

obs = next_observation
except KeyboardInterrupt:
pass

print(f"Total reward: {total_reward}")
57 changes: 57 additions & 0 deletions examples/rl/run_diffuser_locomotion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import d4rl # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline


config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=2,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)


if __name__ == "__main__":
env_name = "hopper-medium-v2"
env = gym.make(env_name)

pipeline = ValueGuidedRLPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
env=env,
)

env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# call the policy
denorm_actions = pipeline(obs, planning_horizon=32)

# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())

obs = next_observation
except KeyboardInterrupt:
pass

print(f"Total reward: {total_reward}")
100 changes: 100 additions & 0 deletions scripts/convert_models_diffuser_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import json
import os

import torch

from diffusers import UNet1DModel


os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)

os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)


def unet(hor):
if hor == 128:
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
block_out_channels = (32, 128, 256)
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")

elif hor == 32:
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
block_out_channels = (32, 64, 128, 256)
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
state_dict = model.state_dict()
config = dict(
down_block_types=down_block_types,
block_out_channels=block_out_channels,
up_block_types=up_block_types,
layers_per_block=1,
use_timestep_embedding=True,
out_block_type="OutConv1DBlock",
norm_num_groups=8,
downsample_each_block=False,
in_channels=14,
out_channels=14,
extra_in_channels=0,
time_embedding_type="positional",
flip_sin_to_cos=False,
freq_shift=1,
sample_size=65536,
mid_block_type="MidResTemporalBlock1D",
act_fn="mish",
)
hf_value_function = UNet1DModel(**config)
print(f"length of state dict: {len(state_dict.keys())}")
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
for k, v in mapping.items():
state_dict[v] = state_dict.pop(k)
hf_value_function.load_state_dict(state_dict)

torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
json.dump(config, f)


def value_function():
config = dict(
in_channels=14,
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
up_block_types=(),
out_block_type="ValueFunction",
mid_block_type="ValueFunctionMidBlock1D",
block_out_channels=(32, 64, 128, 256),
layers_per_block=1,
downsample_each_block=True,
sample_size=65536,
out_channels=14,
extra_in_channels=0,
time_embedding_type="positional",
use_timestep_embedding=True,
flip_sin_to_cos=False,
freq_shift=1,
norm_num_groups=8,
act_fn="mish",
)

model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
state_dict = model
hf_value_function = UNet1DModel(**config)
print(f"length of state dict: {len(state_dict.keys())}")
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
for k, v in mapping.items():
state_dict[v] = state_dict.pop(k)

hf_value_function.load_state_dict(state_dict)

torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
json.dump(config, f)


if __name__ == "__main__":
unet(32)
# unet(128)
value_function()
5 changes: 5 additions & 0 deletions src/diffusers/experimental/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# 🧨 Diffusers Experimental

We are adding experimental code to support novel applications and usages of the Diffusers library.
Currently, the following experiments are supported:
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
1 change: 1 addition & 0 deletions src/diffusers/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rl import ValueGuidedRLPipeline
1 change: 1 addition & 0 deletions src/diffusers/experimental/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .value_guided_sampling import ValueGuidedRLPipeline
Loading