-
Hi all! I have been trying to visualize the attention maps for ViT, as shown in the paper (Fig. 6 and Fig. 13). There are some alternative implementations by lucidrains and jeonsworld. They make this possible either through a dedicated Recorder class or by having the Attention class return both the attention output and the weights. I prefer to use the timm library, since it lets me create models seamlessly, so I started writing my own code to visualize the attention maps. I created a class based on lucidrains's Recorder, editing the parameters as needed to fit timm's modules. The class is attached below.

from timm.models.vision_transformer import Attention
import torch.nn as nn


def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]


class Recorder(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.data = None
        self.recordings = []
        self.hooks = []
        self.hook_registered = False
        self.ejected = False

    def _hook(self, _, input, output):
        # store a detached copy of whatever the hooked module outputs
        self.recordings.append(output.clone().detach())

    def _register_hook(self):
        # locate the Attention modules inside the ViT blocks
        modules = find_modules(self.model.model.blocks, Attention)
        for module in modules:
            # register the recording hook (intended to fire on the attention softmax)
            handle = nn.Softmax(dim=-1).register_forward_hook(self._hook)
            self.hooks.append(handle)
        self.hook_registered = True

    def eject(self):
        # remove the hooks and hand back the wrapped model
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.model

    def clear(self):
        self.recordings.clear()

    def record(self, attn):
        recording = attn.clone().detach()
        self.recordings.append(recording)

    def forward(self, img):
        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()
        pred = self.model(img)
        attns = self.recordings
        return pred, attns

But when I run the following, with model coming from timm's create_model (rough setup sketch below):
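The setup is roughly along these lines; the architecture name and the random input are placeholders, not my exact model and image:

import timm
import torch

# placeholder architecture; any ViT variant from timm is created the same way
model = timm.create_model('vit_base_patch16_224', pretrained=True).cuda().eval()
# placeholder input; in practice x is a preprocessed image tensor of shape (3, 224, 224)
x = torch.rand(3, 224, 224)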
r_model = Recorder(model)
r_model(x[None, :, :, :].cuda())

it returns

(tensor([[ 2.0708e+00,  7.3846e+00, -1.6603e+00, -1.9909e+00, -3.3746e+00,
          -2.2779e+00, -3.9599e-01,  1.9346e+00, -1.4222e-01, -6.1785e-03]],
        device='cuda:0', grad_fn=<AddmmBackward>),
 [])

I can get the output predictions, but not the attention values (the second element of the returned tuple is just an empty list). Any pointers for me to get the attention map working? Thank you!
Replies: 2 comments 2 replies
-
Hey @khvmaths, did you manage to get your Recorder working? Thanks
-
Closing this discussion, there's an issue for this by zlapp over at #292. The Gist is available here.
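For anyone who lands here later: the Recorder above stays empty because it registers its hook on a freshly constructed nn.Softmax that never takes part in the model's forward pass (timm's Attention computes the softmax with a tensor call, not an nn.Softmax submodule), so the hook never fires. Below is a minimal sketch of one way to grab the attention maps from a timm ViT instead, assuming a timm version whose Attention still routes the post-softmax weights through its attn_drop submodule (the classic, non-fused attention path):

import timm
import torch

# assumption: standard timm ViT with the non-fused attention path,
# where block.attn.attn_drop (an nn.Dropout) sees the full attention matrix
model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()

attn_maps = []

def save_attn(module, inputs, output):
    # output of attn_drop has shape (batch, num_heads, tokens, tokens)
    attn_maps.append(output.detach())

hooks = [blk.attn.attn_drop.register_forward_hook(save_attn) for blk in model.blocks]

with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)          # dummy input image
    logits = model(x)

print(len(attn_maps), attn_maps[0].shape)   # one map per block, e.g. 12 x (1, 12, 197, 197)

for h in hooks:                             # remove the hooks when done
    h.remove()

Since the model is in eval mode, attn_drop is the identity, so what the hook records is exactly the softmax output for every block.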