
Add UNet 1d for RL model for planning + colab #105

Merged
merged 69 commits on Nov 14, 2022
Changes from 6 commits
Commits (69)
8d1a17c
re-add RL model code
natolambert Jul 19, 2022
84e94d7
match model forward api
natolambert Jul 19, 2022
f67b036
add register_to_config, pass training tests
natolambert Jul 26, 2022
e42d1c0
fix tests, update forward outputs
natolambert Oct 3, 2022
2dd514e
remove unused code, some comments
natolambert Oct 3, 2022
b4c6188
add to docs
natolambert Oct 3, 2022
c53bba9
remove extra embedding code
natolambert Oct 6, 2022
effcbdb
unify time embedding
natolambert Oct 7, 2022
7865231
remove conv1d output sequential
natolambert Oct 8, 2022
35b0a43
remove sequential from conv1dblock
natolambert Oct 8, 2022
9b1379d
style and deleting duplicated code
natolambert Oct 8, 2022
e97a610
clean files
natolambert Oct 8, 2022
8642560
remove unused variables
natolambert Oct 10, 2022
f58c915
clean variables
natolambert Oct 10, 2022
ad8376d
Merge branch 'main' into rl
natolambert Oct 10, 2022
3b08bea
add 1d resnet block structure for downsample
natolambert Oct 10, 2022
aae2a9a
rename as unet1d
natolambert Oct 10, 2022
dd872af
fix renaming
natolambert Oct 10, 2022
9b67bb7
rename files
natolambert Oct 12, 2022
db012eb
add get_block(...) api
natolambert Oct 12, 2022
4db6e0b
unify args for model1d like model2d
natolambert Oct 12, 2022
634a526
minor cleaning
natolambert Oct 12, 2022
aebf547
fix docs
natolambert Oct 12, 2022
305ecd8
improve 1d resnet blocks
natolambert Oct 12, 2022
42855b9
Merge branch 'main' into rl
natolambert Oct 12, 2022
95d3a1c
fix tests, remove permuts
natolambert Oct 12, 2022
6cbb73b
fix style
natolambert Oct 12, 2022
ffb7355
add output activation
natolambert Oct 18, 2022
a6314f6
rename flax blocks file
natolambert Oct 18, 2022
48a7414
Add Value Function and corresponding example script to Diffuser imple…
bglick13 Oct 21, 2022
3acddb5
update post merge of scripts
natolambert Oct 21, 2022
713e8f2
add mdiblock / outblock architecture
natolambert Oct 24, 2022
268ebdf
Pipeline cleanup (#947)
bglick13 Oct 24, 2022
daa05fb
Update src/diffusers/models/unet_1d_blocks.py
Oct 24, 2022
ea5f231
Update tests/test_models_unet.py
Oct 24, 2022
4f7a3a4
RL Cleanup v2 (#965)
bglick13 Oct 24, 2022
d90b8b1
fix quality in tests
natolambert Oct 24, 2022
ad8b6cf
fix quality style, split test file
natolambert Oct 24, 2022
e06a4a4
Merge branch 'main' into rl
natolambert Oct 24, 2022
99b2c81
fix checks / tests
natolambert Oct 24, 2022
de4b6e4
make timesteps closer to main
natolambert Oct 25, 2022
ef6ca1f
unify block API
natolambert Oct 25, 2022
6e3485c
Merge branch 'main' into rl
natolambert Oct 25, 2022
e6f1a83
unify forward api
natolambert Oct 25, 2022
c35a925
delete lines in examples
natolambert Oct 25, 2022
949b93a
style
natolambert Oct 25, 2022
2f6462b
examples style
natolambert Oct 25, 2022
a2dd559
all tests pass
natolambert Oct 26, 2022
39dff73
make style
natolambert Oct 26, 2022
d5eedff
make dance_diff test pass
natolambert Oct 26, 2022
faeacd5
Refactoring RL PR (#1200)
Nov 8, 2022
be25030
Merge branch 'main' into rl
natolambert Nov 8, 2022
72b7ee8
hotfix for tests
natolambert Nov 8, 2022
cf76a2d
quality
natolambert Nov 8, 2022
2290356
fix some tests
natolambert Nov 9, 2022
a061f7e
change defaults
natolambert Nov 9, 2022
0c58758
more mps test fixes
natolambert Nov 9, 2022
691ddee
unet1d defaults
natolambert Nov 9, 2022
4948ca7
do not default import experimental
natolambert Nov 9, 2022
ac88677
defaults for tests
natolambert Nov 9, 2022
ba204db
fix tests
natolambert Nov 9, 2022
915c41e
fix-copies
natolambert Nov 9, 2022
c901889
Merge branch 'main' into rl
natolambert Nov 14, 2022
becc803
fix
natolambert Nov 14, 2022
9b8e5ee
changes per Patrik's comments (#1285)
bglick13 Nov 14, 2022
3684a8c
fix renaming
natolambert Nov 14, 2022
ebdef16
skip more mps tests
natolambert Nov 14, 2022
a259aae
last test fix
natolambert Nov 14, 2022
1f7702c
Update examples/rl/README.md
Nov 14, 2022
3 changes: 3 additions & 0 deletions docs/source/api/models.mdx
@@ -34,6 +34,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput

## TemporalUNet
[[autodoc]] TemporalUNet

## VQEncoderOutput
[[autodoc]] models.vae.VQEncoderOutput

2 changes: 1 addition & 1 deletion src/diffusers/__init__.py
@@ -18,7 +18,7 @@

if is_torch_available():
    from .modeling_utils import ModelMixin
-   from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+   from .models import AutoencoderKL, TemporalUNet, UNet2DConditionModel, UNet2DModel, VQModel
    from .optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -18,6 +18,7 @@
if is_torch_available():
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
+   from .unet_rl import TemporalUNet
    from .vae import AutoencoderKL, VQModel

if is_flax_available():
134 changes: 134 additions & 0 deletions src/diffusers/models/resnet.py
@@ -5,6 +5,70 @@
import torch.nn.functional as F


class Upsample1D(nn.Module):
"""
An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
If 3D, then
upsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)

x = F.interpolate(x, scale_factor=2.0, mode="nearest")

if self.use_conv:
x = self.conv(x)

return x


class Downsample1D(nn.Module):
"""
A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
If 3D, then
downsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name

if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

def forward(self, x):
assert x.shape[1] == self.channels
return self.conv(x)
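
For intuition, a minimal shape check of the two layers above (illustrative only, not part of this diff): with use_conv_transpose=True the ConvTranspose1d(kernel=4, stride=2, padding=1) doubles the horizon, and the stride-2 Conv1d in Downsample1D halves it.

# Quick sanity check of the 1D resampling layers (assumes the classes defined above).
import torch

up = Upsample1D(channels=32, use_conv_transpose=True)
down = Downsample1D(channels=32, use_conv=True)

x = torch.randn(2, 32, 16)  # (batch, channels, horizon)
print(up(x).shape)    # torch.Size([2, 32, 32]) -> horizon doubled
print(down(x).shape)  # torch.Size([2, 32, 8])  -> horizon halved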


class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
@@ -374,6 +438,76 @@ def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""

def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()

self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
Reviewer comment (Contributor): Not a big fan of using blocks here as it makes it very hard to adapt the class to future models. Sorry this is a bit annoying to do when converting the checkpoint, but could we maybe instead call the layers self.conv_in and self.conv_out? The idea here is that if in the future there are checkpoints which have an intermediate conv layer it'd be much easier to adapt this layer without breaking the previous checkpoints.
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
            # Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]
        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
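
As a rough sketch of the reviewer's suggestion above (hypothetical code, not in this PR; the class name ResidualTemporalBlockNamed is an assumption, and the unused horizon argument is dropped for brevity), the two Conv1dBlocks could be stored as explicitly named layers:

# Hypothetical sketch of the suggested self.conv_in / self.conv_out naming; the computation
# matches ResidualTemporalBlock above, only the attribute names change.
class ResidualTemporalBlockNamed(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
        )
        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        out = self.conv_in(x) + self.time_mlp(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)

Named attributes would let an intermediate conv layer be added later without re-indexing a ModuleList or breaking older checkpoints, which is the rationale given in the comment.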


def upsample_2d(x, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.

193 changes: 193 additions & 0 deletions src/diffusers/models/unet_rl.py
@@ -0,0 +1,193 @@
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
from dataclasses import dataclass
from typing import Tuple, Union

import torch
import torch.nn as nn

from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import get_timestep_embedding


@dataclass
class TemporalUNetOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch, horizon, obs_dimension)`):
Hidden states output. Output of last layer of model.
"""

sample: torch.FloatTensor


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return get_timestep_embedding(x, self.dim)


class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""

def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()

self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
RearrangeDim(),
nn.GroupNorm(n_groups, out_channels),
RearrangeDim(),
nn.Mish(),
)

def forward(self, x):
return self.block(x)


class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
    @register_to_config
    def __init__(
        self,
        training_horizon=128,
        transition_dim=14,
        cond_dim=3,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )
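
A quick worked example of the channel bookkeeping implied by the defaults above (illustrative, not part of the file):

# With transition_dim=14, dim=32, dim_mults=(1, 4, 8):
dims = [14, *map(lambda m: 32 * m, (1, 4, 8))]  # [14, 32, 128, 256]
in_out = list(zip(dims[:-1], dims[1:]))         # [(14, 32), (32, 128), (128, 256)]
# down path: 14 -> 32 -> 128 -> 256 (horizon halved twice: 128 -> 64 -> 32)
# mid blocks: 256 -> 256
# up path: 256 -> 128 -> 32 (skip connections double the input channels via torch.cat)
# final_conv: 32 -> transition_dim = 14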

    # def forward(self, sample, timestep):
    #     """
    #     x : [ batch x horizon x transition ]
    #     """
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[TemporalUNetOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, horizon, obs_dimension) noisy inputs tensor
Reviewer comment (Contributor): I think the shape should be (batch, horizon, obs_dimension + action_dimension).

            timestep (`torch.FloatTensor` or `float` or `int`): batch (batch) timesteps
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        sample = sample.permute(0, 2, 1)

        t = self.time_mlp(timestep)
        h = []

        for resnet, resnet2, downsample in self.downs:
Reviewer comment (Contributor): Let's really try to mirror the design in src/diffusers/models/unet_2d.py here.

Reply from the PR author: I'll try and make it work with the get_block function. The parameter re-naming script is getting quite ugly so we'll see how far I make it 🤗.

            sample = resnet(sample, t)
            sample = resnet2(sample, t)
            h.append(sample)
            sample = downsample(sample)

        sample = self.mid_block1(sample, t)
        sample = self.mid_block2(sample, t)

        for resnet, resnet2, upsample in self.ups:
            sample = torch.cat((sample, h.pop()), dim=1)
            sample = resnet(sample, t)
            sample = resnet2(sample, t)
            sample = upsample(sample)

        sample = self.final_conv(sample)

        sample = sample.permute(0, 2, 1)

        if not return_dict:
            return (sample,)

        return TemporalUNetOutput(sample=sample)
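
Finally, a hedged usage sketch of the model as it stands at this point in the PR (the class is renamed in later commits, so treat the import and defaults below as assumptions tied to this snapshot; the 14-dim transition presumably packs observation and action dims, per the review comment above):

# Illustrative forward pass through the TemporalUNet added in this diff (shapes only).
import torch
from diffusers import TemporalUNet  # exported via the __init__.py changes above

model = TemporalUNet(training_horizon=128, transition_dim=14, dim=32, dim_mults=(1, 4, 8))
sample = torch.randn(1, 128, 14)  # (batch, horizon, transition_dim) noisy trajectories
timestep = torch.tensor([10])     # one diffusion timestep per batch element
out = model(sample, timestep)     # TemporalUNetOutput when return_dict=True
print(out.sample.shape)           # torch.Size([1, 128, 14])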