Skip to content

Commit

Permalink
do not tie vae codebook to dall-e image embedding by default, and if tying, make sure to detach
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 10, 2021
1 parent 40f4119 commit 30a13eb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 10 additions & 2 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def __init__(
ff_dropout = 0,
sparse_attn = False,
noncausal_attn_len = 0,
ignore_index = -100
ignore_index = -100,
tie_codebook_image_emb = False
):
super().__init__()
assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
Expand Down Expand Up @@ -285,9 +286,12 @@ def __init__(
self.noncausal_attn_len = noncausal_attn_len

self.vae = vae
self.tie_codebook_image_emb = tie_codebook_image_emb
if exists(self.vae):
self.vae = vae
self.image_emb = vae.codebook

if tie_codebook_image_emb:
self.image_emb = vae.codebook

self.transformer = Transformer(
dim = dim,
Expand Down Expand Up @@ -394,6 +398,10 @@ def forward(

image_len = image.shape[1]
image_emb = self.image_emb(image)

if self.tie_codebook_image_emb:
image_emb.detach_()

image_emb += self.image_pos_emb(image_emb)

tokens = torch.cat((tokens, image_emb), dim = 1)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.53',
version = '0.0.54',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 30a13eb

Please sign in to comment.