General questions to the algorithmic understanding #21
@CDitzel Hey Carsten! Glad you are reading the repository and double-checking my work! I think the best way to make this clear is to work through a concrete example. Let's say each text token is just a character, and let's pretend numbers are visual tokens. Say we are working with the multi-modal sequence

[c] [a] [t] [4] [1] [2] [3]

During training, what we do first is append an EOS:

[c] [a] [t] [4] [1] [2] [3] [eos]

Then, for the autoregressive objective, we break this sequence into a shifted pair:

[a] [t] [4] [1] [2] [3] [eos]
<bottom tokens predict top>
[c] [a] [t] [4] [1] [2] [3]

You can see that the last text token predicts the start of the first visual token — that is why the logits mask is off by one:

[T] [T] [I] [I] [I] [I] [I]
[a] [t] [4] [1] [2] [3] [eos]
<bottom tokens predict top>
[c] [a] [t] [4] [1] [2] [3]

As for your second question: we are always trying to predict the next token. Even without the masking, the attention net would eventually figure it out. The mask just makes things a bit cleaner, and makes sure that during sampling we don't hit the wrong token.
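The shift and the off-by-one logits mask above can be sketched in a few lines of plain Python (the token ids here are illustrative toys, not the repo's actual vocabulary):

```python
# Toy ids: 0..2 are the text tokens c/a/t, 3..6 are the visual tokens, 7 is EOS.
NUM_TEXT = 3

seq = [0, 1, 2, 3, 4, 5, 6]        # [c][a][t][4][1][2][3]
seq = seq + [7]                    # append [eos] during training

inp    = seq[:-1]                  # [c][a][t][4][1][2][3]  (bottom row)
target = seq[1:]                   # [a][t][4][1][2][3][eos]  (top row)

# Logits mask, off by one: position i may only emit a text token when its
# *target* (the next token) is a text token.
logits_mask = ["T" if t < NUM_TEXT else "I" for t in target]
print(logits_mask)                 # ['T', 'T', 'I', 'I', 'I', 'I', 'I']
```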
The EOS token is not strictly necessary if you assume fixed length (256 + 1024), but I just wanted to give the last token something sensible to predict.
@CDitzel It is also a possibility that DALL-E was trained with full attention on the text tokens, in which case I may simplify this in the future so that text tokens are not included in the logit space at all :)
@CDitzel So after I wrote all that up, I realized that my implementation was perhaps needlessly complex. I decided to switch from using an EOS to having a BOS (assuming 0 as both padding and BOS) in #22. This also gets rid of some off-by-one confusion (but introduces some tensor slicing here and there). Let me know if that makes more sense!
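The BOS variant described above can be sketched the same way (again with illustrative toy ids, 0 doubling as padding/BOS):

```python
# Toy ids: 1..3 are the text tokens c/a/t, 4..7 are the visual tokens.
seq = [1, 2, 3, 4, 5, 6, 7]        # [c][a][t] + four image tokens

seq = [0] + seq                    # prepend BOS instead of appending EOS

inp    = seq[:-1]                  # BOS now predicts the first text token
target = seq[1:]                   # last position's target is the final
                                   # image token, so no EOS is needed
print(inp, target)
```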
Thank you for your effort and time, Phil. I will have a look tomorrow and get back to you. Again, so many thanks!
Hi Phil, so after carefully going through your explanations and the changes you recently merged, things are starting to become clearer, I guess. I still hope you don't mind me asking a couple of further questions.
These are again a lot of questions, but it is very rare to find people who are both knowledgeable and helpful, so I decided to take my chance. Thank you so much in advance!
Correct!
Yup, we are assuming a fixed length generation, so once we hit the last image token, we stop trying to predict the next one
Yup correct
So in this specific setup, because text tokens precede image tokens, the image tokens can attend to all the text tokens (but not the other way around). However, one can imagine a future system where you don't have such restrictions and just mix all tokens from all modalities together in any order
I believe the codebook is pretrained only in the VAE, and the DALL-E trains its own embeddings for the visual tokens
The VAE pre-training should encourage the encoder to discretize the image to unique codebook entries across the latent feature map.
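The discretization step mentioned above can be sketched roughly as follows, assuming the VAE encoder emits one score per codebook entry at each spatial position (all names and sizes here are illustrative, not the repo's actual API):

```python
import torch

codebook_size, dim = 512, 64
fmap_h = fmap_w = 8                              # downsampled latent grid

codebook   = torch.nn.Embedding(codebook_size, dim)
enc_logits = torch.randn(1, codebook_size, fmap_h, fmap_w)  # encoder output

indices = enc_logits.argmax(dim=1)   # (1, 8, 8): one codebook id per position
tokens  = indices.flatten(1)         # (1, 64): the visual token sequence
embeds  = codebook(tokens)           # (1, 64, 64): embeddings used downstream
print(tokens.shape, embeds.shape)
```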
Yup, it's exactly like GPT, but for text and image tokens. Nothing complicated
So this is a mistake on my part, I'll remove this restriction. I had thought perhaps there was a way to share codebook embeddings between DALL-E and VAE, but I don't think that would work
No problem! I'm learning as I go as well, so these questions are helpful for me to think out loud.
thank you once again for your answers and the possibility to discuss matters here.
Mh, but right now, IMHO the codebook is also adjusted during DALL-E training...
I cannot follow. The attention of the transformer receives the input, which is text and image tokens concatenated along the token dimension. But since the mask during training has True set everywhere, this is full-fledged attention from every token to every other, isn't it? Maybe I didn't understand properly.
Another thing I don't get is why there are two BOS tokens prepended. Once in the generate_images function DALLE-pytorch/dalle_pytorch/dalle_pytorch.py Line 340 in 6a50564
and then again here DALLE-pytorch/dalle_pytorch/dalle_pytorch.py Line 379 in 6a50564
because of this DALLE-pytorch/dalle_pytorch/dalle_pytorch.py Line 400 in 6a50564
Doesn't this cause the first image token to be predicted not by the last but by the second-to-last text token?
Yeah, someone else actually brought that up. I don't believe so, because if you read the iGPT paper, they clustered the pixel space into 512 values and then simply trained on those 512 values as unique embeddings, and it still worked. However, I have a branch in this repository named 'end-to-end' that contains what you are describing, and you are free to try it out.
So the attention is only from future to past because the causal flag is turned on: https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/transformer.py#L86 https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L294
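What the causal flag amounts to can be sketched independently of the repo's code: each position is allowed to attend only to itself and earlier positions (a minimal illustration, not the actual implementation):

```python
import torch

n = 5
# True above the diagonal = positions to mask out (the future).
causal_mask = torch.ones(n, n, dtype=torch.bool).triu(1)

sim  = torch.randn(n, n)                          # raw attention scores
sim  = sim.masked_fill(causal_mask, float('-inf'))
attn = sim.softmax(dim=-1)

# Row i now carries non-zero weight only on columns 0..i.
print((attn.tril() == attn).all())
```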
That's a bug on my part, fixed in the latest commit! 🙏
I am wondering: if instead of text one has another image modality, say the left image of a stereo camera pair where the right image has been used to train the VAE, how would one go about using this in DALL-E? According to the discussion section of this repo, the camera image has to be tokenized. I am contemplating whether it makes more sense to use another VAE for the second stream of images and rely on its resulting codebook indices, or whether it is more reasonable to use e.g. a ViT prior to token concatenation and feeding into the main transformer of DALL-E. Maybe even a simple trainable ViT embedding layer within the forward pass of DALL-E, before the concatenation step, suffices? I am just spitballing here and would be grateful for yours or anyone else's take on this.
It seems like dalle-pytorch has used BPE based on individual letters, rather than using the GPT-3 or BART encoder. This issue helped out a lot! Thanks 🤗
I have been trying to get a grasp of the DALL-E code recently. However, there are a couple of things I can't quite wrap my head around, and since the paper is not published yet, I was wondering if we can maybe clarify them here.
So there is the VAE training which basically features the codebook in the bottleneck and is trained a priori.
Next, DALL-E receives text and image pairs, embeds them, and adds positional encodings individually to both modalities.
However, the image data is not embedded as in e.g. ViT, but by downsampling it via the encoder of the VAE (without accumulating gradients), an argmax search within the feature dimension across the downsampled image patches, and finally indexing into the previously trained codebook.
The resulting representations of both modalities are then concatenated along the token dimension. And while every word of the text input is one token, the product of height and width of the VAE-encoded feature map yields the number of image tokens.
The combined embedding is then passed into a single transformer which calculates self-attention not only intra-modal but also across both modalities if I am not mistaken.
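The concatenation step described above can be sketched with illustrative sizes (256 text tokens plus a 32×32 = 1024 grid of image tokens matches the fixed lengths mentioned earlier in this thread; the embedding dimension here is made up):

```python
import torch

batch, dim = 2, 512
text_emb  = torch.randn(batch, 256, dim)    # embedded text + positions
image_emb = torch.randn(batch, 1024, dim)   # embedded codebook tokens + positions

# Concatenate along the token dimension before the transformer, so
# self-attention runs over both modalities in one sequence.
tokens = torch.cat((text_emb, image_emb), dim=1)
print(tokens.shape)                          # (2, 1280, 512)
```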
A masking of the form
mask = torch.ones_like(text).bool()
results in unmasked attention calculation, right?
A final Mlp maps the transformer output to all potential token possibilities (both text and image).
Then I don't understand the masking:
shouldn't there be one more row concerned with the text input and one less row for the image input?
For the following config with 3 text input tokens
the mask looks like this
shouldn't it be?
The purpose of the masking is so that image tokens don't contribute to the predictions of text and vice versa.
The code proceeds by constructing labels from the text integer tokens and from the VAE image encoding by using the codebook indices.
But what is it we are actually trying to predict with this classification task here?
It is a 2d CrossEntropyLoss where, for each token (either text or image), we are trying to predict ... exactly what?
So I am missing the intuition here, I guess...
And then, why is the label vector neglecting the very first label entry but using the EOS entry?
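For what it's worth, the label construction being asked about amounts to a next-token objective: labels are the sequence shifted by one, which is why the first entry is dropped. A minimal sketch of such a shifted 2d cross-entropy (shapes are illustrative, not the repo's actual code):

```python
import torch
import torch.nn.functional as F

vocab, seq_len, batch = 10, 8, 2
logits = torch.randn(batch, seq_len - 1, vocab)            # one prediction per position
labels = torch.randint(0, vocab, (batch, seq_len))[:, 1:]  # drop the first entry: each
                                                           # position predicts the NEXT token

# cross_entropy expects the class dim second, hence the transpose.
loss = F.cross_entropy(logits.transpose(1, 2), labels)
print(loss)
```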
Maybe someone can help me (and others) understand better what's going on here. Thank you in advance.