diff --git a/model.py b/model.py
index 4825c96..b89a19a 100644
--- a/model.py
+++ b/model.py
@@ -44,14 +44,14 @@ def from_name(cls, name: str):
         if name in transformer_configs:
             return cls(**transformer_configs[name])
         # fuzzy search
-        config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
+        config = [config for config in transformer_configs if config.lower() in str(name).lower()]
 
         # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
         # take longer name (as it have more symbols matched)
         if len(config) > 1:
             config.sort(key=len, reverse=True)
             assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
-
+
         return cls(**transformer_configs[config[0]])
 
 
@@ -65,7 +65,9 @@ def from_name(cls, name: str):
     "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
-    "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
+
+    "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
+    "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
 }
 
 class KVCache(nn.Module):
diff --git a/requirements.txt b/requirements.txt
index 04f828c..9be5ad2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
 torch
 sentencepiece
 tiktoken
+blobfile
+safetensors
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index 8a22106..d3a64d9 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -9,7 +9,7 @@
 import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 # support running without installing as a package
@@ -28,62 +28,49 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = checkpoint_dir.name
 
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
@@ -95,39 +82,40 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
 
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
+    if 'llama-3' in model_name.lower():
         original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
diff --git a/tokenizer.py b/tokenizer.py
index c62a0c5..f60b3c1 100644
--- a/tokenizer.py
+++ b/tokenizer.py
@@ -105,7 +105,8 @@ def get_tokenizer(tokenizer_model_path, model_name):
     Returns:
     - TokenizerInterface: An instance of a tokenizer.
     """
-    if "Llama-3" in str(model_name):
+
+    if "llama-3" in str(model_name).lower():
         return TiktokenWrapper(tokenizer_model_path)
     else:
         return SentencePieceWrapper(tokenizer_model_path)
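Note (not code from the diff above): a minimal standalone sketch of the shard-loading strategy the updated scripts/convert_hf_checkpoint.py follows: prefer model.safetensors.index.json, fall back to pytorch_model.bin.index.json, then pick a loader per shard by file type, so mixed .safetensors/.bin layouts merge in a single loop. The helper name load_merged_state_dict and the example checkpoint path are illustrative assumptions.

from pathlib import Path
import json

import torch
from safetensors.torch import load_file as load_safetensors_file


def load_merged_state_dict(checkpoint_dir: Path) -> dict:
    # Prefer the safetensors index, fall back to the PyTorch .bin index.
    index_file = None
    for candidate in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
        if (checkpoint_dir / candidate).is_file():
            index_file = checkpoint_dir / candidate
            break
    if index_file is None:
        raise FileNotFoundError(f"No weight index found in {checkpoint_dir}")

    with open(index_file) as f:
        bin_index = json.load(f)

    # The index maps parameter names to shard files; collect the unique shards.
    shard_files = {checkpoint_dir / name for name in bin_index["weight_map"].values()}

    merged = {}
    for shard in sorted(shard_files):
        if shard.suffix == ".safetensors":
            # safetensors shards go through safetensors.torch.load_file
            merged.update(load_safetensors_file(str(shard), device="cpu"))
        else:
            # .bin shards are regular PyTorch checkpoints
            merged.update(torch.load(str(shard), map_location="cpu", mmap=True, weights_only=True))
    return merged


if __name__ == "__main__":
    # Hypothetical checkpoint directory; adjust to wherever the HF weights were downloaded.
    state_dict = load_merged_state_dict(Path("checkpoints/meta-llama/Meta-Llama-3-8B"))
    print(f"Loaded {len(state_dict)} tensors")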