-
Notifications
You must be signed in to change notification settings - Fork 644
/
dalle_pytorch.py
437 lines (341 loc) · 14.1 KB
/
dalle_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
from math import log2, sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from axial_positional_embedding import AxialPositionalEmbedding
from dalle_pytorch.transformer import Transformer
# helpers
def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def always(val):
    """Build a callable that ignores every argument and returns `val`."""
    def constant(*_args, **_kwargs):
        return val
    return constant
def is_empty(t):
    """Return True when tensor `t` holds zero elements."""
    return not t.nelement()
def masked_mean(t, mask, dim = 1):
    """Average `t` along `dim`, counting only positions where `mask` is True.

    The mask is indexed as `mask[:, :, None]`, i.e. it gets one trailing
    singleton axis appended so that it broadcasts over the last axis of `t`.

    Args:
        t: tensor of values to average.
        mask: boolean tensor; positions that are False are excluded.
        dim: axis to reduce. The same index is applied to both `t` and
            `mask`, so it must be a non-negative axis that exists on the
            mask (0 or 1 for a two-dimensional mask).

    Returns:
        `t` reduced along `dim`. A slice without a single True mask entry is
        divided by zero and therefore comes out as NaN.
    """
    zeroed = t.masked_fill(~mask[:, :, None], 0.)
    # reduce values and mask along the *requested* axis so the counts line up
    # with the sums
    counts = mask.sum(dim = dim)[..., None]
    return zeroed.sum(dim = dim) / counts
def eval_decorator(fn):
    """Wrap a model method so that it runs with the model in eval mode.

    The wrapper records `model.training`, calls `model.eval()`, runs `fn`
    and then calls `model.train(...)` with the recorded flag. The flag is
    restored in a `finally` clause, so an exception raised by `fn` no longer
    leaves the model stuck in eval mode.

    Args:
        fn: callable whose first positional argument is the model.

    Returns:
        A wrapper with the same call signature and metadata as `fn`.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # runs on both the normal and the exceptional path
            model.train(was_training)
    return inner
# sampling helpers
def top_k(logits, thres = 0.5):
    """Keep the largest logits along the last axis and mask out the rest.

    Args:
        logits: tensor of unnormalised scores; candidates lie on the last
            axis. Any number of leading axes is accepted.
        thres: fraction of candidates to discard. `int((1 - thres) * n)`
            entries are kept per row, with a minimum of one.

    Returns:
        A tensor shaped like `logits` holding the kept values at their
        original positions and `-inf` everywhere else, so that a following
        softmax gives the discarded entries zero probability.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    # topk selects along the last axis, so scatter back along that same axis
    # (a hard-coded axis 1 only works for two-dimensional input)
    filtered.scatter_(-1, ind, val)
    return filtered
# discrete vae class
class ResBlock(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3 -> ReLU -> conv1x1, plus a skip connection.

    Every convolution maps `chan` channels to `chan` channels, and the 3x3
    convolutions use padding 1, so the output has the same shape as the input.
    """

    def __init__(self, chan):
        super().__init__()
        layers = [
            nn.Conv2d(chan, chan, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, kernel_size = 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the convolution stack applied to `x`, added to `x` itself."""
        residual = self.net(x)
        return residual + x
class DiscreteVAE(nn.Module):
    """Convolutional autoencoder whose bottleneck is a learned discrete codebook.

    The encoder turns an image into per-position logits over `num_tokens`
    codebook entries. During training a Gumbel-softmax sample of those logits
    mixes the codebook vectors, and the decoder reconstructs the image from
    that mixture. Each of the `num_layers` encoder stages is a stride-2
    convolution and each decoder stage a stride-2 transposed convolution.
    """

    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        temperature = 0.9,
        straight_through = False
    ):
        super().__init__()
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'

        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.codebook = nn.Embedding(num_tokens, codebook_dim)

        use_resblocks = num_resnet_blocks > 0

        # channel widths seen at the boundaries of the down/up-sampling stages
        hidden_widths = [hidden_dim] * num_layers
        encoder_widths = [channels, *hidden_widths]
        # without residual blocks the decoder consumes codebook vectors
        # directly; with them, a 1x1 projection (added below) comes first
        first_decoder_width = hidden_widths[-1] if use_resblocks else codebook_dim
        decoder_widths = [first_decoder_width, *reversed(hidden_widths)]

        encoder_stages = []
        decoder_stages = []

        for idx in range(num_layers):
            enc_in, enc_out = encoder_widths[idx], encoder_widths[idx + 1]
            dec_in, dec_out = decoder_widths[idx], decoder_widths[idx + 1]
            down = nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1)
            encoder_stages.append(nn.Sequential(down, nn.ReLU()))
            up = nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1)
            decoder_stages.append(nn.Sequential(up, nn.ReLU()))

        # residual blocks sit right after the encoder's downsampling stages
        # and right before the decoder's upsampling stages
        for _ in range(num_resnet_blocks):
            decoder_stages.insert(0, ResBlock(decoder_widths[1]))
            encoder_stages.append(ResBlock(encoder_widths[-1]))

        if use_resblocks:
            # project codebook vectors to the width the residual blocks expect
            decoder_stages.insert(0, nn.Conv2d(codebook_dim, decoder_widths[1], 1))

        # 1x1 heads: encoder -> logits over the codebook, decoder -> image channels
        encoder_stages.append(nn.Conv2d(encoder_widths[-1], num_tokens, 1))
        decoder_stages.append(nn.Conv2d(decoder_widths[-1], channels, 1))

        self.encoder = nn.Sequential(*encoder_stages)
        self.decoder = nn.Sequential(*decoder_stages)

    @torch.no_grad()
    def get_codebook_indices(self, images):
        """Return the arg-max codebook index per position, flattened to (batch, positions)."""
        token_logits = self.forward(images, return_logits = True)
        return token_logits.argmax(dim = 1).flatten(1)

    def decode(
        self,
        img_seq
    ):
        """Decode a (batch, n) sequence of codebook indices into images.

        The sequence is laid out on a square grid of side `int(sqrt(n))`
        before being passed through the decoder.
        """
        codes = self.codebook(img_seq)
        _, n, _ = codes.shape
        side = int(sqrt(n))
        grid = rearrange(codes, 'b (h w) d -> b d h w', h = side, w = side)
        return self.decoder(grid)

    def forward(
        self,
        img,
        return_recon_loss = False,
        return_logits = False
    ):
        """Encode `img`, sample a soft code assignment, and reconstruct.

        Returns the raw encoder logits when `return_logits` is set, the MSE
        reconstruction loss when `return_recon_loss` is set, and the
        reconstructed image otherwise. `return_logits` takes precedence.
        """
        logits = self.encoder(img)

        if return_logits:
            # hard image indices for DALL-E training are derived from these
            return logits

        soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1, hard = self.straight_through)
        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
        recon = self.decoder(sampled)

        if return_recon_loss:
            return F.mse_loss(img, recon)
        return recon
# main classes
class CLIP(nn.Module):
    """Contrastive text/image model with one transformer encoder per modality.

    Text tokens and flattened image patches are embedded, given learned
    positional embeddings, encoded by non-causal transformers, mean-pooled,
    projected to a shared latent space and L2-normalised. Similarities are
    scaled by `exp(temperature)`, where `temperature` is a learned parameter.
    """

    def __init__(
        self,
        *,
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 10000,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        num_visual_tokens = 512,
        visual_enc_depth = 6,
        visual_heads = 8,
        visual_image_size = 256,
        visual_patch_size = 32,
        channels = 3
    ):
        super().__init__()

        # text tower
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        self.text_transformer = Transformer(
            causal = False,
            seq_len = text_seq_len,
            dim = dim_text,
            depth = text_enc_depth,
            heads = text_heads
        )
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # image tower, operating on non-overlapping square patches
        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = visual_image_size // visual_patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * visual_patch_size ** 2

        self.visual_patch_size = visual_patch_size
        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
        self.visual_transformer = Transformer(
            causal = False,
            seq_len = num_patches,
            dim = dim_image,
            depth = visual_enc_depth,
            heads = visual_heads
        )
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        self.temperature = nn.Parameter(torch.tensor(1.))

    def forward(
        self,
        text,
        image,
        text_mask = None,
        return_loss = False
    ):
        """Score text/image pairs, or compute the symmetric contrastive loss.

        With `return_loss = False` the result holds one scaled similarity per
        (text[i], image[i]) pair. With `return_loss = True` the full pairwise
        similarity matrix is built and the mean of the text->image and
        image->text cross entropies (matching pairs on the diagonal) is
        returned.
        """
        device = text.device
        patch = self.visual_patch_size

        # token + positional embeddings for the text
        text_positions = torch.arange(text.shape[1], device = device)
        text_emb = self.text_emb(text)
        text_emb += self.text_pos_emb(text_positions)

        # cut the image into patches, then patch + positional embeddings
        image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch, p2 = patch)
        image_emb = self.to_visual_embedding(image_patches)
        patch_positions = torch.arange(image_emb.shape[1], device = device)
        image_emb += self.visual_pos_emb(patch_positions)

        enc_text = self.text_transformer(text_emb, mask = text_mask)
        enc_image = self.visual_transformer(image_emb)

        # pool over the sequence axis; masked-out text positions are excluded
        if exists(text_mask):
            pooled_text = masked_mean(enc_text, text_mask, dim = 1)
        else:
            pooled_text = enc_text.mean(dim = 1)
        pooled_image = enc_image.mean(dim = 1)

        text_latents = F.normalize(self.to_text_latent(pooled_text), p = 2, dim = -1)
        image_latents = F.normalize(self.to_visual_latent(pooled_image), p = 2, dim = -1)

        temp = self.temperature.exp()

        if return_loss:
            sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
            labels = torch.arange(text.shape[0], device = device)
            text_to_image = F.cross_entropy(sim, labels)
            image_to_text = F.cross_entropy(sim.t(), labels)
            return (text_to_image + image_to_text) / 2

        return einsum('n d, n d -> n', text_latents, image_latents) * temp
# main DALL-E class
class DALLE(nn.Module):
    """Autoregressive transformer over a text-token prefix followed by VAE image tokens.

    Text and image tokens share one output vocabulary: indices below
    `num_text_tokens` are text, the remaining `vae.num_tokens` indices are
    image codes. A precomputed `logits_mask` buffer forbids image logits at
    text positions and text logits at image positions.

    Args:
        dim: transformer width. Image embeddings are taken from
            `vae.codebook` and concatenated with the text embeddings along
            the sequence axis, so the codebook width has to equal `dim`.
        vae: a `DiscreteVAE` used to tokenize and decode images.
        num_text_tokens: size of the text vocabulary.
        text_seq_len: maximum number of text tokens (excluding `<bos>`).
        depth, heads, dim_head, reversible, attn_dropout, ff_dropout,
        sparse_attn: forwarded to `Transformer`.
        noncausal_attn_len: number of leading positions whose labels are
            excluded from the loss; `noncausal_attn_len + 1` is forwarded to
            `Transformer`.
        ignore_index: label value written at excluded positions and passed
            to the cross entropy so that those positions are skipped.
    """

    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        noncausal_attn_len = 0,
        ignore_index = -100
    ):
        super().__init__()
        assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        # each VAE layer halves the side length, so this is the token-grid area
        image_seq_len = (vae.image_size // (2 ** vae.num_layers)) ** 2

        self.text_emb = nn.Embedding(num_text_tokens, dim)
        self.image_emb = nn.Embedding(num_image_tokens, dim)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) # +1 for <bos>
        # NOTE(review): axial_shape is built from the pixel image_size rather
        # than the token-grid side (image_size // 2 ** num_layers) - confirm
        # this is intended before changing it, since it fixes parameter shapes
        self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size))

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        self.noncausal_attn_len = noncausal_attn_len

        # the isinstance assertion above already guarantees a VAE is present;
        # image tokens are embedded with the VAE's own codebook, which
        # replaces the embedding table allocated above
        self.vae = vae
        self.image_emb = vae.codebook

        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            noncausal_attn_len = (noncausal_attn_len + 1),
            sparse_attn = sparse_attn,
            sparse_attn_global_indices = range(text_seq_len)
        )

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        # True marks (position, vocabulary index) pairs that must not be
        # predicted: text logits at image positions and image logits at text
        # positions
        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        self.register_buffer('logits_mask', logits_mask)

        self.ignore_index = ignore_index

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip = None,
        mask = None,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Sample the remaining tokens after `text` and decode them to images.

        Args:
            text: (batch, n) text token ids, with n <= `text_seq_len`. A
                shorter prompt is completed by sampling text tokens first.
            clip: optional model called as `clip(text, images, return_loss = False)`
                to score the generated images.
            mask: optional boolean mask over `text`; it is extended with True
                for every sampled text token. May be None.
            filter_thres: fraction of logits discarded by `top_k` before sampling.
            temperature: divisor applied to the filtered logits before softmax.

        Returns:
            The decoded images, or `(images, scores)` when `clip` is given.
        """
        vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        out = text

        for cur_len in range(text.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            logits = self(text, image, mask = mask)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)

            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
            out = torch.cat((out, sample), dim=-1)

            # only a mask that was actually supplied can be extended; padding
            # None would raise for prompts shorter than text_seq_len
            if exists(mask) and out.shape[1] <= text_seq_len:
                mask = F.pad(mask, (0, 1), value = True)

        text_seq = out[:, :text_seq_len]

        img_seq = out[:, -image_seq_len:]
        images = vae.decode(img_seq)

        if exists(clip):
            scores = clip(text_seq, images, return_loss = False)
            return images, scores

        return images

    def forward(
        self,
        text,
        image = None,
        mask = None,
        return_loss = False
    ):
        """Compute next-token logits, or the training loss, for text (+ image).

        Args:
            text: (batch, n) text token ids; a `<bos>` token (id 0) is
                prepended internally.
            image: optional image input. A 4-dimensional tensor is treated as
                raw images and tokenized with the VAE; anything else is used
                as codebook indices directly. None or an empty tensor means
                text only.
            mask: optional boolean mask over `text`; it is extended with True
                for `<bos>` and for every image position.
            return_loss: when True, return the cross entropy over text and
                image tokens instead of the logits.

        Returns:
            The masked logits, or a scalar loss when `return_loss` is set.
        """
        device, ignore_index, total_seq_len = text.device, self.ignore_index, self.total_seq_len

        text = F.pad(text, (1, 0), value = 0) # use padding as <bos>

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        seq_len = tokens.shape[1]

        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 4

            if is_raw_image:
                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)
            image_emb += self.image_pos_emb(image_emb)

            tokens = torch.cat((tokens, image_emb), dim = 1)

            seq_len += image_len
            if exists(mask):
                mask = F.pad(mask, (0, image_emb.shape[1]), value = True)

        # when training, if the length exceeds the total text + image length
        # remove the last token, since it needs not to be trained

        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

            if exists(mask):
                mask = mask[:, :-1]

        out = self.transformer(tokens, mask = mask)
        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except last token), and image predicts image
        logits_mask = self.logits_mask[:, :seq_len]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        noncausal_attn_len = self.noncausal_attn_len
        # image codes live after the text vocabulary in the shared logit space
        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

        if noncausal_attn_len > 0:
            seq_range = torch.arange(seq_len, device = device)
            noncausal_attn_mask = seq_range < noncausal_attn_len
            noncausal_attn_mask = rearrange(noncausal_attn_mask, 'n -> () n')
            labels.masked_fill_(noncausal_attn_mask, ignore_index)

        # pass the configured ignore_index through; otherwise only the
        # library default (-100) would be skipped and any other value would
        # be treated as a real class label
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)
        return loss