Make modeling compatible with Nanotron + few optims #23

Closed
Changes from 1 commit
Commits (27)
7384efb
implement GPT Neo's rope
NouamaneTazi Oct 25, 2023
7ebc9ea
fix imports
NouamaneTazi Oct 28, 2023
2c28b04
output logits
NouamaneTazi Oct 28, 2023
253da5b
attn mask.all()
NouamaneTazi Oct 30, 2023
d3c15ba
fix caching in rope
NouamaneTazi Oct 31, 2023
1e74664
GQA generation without cache
NouamaneTazi Nov 1, 2023
1f424bb
fix use_cache for GQA
NouamaneTazi Nov 1, 2023
39a3483
reshapes fixes for num_heads=2
NouamaneTazi Nov 1, 2023
1c79ecd
.
NouamaneTazi Nov 2, 2023
19cf153
add flash_attn_with_kvcache to GQA
NouamaneTazi Dec 7, 2023
b493268
add merging word embedding checkpoints
xrsrke Dec 29, 2023
4446fe0
add merging quite a bit
xrsrke Dec 31, 2023
1d949b2
add reference starcoder model
xrsrke Dec 31, 2023
a58a947
merged most of the checkpoints
xrsrke Dec 31, 2023
ac559a1
add merged checkpoints
xrsrke Jan 1, 2024
78114b7
add mapping to target state dict
xrsrke Jan 1, 2024
7d50b80
refactor converting scrip
xrsrke Jan 2, 2024
21ee689
refactor
xrsrke Jan 3, 2024
210311b
add inference script
xrsrke Jan 3, 2024
09c086a
refactor
xrsrke Jan 3, 2024
ae54653
refactor all functions
xrsrke Jan 3, 2024
594099c
save some files before cleaning it all
xrsrke Jan 3, 2024
fb8a86b
delete uncessary files
xrsrke Jan 3, 2024
c26472c
add rope_theta to config
NouamaneTazi Jan 5, 2024
9c9cfbb
fix config.attn_pdrop for flash attn
NouamaneTazi Jan 8, 2024
6bdf78a
Merge pull request #1 from xrsrke/sc2-rope
NouamaneTazi Jan 8, 2024
1507798
Refactor GPTBigCode model conversion code
NouamaneTazi Jan 8, 2024
refactor all functions
xrsrke committed Jan 3, 2024
commit ae54653bff3222a11f9df92c8226503157f7c91f
188 changes: 84 additions & 104 deletions src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py
@@ -1,10 +1,7 @@
import re
from tqdm import tqdm
from pathlib import Path

import numpy as np
import torch
import yaml
from safetensors import safe_open
from collections import defaultdict

@@ -47,119 +44,102 @@ def get_safetensor_checkpoint_paths(checkpoint_dir: Path):

    return safetensor_files

-def merge_checkpoint(checkpoint_dir: Path):
-    """Load a checkpoint from the BRRR format and merge tensor parallel shards."""
-    checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir)

-    def transform_paths(paths):
-        # Convert to Path objects and find common prefix
-        path_objs = [Path(p) for p in paths]
-        common_path_prefix = Path(commonprefix(path_objs)).parent

-        # Initialize the final paths dictionary
-        final_paths = {}
+def transform_paths(paths):
+    path_objs = [Path(p) for p in paths]
+    common_path_prefix = Path(commonprefix(path_objs)).parent

-        for path in path_objs:
-            # Relative path
-            relative_path = str(path.relative_to(common_path_prefix))
+    final_paths = {}
+    for path in path_objs:
+        relative_path = str(path.relative_to(common_path_prefix))
+        dot_path = relative_path.replace('/', '.')

+        weight_replaced = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path)
+        bias_replaced = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', weight_replaced)
+        cleaned_path = bias_replaced.replace('.safetensors', '')

-            # Convert slashes to dots
-            dot_path = relative_path.replace('/', '.')
+        final_paths[cleaned_path] = str(path)

-            # Replace patterns for model weights and biases
-            weight_replaced = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path)
-            bias_replaced = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', weight_replaced)
+    return final_paths
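For reference, the new module-level transform_paths flattens each tensor-parallel shard path into a dot-separated key whose trailing suffix is the shard index. A minimal sketch of that rewrite on a single hypothetical shard path (the actual directory layout depends on the BRRR/Nanotron checkpoint, and the regexes assume the fixed pp-rank-0-of-1 / tp-rank-*-of-4 naming used above):

import re

# Hypothetical relative path of one tensor-parallel shard (illustrative only).
relative_path = "model/decoder/3/pp_block/attn/query_key_value/model_weight_pp-rank-0-of-1_tp-rank-2-of-4.safetensors"

dot_path = relative_path.replace('/', '.')
key = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path)
key = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', key)
key = key.replace('.safetensors', '')

print(key)  # model.decoder.3.pp_block.attn.query_key_value.weight.2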

-            # Remove '.safetensors' extension
-            cleaned_path = bias_replaced.replace('.safetensors', '')
+def group_and_sort_paths(paths):
+    grouped_paths = defaultdict(list)

-            # Add to final dictionary
-            final_paths[cleaned_path] = str(path)
+    for key, path in paths.items():
+        try:
+            module_name, shard_number = key.rsplit('.', 1)
+            grouped_paths[module_name].append((int(shard_number), path))
+        except ValueError:
+            # NOTE: these are layer norm's weight and biases
+            # so it don't have shard number
+            grouped_paths[key].append(path)

-        return final_paths
+    # Remove any entries with empty lists
+    grouped_paths = {k: v for k, v in grouped_paths.items() if v}

-    paths = transform_paths(checkpoint_paths)

-    def group_and_sort_paths(paths):
-        grouped_paths = defaultdict(list)

-        for key, path in paths.items():
-            try:
-                module_name, shard_number = key.rsplit('.', 1)
-                grouped_paths[module_name].append((int(shard_number), path))
-            except ValueError:
-                # Handle cases where the key does not split into two parts
-                print(f"skipped {key}, {path}")
-                grouped_paths[key].append(path)

-        # Remove any entries with empty lists
-        grouped_paths = {k: v for k, v in grouped_paths.items() if v}

-        # Sort paths in each group
-        sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0])
-                                for module, paths in grouped_paths.items()}

-        return sorted_grouped_paths

-    paths = group_and_sort_paths(paths)

-    def merge_checkpoints(paths):
-        def find_corresponding_dim(name):
-            for key, value in MERGE_DIM_MAPPING.items():
-                if key in name:
-                    return value
-            return None

-        model_states = {}
-        for state_key, path in paths.items():
-            model_states[state_key] = {}
-            for shard_id, _path in enumerate(path):
-                checkpoint_path = _path[1] if isinstance(_path, tuple) else _path
-                with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        data = f.get_tensor(key)
-                        model_states[state_key][shard_id] = data

-            tensor_list = [tensor for _, tensor in sorted(model_states[state_key].items())]
-            merge_dim = find_corresponding_dim(state_key)
-            print(f"trying to merge: {state_key}")

-            if len(tensor_list) > 1:
-                try:
-                    model_states[state_key] = torch.cat(tensor_list, dim=merge_dim)
-                except:
-                    print(f"skipped {state_key}, {[x.shape for x in tensor_list]}")
-            else:
-                # NOTE: these are biases
-                model_states[state_key] = tensor_list[0]
-        return model_states

-    model_states = merge_checkpoints(paths)
+    # NOTE: Sort paths in each group
+    # module: [(4, path), (0, path), (3, path) ...] -> module: [(0, path), (1, path), (2, path) ...]
+    sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0])
+                            for module, paths in grouped_paths.items()}

-    def remap_keys(target_dict):
-        new_dict = {}
-        for key, value in target_dict.items():
-            parts = key.split('.')

-            if 'model.decoder' in key and 'pp_block' in key:
-                block_number = parts[2]
-                component_parts = parts[4:]
-                component = '.'.join(component_parts)

-                new_component = BRRR_TFMS_NAME_MAPPING.get(component, component)
-                new_key = f"transformer.h.{block_number}.{new_component}"
-                new_dict[new_key] = value
-            elif key == 'model.final_layer_norm.pp_block.model_weight':
-                new_dict['transformer.ln_f.weight'] = value
-            elif key == 'model.final_layer_norm.pp_block.model_bias':
-                new_dict['transformer.ln_f.bias'] = value
+    return sorted_grouped_paths
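For reference, group_and_sort_paths buckets the flattened keys by module and orders each bucket by shard index, so that later concatenation sees tensor-parallel shards in rank order; keys without a numeric suffix (layer norm weights and biases) fall into the except branch and are kept as single entries. A small sketch with hypothetical keys and paths (names are illustrative only):

from collections import defaultdict

# Hypothetical output of transform_paths: flat key -> shard file path (illustrative only).
paths = {
    "model.decoder.0.pp_block.attn.qkv.weight.1": "shard_tp1.safetensors",
    "model.decoder.0.pp_block.attn.qkv.weight.0": "shard_tp0.safetensors",
    "model.decoder.0.pp_block.ln_1.weight": "ln.safetensors",  # no numeric shard suffix
}

grouped = defaultdict(list)
for key, path in paths.items():
    try:
        module_name, shard_number = key.rsplit('.', 1)
        grouped[module_name].append((int(shard_number), path))  # sharded tensors
    except ValueError:
        grouped[key].append(path)  # un-sharded tensors such as layer norm weight/bias

sorted_grouped = {module: sorted(entries, key=lambda x: x[0]) for module, entries in grouped.items()}
print(sorted_grouped["model.decoder.0.pp_block.attn.qkv.weight"])
# [(0, 'shard_tp0.safetensors'), (1, 'shard_tp1.safetensors')]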

+def merge_checkpoints(paths):
+    def find_corresponding_dim(name):
+        for key, value in MERGE_DIM_MAPPING.items():
+            if key in name:
+                return value
+        return None

+    model_states = {}
+    for state_key, path in paths.items():
+        model_states[state_key] = {}
+        for shard_id, _path in enumerate(path):
+            checkpoint_path = _path[1] if isinstance(_path, tuple) else _path
+            with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    data = f.get_tensor(key)
+                    model_states[state_key][shard_id] = data

+        tensor_list = [tensor for _, tensor in sorted(model_states[state_key].items())]
+        merge_dim = find_corresponding_dim(state_key)

+        if len(tensor_list) > 1:
+            model_states[state_key] = torch.cat(tensor_list, dim=merge_dim)
+        else:
+            # NOTE: these are biases
+            model_states[state_key] = tensor_list[0]
+    return model_states
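For reference, merge_checkpoints loads every shard of a parameter with safetensors and concatenates them with torch.cat along the dimension that MERGE_DIM_MAPPING associates with that parameter name; entries with a single shard (biases and other unsharded tensors) are taken as-is. A minimal sketch with made-up shapes and a hypothetical mapping (the real MERGE_DIM_MAPPING is defined elsewhere in this file; in Megatron-style tensor parallelism, column-parallel weights typically concatenate along dim 0 and row-parallel weights along dim 1):

import torch

# Hypothetical merge dims, standing in for the module-level MERGE_DIM_MAPPING.
MERGE_DIM_MAPPING = {"c_fc.weight": 0, "c_proj.weight": 1}

def find_corresponding_dim(name):
    for key, value in MERGE_DIM_MAPPING.items():
        if key in name:
            return value
    return None

# Two tensor-parallel shards of a hypothetical column-parallel weight (out_features split across ranks).
shards = [torch.zeros(2048, 1024), torch.zeros(2048, 1024)]
dim = find_corresponding_dim("model.decoder.0.pp_block.mlp.c_fc.weight")
merged = torch.cat(shards, dim=dim)
print(merged.shape)  # torch.Size([4096, 1024])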

-            elif key == 'model.token_embeddings.pp_block.token_embedding.weight':
-                new_dict['transformer.wte.weight'] = value
+def remap_keys(target_dict):
+    key_mapping = {
+        'model.final_layer_norm.pp_block.model_weight': 'transformer.ln_f.weight',
+        'model.final_layer_norm.pp_block.model_bias': 'transformer.ln_f.bias',
+        'model.token_embeddings.pp_block.token_embedding.weight': 'transformer.wte.weight'
+    }

-        new_dict["lm_head.weight"] = new_dict["transformer.wte.weight"]
-        return new_dict
+    def get_new_key(key):
+        if 'model.decoder' in key and 'pp_block' in key:
+            parts = key.split('.')
+            block_number = parts[2]
+            component_parts = parts[4:]
+            component = '.'.join(component_parts)
+            new_component = BRRR_TFMS_NAME_MAPPING.get(component, component)
+            return f"transformer.h.{block_number}.{new_component}"
+        else:
+            return key_mapping.get(key, key)

+    new_dict = {get_new_key(key): value for key, value in target_dict.items()}
+    new_dict["lm_head.weight"] = new_dict.get("transformer.wte.weight", new_dict.get("lm_head.weight"))

+    return new_dict
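For reference, remap_keys translates the flattened BRRR parameter names into transformers GPTBigCode names: decoder-block keys are rewritten to transformer.h.<block>.<component> via BRRR_TFMS_NAME_MAPPING, the handful of fixed keys go through key_mapping, and lm_head.weight is tied to the word embedding weight. A small sketch of the block-key rewrite with a hypothetical entry standing in for BRRR_TFMS_NAME_MAPPING (the real mapping is defined elsewhere in this file):

# Hypothetical subset of BRRR_TFMS_NAME_MAPPING (illustrative only).
BRRR_TFMS_NAME_MAPPING = {"attn.qkv.weight": "attn.c_attn.weight"}

key = "model.decoder.7.pp_block.attn.qkv.weight"
parts = key.split('.')
block_number = parts[2]            # "7"
component = '.'.join(parts[4:])    # "attn.qkv.weight"
new_component = BRRR_TFMS_NAME_MAPPING.get(component, component)

print(f"transformer.h.{block_number}.{new_component}")
# transformer.h.7.attn.c_attn.weight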

+def merge_checkpoint(checkpoint_dir: Path):
+    """Load a checkpoint from the BRRR format and merge tensor parallel shards."""
+    checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir)
+    paths = transform_paths(checkpoint_paths)
+    paths = group_and_sort_paths(paths)
+    model_states = merge_checkpoints(paths)
    model_states = remap_keys(model_states)

    return model_states
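Taken together, merge_checkpoint goes from a sharded BRRR/Nanotron checkpoint directory to a single transformers-style state dict. A hedged usage sketch (the checkpoint paths are placeholders, and the GPTBigCodeConfig values must be filled in to match the model that produced the checkpoint):

from pathlib import Path
from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM

state_dict = merge_checkpoint(Path("/path/to/brrr_checkpoint"))  # placeholder path

config = GPTBigCodeConfig()  # set n_layer, n_head, n_embd, etc. to match the checkpoint
model = GPTBigCodeForCausalLM(config)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing, unexpected)  # ideally both are empty once the key remapping is complete

model.save_pretrained("/path/to/hf_checkpoint")  # placeholder path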