
Could you explain the code of Mask Class Token in the project? #2

Open
FYY799 opened this issue Nov 18, 2023 · 1 comment
FYY799 commented Nov 18, 2023

Could you explain the code of Mask Class Token in the project?


zh-ding commented Nov 29, 2023

There are two main parts to the Mask Class Token. The first is that we add mask class tokens (by repeating the class token) to the CLIP token sequence:

x = x.permute(1, 0, 2)                 # NLD -> LND
cls_embed = x[0:1]                     # original class token, shape (1, N, D)
cls_embed = cls_embed.repeat(q, 1, 1)  # q copies serve as the mask class tokens
x = torch.cat([cls_embed, x], dim=0)   # prepend mask class tokens to the sequence
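
For illustration, here is a minimal standalone sketch of what this does to the token layout. The shapes below (L, N, D, q) are made up for the example and are not the project's actual configuration:

import torch

L, N, D, q = 197, 2, 768, 5            # hypothetical: 1 cls + 196 patch tokens, 5 masks

x = torch.randn(N, L, D)               # CLIP token sequence, NLD
x = x.permute(1, 0, 2)                 # NLD -> LND
cls_embed = x[0:1]                     # (1, N, D)
cls_embed = cls_embed.repeat(q, 1, 1)  # (q, N, D) mask class tokens
x = torch.cat([cls_embed, x], dim=0)   # (q + L, N, D): [mask cls tokens | cls | patches]
print(x.shape)                         # torch.Size([202, 2, 768])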

The second is that we use the segmentation masks as the attention mask in the self-attention layer:

def forward(self, x: torch.Tensor, masks: torch.Tensor, masks_embed: torch.Tensor = None):
    l, b, d = x.shape          # sequence length, batch size, embedding dim
    _, q, _, _ = masks.shape   # number of predicted masks
    masks = (masks.sigmoid() >= 0.5).float()                      # binarize mask logits
    masks = F.max_pool2d(masks, self.clip_patch_size).flatten(2)  # downsample to patch resolution, (b, q, num_patches)
    attn_mask = torch.empty((b, l, l), device=x.device, dtype=torch.bool)
    attn_mask[:, :, :] = False             # False = attention allowed
    attn_mask[:, :, :q] = True             # no token attends to the mask class tokens
    attn_mask[:, :q, q+1:] = masks == 0.   # each mask class token only attends to patches inside its mask
    attn_mask = torch.repeat_interleave(attn_mask, self.n_head, dim=0)  # expand to (b * n_head, l, l)
    x_res, masks_res = self.attention(self.ln_1(x), attn_mask=attn_mask, masks_embed=masks_embed)
    x = x + x_res
    x = x + self.mlp(self.ln_2(x))
    return x, masks_res
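
To make the mask construction concrete, here is a rough self-contained sketch of the same logic on dummy tensors. All sizes (q, n_head, the 14x14 patch grid, patch size 16) are assumptions for the example, not values taken from the project:

import torch
import torch.nn.functional as F

q, n_head = 3, 8                       # hypothetical: 3 masks, 8 attention heads
clip_patch_size = 16
H = W = 14 * clip_patch_size           # mask resolution before pooling
b = 2
l = q + 1 + 14 * 14                    # mask cls tokens + cls token + patch tokens

masks = torch.randn(b, q, H, W)                                      # predicted mask logits
masks_bin = (masks.sigmoid() >= 0.5).float()                         # binarize
masks_bin = F.max_pool2d(masks_bin, clip_patch_size).flatten(2)      # (b, q, 196) patch-level masks

attn_mask = torch.zeros((b, l, l), dtype=torch.bool)                 # False = attention allowed
attn_mask[:, :, :q] = True                   # nothing attends to the mask class tokens
attn_mask[:, :q, q + 1:] = masks_bin == 0.   # each mask cls token only sees its own mask's patches
attn_mask = torch.repeat_interleave(attn_mask, n_head, dim=0)        # (b * n_head, l, l) per-head mask

So each of the q mask class tokens pools information only from the patch tokens covered by its segmentation mask, while the rest of the sequence is kept from attending back to the mask class tokens.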
