Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For visibility: Gqa megatron rope #25

Open
wants to merge 8 commits into
base: gqa
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/transformers/models/gpt_bigcode/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
## Conversion to `transformers`

To convert a model from Megatron-LM to transformers use:
```bash
source ~/.bashrc
export PYTHONPATH=Megatron-LM
export PYTHONPATH=transformers/src:$PYTHONPATH

cd transformers/src/transformers/models

python gpt_bigcode/convert_megatron_checkpoint.py \
--path_to_checkpoint /fsx/bigcode/experiments/pretraining/starcoder2-1B/checkpoints/iter_0200000/mp_rank_00/model_optim_rng.pt \
--save_dir /fsx/bigcode/experiments/pretraining/starcoder2-1B/checkpoints/conversions \
--test_generation \
--tokenizer_path /fsx/loubna/data/tokenizer/starcoder2-smol-internal-1
```

For `fast-llm` use `convert_fast_llm_checkpoint.py`. For cloning and pushing models from existng iterations directly to HF hub check `push_checkpoints.py`.
304 changes: 304 additions & 0 deletions src/transformers/models/gpt_bigcode/convert_megatron_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
####################################################################################################

# Copyright (c) 2021-, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

####################################################################################################

#
# Note: If when running this conversion script you're getting an exception:
# ModuleNotFoundError: No module named 'megatron.model.enums'
# you need to tell python where to find the clone of Megatron-LM, e.g.:
#
# cd /tmp
# git clone https://github.com/NVIDIA/Megatron-LM
# PYTHONPATH=/tmp/Megatron-LM python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py ...
#
# if you already have it cloned elsewhere, simply adjust the path to the existing path
#
# If the training was done using a Megatron-LM fork, e.g.,
# https://github.com/microsoft/Megatron-DeepSpeed/ then chances are that you need to have that one
# in your path, i.e., /path/to/Megatron-DeepSpeed/
#

import argparse
import os
import re

import torch
import math

from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel


# The simple map of names for "automated" rules.
NAME_MAP = {
"attention.dense": ".attn.c_proj.",
"self_attention.dense": ".attn.c_proj.",
"mlp.dense_h_to_4h": ".mlp.c_fc.",
"mlp.dense_4h_to_h": ".mlp.c_proj.",
"self_attention.query_key_value": ".attn.c_attn.",
"self_attention.query": ".attn.q_attn.",
"self_attention.key_value": ".attn.kv_attn.",
}


def recursive_print(name, val, spaces=0):
# Format the message.
if name is None:
msg = None
else:
fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
msg = fmt.format(name)

# Print and recurse (if needed).
if isinstance(val, dict):
if msg is not None:
print(msg)
for k in val.keys():
recursive_print(k, val[k], spaces + 2)
elif isinstance(val, torch.Tensor):
print(msg, ":", val.size())
else:
print(msg, ":", val)


def convert_megatron_checkpoint(input_state_dict, merge_qkv):
# The converted output model.
output_state_dict = {}
ds_args = input_state_dict["args"]

if ds_args is not None:
# @loubnabnl fastllm uses gelu?
if ds_args.bias_gelu_fusion:
activation_function = "gelu_pytorch_tanh"
elif ds_args.openai_gelu:
activation_function = "gelu_new"
else:
activation_function = "gelu"
else:
# in the very early days this used to be "gelu_new"
activation_function = "gelu_new"

if ds_args.attention_head_type == "multihead":
multi_query = False
else:
assert ds_args.attention_head_type == "multiquery"
# @loubnabnl we don't use the no-merge-kv anymore?
# attention_type = 2 if merge_qkv else 3
multi_query = True

attention_softmax_in_fp32 = ds_args.attention_softmax_in_fp32 or ds_args.apply_query_key_layer_scaling

# Spell out all parameters in case the defaults change.
config = GPTBigCodeConfig(
architectures=["GPTBigCodeLMHeadModel"],
vocab_size=ds_args.padded_vocab_size,
n_positions=ds_args.max_position_embeddings,
n_embd=ds_args.hidden_size,
n_layer=ds_args.num_layers,
n_head=ds_args.num_attention_heads,
n_inner=ds_args.ffn_hidden_size,
activation_function=activation_function,
multi_query=multi_query,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
use_cache=True,
bos_token_id=0,
eos_token_id=0,
attention_softmax_in_fp32=attention_softmax_in_fp32,
scale_attention_softmax_in_fp32=True,
use_rotary_embeddings=ds_args.use_rotary_position_embeddings,
rotary_embedding_scale=-math.log(ds_args.rotary_theta),
use_position_embeddings=ds_args.add_position_embedding,
)

from pprint import pprint
pprint(vars(ds_args))
pprint(config)

# Megatron-LM checkpoint version
checkpoint_version = input_state_dict["checkpoint_version"]
if checkpoint_version < 2.0:
raise NotImplementedError(f"Checkpoint version {checkpoint_version} not supported.")

# The model.
model = input_state_dict["model"]["language_model"]

# The word embeddings, truncated to to vocab_size rows.
word_embeddings = model["embedding"]["word_embeddings"]["weight"][: config.vocab_size, :]
output_state_dict["transformer.wte.weight"] = word_embeddings

# The position embeddings.
output_state_dict["transformer.wpe.weight"] = model["embedding"]["position_embeddings"]["weight"]

# The transformer.
transformer = model["transformer"] if "transformer" in model else model["encoder"]

# The regex to extract layer names.
layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

# Extract the layers.
for key, val in transformer.items():
# Match the name.
m = layer_re.match(key)

# Stop if that's not a layer
if m is None:
break

# The index of the layer.
layer_idx = int(m.group(1))
# The name of the operation.
op_name = m.group(2)
# Is it a weight or a bias?
weight_or_bias = m.group(3)

# The name of the layer.
layer_name = f"transformer.h.{layer_idx}"

# For layernorm(s), simply store the layer norm.
if op_name.endswith("layernorm"):

ln_name = "ln_1" if op_name.startswith("input") else "ln_2"
output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

# Concatenate QKV matrix.
elif merge_qkv and (op_name == "self_attention.key_value"):
# Query is before key_value in the dict.
query = output_state_dict.pop(layer_name + ".attn.q_attn." + weight_or_bias)
out_val = torch.cat([query, val], dim=0)
output_state_dict[layer_name + ".attn.c_attn." + weight_or_bias] = out_val

# Copy the parameters.
else:
output_state_dict[layer_name + NAME_MAP[op_name] + weight_or_bias] = val

# DEBUG.
assert config.n_layer == layer_idx + 1

# The final layernorm.
output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"]
output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"]

# For LM head, transformers' wants the matrix to weight embeddings.
output_state_dict["lm_head.weight"] = word_embeddings

# It should be done!
return config, output_state_dict


def test_conversion(checkpoint_path, tokenizer_path, device="cpu", prompt=None):
from transformers import AutoTokenizer
from transformers.models.gpt_bigcode import GPTBigCodeForCausalLM
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = GPTBigCodeForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16, device_map=device)
prompt_1 = 'def separate_paren_groups(paren_string: str) -> List[str]:\n """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups(\'( ) (( )) (( )( ))\')\n [\'()\', \'(())\', \'(()())\']\n """'
prompt_2 = 'def fibonnaci(n'
prompts = [prompt_1, prompt_2]
for text in prompts:
inputs = tokenizer(text, return_tensors="pt").to(device)
print(f"Testing generation with prompt '{text}'")
print(f"Input ids: {inputs['input_ids']}")
output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(output[0]))


def main(argv=None):
# Create the argument parser.
parser = argparse.ArgumentParser()
parser.add_argument("--print-checkpoint-structure", action="store_true")
parser.add_argument(
"--path_to_checkpoint",
type=str,
help="Path to the checkpoint file (.zip archive or direct .pt file)",
)
parser.add_argument(
"--no_merge_qkv",
dest="merge_qkv",
action="store_false",
help="Do not merge the query and key_value tensors (MQA).",
)
parser.add_argument(
"--custom_model",
action="store_true",
help="Save as custom model so it can be used with huggingface transformers.",
)
parser.add_argument(
"--save_dir", help="Path where the converted model is saved. Will use the checkpoint directory if not provided"
)
parser.add_argument(
"--tokenizer_path",
type=str,
help="Path to the tokenizer or repo name on the HF hub for testing",
)
parser.add_argument(
"--test_generation",
action="store_true",
help="Test generation with the converted model",
)
args = parser.parse_args(argv)

# Extract the basename.
basename = args.save_dir or os.path.dirname(args.path_to_checkpoint)

# Load the model.
print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")

# Convert.
print("Converting")
config, output_state_dict = convert_megatron_checkpoint(input_state_dict, args.merge_qkv)

# Print the structure of converted state dict.
if args.print_checkpoint_structure:
recursive_print(None, output_state_dict)

if args.custom_model:
# Save custom model
GPTBigCodeConfig.register_for_auto_class()
GPTBigCodeModel.register_for_auto_class("AutoModelForCausalLM")
hf_model = GPTBigCodeForCausalLM(config)
hf_model.load_state_dict(output_state_dict)
hf_model.save_pretrained(basename)

else:
# Store the config to file.
print("Saving config")
config.save_pretrained(basename)

# Store the state_dict to file.
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
print(f'Saving checkpoint to "{output_checkpoint_file}"')
torch.save(output_state_dict, output_checkpoint_file)

# test model
if args.test_generation:
print(f"Testing converted model at {args.save_dir}")
if args.tokenizer_path is None:
raise ValueError("Please provide a tokenizer path for testing")
test_conversion(checkpoint_path=args.save_dir, tokenizer_path=args.tokenizer_path)


if __name__ == "__main__":
main()
32 changes: 30 additions & 2 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _apply_rotary_embeddings(
* Convert back tho the input format.
# TODO: Full precision only needed for bfloat16? (Doesn't support complex numbers)
"""
complex_tensor = torch.view_as_complex(tensor.float().view(*tensor.shape[:-1], -1, rope_frequencies.size(-1), 2))
complex_tensor = torch.view_as_complex(tensor.float().view(*tensor.shape[:-1], -1, 2, rope_frequencies.size(-1)).transpose(-2, -1).contiguous())
return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor)


Expand Down Expand Up @@ -273,7 +273,35 @@ def forward(
# .split((self.head_dim, 2 * self.head_dim), dim=3)
# )

query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
# query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)

# Split the KV tensors based on Megatron-LM's way
c_states = self.c_attn(hidden_states)
new_tensor_shape = hidden_states.size()[:-1] + (self.kv_heads, ((self.num_heads // self.kv_heads + 2)* self.head_dim),)
c_states= c_states.view(*new_tensor_shape)
(query, key, value) = torch.split(
c_states,
[
(
self.num_heads
// self.kv_heads
* self.head_dim
),
self.head_dim,
self.head_dim,
],
dim=3,
)

query = query.reshape(query.size()[:-2] + (-1,))
key = key.reshape(key.size()[:-2] + (-1,))
value = value.reshape(value.size()[:-2] + (-1,))
key_value = torch.cat([key, value], dim=-1)
if layer_past is not None:
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None

key, value = key_value.split((self.kv_heads * self.head_dim), dim=-1)
# key_value: (batch, sequence, 2 * kv_heads * head_dim)

if layer_past is not None:
Expand Down