Commit

Merge pull request facebookresearch#71 from themachinefan/loading_attn_weights_fix

Loading attn weights fix
soniajoseph authored Feb 9, 2024
2 parents 61afe9e + ba4240d commit 9e4a4b4
Showing 5 changed files with 190 additions and 208 deletions.
2 changes: 1 addition & 1 deletion src/vit_prisma/configs/HookedViTConfig.py
@@ -97,7 +97,7 @@ class HookedViTConfig:
max_grad_norm = 1.0

# Saving related
parent_dir: str = "/Users/praneets/Downloads/working_dir"
parent_dir: str = ""
save_dir: str = 'Checkpoints'
save_checkpoints: bool = True
save_cp_frequency: int = 5
253 changes: 110 additions & 143 deletions src/vit_prisma/models/layers/attention.py
@@ -1,7 +1,3 @@
"""
Attention code from timm model
"""

import torch.nn as nn
import torch

@@ -37,76 +33,70 @@ def __init__(

self.cfg = cfg

self.qkv = nn.Linear(self.cfg.d_model, self.cfg.d_model * 3, bias=True)
self.proj = nn.Linear(self.cfg.d_model, self.cfg.d_model)

# # Initialize parameters
# self.W_Q = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_model,
# self.cfg.d_head,
# dtype = self.cfg.dtype
# )
# )
# self.W_K = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_model,
# self.cfg.d_head,
# dtype = self.cfg.dtype
# )
# )
# self.W_V = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_model,
# self.cfg.d_head,
# dtype = self.cfg.dtype
# )
# )
# self.W_O = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_head,
# self.cfg.d_model,
# dtype = self.cfg.dtype
# )
# )

# # Initialize biases
# self.b_Q = nn.Parameter(
# torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
# )
# self.b_K = nn.Parameter(
# torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
# )
# self.b_V = nn.Parameter(
# torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
# )
# self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
# Initialize parameters
self.W_Q = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_model,
self.cfg.d_head,
dtype = self.cfg.dtype
)
)
self.W_K = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_model,
self.cfg.d_head,
dtype = self.cfg.dtype
)
)
self.W_V = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_model,
self.cfg.d_head,
dtype = self.cfg.dtype
)
)
self.W_O = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_head,
self.cfg.d_model,
dtype = self.cfg.dtype
)
)

# Initialize biases
self.b_Q = nn.Parameter(
torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
)
self.b_K = nn.Parameter(
torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
)
self.b_V = nn.Parameter(
torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
)
self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))


# Add hook points
self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]
# self.hook_result = HookPoint() # [batch, pos, head_index, d_model]

self.hook_qkv = HookPoint() # [batch, pos, 3, head_index, d_head]
self.hook_result = HookPoint() # [batch, pos, head_index, d_model]

self.layer_id = layer_id

# # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
# if self.cfg.use_attn_scale:
# self.attn_scale = np.sqrt(self.cfg.d_head)
# else:
# self.attn_scale = 1.0

self.scale = self.cfg.d_head ** -0.5
# Note to Sonia: check this.
# attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
if self.cfg.use_attn_scale:
self.attn_scale = np.sqrt(self.cfg.d_head)
else:
self.attn_scale = 1.0
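
For reference, here is a minimal sketch (illustrative only, not code from this PR) of how a timm-style fused `qkv` projection and `proj` output layer could be folded into the per-head `W_Q`/`W_K`/`W_V`/`W_O` parameters and biases defined above; presumably this is the kind of mapping the weight-loading fix performs elsewhere in this change set:

```python
import torch
from einops import rearrange

def split_fused_qkv(qkv_weight, qkv_bias, proj_weight, proj_bias, n_heads):
    """Hypothetical helper: qkv_weight is [3*d_model, d_model], proj_weight is [d_model, d_model]."""
    w_q, w_k, w_v = torch.chunk(qkv_weight, 3, dim=0)  # each [d_model, d_model]
    b_q, b_k, b_v = torch.chunk(qkv_bias, 3, dim=0)    # each [d_model]

    # nn.Linear computes x @ W.T, so output feature h*d_head + d of the fused
    # projection belongs to head h, dim d; rearrange into [n_heads, d_model, d_head].
    W_Q = rearrange(w_q, "(h dh) m -> h m dh", h=n_heads)
    W_K = rearrange(w_k, "(h dh) m -> h m dh", h=n_heads)
    W_V = rearrange(w_v, "(h dh) m -> h m dh", h=n_heads)
    b_Q = rearrange(b_q, "(h dh) -> h dh", h=n_heads)
    b_K = rearrange(b_k, "(h dh) -> h dh", h=n_heads)
    b_V = rearrange(b_v, "(h dh) -> h dh", h=n_heads)

    # The output projection consumes feature h*d_head + d, so it becomes
    # W_O of shape [n_heads, d_head, d_model]; b_O carries over unchanged.
    W_O = rearrange(proj_weight, "m (h dh) -> h dh m", h=n_heads)
    b_O = proj_bias
    return W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O
```
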

@property
def OV(self) -> FactoredMatrix:
@@ -135,80 +125,62 @@ def QK(self) -> FactoredMatrix:

def forward(
self,
x: Union[
query_input: Union[
Float[torch.Tensor, "batch pos d_model"],
Float[torch.Tensor, "batch pos head_index d_model"],
],
key_input: Union[
Float[torch.Tensor, "batch pos d_model"],
Float[torch.Tensor, "batch pos head_index d_model"],
],
value_input: Union[
Float[torch.Tensor, "batch pos d_model"],
# Float[torch.Tensor, "batch pos head_index d_model"],
Float[torch.Tensor, "batch pos head_index d_model"],
]
) -> Float[torch.Tensor, "batch pos d_model"]:

# Calculate QKV
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.cfg.n_heads, self.cfg.d_head).permute(2, 0, 3, 1, 4)
q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

qkv = self.hook_qkv(qkv)

q, k, v = qkv.unbind(0)

q = self.hook_q(q)
k = self.hook_k(k)
v = self.hook_v(v)

# q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

# Attention Scores
q = q * self.scale
attn_scores = q @ k.transpose(-2, -1)
attn_scores = self.calculate_attn_scores(q, k)
attn_scores = self.hook_attn_scores(attn_scores)

# Attention Pattern
attn_pattern = attn_scores.softmax(dim=-1)
attn_pattern = torch.where(torch.isnan(attn_pattern), torch.zeros_like(attn_pattern), attn_pattern)
attn_pattern = self.hook_pattern(attn_pattern)
attn_pattern = attn_pattern.to(self.cfg.dtype)

# Value matrix
z = attn_pattern @ v
z = self.hook_z(z)

# Output projection layer
z = z.transpose(1,2).reshape(B,N,C)
output = self.proj(z)
# output = self.hook_result(output)

return output

# z = self.calculate_z_scores(v, pattern)

# if not self.cfg.use_attn_result:
# out = (
# (
# einsum(
# "batch pos head_index d_head, \
# head_index d_head d_model -> \
# batch pos d_model",
# z,
# self.W_O,
# )
# )
# + self.b_O
# )
# else:
# # Explicitly calculate the attention result so it can be accessed by a hook.
# # Off by default to not eat through GPU memory.
# result = self.hook_result(
# einsum(
# "batch pos head_index d_head, \
# head_index d_head d_model -> \
# batch pos head_index d_model",
# z,
# self.W_O,
# )
# )
# out = (
# einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum")
# + self.b_O
# )
# return out
pattern = F.softmax(attn_scores, dim=-1) # where do I do normalization?
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern)

pattern = pattern.to(self.cfg.dtype)
z = self.calculate_z_scores(v, pattern)

if not self.cfg.use_attn_result:
out = (
(
einsum(
"batch pos head_index d_head, \
head_index d_head d_model -> \
batch pos d_model",
z,
self.W_O,
)
)
+ self.b_O
)
else:
# Explicitly calculate the attention result so it can be accessed by a hook.
# Off by default to not eat through GPU memory.
result = self.hook_result(
einsum(
"batch pos head_index d_head, \
head_index d_head d_model -> \
batch pos head_index d_model",
z,
self.W_O,
)
)
out = (
einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum")
+ self.b_O
)
return out
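
As a sanity check on the new output path: the per-head einsum against `W_O` plus `b_O` is numerically equivalent to the old fused projection that flattened heads and multiplied by a single `[d_model, d_model]` matrix. A self-contained sketch with invented shapes (not repository code):

```python
import torch

batch, pos, n_heads, d_head, d_model = 2, 5, 4, 8, 32
z = torch.randn(batch, pos, n_heads, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_O = torch.randn(d_model)

# New path: per-head output projection, summing over head_index and d_head.
out_einsum = torch.einsum("bphd,hdm->bpm", z, W_O) + b_O

# Old fused path: flatten heads, then one matmul against W_O stacked along
# its input dimension (an nn.Linear would store the transpose of this matrix).
out_fused = z.reshape(batch, pos, n_heads * d_head) @ W_O.reshape(n_heads * d_head, d_model) + b_O

assert torch.allclose(out_einsum, out_fused, atol=1e-5)
```
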

def calculate_qkv_matrices(
self,
@@ -240,15 +212,16 @@ def calculate_qkv_matrices(
else:
qkv_einops_string = "batch pos d_model"

q = einsum(

q = self.hook_q(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
-> batch pos head_index d_head",
query_input,
self.W_Q,
) + self.b_Q
# [batch, pos, head_index, d_head]
q = self.hook_q(q.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm

)
+ self.b_Q
) # [batch, pos, head_index, d_head]
k = self.hook_k(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
Expand All @@ -258,8 +231,6 @@ def calculate_qkv_matrices(
)
+ self.b_K
) # [batch, pos, head_index, d_head]
k = self.hook_k(k.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm

v = self.hook_v(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
Expand All @@ -269,7 +240,6 @@ def calculate_qkv_matrices(
)
+ self.b_V
) # [batch, pos, head_index, d_head]
v = self.hook_v(v.transpose(-1,-2))
return q, k, v

def calculate_attn_scores(
Expand All @@ -282,22 +252,19 @@ def calculate_attn_scores(
Returns a tensor of shape [batch, head_index, query_pos, key_pos]
"""
q = q * self.scale
attn_scores = einsum(
"batch query_pos d_head head_index, batch key_pos d_head -> batch head_index query_pos key_pos",
"batch query_pos head_index d_head, batch key_pos head_index d_head -> batch head_index query_pos key_pos",
q,
k,
)

attn_scores = attn_scores / self.attn_scale
return attn_scores
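
The corrected einsum string above is the standard per-head dot product between queries and keys; a tiny equivalence check with arbitrary shapes (illustrative only, not repository code):

```python
import torch

batch, q_pos, k_pos, n_heads, d_head = 2, 3, 3, 4, 8
q = torch.randn(batch, q_pos, n_heads, d_head)
k = torch.randn(batch, k_pos, n_heads, d_head)

# Named-index form used in the diff, written with torch.einsum letters:
scores = torch.einsum("bqhd,bkhd->bhqk", q, k)

# Explicit q @ k^T per head gives the same [batch, head_index, query_pos, key_pos] tensor.
reference = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 1, 3).transpose(-2, -1)

assert torch.allclose(scores, reference, atol=1e-5)
```
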

def calculate_z_scores(
self,
v: Float[torch.Tensor, "batch key_pos head_index d_head"],
pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:


z = self.hook_z(
einsum(
"batch key_pos head_index d_head, \
