Skip to content

Commit

Permalink
make sure generate script works with openai pretrained vae
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 10, 2021
1 parent 1744c5d commit e1a847c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# dalle related classes and utils

from dalle_pytorch import DiscreteVAE, DALLE
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, DALLE
from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE

# argument parsing
Expand Down Expand Up @@ -52,7 +52,7 @@
load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')

vae = DiscreteVAE(**vae_params)
vae = DiscreteVAE(**vae_params) if vae_params is not None else OpenAIDiscreteVAE()

dalle = DALLE(vae = vae, **dalle_params).cuda()

Expand Down

0 comments on commit e1a847c

Please sign in to comment.