Make modeling compatible with Nanotron + few optims #23

Closed
Changes from 1 commit

Commits (27)
7384efb  implement GPT Neo's rope (NouamaneTazi, Oct 25, 2023)
7ebc9ea  fix imports (NouamaneTazi, Oct 28, 2023)
2c28b04  output logits (NouamaneTazi, Oct 28, 2023)
253da5b  attn mask.all() (NouamaneTazi, Oct 30, 2023)
d3c15ba  fix caching in rope (NouamaneTazi, Oct 31, 2023)
1e74664  GQA generation without cache (NouamaneTazi, Nov 1, 2023)
1f424bb  fix use_cache for GQA (NouamaneTazi, Nov 1, 2023)
39a3483  reshapes fixes for num_heads=2 (NouamaneTazi, Nov 1, 2023)
1c79ecd  . (NouamaneTazi, Nov 2, 2023)
19cf153  add flash_attn_with_kvcache to GQA (NouamaneTazi, Dec 7, 2023)
b493268  add merging word embedding checkpoints (xrsrke, Dec 29, 2023)
4446fe0  add merging quite a bit (xrsrke, Dec 31, 2023)
1d949b2  add reference starcoder model (xrsrke, Dec 31, 2023)
a58a947  merged most of the checkpoints (xrsrke, Dec 31, 2023)
ac559a1  add merged checkpoints (xrsrke, Jan 1, 2024)
78114b7  add mapping to target state dict (xrsrke, Jan 1, 2024)
7d50b80  refactor converting scrip (xrsrke, Jan 2, 2024)
21ee689  refactor (xrsrke, Jan 3, 2024)
210311b  add inference script (xrsrke, Jan 3, 2024)
09c086a  refactor (xrsrke, Jan 3, 2024)
ae54653  refactor all functions (xrsrke, Jan 3, 2024)
594099c  save some files before cleaning it all (xrsrke, Jan 3, 2024)
fb8a86b  delete uncessary files (xrsrke, Jan 3, 2024)
c26472c  add rope_theta to config (NouamaneTazi, Jan 5, 2024)
9c9cfbb  fix config.attn_pdrop for flash attn (NouamaneTazi, Jan 8, 2024)
6bdf78a  Merge pull request #1 from xrsrke/sc2-rope (NouamaneTazi, Jan 8, 2024)
1507798  Refactor GPTBigCode model conversion code (NouamaneTazi, Jan 8, 2024)
Refactor GPTBigCode model conversion code
NouamaneTazi committed Jan 8, 2024
commit 15077983d17f926bbf61cff744bb4a3221b566f7
4 changes: 1 addition & 3 deletions src/transformers/generation/utils.py
@@ -2469,9 +2469,7 @@ def greedy_search(
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
-                   scores += (next_tokens_scores,) if outputs.logits.shape[1] == 1 else (
-                       outputs.logits,
-                   )
+                   scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
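Note: the removed branch appended the full outputs.logits tensor whenever the model returned logits for more than the last position; after this change greedy_search always stores next_tokens_scores, so every element of the returned scores tuple keeps the per-step shape (batch_size, vocab_size). A minimal consumer-side sketch, not part of this PR; the checkpoint name and prompt are only illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=8,
    do_sample=False,
    return_dict_in_generate=True,
    output_scores=True,
)

# One score tensor per generated token; each has shape (batch_size, vocab_size)
# regardless of how many positions the forward pass returned logits for.
assert all(s.shape == (1, model.config.vocab_size) for s in out.scores)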
141 changes: 109 additions & 32 deletions src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py
@@ -1,61 +1,138 @@
import argparse
import os
from pathlib import Path
import re

import torch
from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint
from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel


# The simple map of names for "automated" rules.
NAME_MAP = {
    "_mlp._layer_1": "mlp.c_fc",
    "_mlp._layer_2": "mlp.c_proj",
    "layer_norm_1": "ln_1",
    "layer_norm_2": "ln_2",
    # "attention.dense": "attn.c_proj",
    "self_attn.dense": "attn.c_proj",
    # "self_attention.query_key_value": "attn.c_attn",
}


def convert_fast_llm_checkpoint(state_dict, config):
    # The converted output model.
    output_state_dict = {}
    if "window_size" in config:
        attention_window_size = config["window_size"]
    else:
        attention_window_size = config.get("attention_window_size", None)

    config = GPTBigCodeConfig(
        architectures=["GPTBigCodeLMHeadModel"],
        vocab_size=config["vocab_size"],
        n_positions=config["max_position_embeddings"],
        n_embd=config["hidden_size"],
        n_layer=config["num_layers"],
        n_head=config["num_attention_heads"],
        n_inner=config["ffn_hidden_size"],
        activation_function="gelu",  # TODO
        multi_query=True,  # TODO
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=0,  # TODO: can we remove these?
        eos_token_id=0,
        attention_softmax_in_fp32=True,
        scale_attention_softmax_in_fp32=True,
        use_rotary_embeddings=config["use_rotary_embeddings"],
        rotary_embedding_scale=config["rotary_embedding_scale"],
        use_position_embeddings=config["use_position_embeddings"],
        attention_window_size=attention_window_size
    )

    # Truncate the word embeddings to the vocab size
    word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :]
    output_state_dict["transformer.wte.weight"] = word_embeddings
    if config.use_position_embeddings:
        output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight")

    # Layer 0 is the word/position embeddings.
    # Layers 1 to n_layer need to be re-mapped to 0 to n_layer - 1.
    # _layers.{layer_index}.{op}.{w/b}

    # Concatenate QKV matrix
    for layer_index in range(1, config.n_layer + 1):
        for weight_or_bias in ["weight", "bias"]:
            query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}")
            key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}")
            output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0)

    # Extract the other ops
    layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
    for name, value in state_dict.items():
        m = layer_re.match(name)
        assert m is not None, f"Invalid layer name: {name}"

        # The index of the layer.
        layer_index = int(m.group(1))
        # The name of the operation.
        op_name = m.group(2)
        # Is it a weight or a bias?
        weight_or_bias = m.group(3)

        # Final layernorm
        if op_name == "final_layernorm":
            assert layer_index == config.n_layer + 1
            output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value
        else:
            output_state_dict[f"transformer.h.{layer_index - 1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value

    # For the LM head, transformers wants the weight matrix tied to the word embeddings.
    output_state_dict["lm_head.weight"] = word_embeddings

    return output_state_dict, config


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        # default="/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d",
-       help="Path where the converted model is saved"
+       help="Path to the experiment directory",
    )
    parser.add_argument(
        "--save_dir",
        type=Path,
        # default="./",
        help="Path where the converted model is saved"
    )
    args = parser.parse_args(argv)

    print("start")

    # TODO(xrsrke): auto convert checkpoint_dir to Path
    # checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d"
    # checkpoint_dir = Path(checkpoint_dir)

-   state_dict = merge_checkpoint(args.checkpoint_dir)
+   state_dict, config = merge_checkpoint(
+       args.checkpoint_dir,
+       dummy_experiment_dir=None
+   )

    output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config)

    print("Saving config")
    save_dir = args.save_dir or args.checkpoint_dir / "converted"
    output_config.save_pretrained(save_dir)

    # Store the state_dict to file.
    output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin")

    print(f'Saving checkpoint to "{output_checkpoint_file}"')
-   torch.save(state_dict, output_checkpoint_file)
+   torch.save(output_state_dict, output_checkpoint_file)
    print(f'Done!')

    # # Compare
    # def compare_state_dicts(dict1, dict2):
    #     # Compare keys
    #     if set(dict1.keys()) != set(dict2.keys()):
    #         return "Different keys"

    #     # Compare shapes and values
    #     for key in dict1:
    #         if dict1[key].shape != dict2[key].shape:
    #             return f"Different shape for key: {key}"
    #         if not torch.allclose(dict1[key], dict2[key]):
    #             return f"Different values for key: {key}"

    #     return "State dictionaries are identical"

    # ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth")
    # result = compare_state_dicts(state_dict, ref_state_dict)
    # print(result)


if __name__ == "__main__":
    main()
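For reference, a minimal sketch of how the refactored conversion script might be driven end to end. The experiment and output paths are placeholders, and it assumes the module is importable from the source tree as laid out in this PR:

from transformers import GPTBigCodeForCausalLM
from transformers.models.gpt_bigcode.convert_fast_llm_checkpoint import main

# Merge the fast-llm shards, convert the state dict, and write
# config.json + pytorch_model.bin into --save_dir.
main([
    "--checkpoint_dir", "/path/to/fast_llm_experiment",
    "--save_dir", "/path/to/converted_model",
])

# The converted directory then loads like any other GPTBigCode checkpoint.
model = GPTBigCodeForCausalLM.from_pretrained("/path/to/converted_model")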