Align the code with non-interleaved placement of QKV
Andrei-Aksionov committed Dec 26, 2024
1 parent ac310a9 commit b7d82aa
Showing 4 changed files with 44 additions and 48 deletions.
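
Before the per-file diffs, a toy illustration (made-up sizes, not code from the repository) of the two QKV placements this commit is about. The interleaved layout mirrors the zip/cat cycling removed from convert_hf_checkpoint.py below; the non-interleaved layout is the plain torch.cat((q, k, v)) that replaces it.

import torch

# Toy dimensions, chosen only for illustration.
n_head, n_query_groups, head_size = 4, 2, 3
q_per_kv = n_head // n_query_groups

q = torch.full((n_head * head_size, 1), 1.0)          # all query rows
k = torch.full((n_query_groups * head_size, 1), 2.0)  # all key rows
v = torch.full((n_query_groups * head_size, 1), 3.0)  # all value rows

# Non-interleaved placement (what this commit standardizes on):
# all Q rows, then all K rows, then all V rows.
qkv_non_interleaved = torch.cat((q, k, v))

# Interleaved placement (the old layout): per query group, that group's
# Q rows, then its K rows, then its V rows.
qs = torch.split(q, head_size * q_per_kv)
ks = torch.split(k, head_size)
vs = torch.split(v, head_size)
qkv_interleaved = torch.cat([t for group in zip(qs, ks, vs) for t in group])

assert qkv_non_interleaved.shape == qkv_interleaved.shape  # same rows, different order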
2 changes: 1 addition & 1 deletion litgpt/adapter_v2.py
@@ -164,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
# output projection
# if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
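
The adapter change above is only a rename of the fused projection from self.attn to self.qkv. For context, a minimal sketch (plain nn.Linear instead of AdapterV2Linear, hypothetical sizes) of why the output width is (n_head + 2 * n_query_groups) * head_size and how the non-interleaved layout lets a single split recover Q, K and V:

import torch
import torch.nn as nn

# Hypothetical config values, mirroring the names in litgpt's Config.
n_embd, n_head, n_query_groups, head_size = 16, 4, 2, 4

# One fused projection yields queries for every head plus keys and values
# for every query group, hence the (n_head + 2 * n_query_groups) factor.
shape = (n_head + 2 * n_query_groups) * head_size
qkv = nn.Linear(n_embd, shape, bias=False)

x = torch.randn(2, 5, n_embd)  # (B, T, n_embd)
fused = qkv(x)                 # (B, T, shape)

# With non-interleaved placement, splitting is a single call.
q, k, v = fused.split(
    (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size),
    dim=-1,
)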
5 changes: 3 additions & 2 deletions litgpt/model.py
@@ -347,8 +347,9 @@ def forward(
# NOTE: flash attention requires it in training mode.
# Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
k = k.expand(*q.shape) # (B, nh_q, T, hs)
v = v.expand(*q.shape) # (B, nh_q, T, hs)
q_per_kv = self.config.n_head // self.config.n_query_groups
k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs)
v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs)

if self.apply_sliding_window_attention:
"""
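
The model.py hunk replaces expand with repeat_interleave, which physically repeats each key/value group q_per_kv times along the head dimension so that k and v line up head-for-head with q. A toy sketch with hypothetical shapes:

import torch

# Toy grouped-query attention shapes, for illustration only.
B, T, head_size = 1, 2, 4
n_head, n_query_groups = 8, 2
q_per_kv = n_head // n_query_groups

q = torch.randn(B, n_head, T, head_size)          # (B, nh_q, T, hs)
k = torch.randn(B, n_query_groups, T, head_size)  # (B, nh_k, T, hs)

# Each key group is repeated q_per_kv times in place along dim=1:
# group 0, group 0, ..., then group 1, group 1, ...
k_expanded = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)

assert k_expanded.shape == q.shape
assert torch.equal(k_expanded[:, 0], k_expanded[:, q_per_kv - 1])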
40 changes: 17 additions & 23 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -408,20 +408,17 @@ def copy_weights_qwen_2_5(
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

for name, param in hf_weights.items():
if "model.layers" in name:
from_name, l = layer_template(name, 2)
qkv = qkv_weights.setdefault(l, defaultdict(dict))
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
to_name = weight_map[from_name]
if to_name is None:
continue
to_name = to_name.format(l)
else:
to_name = weight_map[name]
param = load_param(param, name, dtype, verbose=debug_mode)
for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
@@ -436,22 +433,19 @@
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# split across different .bin files
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
q_per_kv = config.n_head // config.n_query_groups
qs = torch.split(q, config.head_size * q_per_kv)
ks = torch.split(k, config.head_size)
vs = torch.split(v, config.head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
pbar.update(progress_per_file)


def qkv_reassemble(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -483,7 +477,7 @@ def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]:


def load_param(
param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose=False
param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool = False
) -> torch.Tensor:
if hasattr(param, "_load_tensor"):
# support tensors loaded via `lazy_load()`
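
A compressed sketch of the collection pattern in copy_weights_qwen_2_5 above (fabricated layer names and shapes): q/k/v projections can arrive from different checkpoint shards, so they are buffered per layer and fused, in non-interleaved order, only once all three are present.

import torch
from collections import defaultdict

# Hypothetical shard contents: layer 0's projections arriving one by one.
hf_weights = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(8, 4),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(4, 4),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(4, 4),
}

qkv_weights = {}
for from_name, param in hf_weights.items():
    layer_idx = int(from_name.split(".")[2])
    weight_name, weight_type = from_name.split(".")[-2:]  # e.g. "q_proj", "weight"
    qkv_weights.setdefault(layer_idx, defaultdict(dict))[weight_type][weight_name] = param

# Fuse only once q_proj, k_proj and v_proj have all been seen,
# in non-interleaved order: all Q rows, then all K rows, then all V rows.
qkv = qkv_weights[0]["weight"]
if len(qkv) == 3:
    fused = torch.cat((qkv["q_proj"], qkv["k_proj"], qkv["v_proj"]))
    print(fused.shape)  # torch.Size([16, 4])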
45 changes: 23 additions & 22 deletions litgpt/scripts/convert_lit_checkpoint.py
@@ -310,34 +310,35 @@ def copy_weights_qwen_2_5(
"lm_head.weight": "lm_head.weight",
}

for name, param in lit_weights.items():
if name == "lm_head.weight" and untie_weights:
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
if name.endswith((".attn.attn.weight", ".attn.attn.bias")):
from_name, l_idx = layer_template(name, 2)
qkv = load_param(param, name, None)
qp, kp, vp = qkv_split(qkv, config)

weight_type = name.split(".")[-1] # weight or bias
q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}"
k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}"
v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}"
for to_name, param in zip((q, k, v), (qp, kp, vp)):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
weight_type = from_name.split(".")[-1] # weight or bias
to_names = (
"model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type),
"model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type),
"model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
if "transformer.h" in name:
from_name, l_idx = layer_template(name, 2)
to_name = weight_map[from_name]
to_name = to_name.format(l_idx)
else:
to_name = weight_map[name]
param = load_param(param, name, None)
to_names = (weight_map[name_template].format(*ids),)
params = (param,)

for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param


def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
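
For the reverse direction handled in convert_lit_checkpoint.py, a minimal sketch (hypothetical sizes) of how the fused non-interleaved weight is split back into the three HuggingFace projection tensors by row count:

import torch

# Hypothetical dimensions, matching the Config fields referenced above.
n_head, n_query_groups, head_size, n_embd = 4, 2, 4, 16
layer_idx = 0

# A fused litgpt weight in non-interleaved order: Q rows, then K rows, then V rows.
fused = torch.randn((n_head + 2 * n_query_groups) * head_size, n_embd)

q, k, v = fused.split(
    (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size)
)

to_names = (
    f"model.layers.{layer_idx}.self_attn.q_proj.weight",
    f"model.layers.{layer_idx}.self_attn.k_proj.weight",
    f"model.layers.{layer_idx}.self_attn.v_proj.weight",
)
state_dict = dict(zip(to_names, (q, k, v)))
assert state_dict[to_names[0]].shape == (n_head * head_size, n_embd)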
