From e1a847ce8286c0aa62bf501ddb7fb16cb3e86a69 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 10 Mar 2021 07:02:56 -0800 Subject: [PATCH] make sure generate script works with openai pretrained vae --- generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generate.py b/generate.py index 7e2150bd..33bc5d5b 100644 --- a/generate.py +++ b/generate.py @@ -15,7 +15,7 @@ # dalle related classes and utils -from dalle_pytorch import DiscreteVAE, DALLE +from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, DALLE from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE # argument parsing @@ -52,7 +52,7 @@ load_obj = torch.load(str(dalle_path)) dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights') -vae = DiscreteVAE(**vae_params) +vae = DiscreteVAE(**vae_params) if vae_params is not None else OpenAIDiscreteVAE() dalle = DALLE(vae = vae, **dalle_params).cuda()