From b7d82aa9791c933f9a12c596f6e7966b7d53e67b Mon Sep 17 00:00:00 2001
From: Andrei-Aksionov
Date: Thu, 26 Dec 2024 19:39:05 +0300
Subject: [PATCH] Align the code with non-interleaved placement of QKV

---
 litgpt/adapter_v2.py                     |  2 +-
 litgpt/model.py                          |  5 +--
 litgpt/scripts/convert_hf_checkpoint.py  | 40 +++++++++------------
 litgpt/scripts/convert_lit_checkpoint.py | 45 ++++++++++++------------
 4 files changed, 44 insertions(+), 48 deletions(-)

diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py
index 6885f628aa..9b975260f0 100644
--- a/litgpt/adapter_v2.py
+++ b/litgpt/adapter_v2.py
@@ -164,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
         nn.Module.__init__(self)
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         # key, query, value projections for all heads, but in a batch
-        self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
+        self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
         # output projection
         # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
         self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
diff --git a/litgpt/model.py b/litgpt/model.py
index 5fbd0a8c24..cbdf2a4bdd 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -347,8 +347,9 @@ def forward(
         # NOTE: flash attention requires it in training mode.
         # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
         if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
-            k = k.expand(*q.shape)  # (B, nh_q, T, hs)
-            v = v.expand(*q.shape)  # (B, nh_q, T, hs)
+            q_per_kv = self.config.n_head // self.config.n_query_groups
+            k = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+            v = v.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)

         if self.apply_sliding_window_attention:
             """
diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 2c0dbb6aad..fbcfa871a6 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -408,20 +408,17 @@ def copy_weights_qwen_2_5(
     if progress_per_file is not None:
         progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

-    for name, param in hf_weights.items():
-        if "model.layers" in name:
-            from_name, l = layer_template(name, 2)
-            qkv = qkv_weights.setdefault(l, defaultdict(dict))
-            if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
-                weight_name, weight_type = from_name.split(".")[-2:]
-                qkv[weight_type][weight_name] = param
-            to_name = weight_map[from_name]
-            if to_name is None:
-                continue
-            to_name = to_name.format(l)
-        else:
-            to_name = weight_map[name]
-        param = load_param(param, name, dtype, verbose=debug_mode)
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
         if saver is not None:
             param = saver.store_early(param)
         state_dict[to_name] = param
@@ -436,22 +433,19 @@ def copy_weights_qwen_2_5(
         for weight_type in list(qkv_weights[i]):
             qkv = qkv_weights[i][weight_type]
             if len(qkv) != 3:
-                # split across different .bin files
+                # qkv is split across different .bin files
                 continue
             q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
             k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
             v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
-            q_per_kv = config.n_head // config.n_query_groups
-            qs = torch.split(q, config.head_size * q_per_kv)
-            ks = torch.split(k, config.head_size)
-            vs = torch.split(v, config.head_size)
-            cycled = [t for group in zip(qs, ks, vs) for t in group]
-            qkv = torch.cat(cycled)
-            state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
             del qkv_weights[i][weight_type]
+
     if progress_per_file is not None:
         pbar.update(progress_per_file)
+

 def qkv_reassemble(
     param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -483,7 +477,7 @@ def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]:


 def load_param(
-    param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose=False
+    param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool = False
 ) -> torch.Tensor:
     if hasattr(param, "_load_tensor"):
         # support tensors loaded via `lazy_load()`
diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py
index 5bb08ea4f6..f276e3ae31 100644
--- a/litgpt/scripts/convert_lit_checkpoint.py
+++ b/litgpt/scripts/convert_lit_checkpoint.py
@@ -310,34 +310,35 @@ def copy_weights_qwen_2_5(
         "lm_head.weight": "lm_head.weight",
     }

-    for name, param in lit_weights.items():
-        if name == "lm_head.weight" and untie_weights:
+    for from_name, param in lit_weights.items():
+        if from_name == "lm_head.weight" and untie_weights:
             continue
-        if name.endswith((".attn.attn.weight", ".attn.attn.bias")):
-            from_name, l_idx = layer_template(name, 2)
-            qkv = load_param(param, name, None)
-            qp, kp, vp = qkv_split(qkv, config)
-
-            weight_type = name.split(".")[-1]  # weight or bias
-            q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}"
-            k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}"
-            v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}"
-            for to_name, param in zip((q, k, v), (qp, kp, vp)):
-                if saver is not None:
-                    param = saver.store_early(param)
-                state_dict[to_name] = param
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        param = load_param(param, from_name, None)
+        if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
+            weight_type = from_name.split(".")[-1]  # weight or bias
+            to_names = (
+                "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type),
+                "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type),
+                "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type),
+            )
+            params = param.split(
+                (
+                    config.n_head * config.head_size,
+                    config.n_query_groups * config.head_size,
+                    config.n_query_groups * config.head_size,
+                )
+            )
         else:
-            if "transformer.h" in name:
-                from_name, l_idx = layer_template(name, 2)
-                to_name = weight_map[from_name]
-                to_name = to_name.format(l_idx)
-            else:
-                to_name = weight_map[name]
-            param = load_param(param, name, None)
+            to_names = (weight_map[name_template].format(*ids),)
+            params = (param,)
+
+        for to_name, param in zip(to_names, params):
             if saver is not None:
                 param = saver.store_early(param)
             state_dict[to_name] = param
+

 def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
     """Reassemble from a normal to an interleaved placement in a QKV matrix.
     [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
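
For context, a minimal sketch of the non-interleaved QKV placement this patch standardizes on. It is illustrative only and not part of the patch: the head counts and sizes below are made-up values, not taken from any litgpt config; only the split sizes and the repeat_interleave expansion mirror what the changed code does.

    import torch

    # Hypothetical GQA shapes for illustration: 8 query heads, 2 KV groups, head size 4.
    n_head, n_query_groups, head_size = 8, 2, 4
    n_embd = n_head * head_size

    # Non-interleaved fused QKV weight: all Q rows first, then all K rows, then all V rows.
    qkv = torch.randn((n_head + 2 * n_query_groups) * head_size, n_embd)

    # Split back into per-projection weights with the same sizes used in convert_lit_checkpoint.py.
    q, k, v = qkv.split((n_head * head_size, n_query_groups * head_size, n_query_groups * head_size))
    assert q.shape == (n_head * head_size, n_embd)
    assert k.shape == v.shape == (n_query_groups * head_size, n_embd)

    # Expand the grouped K heads to one per query head, mirroring the repeat_interleave in model.py.
    q_per_kv = n_head // n_query_groups
    k_heads = k.view(n_query_groups, head_size, n_embd)
    k_expanded = k_heads.repeat_interleave(q_per_kv, dim=0)
    assert k_expanded.shape == (n_head, head_size, n_embd)

With this layout, a fused QKV tensor splits into contiguous Q, K, and V blocks by size alone, so checkpoint conversion no longer needs the per-head interleaving and reordering that the removed code performed.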