From 61afe9e25756a4145c22a74edac2ff69ed92095d Mon Sep 17 00:00:00 2001 From: soniajoseph Date: Thu, 8 Feb 2024 17:59:09 -0500 Subject: [PATCH] refactor attn for real this time back to timm code --- src/vit_prisma/models/layers/attention.py | 236 +++++---- .../models/layers/transformer_block.py | 34 +- .../prisma/loading_from_pretrained.py | 55 ++- tests/test_notebooks/timm_comparison.ipynb | 463 +++++++++++------- 4 files changed, 463 insertions(+), 325 deletions(-) diff --git a/src/vit_prisma/models/layers/attention.py b/src/vit_prisma/models/layers/attention.py index bda513d3..6a5beb89 100644 --- a/src/vit_prisma/models/layers/attention.py +++ b/src/vit_prisma/models/layers/attention.py @@ -1,3 +1,7 @@ +""" +Attention code from timm model +""" + import torch.nn as nn import torch @@ -33,51 +37,54 @@ def __init__( self.cfg = cfg - # Initialize parameters - self.W_Q = nn.Parameter( - torch.empty( - self.cfg.n_heads, - self.cfg.d_model, - self.cfg.d_head, - dtype = self.cfg.dtype - ) - ) - self.W_K = nn.Parameter( - torch.empty( - self.cfg.n_heads, - self.cfg.d_model, - self.cfg.d_head, - dtype = self.cfg.dtype - ) - ) - self.W_V = nn.Parameter( - torch.empty( - self.cfg.n_heads, - self.cfg.d_model, - self.cfg.d_head, - dtype = self.cfg.dtype - ) - ) - self.W_O = nn.Parameter( - torch.empty( - self.cfg.n_heads, - self.cfg.d_head, - self.cfg.d_model, - dtype = self.cfg.dtype - ) - ) - - # Initialize biases - self.b_Q = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) - ) - self.b_K = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) - ) - self.b_V = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) - ) - self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) + self.qkv = nn.Linear(self.cfg.d_model, self.cfg.d_model * 3, bias=True) + self.proj = nn.Linear(self.cfg.d_model, self.cfg.d_model) + + # # Initialize parameters + # self.W_Q = nn.Parameter( + # torch.empty( + # self.cfg.n_heads, + # self.cfg.d_model, + # self.cfg.d_head, + # dtype = self.cfg.dtype + # ) + # ) + # self.W_K = nn.Parameter( + # torch.empty( + # self.cfg.n_heads, + # self.cfg.d_model, + # self.cfg.d_head, + # dtype = self.cfg.dtype + # ) + # ) + # self.W_V = nn.Parameter( + # torch.empty( + # self.cfg.n_heads, + # self.cfg.d_model, + # self.cfg.d_head, + # dtype = self.cfg.dtype + # ) + # ) + # self.W_O = nn.Parameter( + # torch.empty( + # self.cfg.n_heads, + # self.cfg.d_head, + # self.cfg.d_model, + # dtype = self.cfg.dtype + # ) + # ) + + # # Initialize biases + # self.b_Q = nn.Parameter( + # torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + # ) + # self.b_K = nn.Parameter( + # torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + # ) + # self.b_V = nn.Parameter( + # torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + # ) + # self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) # Add hook points @@ -87,7 +94,9 @@ def __init__( self.hook_z = HookPoint() # [batch, pos, head_index, d_head] self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] - self.hook_result = HookPoint() # [batch, pos, head_index, d_model] + # self.hook_result = HookPoint() # [batch, pos, head_index, d_model] + + self.hook_qkv = HookPoint() # [batch, pos, 3, head_index, d_head] self.layer_id = layer_id @@ -126,62 +135,80 @@ def QK(self) -> 
FactoredMatrix: def forward( self, - query_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - key_input: Union[ + x: Union[ Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - value_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], + # Float[torch.Tensor, "batch pos head_index d_model"], ] ) -> Float[torch.Tensor, "batch pos d_model"]: - q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) + # Calculate QKV + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.cfg.n_heads, self.cfg.d_head).permute(2, 0, 3, 1, 4) + + qkv = self.hook_qkv(qkv) + + q, k, v = qkv.unbind(0) + + q = self.hook_q(q) + k = self.hook_k(k) + v = self.hook_v(v) + + # q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) - attn_scores = self.calculate_attn_scores(q, k) + # Attention Scores + q = q * self.scale + attn_scores = q @ k.transpose(-2, -1) attn_scores = self.hook_attn_scores(attn_scores) - pattern = F.softmax(attn_scores, dim=-1) # where do I do normalization? - pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) - pattern = self.hook_pattern(pattern) - - pattern = pattern.to(self.cfg.dtype) - z = self.calculate_z_scores(v, pattern) - - if not self.cfg.use_attn_result: - out = ( - ( - einsum( - "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos d_model", - z, - self.W_O, - ) - ) - + self.b_O - ) - else: - # Explicitly calculate the attention result so it can be accessed by a hook. - # Off by default to not eat through GPU memory. - result = self.hook_result( - einsum( - "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos head_index d_model", - z, - self.W_O, - ) - ) - out = ( - einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum") - + self.b_O - ) - return out + # Attention Pattern + attn_pattern = attn_scores.softmax(dim=-1) + attn_pattern = torch.where(torch.isnan(attn_pattern), torch.zeros_like(attn_pattern), attn_pattern) + attn_pattern = self.hook_pattern(attn_pattern) + attn_pattern = attn_pattern.to(self.cfg.dtype) + + # Value matrix + z = attn_pattern @ v + z = self.hook_z(z) + + # Output projection layer + z = z.transpose(1,2).reshape(B,N,C) + output = self.proj(z) + # output = self.hook_result(output) + + return output + + # z = self.calculate_z_scores(v, pattern) + + # if not self.cfg.use_attn_result: + # out = ( + # ( + # einsum( + # "batch pos head_index d_head, \ + # head_index d_head d_model -> \ + # batch pos d_model", + # z, + # self.W_O, + # ) + # ) + # + self.b_O + # ) + # else: + # # Explicitly calculate the attention result so it can be accessed by a hook. + # # Off by default to not eat through GPU memory. 
+ # result = self.hook_result( + # einsum( + # "batch pos head_index d_head, \ + # head_index d_head d_model -> \ + # batch pos head_index d_model", + # z, + # self.W_O, + # ) + # ) + # out = ( + # einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum") + # + self.b_O + # ) + # return out def calculate_qkv_matrices( self, @@ -213,16 +240,15 @@ def calculate_qkv_matrices( else: qkv_einops_string = "batch pos d_model" - - q = self.hook_q( - einsum( + q = einsum( f"{qkv_einops_string}, head_index d_model d_head \ -> batch pos head_index d_head", query_input, self.W_Q, - ) - + self.b_Q - ) # [batch, pos, head_index, d_head] + ) + self.b_Q + # [batch, pos, head_index, d_head] + q = self.hook_q(q.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm + k = self.hook_k( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -232,6 +258,8 @@ def calculate_qkv_matrices( ) + self.b_K ) # [batch, pos, head_index, d_head] + k = self.hook_k(k.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm + v = self.hook_v( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -241,6 +269,7 @@ def calculate_qkv_matrices( ) + self.b_V ) # [batch, pos, head_index, d_head] + v = self.hook_v(v.transpose(-1,-2)) return q, k, v def calculate_attn_scores( @@ -255,10 +284,11 @@ def calculate_attn_scores( """ q = q * self.scale attn_scores = einsum( - "batch query_pos head_index d_head, batch key_pos head_index d_head -> batch head_index query_pos key_pos", + "batch query_pos d_head head_index, batch key_pos d_head -> batch head_index query_pos key_pos", q, k, ) + return attn_scores def calculate_z_scores( @@ -266,6 +296,8 @@ def calculate_z_scores( v: Float[torch.Tensor, "batch key_pos head_index d_head"], pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: + + z = self.hook_z( einsum( "batch key_pos head_index d_head, \ diff --git a/src/vit_prisma/models/layers/transformer_block.py b/src/vit_prisma/models/layers/transformer_block.py index 8dae8b82..72cace19 100644 --- a/src/vit_prisma/models/layers/transformer_block.py +++ b/src/vit_prisma/models/layers/transformer_block.py @@ -85,35 +85,33 @@ def forward( resid_pre = self.hook_resid_pre(resid_pre) - if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: - # We're adding a head dimension - attn_in = add_head_dimension(resid_pre, self.cfg.n_heads, clone_tensor=False) - else: - attn_in = resid_pre + # if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: + # # We're adding a head dimension + # attn_in = add_head_dimension(resid_pre, self.cfg.n_heads, clone_tensor=False) + # else: + attn_in = resid_pre if self.cfg.use_attn_in: attn_in = self.hook_attn_in(attn_in.clone()) - - if self.cfg.use_split_qkv_input: - query_input = self.hook_q_input(attn_in.clone()) - key_input = self.hook_k_input(attn_in.clone()) - value_input = self.hook_v_input(attn_in.clone()) - else: - query_input = attn_in - key_input = attn_in - value_input = attn_in + # if self.cfg.use_split_qkv_input: + # query_input = self.hook_q_input(attn_in.clone()) + # key_input = self.hook_k_input(attn_in.clone()) + # value_input = self.hook_v_input(attn_in.clone()) + # else: + # query_input = attn_in + # key_input = attn_in + # value_input = attn_in # attn_out = self.attn( # query_input = query_input, # key_input = key_input, # value_input = value_input, # ) - + + # Update: moving to timm setup for accuracy; add function to split qkv later. 
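+        # The attention block now receives a single normalized residual-stream input;
+        # q, k, and v are computed inside Attention from its fused qkv linear projection
+        # and can be recovered from hook_qkv (shape [3, batch, n_heads, pos, d_head],
+        # see the timm_comparison notebook) via qkv.unbind(0), as in timm.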
attn_out = self.attn( - query_input = self.ln1(query_input), - key_input = self.ln1(key_input), - value_input = self.ln1(value_input), + self.ln1(attn_in) ) attn_out = self.hook_attn_out( diff --git a/src/vit_prisma/prisma/loading_from_pretrained.py b/src/vit_prisma/prisma/loading_from_pretrained.py index 1738e9c3..ac401e52 100644 --- a/src/vit_prisma/prisma/loading_from_pretrained.py +++ b/src/vit_prisma/prisma/loading_from_pretrained.py @@ -46,30 +46,37 @@ def convert_timm_weights( new_state_dict[f"{layer_key}.ln2.b"] = old_state_dict[f"{layer_key}.norm2.bias"] W = old_state_dict[f"{layer_key}.attn.qkv.weight"] - W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) - W_Q = einops.rearrange(W_Q, "(i h) m->h m i", h=cfg.n_heads) - W_K = einops.rearrange(W_K, "(i h) m->h m i", h=cfg.n_heads) - W_V = einops.rearrange(W_V, "(i h) m->h m i", h=cfg.n_heads) - new_state_dict[f"{layer_key}.attn.W_Q"] = W_Q - new_state_dict[f"{layer_key}.attn.W_K"] = W_K - new_state_dict[f"{layer_key}.attn.W_V"] = W_V - - W_O = old_state_dict[f"{layer_key}.attn.proj.weight"] - W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) - new_state_dict[f"{layer_key}.attn.W_O"] = W_O - - attn_bias = old_state_dict[f"{layer_key}.attn.qkv.bias"] - b_Q, b_K, b_V = torch.tensor_split(attn_bias, 3, dim=0) - b_Q = einops.rearrange(b_Q, "(i h) -> h i", h=cfg.n_heads) - b_K = einops.rearrange(b_K, "(i h) -> h i", h=cfg.n_heads) - b_V = einops.rearrange(b_V, "(i h) -> h i", h=cfg.n_heads) - new_state_dict[f"{layer_key}.attn.b_Q"] = b_Q - new_state_dict[f"{layer_key}.attn.b_K"] = b_K - new_state_dict[f"{layer_key}.attn.b_V"] = b_V - - b_O = old_state_dict[f"{layer_key}.attn.proj.bias"] - b_O = einops.rearrange(b_O, "m -> m") - new_state_dict[f"{layer_key}.attn.b_O"] = b_O + new_state_dict[f"{layer_key}.attn.qkv.weight"] = old_state_dict[f"{layer_key}.attn.qkv.weight"] + new_state_dict[f"{layer_key}.attn.qkv.bias"] = old_state_dict[f"{layer_key}.attn.qkv.bias"] + + new_state_dict[f"{layer_key}.attn.proj.weight"] = old_state_dict[f"{layer_key}.attn.proj.weight"] + new_state_dict[f"{layer_key}.attn.proj.bias"] = old_state_dict[f"{layer_key}.attn.proj.bias"] + + + # W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) + # W_Q = einops.rearrange(W_Q, "(i h) m->h m i", h=cfg.n_heads) + # W_K = einops.rearrange(W_K, "(i h) m->h m i", h=cfg.n_heads) + # W_V = einops.rearrange(W_V, "(i h) m->h m i", h=cfg.n_heads) + # new_state_dict[f"{layer_key}.attn.W_Q"] = W_Q + # new_state_dict[f"{layer_key}.attn.W_K"] = W_K + # new_state_dict[f"{layer_key}.attn.W_V"] = W_V + + # W_O = old_state_dict[f"{layer_key}.attn.proj.weight"] + # W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) + # new_state_dict[f"{layer_key}.attn.W_O"] = W_O + + # attn_bias = old_state_dict[f"{layer_key}.attn.qkv.bias"] + # b_Q, b_K, b_V = torch.tensor_split(attn_bias, 3, dim=0) + # b_Q = einops.rearrange(b_Q, "(i h) -> h i", h=cfg.n_heads) + # b_K = einops.rearrange(b_K, "(i h) -> h i", h=cfg.n_heads) + # b_V = einops.rearrange(b_V, "(i h) -> h i", h=cfg.n_heads) + # new_state_dict[f"{layer_key}.attn.b_Q"] = b_Q + # new_state_dict[f"{layer_key}.attn.b_K"] = b_K + # new_state_dict[f"{layer_key}.attn.b_V"] = b_V + + # b_O = old_state_dict[f"{layer_key}.attn.proj.bias"] + # b_O = einops.rearrange(b_O, "m -> m") + # new_state_dict[f"{layer_key}.attn.b_O"] = b_O new_state_dict[f"{layer_key}.mlp.b_in"] = old_state_dict[f"{layer_key}.mlp.fc1.bias"] new_state_dict[f"{layer_key}.mlp.b_out"] = old_state_dict[f"{layer_key}.mlp.fc2.bias"] diff --git 
a/tests/test_notebooks/timm_comparison.ipynb b/tests/test_notebooks/timm_comparison.ipynb index ffef9cff..48847cc6 100644 --- a/tests/test_notebooks/timm_comparison.ipynb +++ b/tests/test_notebooks/timm_comparison.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -39,21 +39,42 @@ " fold_ln=False, \n", " fold_value_biases=False,\n", " use_attn_scale=False,\n", - " use_split_qkv_input=True,\n", + " use_attn_in=True,\n", ")\n", "timm_model = timm.create_model('vit_base_patch16_224', pretrained=True)\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Files already downloaded and verified\n" + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/170498071 [00:00 batch pos n_heads d_model\",\n", + "# n_heads=12,\n", + "# )\n", + "\n", + "# print(repeated_tensor[0].shape)\n", + "print(cache['blocks.0.hook_attn_in'].shape)\n", + "\n", + "assert torch.allclose(activations[0], cache['blocks.0.hook_attn_in'][0], atol=1), \"Activations differ more than the allowed tolerance\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 197, 768])\n", + "torch.Size([1, 197, 768])\n" + ] + } + ], + "source": [ "\n", + "import einops \n", "\n", "activations = []\n", "def hook_fn(module, input, output):\n", @@ -393,6 +482,8 @@ "timm_output = timm_model(image)\n", "hook_handle.remove()\n", "\n", + "print(activations[0].shape)\n", + "print(cache['blocks.0.ln1.hook_normalized'].shape)\n", "\n", "# Assert equal to the first layer\n", "assert torch.allclose(activations[0], cache['blocks.0.ln1.hook_normalized'][0], atol=1e-6), \"Activations differ more than the allowed tolerance\"" @@ -402,69 +493,103 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Attention" + "### Attention\n", + "\n", + "**Weights**" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "# let's compare qkv weights\n", - "QKV = timm_model.blocks[0].attn.qkv.weight\n", - "W_Q, W_K, W_V = torch.tensor_split(QKV, 3, dim=0)\n", - "t_Q = einops.rearrange(W_Q, \"(i h) m->h m i\", h=12)\n", - "t_K = einops.rearrange(W_K, \"(i h) m->h m i\", h=12)\n", - "t_V = einops.rearrange(W_V, \"(i h) m->h m i\", h=12)\n", + "# # let's compare qkv weights\n", + "# QKV = timm_model.blocks[0].attn.qkv.weight\n", + "# W_Q, W_K, W_V = torch.tensor_split(QKV, 3, dim=0)\n", + "# t_Q = einops.rearrange(W_Q, \"(i h) m->h m i\", h=12)\n", + "# t_K = einops.rearrange(W_K, \"(i h) m->h m i\", h=12)\n", + "# t_V = einops.rearrange(W_V, \"(i h) m->h m i\", h=12)\n", "\n", - "p_Q = prisma_model.blocks[0].attn.W_Q\n", - "p_K = prisma_model.blocks[0].attn.W_K\n", - "p_V = prisma_model.blocks[0].attn.W_V\n", + "# p_Q = prisma_model.blocks[0].attn.W_Q\n", + "# p_K = prisma_model.blocks[0].attn.W_K\n", + "# p_V = prisma_model.blocks[0].attn.W_V\n", "\n", - "assert torch.allclose(p_Q, t_Q, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", - "assert torch.allclose(p_K, t_K, atol=1e-6), \"Activations differ 
more than the allowed tolerance\"\n", - "assert torch.allclose(p_V, t_V, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "# assert torch.allclose(p_Q, t_Q, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "# assert torch.allclose(p_K, t_K, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "# assert torch.allclose(p_V, t_V, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", "\n", - "# qkv bias\n", - "bias_QKV = timm_model.blocks[0].attn.qkv.bias\n", + "# # qkv bias\n", + "# bias_QKV = timm_model.blocks[0].attn.qkv.bias\n", "\n", - "b_Q, b_K, b_V = torch.tensor_split(bias_QKV, 3, dim=0)\n", + "# b_Q, b_K, b_V = torch.tensor_split(bias_QKV, 3, dim=0)\n", "\n", - "bt_Q = einops.rearrange(b_Q, \"(i h) -> h i\", h=12)\n", - "bt_K = einops.rearrange(b_K, \"(i h) -> h i\", h=12)\n", - "bt_V = einops.rearrange(b_V, \"(i h) -> h i\", h=12)\n", + "# bt_Q = einops.rearrange(b_Q, \"(i h) -> h i\", h=12)\n", + "# bt_K = einops.rearrange(b_K, \"(i h) -> h i\", h=12)\n", + "# bt_V = einops.rearrange(b_V, \"(i h) -> h i\", h=12)\n", "\n", - "bp_Q = prisma_model.blocks[0].attn.b_Q\n", - "bp_K = prisma_model.blocks[0].attn.b_K\n", - "bp_V = prisma_model.blocks[0].attn.b_V\n", + "# bp_Q = prisma_model.blocks[0].attn.b_Q\n", + "# bp_K = prisma_model.blocks[0].attn.b_K\n", + "# bp_V = prisma_model.blocks[0].attn.b_V\n", "\n", - "assert torch.allclose(bp_Q, bt_Q, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", - "assert torch.allclose(bp_K, bt_K, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", - "assert torch.allclose(bp_V, bt_V, atol=1e-6), \"Activations differ more than the allowed tolerance\"" + "# assert torch.allclose(bp_Q, bt_Q, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "# assert torch.allclose(bp_K, bt_K, atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "# assert torch.allclose(bp_V, bt_V, atol=1e-6), \"Activations differ more than the allowed tolerance\"" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def print_matrix_corner(matrix, rows=1, cols=1):\n", + " \"\"\"\n", + " Prints the top-left corner of a matrix (tensor) up to the specified number of rows and columns.\n", + "\n", + " Parameters:\n", + " - matrix (torch.Tensor): The matrix (tensor) from which to print the corner.\n", + " - rows (int): The number of rows to include in the printed corner. Default is 5.\n", + " - cols (int): The number of columns to include in the printed corner. 
Default is 5.\n", + " \"\"\"\n", + " # Ensure the matrix is a PyTorch tensor\n", + " if not isinstance(matrix, torch.Tensor):\n", + " print(\"The input is not a PyTorch tensor.\")\n", + " return\n", + "\n", + " # Get the size of the matrix\n", + " num_rows, num_cols = matrix.shape[:2]\n", + "\n", + " # Adjust rows and cols if the matrix is smaller than specified dimensions\n", + " rows_to_print = min(rows, num_rows)\n", + " cols_to_print = min(cols, num_cols)\n", + "\n", + " # Slice the matrix to get the top-left corner\n", + " corner = matrix[:rows_to_print, :cols_to_print]\n", + "\n", + " print(f\"Top-left corner ({rows_to_print}x{cols_to_print}):\\n{corner}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**QKV matrix**" + ] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([12, 197, 64])\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "Activations differ more than the allowed tolerance", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[22], line 13\u001b[0m\n\u001b[1;32m 8\u001b[0m hook_handle\u001b[38;5;241m.\u001b[39mremove()\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(activations[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m---> 13\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(activations[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;241m0\u001b[39m], cache[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mblocks.0.attn.hook_q\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m2\u001b[39m), atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-2\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mActivations differ more than the allowed tolerance\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", - "\u001b[0;31mAssertionError\u001b[0m: Activations differ more than the allowed tolerance" + "timm output torch.Size([1, 197, 2304])\n", + "timm shape torch.Size([3, 1, 12, 197, 64])\n", + "prisma shape torch.Size([3, 1, 12, 197, 64])\n", + "prisma q shape torch.Size([1, 12, 197, 64])\n" ] } ], @@ -474,14 +599,31 @@ "def hook_fn(module, input, output):\n", " activations.append(output)\n", "\n", - "hook_handle = timm_model.blocks[0].attn.q_norm.register_forward_hook(hook_fn)\n", + "hook_handle = timm_model.blocks[0].attn.qkv.register_forward_hook(hook_fn)\n", + "\n", "timm_output = timm_model(image)\n", "hook_handle.remove()\n", "\n", - "print(activations[0][0].shape)\n", + "print(\"timm output\", activations[0].shape)\n", + "qkv = activations[0].reshape(-1, 197, 3, 12, 64).permute(2, 0, 3, 1, 4)\n", + "q, k, v = qkv.unbind(0)\n", "\n", + "print(\"timm shape\", qkv.shape)\n", + "print(\"prisma shape\", cache['blocks.0.attn.hook_qkv'].shape)\n", "\n", - "assert torch.allclose(activations[0][0], cache['blocks.0.attn.hook_q'][0].permute(1,0,2), atol=1e-2), \"Activations differ more than the allowed tolerance\"" + "print(\"prisma q shape\", cache['blocks.0.attn.hook_q'].shape)\n", + "\n", + "assert torch.allclose(qkv, cache['blocks.0.attn.hook_qkv'], atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "assert torch.allclose(q, 
cache['blocks.0.attn.hook_q'], atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "assert torch.allclose(k, cache['blocks.0.attn.hook_k'], atol=1e-6), \"Activations differ more than the allowed tolerance\"\n", + "assert torch.allclose(v, cache['blocks.0.attn.hook_v'], atol=1e-6), \"Activations differ more than the allowed tolerance\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Attention Scores**" ] }, { @@ -493,58 +635,56 @@ "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([1, 12, 197, 64])\n", - "torch.Size([12, 197, 64])\n" + "timm attn scores torch.Size([1, 12, 197, 197])\n", + "prisma attn scores torch.Size([1, 12, 197, 197])\n" ] } ], "source": [ - " cache['blocks.0.attn.hook_k'][0].shape\n", - "print(activations[0].shape)\n", - "print(cache['blocks.0.attn.hook_k'][0].permute(1,0,2).shape)" + "scaled_q = q * 64 ** -0.5\n", + "timm_attn_scores = scaled_q @ k.transpose(-2,-1)\n", + "\n", + "print(\"timm attn scores\", timm_attn_scores.shape)\n", + "print(\"prisma attn scores\", cache['blocks.0.attn.hook_attn_scores'].shape)\n", + "\n", + "assert torch.allclose(timm_attn_scores, cache['blocks.0.attn.hook_attn_scores'], atol=1e-4), \"Activations differ more than the allowed tolerance\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Attention pattern**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.float32" - ] - }, - "execution_count": 82, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "activations[0].dtype" + "timm_attn_pattern = timm_attn_scores.softmax(dim=-1) \n", + "\n", + "assert torch.allclose(timm_attn_pattern, cache['blocks.0.attn.hook_pattern'], atol=1e-4), \"Activations differ more than the allowed tolerance\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Attention Output**" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([12, 197, 64])\n" - ] - }, - { - "ename": "RuntimeError", - "evalue": "The size of tensor a (768) must match the size of tensor b (64) at non-singleton dimension 2", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[19], line 15\u001b[0m\n\u001b[1;32m 11\u001b[0m q, k, v \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msplit(activations[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m768\u001b[39m, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 13\u001b[0m hook_handle\u001b[38;5;241m.\u001b[39mremove()\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mallclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mcache\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mblocks.0.attn.hook_k\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpermute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43matol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-6\u001b[39;49m\u001b[43m)\u001b[49m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mActivations differ more than the allowed tolerance\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", - "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (768) must match the size of tensor b (64) at non-singleton dimension 2" + "torch.Size([1, 197, 768])\n" ] } ], @@ -554,67 +694,28 @@ "def hook_fn(module, input, output):\n", " activations.append(output)\n", "\n", - "hook_handle = timm_model.blocks[0].attn.qkv.register_forward_hook(hook_fn)\n", + "hook_handle = timm_model.blocks[0].attn.proj.register_forward_hook(hook_fn)\n", "timm_output = timm_model(image)\n", + "hook_handle.remove()\n", "\n", - "print(cache['blocks.0.attn.hook_k'][0].permute(1,0,2).shape)\n", - "\n", - "q, k, v = torch.split(activations[0], 768, dim=-1)\n", + "print(activations[0].shape)\n", + "# print(cache['blocks.0.attn.hook_attn_out'].shape)\n", "\n", - "hook_handle.remove()\n", "\n", - "assert torch.allclose(k, cache['blocks.0.attn.hook_k'][0].permute(1,0,2), atol=1e-6), \"Activations differ more than the allowed tolerance\"" + "assert torch.allclose(cache['blocks.0.attn.hook_result'], activations[0], atol=1e-3), \"Activations differ more than the allowed tolerance\"\n" ] }, { - "cell_type": "code", - "execution_count": 20, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "ename": "AssertionError", - "evalue": "Activations differ more than the allowed tolerance", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[20], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m timm_output \u001b[38;5;241m=\u001b[39m timm_model(image)\n\u001b[1;32m 8\u001b[0m hook_handle\u001b[38;5;241m.\u001b[39mremove()\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(activations[\u001b[38;5;241m0\u001b[39m], cache[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mblocks.0.hook_attn_out\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;241m0\u001b[39m], atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-6\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mActivations differ more than the allowed tolerance\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", - "\u001b[0;31mAssertionError\u001b[0m: Activations differ more than the allowed tolerance" - ] - } - ], "source": [ - "# First layer\n", - "activations = []\n", - "def hook_fn(module, input, output):\n", - " activations.append(output)\n", - "\n", - "hook_handle = timm_model.blocks[0].attn.register_forward_hook(hook_fn)\n", - "timm_output = timm_model(image)\n", - "hook_handle.remove()\n", - "\n", - "assert torch.allclose(activations[0], cache['blocks.0.hook_attn_out'][0], atol=1e-6), \"Activations differ more than the allowed tolerance\"" + "## 
MLP" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prisma_model.cfg.use_split_qkv_input" - ] + "source": [] }, { "cell_type": "code",