Code for DeiT #11

Open · yueatsprograms opened this issue Sep 22, 2023 · 1 comment

@yueatsprograms

Thank you for publishing your code. I saw the DeiT ablation in your paper. Is there a chance you could also provide code to reproduce that? If you'd prefer to contact me in private, my email is [email protected]

Thanks again!

@VictorLlu (Collaborator)

You can just use the following block to replace the pyramid block structure:

# One flat stage of SoftmaxFreeTransformerBlocks (same dim and heads throughout),
# replacing the pyramid's per-stage block lists
self.block = nn.ModuleList([SoftmaxFreeTransformerBlock(
    dim=embed_dim, num_heads=num_head, drop_path=dpr[i], ratio=sr_ratios[0],
    conv_size=9, max_iter=newton_max_iter, kernel_method=kernel_method)
    for i in range(depths)])
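
For context, in the flat DeiT-style variant depths is a single integer rather than a per-stage list, and dpr is one stochastic-depth rate per block. A minimal sketch of the assumed setup (the values and the linspace schedule are illustrative assumptions, not the repo's exact code):

import torch

# Assumed setup for the flat variant (values for illustration only)
depths = 12                       # single stage depth instead of a per-stage list
drop_path_rate = 0.1
# one linearly increasing stochastic-depth rate per block, DeiT-style
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]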

Then use average pooling over the tokens before the classification head; a minimal sketch of such a pooling head follows.
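
A sketch of what that could look like, assuming the usual (B, N, C) token layout; the name AvgPoolHead and the LayerNorm placement are illustrative, not the repo's code:

import torch.nn as nn

class AvgPoolHead(nn.Module):
    # Hypothetical pooling head: mean over the token dimension followed by a
    # linear classifier, replacing cls-token readout
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens):                   # tokens: (B, N, C)
        pooled = self.norm(tokens).mean(dim=1)   # (B, C)
        return self.fc(pooled)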
Alternatively, you can insert the following

# Build landmark queries: drop the cls token, fold the heads into the batch
# dimension and restore the (H, W) spatial layout for the landmark convolution
Q_landmarks = Q[:, :, 1:, :].reshape(b * nhead, H, W, headdim).permute(0, 3, 1, 2)
Q_landmarks = self.Qlandmark_op(Q_landmarks)
Q_landmarks = Q_landmarks.flatten(2).transpose(1, 2).reshape(b, nhead, -1, headdim)
Q_landmarks = self.Qnorm_act(Q_landmarks)
# the key landmarks are tied to the query landmarks
K_landmarks = Q_landmarks

# kernel between all queries (cls token included) and the landmarks
kernel_1_ = self.kernel_function(Q, K_landmarks.transpose(-1, -2).contiguous())
kernel_1_ = torch.exp(-kernel_1_ / 2)

# kernel among the landmarks themselves
kernel_2_ = self.kernel_function(Q_landmarks, K_landmarks.transpose(-1, -2).contiguous())
kernel_2_ = torch.exp(-kernel_2_ / 2)

kernel_3_ = kernel_1_.transpose(-1, -2)

# low-rank attention: kernel_1_ @ kernel_2_^{-1} @ kernel_3_ @ V, with the
# inverse approximated by Newton iteration
X = torch.matmul(torch.matmul(kernel_1_, newton_inverse_kernel(kernel_2_, self.max_iter)),
                 torch.matmul(kernel_3_, V))

if self.use_conv:
    # apply the convolution only to the patch tokens; the cls token has no
    # spatial position, so it is split off and concatenated back afterwards
    V_ = V[:, :, 1:, :]
    cls_token = V[:, :, 0, :].unsqueeze(2)
    V_ = V_.reshape(b, nhead, h, w, headdim)
    V_ = V_.permute(0, 4, 1, 2, 3).reshape(b * headdim, nhead, h, w)
    out = self.conv(V_).reshape(b, headdim, nhead, h, w).flatten(3).permute(0, 2, 3, 1)
    out = torch.cat([cls_token, out], dim=2)
    X += out

in the forward of the SoftmaxFreeAttentionKernel class to handle the cls token.
The training strategy is exactly the same as for SOFT.
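
For orientation, the newton_inverse_kernel call above approximates the inverse of the landmark kernel kernel_2_ iteratively. A generic Newton-Schulz iteration conveys the idea; this is a sketch and not necessarily identical to the repo's implementation:

import torch

def newton_schulz_inverse(mat, n_iter=20):
    # Sketch only; the repo's newton_inverse_kernel may differ in details.
    # mat: (..., m, m) landmark kernel matrix
    m = mat.size(-1)
    identity = torch.eye(m, device=mat.device, dtype=mat.dtype)
    # initial guess X0 = A^T / (||A||_1 * ||A||_inf) ensures convergence
    norm_inf = mat.abs().sum(dim=-1).amax(dim=-1, keepdim=True).unsqueeze(-1)
    norm_1 = mat.abs().sum(dim=-2).amax(dim=-1, keepdim=True).unsqueeze(-1)
    inv = mat.transpose(-1, -2) / (norm_1 * norm_inf)
    for _ in range(n_iter):
        # X <- X (2I - A X): converges quadratically to A's (pseudo-)inverse
        inv = torch.matmul(inv, 2.0 * identity - torch.matmul(mat, inv))
    return inv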
