diff --git a/README.md b/README.md index 29bb63de..0d1ae428 100644 --- a/README.md +++ b/README.md @@ -386,6 +386,14 @@ $ python generate.py --dalle_path ./dalle.pt --text 'fireflies in a field under You should see your images saved as `./outputs/{your prompt}/{image number}.jpg` +To generate multiple images, just pass in your text with '|' character as a separator. + +ex. + +```python +$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone|a cat chasing mice|a frog eating a fly' +``` + ### Distributed Training #### DeepSpeed diff --git a/generate.py b/generate.py index cb699dba..13793d8e 100644 --- a/generate.py +++ b/generate.py @@ -88,24 +88,27 @@ def exists(val): image_size = vae.image_size -text = tokenizer.tokenize([args.text], dalle.text_seq_len).cuda() +texts = args.text.split('|') -text = repeat(text, '() n -> b n', b = args.num_images) +for text in tqdm(texts): + text = tokenizer.tokenize([args.text], dalle.text_seq_len).cuda() -outputs = [] + text = repeat(text, '() n -> b n', b = args.num_images) -for text_chunk in tqdm(text.split(args.batch_size), desc = 'generating images'): - output = dalle.generate_images(text_chunk, filter_thres = args.top_k) - outputs.append(output) + outputs = [] -outputs = torch.cat(outputs) + for text_chunk in tqdm(text.split(args.batch_size), desc = f'generating images for - {text}'): + output = dalle.generate_images(text_chunk, filter_thres = args.top_k) + outputs.append(output) -# save all images + outputs = torch.cat(outputs) -outputs_dir = Path(args.outputs_dir) / args.text.replace(' ', '_') -outputs_dir.mkdir(parents = True, exist_ok = True) + # save all images -for i, image in tqdm(enumerate(outputs), desc = 'saving images'): - save_image(image, outputs_dir / f'{i}.jpg', normalize=True) + outputs_dir = Path(args.outputs_dir) / args.text.replace(' ', '_') + outputs_dir.mkdir(parents = True, exist_ok = True) -print(f'created {args.num_images} images at "{str(outputs_dir)}"') + for i, image in tqdm(enumerate(outputs), desc = 'saving images'): + save_image(image, outputs_dir / f'{i}.jpg', normalize=True) + + print(f'created {args.num_images} images at "{str(outputs_dir)}"')