
How to visualize attention map #1

Open
piantic opened this issue Dec 4, 2020 · 5 comments
@piantic

piantic commented Dec 4, 2020

Hi,

I want to visualize the attention map.
I found https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb

In this repo, I did not find an option to visualize the attention map.
(If there is one, please let me know; I'd appreciate it.)

So I decided to add this to model.py, like this:

# In VisionTransformer
def forward(self, x):
    feat, attn_weights = self.extract_features(x)

    # classifier on the class token
    logits = self.classifier(feat[:, 0])
    return logits, attn_weights

# In Encoder
def forward(self, x):
    attn_weights = []  # collect the attention weights of every layer
    out = self.pos_embedding(x)

    for layer in self.encoder_layers:
        out, weights = layer(out)
        attn_weights.append(weights)

    out = self.norm(out)
    return out, attn_weights

# In SelfAttention
def forward(self, x):
    b, n, _ = x.shape

    q = self.query(x, dims=([2], [0]))  # (b, n, heads, head_dim)
    k = self.key(x, dims=([2], [0]))
    v = self.value(x, dims=([2], [0]))

    q = q.permute(0, 2, 1, 3)  # (b, heads, n, head_dim)
    k = k.permute(0, 2, 1, 3)
    v = v.permute(0, 2, 1, 3)

    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # (b, heads, n, n)
    attn_weights = F.softmax(attn_weights, dim=-1)
    out = torch.matmul(attn_weights, v)
    out = out.permute(0, 2, 1, 3)

    out = self.out(out, dims=([2, 3], [0, 1]))  # back to (b, n, emb_dim)

    return out, attn_weights

And I got this result:

[image: attention map produced by my modified model]

But I don't know whether it is right or not,
because the attention map from the notebook linked above looks quite different from mine
(I used the pretrained weights provided here).

[image: attention map from the linked notebook]

I am not sure if my results are correct.
I would be happy to hear your answer.
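For reference, the linked notebook aggregates the per-layer weights into a single map with attention rollout, roughly like this (just a sketch of that post-processing, assuming attn_weights is the list returned above, each entry of shape (1, heads, n, n) for a single image):

import torch

att = torch.stack(attn_weights).squeeze(1)   # (layers, heads, n, n)
att = att.mean(dim=1)                        # average over heads -> (layers, n, n)
att = att + torch.eye(att.size(-1))          # account for residual connections
att = att / att.sum(dim=-1, keepdim=True)    # re-normalize each row

rollout = att[0]
for layer_att in att[1:]:
    rollout = layer_att @ rollout            # joint attention across layers

mask = rollout[0, 1:]                        # class-token attention over patch tokens
grid = int(mask.numel() ** 0.5)
mask = mask.reshape(grid, grid)              # patch grid; upsample to image size to overlay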

Thanks.

@tczhangzhi
Owner

Looks good to me, but one thing you should pay attention to is that vit-model-1 is fine-tuned on the cassava-leaf-disease-classification task. You may want to visualize an image from that dataset. It is quite different from object classification and focuses on the low-level texture of the input leaf. To visualize the attention map of a dog, you can use the pre-trained models here.

Anyway, it is a good first try. I'm still hesitant about how to extract the attention map, since I don't want it to affect the inference process, that is, I don't want to modify the forward function. Maybe later I will look into best practices with hooks. If you are willing to, you can open a PR with your implementation.
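For example, something like this might work (just a sketch; it assumes the blocks are reachable as model.transformer.encoder_layers with an .attn submodule, and it recomputes the weights inside the hook so forward itself stays unchanged):

import torch
import torch.nn.functional as F

attn_weights = []

def save_attn_weights(module, inputs, output):
    # Recompute the softmax weights from the block input; the module's own
    # forward stays untouched. Shapes assume the SelfAttention shown above.
    x = inputs[0]                                              # (b, n, emb_dim)
    q = module.query(x, dims=([2], [0])).permute(0, 2, 1, 3)   # (b, heads, n, head_dim)
    k = module.key(x, dims=([2], [0])).permute(0, 2, 1, 3)
    w = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / module.scale, dim=-1)
    attn_weights.append(w.detach())                            # (b, heads, n, n)

hooks = [block.attn.register_forward_hook(save_attn_weights)
         for block in model.transformer.encoder_layers]

with torch.no_grad():
    logits = model(x)

for h in hooks:
    h.remove()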

@piantic
Author

piantic commented Dec 5, 2020

Thanks for the answer.
I used the pre-trained models you recommended.

Here is the result for a dog.

[image: attention map result for a dog]

The original attention map from https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb is below:

[image: attention map from the jeonsworld/ViT-pytorch notebook]

It seems to show something similar, but I'm not sure.
What do you think about this part?

If you think this part is okay, a simple vis flag could minimize the impact on inference:
if vis is False, the model runs the original forward function.
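For example (just a sketch of the idea; vis is a hypothetical argument, and extract_features is assumed to return the weights as in my snippet above):

# In VisionTransformer
def forward(self, x, vis=False):
    feat, attn_weights = self.extract_features(x)
    logits = self.classifier(feat[:, 0])
    if vis:
        return logits, attn_weights  # expose per-layer attention weights
    return logits                    # original behaviour when vis is False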

@cifkao

cifkao commented Mar 3, 2022

For people still looking for a solution, my package NoPdb allows capturing attention weights from pretty much any Transformer implementation without any modifications to the code. See a Colab notebook showing how to do this for ViT (a different implementation).

In this case, it would be something like:

import nopdb

with nopdb.capture_calls(SelfAttention.forward) as calls:
    logits = model(x)

calls[0].locals["attn_weights"]  # attention weights of the first layer
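And to collect every layer at once (assuming one SelfAttention call per encoder block, so calls[i] corresponds to layer i):

attn_weights = torch.stack([call.locals["attn_weights"] for call in calls])  # (layers, b, heads, n, n)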

@Suryanshg

Hi, when I try to implement the changes by @piantic, this is the error I am getting:

Traceback (most recent call last):
  File "C:\Users\Surya\Desktop\Automatic-Pain-Estimation-MQP\scripts\Visualize_Attention_Map.py", line 96, in <module>
    result_img = get_attention_map(viz_image)
  File "C:\Users\Surya\Desktop\Automatic-Pain-Estimation-MQP\scripts\Visualize_Attention_Map.py", line 25, in get_attention_map
    logits, att_mat = model(x.unsqueeze(0))
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 269, in forward
    feat, attn_weights = self.extract_features(x)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 265, in extract_features
    feat = self.transformer(emb)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 177, in forward
    out, weights = layer(out)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 139, in forward
    out = self.dropout(out)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\functional.py", line 1169, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
TypeError: dropout(): argument 'input' (position 1) must be Tensor, not tuple

Is there anything else I need to do? I feel that some change might also need to be made in the EncoderBlock part of the model.py file.

@piantic
Author

piantic commented Apr 6, 2022

Hi, @Suryanshg.

This is my example notebook for visualizing the attention map with this repo:
https://www.kaggle.com/code/piantic/vision-transformer-vit-visualize-attention-map/notebook

And you can see a visualized version of ViT at the link below:
https://www.kaggle.com/datasets/piantic/visiontransformerpytorch121
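About the dropout error: it happens because the modified SelfAttention now returns a tuple, so EncoderBlock.forward also has to unpack it before calling dropout. A rough sketch of that change (I'm assuming the usual pre-norm layout and attribute names based on your traceback, so please adjust to the actual model.py):

# In EncoderBlock (a sketch; attribute names are assumptions, not copied from model.py)
def forward(self, x):
    residual = x
    out = self.norm1(x)
    out, weights = self.attn(out)  # unpack the (output, attention) tuple
    out = self.dropout(out)        # dropout receives a Tensor again
    out = out + residual

    residual = out
    out = self.norm2(out)
    out = self.mlp(out)
    out = out + residual
    return out, weights            # matches `out, weights = layer(out)` in Encoder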

I hope this helps you.
Thanks.
