Skip to content

Commit

Permalink
offer ability to generate more than one text
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 15, 2021
1 parent dbbbcfd commit ce0c892
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,14 @@ $ python generate.py --dalle_path ./dalle.pt --text 'fireflies in a field under

You should see your images saved as `./outputs/{your prompt}/{image number}.jpg`

To generate multiple images, just pass in your text with '|' character as a separator.

ex.

```python
$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone|a cat chasing mice|a frog eating a fly'
```

### Distributed Training

#### DeepSpeed
Expand Down
29 changes: 16 additions & 13 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,27 @@ def exists(val):

image_size = vae.image_size

text = tokenizer.tokenize([args.text], dalle.text_seq_len).cuda()
texts = args.text.split('|')

text = repeat(text, '() n -> b n', b = args.num_images)
for text in tqdm(texts):
text = tokenizer.tokenize([args.text], dalle.text_seq_len).cuda()

outputs = []
text = repeat(text, '() n -> b n', b = args.num_images)

for text_chunk in tqdm(text.split(args.batch_size), desc = 'generating images'):
output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
outputs.append(output)
outputs = []

outputs = torch.cat(outputs)
for text_chunk in tqdm(text.split(args.batch_size), desc = f'generating images for - {text}'):
output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
outputs.append(output)

# save all images
outputs = torch.cat(outputs)

outputs_dir = Path(args.outputs_dir) / args.text.replace(' ', '_')
outputs_dir.mkdir(parents = True, exist_ok = True)
# save all images

for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
save_image(image, outputs_dir / f'{i}.jpg', normalize=True)
outputs_dir = Path(args.outputs_dir) / args.text.replace(' ', '_')
outputs_dir.mkdir(parents = True, exist_ok = True)

print(f'created {args.num_images} images at "{str(outputs_dir)}"')
for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
save_image(image, outputs_dir / f'{i}.jpg', normalize=True)

print(f'created {args.num_images} images at "{str(outputs_dir)}"')

0 comments on commit ce0c892

Please sign in to comment.