From bc678c0b33e19d11fe52d8ae09afea6520cc2a31 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 8 Feb 2024 19:06:37 -0500 Subject: [PATCH 1/3] fix attention weight loading --- src/vit_prisma/configs/HookedViTConfig.py | 2 +- .../prisma/loading_from_pretrained.py | 61 +++++++------------ 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/src/vit_prisma/configs/HookedViTConfig.py b/src/vit_prisma/configs/HookedViTConfig.py index d32f9f6f..4dfe3fed 100644 --- a/src/vit_prisma/configs/HookedViTConfig.py +++ b/src/vit_prisma/configs/HookedViTConfig.py @@ -97,7 +97,7 @@ class HookedViTConfig: max_grad_norm = 1.0 # Saving related - parent_dir: str = "/Users/praneets/Downloads/working_dir" + parent_dir: str = "" save_dir: str = 'Checkpoints' save_checkpoints: bool = True save_cp_frequency: int = 5 diff --git a/src/vit_prisma/prisma/loading_from_pretrained.py b/src/vit_prisma/prisma/loading_from_pretrained.py index ac401e52..d56ee001 100644 --- a/src/vit_prisma/prisma/loading_from_pretrained.py +++ b/src/vit_prisma/prisma/loading_from_pretrained.py @@ -19,7 +19,7 @@ import einops -def convert_timm_weights( +def convert_timm_weigthts( old_state_dict, cfg: HookedViTConfig, ): @@ -29,10 +29,7 @@ def convert_timm_weights( new_state_dict = {} new_state_dict["cls_token"] = old_state_dict["cls_token"] - new_state_dict["pos_embed.W_pos"] = old_state_dict["pos_embed"] - pos_embed_W_pos = old_state_dict["pos_embed"] - pos_embed_W_pos = pos_embed_W_pos.squeeze(0) - new_state_dict["pos_embed.W_pos"] = pos_embed_W_pos + new_state_dict["pos_embed.W_pos"] = old_state_dict["pos_embed"].squeeze(0) new_state_dict["embed.proj.weight"] = old_state_dict["patch_embed.proj.weight"] new_state_dict["embed.proj.bias"] = old_state_dict["patch_embed.proj.bias"] new_state_dict["ln_final.w"] = old_state_dict["norm.weight"] @@ -46,37 +43,25 @@ def convert_timm_weights( new_state_dict[f"{layer_key}.ln2.b"] = old_state_dict[f"{layer_key}.norm2.bias"] W = old_state_dict[f"{layer_key}.attn.qkv.weight"] - new_state_dict[f"{layer_key}.attn.qkv.weight"] = old_state_dict[f"{layer_key}.attn.qkv.weight"] - new_state_dict[f"{layer_key}.attn.qkv.bias"] = old_state_dict[f"{layer_key}.attn.qkv.bias"] - - new_state_dict[f"{layer_key}.attn.proj.weight"] = old_state_dict[f"{layer_key}.attn.proj.weight"] - new_state_dict[f"{layer_key}.attn.proj.bias"] = old_state_dict[f"{layer_key}.attn.proj.bias"] - - - # W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) - # W_Q = einops.rearrange(W_Q, "(i h) m->h m i", h=cfg.n_heads) - # W_K = einops.rearrange(W_K, "(i h) m->h m i", h=cfg.n_heads) - # W_V = einops.rearrange(W_V, "(i h) m->h m i", h=cfg.n_heads) - # new_state_dict[f"{layer_key}.attn.W_Q"] = W_Q - # new_state_dict[f"{layer_key}.attn.W_K"] = W_K - # new_state_dict[f"{layer_key}.attn.W_V"] = W_V - - # W_O = old_state_dict[f"{layer_key}.attn.proj.weight"] - # W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) - # new_state_dict[f"{layer_key}.attn.W_O"] = W_O - - # attn_bias = old_state_dict[f"{layer_key}.attn.qkv.bias"] - # b_Q, b_K, b_V = torch.tensor_split(attn_bias, 3, dim=0) - # b_Q = einops.rearrange(b_Q, "(i h) -> h i", h=cfg.n_heads) - # b_K = einops.rearrange(b_K, "(i h) -> h i", h=cfg.n_heads) - # b_V = einops.rearrange(b_V, "(i h) -> h i", h=cfg.n_heads) - # new_state_dict[f"{layer_key}.attn.b_Q"] = b_Q - # new_state_dict[f"{layer_key}.attn.b_K"] = b_K - # new_state_dict[f"{layer_key}.attn.b_V"] = b_V - - # b_O = old_state_dict[f"{layer_key}.attn.proj.bias"] - # b_O = einops.rearrange(b_O, "m -> m") - # 
new_state_dict[f"{layer_key}.attn.b_O"] = b_O + W_reshape = einops.rearrange( W, "(three h dh) d ->three h d dh" , three=3, h=cfg.n_heads, d=cfg.d_model, dh=cfg.d_head) + W_Q, W_K, W_V = torch.unbind(W_reshape, dim=0) + new_state_dict[f"{layer_key}.attn.W_Q"] = W_Q + new_state_dict[f"{layer_key}.attn.W_K"] = W_K + new_state_dict[f"{layer_key}.attn.W_V"] = W_V + + W_O = old_state_dict[f"{layer_key}.attn.proj.weight"] + W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) + new_state_dict[f"{layer_key}.attn.W_O"] = W_O + + attn_bias = old_state_dict[f"{layer_key}.attn.qkv.bias"] + attn_bias_reshape = einops.rearrange(attn_bias, "(three h dh) -> three h dh", three=3, h=cfg.n_heads, dh=cfg.d_head) + b_Q, b_K, b_V = torch.unbind(attn_bias_reshape, dim=0) + new_state_dict[f"{layer_key}.attn.b_Q"] = b_Q + new_state_dict[f"{layer_key}.attn.b_K"] = b_K + new_state_dict[f"{layer_key}.attn.b_V"] = b_V + + b_O = old_state_dict[f"{layer_key}.attn.proj.bias"] + new_state_dict[f"{layer_key}.attn.b_O"] = b_O new_state_dict[f"{layer_key}.mlp.b_in"] = old_state_dict[f"{layer_key}.mlp.fc1.bias"] new_state_dict[f"{layer_key}.mlp.b_out"] = old_state_dict[f"{layer_key}.mlp.fc2.bias"] @@ -139,7 +124,7 @@ def get_pretrained_state_dict( param.requires_grad = False # state_dict = None # Conversion of state dict to HookedTransformer format - state_dict = convert_timm_weights(hf_model.state_dict(), cfg) + state_dict = convert_timm_weigthts(hf_model.state_dict(), cfg) return state_dict @@ -204,4 +189,4 @@ def convert_pretrained_model_config(model: str, is_timm: bool = True) -> HookedV 'n_params' : sum(p.numel() for p in model.parameters() if p.requires_grad) if is_timm else None, } - return HookedViTConfig.from_dict(pretrained_config) # Does this entirely override config or add to it? 
+ return HookedViTConfig.from_dict(pretrained_config) From 275bdf20e6ca985d938fcd71392a7b34a61e3d02 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 8 Feb 2024 19:40:33 -0500 Subject: [PATCH 2/3] Fixed loading of attention block, revert changes to attention and transformer_block, added test --- src/vit_prisma/models/layers/attention.py | 253 ++++++++---------- .../models/layers/transformer_block.py | 43 ++- .../prisma/loading_from_pretrained.py | 17 +- tests/test_loading_timm.py | 26 ++ 4 files changed, 168 insertions(+), 171 deletions(-) create mode 100644 tests/test_loading_timm.py diff --git a/src/vit_prisma/models/layers/attention.py b/src/vit_prisma/models/layers/attention.py index 6a5beb89..23451543 100644 --- a/src/vit_prisma/models/layers/attention.py +++ b/src/vit_prisma/models/layers/attention.py @@ -1,7 +1,3 @@ -""" -Attention code from timm model -""" - import torch.nn as nn import torch @@ -37,76 +33,70 @@ def __init__( self.cfg = cfg - self.qkv = nn.Linear(self.cfg.d_model, self.cfg.d_model * 3, bias=True) - self.proj = nn.Linear(self.cfg.d_model, self.cfg.d_model) - - # # Initialize parameters - # self.W_Q = nn.Parameter( - # torch.empty( - # self.cfg.n_heads, - # self.cfg.d_model, - # self.cfg.d_head, - # dtype = self.cfg.dtype - # ) - # ) - # self.W_K = nn.Parameter( - # torch.empty( - # self.cfg.n_heads, - # self.cfg.d_model, - # self.cfg.d_head, - # dtype = self.cfg.dtype - # ) - # ) - # self.W_V = nn.Parameter( - # torch.empty( - # self.cfg.n_heads, - # self.cfg.d_model, - # self.cfg.d_head, - # dtype = self.cfg.dtype - # ) - # ) - # self.W_O = nn.Parameter( - # torch.empty( - # self.cfg.n_heads, - # self.cfg.d_head, - # self.cfg.d_model, - # dtype = self.cfg.dtype - # ) - # ) - - # # Initialize biases - # self.b_Q = nn.Parameter( - # torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) - # ) - # self.b_K = nn.Parameter( - # torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) - # ) - # self.b_V = nn.Parameter( - # torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) - # ) - # self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) + # Initialize parameters + self.W_Q = nn.Parameter( + torch.empty( + self.cfg.n_heads, + self.cfg.d_model, + self.cfg.d_head, + dtype = self.cfg.dtype + ) + ) + self.W_K = nn.Parameter( + torch.empty( + self.cfg.n_heads, + self.cfg.d_model, + self.cfg.d_head, + dtype = self.cfg.dtype + ) + ) + self.W_V = nn.Parameter( + torch.empty( + self.cfg.n_heads, + self.cfg.d_model, + self.cfg.d_head, + dtype = self.cfg.dtype + ) + ) + self.W_O = nn.Parameter( + torch.empty( + self.cfg.n_heads, + self.cfg.d_head, + self.cfg.d_model, + dtype = self.cfg.dtype + ) + ) + + # Initialize biases + self.b_Q = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + ) + self.b_K = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + ) + self.b_V = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + ) + self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) # Add hook points - self.hook_q = HookPoint() # [batch, pos, head_index, d_head] self.hook_k = HookPoint() # [batch, pos, head_index, d_head] + self.hook_q = HookPoint() # [batch, pos, head_index, d_head] self.hook_v = HookPoint() # [batch, pos, head_index, d_head] self.hook_z = HookPoint() # [batch, pos, head_index, d_head] self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_pattern = 
HookPoint() # [batch, head_index, query_pos, key_pos] - # self.hook_result = HookPoint() # [batch, pos, head_index, d_model] - - self.hook_qkv = HookPoint() # [batch, pos, 3, head_index, d_head] + self.hook_result = HookPoint() # [batch, pos, head_index, d_model] self.layer_id = layer_id - # # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? - # if self.cfg.use_attn_scale: - # self.attn_scale = np.sqrt(self.cfg.d_head) - # else: - # self.attn_scale = 1.0 - - self.scale = self.cfg.d_head ** -0.5 + # Note to Sonia: check this. + # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? + if self.cfg.use_attn_scale: + self.attn_scale = np.sqrt(self.cfg.d_head) + else: + self.attn_scale = 1.0 @property def OV(self) -> FactoredMatrix: @@ -135,80 +125,62 @@ def QK(self) -> FactoredMatrix: def forward( self, - x: Union[ + query_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + key_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + value_input: Union[ Float[torch.Tensor, "batch pos d_model"], - # Float[torch.Tensor, "batch pos head_index d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], ] ) -> Float[torch.Tensor, "batch pos d_model"]: - # Calculate QKV - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.cfg.n_heads, self.cfg.d_head).permute(2, 0, 3, 1, 4) + q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) - qkv = self.hook_qkv(qkv) - - q, k, v = qkv.unbind(0) - - q = self.hook_q(q) - k = self.hook_k(k) - v = self.hook_v(v) - - # q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) - - # Attention Scores - q = q * self.scale - attn_scores = q @ k.transpose(-2, -1) + attn_scores = self.calculate_attn_scores(q, k) attn_scores = self.hook_attn_scores(attn_scores) - # Attention Pattern - attn_pattern = attn_scores.softmax(dim=-1) - attn_pattern = torch.where(torch.isnan(attn_pattern), torch.zeros_like(attn_pattern), attn_pattern) - attn_pattern = self.hook_pattern(attn_pattern) - attn_pattern = attn_pattern.to(self.cfg.dtype) - - # Value matrix - z = attn_pattern @ v - z = self.hook_z(z) - - # Output projection layer - z = z.transpose(1,2).reshape(B,N,C) - output = self.proj(z) - # output = self.hook_result(output) - - return output - - # z = self.calculate_z_scores(v, pattern) - - # if not self.cfg.use_attn_result: - # out = ( - # ( - # einsum( - # "batch pos head_index d_head, \ - # head_index d_head d_model -> \ - # batch pos d_model", - # z, - # self.W_O, - # ) - # ) - # + self.b_O - # ) - # else: - # # Explicitly calculate the attention result so it can be accessed by a hook. - # # Off by default to not eat through GPU memory. - # result = self.hook_result( - # einsum( - # "batch pos head_index d_head, \ - # head_index d_head d_model -> \ - # batch pos head_index d_model", - # z, - # self.W_O, - # ) - # ) - # out = ( - # einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum") - # + self.b_O - # ) - # return out + pattern = F.softmax(attn_scores, dim=-1) # where do I do normalization? 
+ pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) + pattern = self.hook_pattern(pattern) + + pattern = pattern.to(self.cfg.dtype) + z = self.calculate_z_scores(v, pattern) + + if not self.cfg.use_attn_result: + out = ( + ( + einsum( + "batch pos head_index d_head, \ + head_index d_head d_model -> \ + batch pos d_model", + z, + self.W_O, + ) + ) + + self.b_O + ) + else: + # Explicitly calculate the attention result so it can be accessed by a hook. + # Off by default to not eat through GPU memory. + result = self.hook_result( + einsum( + "batch pos head_index d_head, \ + head_index d_head d_model -> \ + batch pos head_index d_model", + z, + self.W_O, + ) + ) + out = ( + einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum") + + self.b_O + ) + return out def calculate_qkv_matrices( self, @@ -240,15 +212,16 @@ def calculate_qkv_matrices( else: qkv_einops_string = "batch pos d_model" - q = einsum( + + q = self.hook_q( + einsum( f"{qkv_einops_string}, head_index d_model d_head \ -> batch pos head_index d_head", query_input, self.W_Q, - ) + self.b_Q - # [batch, pos, head_index, d_head] - q = self.hook_q(q.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm - + ) + + self.b_Q + ) # [batch, pos, head_index, d_head] k = self.hook_k( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -258,8 +231,6 @@ def calculate_qkv_matrices( ) + self.b_K ) # [batch, pos, head_index, d_head] - k = self.hook_k(k.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm - v = self.hook_v( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -269,7 +240,6 @@ def calculate_qkv_matrices( ) + self.b_V ) # [batch, pos, head_index, d_head] - v = self.hook_v(v.transpose(-1,-2)) return q, k, v def calculate_attn_scores( @@ -282,13 +252,12 @@ def calculate_attn_scores( Returns a tensor of shape [batch, head_index, query_pos, key_pos] """ - q = q * self.scale attn_scores = einsum( - "batch query_pos d_head head_index, batch key_pos d_head -> batch head_index query_pos key_pos", + "batch query_pos head_index d_head, batch key_pos head_index d_head -> batch head_index query_pos key_pos", q, k, ) - + attn_scores = attn_scores / self.attn_scale return attn_scores def calculate_z_scores( @@ -296,8 +265,6 @@ def calculate_z_scores( v: Float[torch.Tensor, "batch key_pos head_index d_head"], pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: - - z = self.hook_z( einsum( "batch key_pos head_index d_head, \ diff --git a/src/vit_prisma/models/layers/transformer_block.py b/src/vit_prisma/models/layers/transformer_block.py index 72cace19..70a4db70 100644 --- a/src/vit_prisma/models/layers/transformer_block.py +++ b/src/vit_prisma/models/layers/transformer_block.py @@ -29,7 +29,7 @@ def add_head_dimension( if clone_tensor: return repeated_tensor.clone() else: - return repeated_tensor + return repeated_tensor class TransformerBlock(nn.Module): """ @@ -85,35 +85,32 @@ def forward( resid_pre = self.hook_resid_pre(resid_pre) - # if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: - # # We're adding a head dimension - # attn_in = add_head_dimension(resid_pre, self.cfg.n_heads, clone_tensor=False) - # else: - attn_in = resid_pre + if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: + # We're adding a head dimension + attn_in = add_head_dimension(resid_pre, self.cfg.n_heads, clone_tensor=False) + else: + attn_in = resid_pre if self.cfg.use_attn_in: 
attn_in = self.hook_attn_in(attn_in.clone()) - # if self.cfg.use_split_qkv_input: - # query_input = self.hook_q_input(attn_in.clone()) - # key_input = self.hook_k_input(attn_in.clone()) - # value_input = self.hook_v_input(attn_in.clone()) - # else: - # query_input = attn_in - # key_input = attn_in - # value_input = attn_in + if self.cfg.use_split_qkv_input: + query_input = self.hook_q_input(attn_in.clone()) + key_input = self.hook_k_input(attn_in.clone()) + value_input = self.hook_v_input(attn_in.clone()) + else: + query_input = attn_in + key_input = attn_in + value_input = attn_in - # attn_out = self.attn( - # query_input = query_input, - # key_input = key_input, - # value_input = value_input, - # ) - - # Update: moving to timm setup for accuracy; add function to split qkv later. attn_out = self.attn( - self.ln1(attn_in) + query_input = self.ln1(query_input), + key_input = self.ln1(key_input), + value_input = self.ln1(value_input), ) - + + # Take hook fn + attn_out = self.hook_attn_out( attn_out ) diff --git a/src/vit_prisma/prisma/loading_from_pretrained.py b/src/vit_prisma/prisma/loading_from_pretrained.py index d56ee001..2bc857fb 100644 --- a/src/vit_prisma/prisma/loading_from_pretrained.py +++ b/src/vit_prisma/prisma/loading_from_pretrained.py @@ -162,13 +162,13 @@ def fill_missing_keys(model, state_dict): state_dict[key] = default_state_dict[key] return state_dict -def convert_pretrained_model_config(model: str, is_timm: bool = True) -> HookedViTConfig: +def convert_pretrained_model_config(model_name: str, is_timm: bool = True) -> HookedViTConfig: if is_timm: - model = timm.create_model(model) + model = timm.create_model(model_name) hf_config = AutoConfig.from_pretrained(model.default_cfg['hf_hub_id']) else: - hf_config = AutoConfig.from_pretrained(model) + hf_config = AutoConfig.from_pretrained(model_name) pretrained_config = { 'n_layers' : hf_config.num_hidden_layers, @@ -178,8 +178,7 @@ def convert_pretrained_model_config(model: str, is_timm: bool = True) -> HookedV 'n_heads' : hf_config.num_attention_heads, 'd_mlp' : hf_config.intermediate_size, 'activation_name' : hf_config.hidden_act, - 'eps': 1e-6, # There is a bug here - # 'eps' : hf_config.layer_norm_eps, + 'eps' : hf_config.layer_norm_eps, 'original_architecture' : hf_config.architecture, 'initializer_range' : hf_config.initializer_range, 'n_channels' : hf_config.num_channels, @@ -188,5 +187,13 @@ def convert_pretrained_model_config(model: str, is_timm: bool = True) -> HookedV 'n_classes' : hf_config.num_classes, 'n_params' : sum(p.numel() for p in model.parameters() if p.requires_grad) if is_timm else None, } + + # Currently a bug getting configs, only this model works and still requires modification of eps + if is_timm and model_name == "vit_base_patch16_224": + pretrained_config.update({ + "eps": 1e-6, + "return_type": "class_logits", + }) + return HookedViTConfig.from_dict(pretrained_config) diff --git a/tests/test_loading_timm.py b/tests/test_loading_timm.py new file mode 100644 index 00000000..5803913f --- /dev/null +++ b/tests/test_loading_timm.py @@ -0,0 +1,26 @@ +import pytest +import torch +import timm +from vit_prisma.models.base_vit import HookedViT + +#currently only vit_base_patch16_224 supported (config loading issue) +def test_loading_timm(): + TOLERANCE = 1e-5 + + model_name = "vit_base_patch16_224" + batch_size = 5 + channels = 3 + height = 224 + width = 224 + device = "cpu" + + hooked_model = HookedViT.from_pretrained(model_name) + hooked_model.to(device) + timm_model = timm.create_model(model_name, 
pretrained=True) + timm_model.to(device) + + with torch.random.fork_rng(): + torch.manual_seed(1) + input_image = torch.rand((batch_size, channels, height, width)).to(device) + + assert torch.allclose(hooked_model(input_image), timm_model(input_image), atol=TOLERANCE), "Model output diverges!" From ba4240d3979bfbc603b5f4633f5f39df30cf279e Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 8 Feb 2024 19:47:37 -0500 Subject: [PATCH 3/3] typo + comment change --- src/vit_prisma/prisma/loading_from_pretrained.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vit_prisma/prisma/loading_from_pretrained.py b/src/vit_prisma/prisma/loading_from_pretrained.py index 2bc857fb..a03039ca 100644 --- a/src/vit_prisma/prisma/loading_from_pretrained.py +++ b/src/vit_prisma/prisma/loading_from_pretrained.py @@ -19,7 +19,7 @@ import einops -def convert_timm_weigthts( +def convert_timm_weights( old_state_dict, cfg: HookedViTConfig, ): @@ -124,7 +124,7 @@ def get_pretrained_state_dict( param.requires_grad = False # state_dict = None # Conversion of state dict to HookedTransformer format - state_dict = convert_timm_weigthts(hf_model.state_dict(), cfg) + state_dict = convert_timm_weights(hf_model.state_dict(), cfg) return state_dict @@ -188,7 +188,7 @@ def convert_pretrained_model_config(model_name: str, is_timm: bool = True) -> Ho 'n_params' : sum(p.numel() for p in model.parameters() if p.requires_grad) if is_timm else None, } - # Currently a bug getting configs, only this model works and still requires modification of eps + # Currently a bug getting configs, only this model confirmed to work and even it requires modification of eps if is_timm and model_name == "vit_base_patch16_224": pretrained_config.update({ "eps": 1e-6,
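
Note on the weight conversion: the core of PATCH 1/3 is splitting timm's fused attn.qkv.weight/bias and attn.proj.weight/bias into per-head W_Q/W_K/W_V/b_Q/b_K/b_V and W_O/b_O via the einops rearranges in convert_timm_weights. The sketch below is a minimal standalone sanity check of that split, assuming ViT-Base shapes (d_model=768, n_heads=12, d_head=64); every tensor here is a random stand-in, not an actual timm or vit_prisma state dict, and the einsums only mirror the per-head projections used by the HookedViT Attention module.

import torch
import einops

torch.set_default_dtype(torch.float64)  # keep the equality checks numerically tight
d_model, n_heads, d_head = 768, 12, 64  # assumed ViT-Base dimensions

# Stand-ins for timm's fused attention parameters:
#   attn.qkv.weight: (3*d_model, d_model), attn.qkv.bias: (3*d_model,)
#   attn.proj.weight: (d_model, d_model),  attn.proj.bias: (d_model,)
qkv_w = torch.randn(3 * n_heads * d_head, d_model)
qkv_b = torch.randn(3 * n_heads * d_head)
proj_w = torch.randn(d_model, n_heads * d_head)
proj_b = torch.randn(d_model)
x = torch.randn(2, 197, d_model)                   # (batch, tokens, d_model)
z = torch.randn(2, 197, n_heads, d_head)           # pretend post-attention values

# timm-style computation: fused qkv reshaped to (B, N, 3, heads, d_head); proj on flattened heads.
qkv = (x @ qkv_w.T + qkv_b).reshape(2, 197, 3, n_heads, d_head)
q_timm = qkv[:, :, 0]                              # query slice, (B, N, heads, d_head)
out_timm = z.reshape(2, 197, -1) @ proj_w.T + proj_b

# Per-head parameters as produced by convert_timm_weights.
W_Q = einops.rearrange(qkv_w, "(three h dh) d -> three h d dh", three=3, h=n_heads, dh=d_head)[0]
b_Q = einops.rearrange(qkv_b, "(three h dh) -> three h dh", three=3, h=n_heads, dh=d_head)[0]
W_O = einops.rearrange(proj_w, "m (i h) -> i h m", i=n_heads)
b_O = proj_b

# HookedViT-style per-head projections (as in calculate_qkv_matrices and the W_O output einsum).
q_prisma = torch.einsum("bnd,hdk->bnhk", x, W_Q) + b_Q
out_prisma = torch.einsum("bnhk,hkm->bnm", z, W_O) + b_O

assert torch.allclose(q_timm, q_prisma, atol=1e-5)
assert torch.allclose(out_timm, out_prisma, atol=1e-5)

W_K/W_V and b_K/b_V follow the same check using indices 1 and 2 of the unbound qkv split.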