From d2f0b15273732d0c987b9cb83cfca4673aa096af Mon Sep 17 00:00:00 2001
From: Mark O'Connor
Date: Wed, 12 Feb 2025 04:25:53 +0100
Subject: [PATCH] [TT-Transformer] Add HF_MODEL to load models directly from
 huggingface

Co-authored-by: mtairum
---
 models/demos/llama3/PERF.md                   | 16 +++----
 .../demos/llama3/tests/test_llama_accuracy.py | 29 ++++++-------
 models/demos/llama3/tt/llama_attention.py     |  6 +--
 models/demos/llama3/tt/load_checkpoints.py    | 11 +++--
 models/demos/llama3/tt/model_config.py        | 42 +++++++++++++++----
 5 files changed, 67 insertions(+), 37 deletions(-)

diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md
index 2aefa56be3c..8fb3be2baf7 100644
--- a/models/demos/llama3/PERF.md
+++ b/models/demos/llama3/PERF.md
@@ -11,16 +11,16 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.
 | Model          | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
 |----------------|--------|-----------|-----------|---------------|
 | Llama3.2-1B    | N150   | 89        | 98        | 86.9          |
-| Llama3.2-1B    | N300   | 91        | 98        | 104.3         |
-| Llama3.2-1B    | T3K    | 91        | 98        | 118.5         |
+| Llama3.2-1B    | N300   | 90        | 98        | 104.3         |
+| Llama3.2-1B    | T3K    | 87        | 98        | 118.5         |
 | Llama3.2-1B    | TG     |           |           | 72.3          |
-| Llama3.2-3B    | N150   | 92        | 96        | 53.3          |
+| Llama3.2-3B    | N150   | 91        | 96        | 53.3          |
 | Llama3.2-3B    | N300   | 91        | 96        | 66.1          |
 | Llama3.2-3B    | T3K    | 91        | 96        | 66.9          |
 | Llama3.2-3B    | TG     |           |           | 48.5          |
 | Llama3.1-8B    | N150   | 87        | 99        | 27.9          |
 | Llama3.1-8B    | N300   | 88        | 99        | 43.7          |
-| Llama3.1-8B    | T3K    | 88        | 100       | 64.2          |
+| Llama3.1-8B    | T3K    | 88        | 99        | 64.2          |
 | Llama3.1-8B    | TG     |           |           | 41.0          |
 | Llama3.2-11B   | N300   | 89        | 99        | 43.5          |
 | Llama3.2-11B   | T3K    | 88        | 99        | 63.4          |
@@ -37,12 +37,12 @@ This configuration uses bfp4 MLP FF1+FF3 only for the Llama-3.1-70B model and th
 | Model          | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
 |----------------|--------|-----------|-----------|---------------|
 | Llama3.2-1B    | N150   | 88        | 98        | 86.8          |
-| Llama3.2-1B    | N300   | 90        | 98        | 98.1          |
-| Llama3.2-1B    | T3K    | 90        | 98        | 97.5          |
+| Llama3.2-1B    | N300   | 88        | 98        | 98.1          |
+| Llama3.2-1B    | T3K    | 89        | 99        | 97.5          |
 | Llama3.2-1B    | TG     | 87        | 98        | 51.3          |
-| Llama3.2-3B    | N150   | 93        | 99        | 44.2          |
+| Llama3.2-3B    | N150   | 92        | 99        | 44.2          |
 | Llama3.2-3B    | N300   | 92        | 98        | 54.2          |
-| Llama3.2-3B    | T3K    | 93        | 98        | 55.6          |
+| Llama3.2-3B    | T3K    | 91        | 100       | 55.6          |
 | Llama3.2-3B    | TG     | 91        | 98        | 33.6          |
 | Llama3.1-8B    | N150   | 93        | 100       | 23.6          |
 | Llama3.1-8B    | N300   | 93        | 100       | 34.5          |
diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py
index d0fd2d2a15b..5a40dec57ac 100644
--- a/models/demos/llama3/tests/test_llama_accuracy.py
+++ b/models/demos/llama3/tests/test_llama_accuracy.py
@@ -157,7 +157,7 @@ def test_tt_model_acc(
         text = f.read()
 
     # Encode text to tokens
-    encoded_tokens = tokenizer.encode(text, bos=True, eos=False)
+    encoded_tokens = model_args.encode_prompt(text, system_prompt_text=None, instruct=False)
     total_length = prefill_len + decode_len + 1
     reference_tokens = torch.tensor(encoded_tokens[:total_length]).unsqueeze(0)
     top5_tokens = None  # Will be computed during inference
@@ -439,17 +439,18 @@ def test_tt_model_acc(
             true_word = sanitize(tokenizer.decode([true_token]))
             logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")
 
-        # Get accuracy thresholds from PERF.md
-        min_top1_acc, min_top5_acc = get_accuracy_thresholds(
-            model_args.base_model_name,
-            model_args.device_name,
-            optimizations,
-        )
+        if use_reference_file:
+            # Get accuracy thresholds from PERF.md
+            min_top1_acc, min_top5_acc = get_accuracy_thresholds(
+                model_args.base_model_name,
+                model_args.device_name,
+                optimizations,
+            )
 
-        logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
-        assert (
-            total_top1_acc >= min_top1_acc
-        ), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
-        assert (
-            total_top5_acc >= min_top5_acc
-        ), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
+            logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
+            assert (
+                total_top1_acc >= min_top1_acc
+            ), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
+            assert (
+                total_top5_acc >= min_top5_acc
+            ), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py
index ac67c80f1c2..a8c8581dc98 100644
--- a/models/demos/llama3/tt/llama_attention.py
+++ b/models/demos/llama3/tt/llama_attention.py
@@ -8,8 +8,6 @@
 import ttnn
 from models.common.lightweightmodule import LightweightModule
 from models.demos.llama3.tt.llama_ccl import tt_all_reduce, tt_all_gather
-from models.demos.llama3.tt.llama_common import first_five
-from models.demos.llama3.tt.load_checkpoints import permute
 
 
 class TtLlamaAttention(LightweightModule):
@@ -138,7 +136,9 @@ def __init__(
             )
             # as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size
             self.wqkv_bias_prefill = ttnn.reshape(
-                self.wqkv_bias_prefill, ttnn.Shape([1, 1, 1, self.wqkv_bias_prefill.shape[-1]])
+                self.wqkv_bias_prefill,
+                (1, 1, 1, self.wqkv_bias_prefill.shape[-1]),
+                (1, 1, self.wqkv_bias_prefill.shape[-2], self.wqkv_bias_prefill.shape[-1]),
             )
 
             # Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size
diff --git a/models/demos/llama3/tt/load_checkpoints.py b/models/demos/llama3/tt/load_checkpoints.py
index 7e330a2e18d..f85788ee1e3 100644
--- a/models/demos/llama3/tt/load_checkpoints.py
+++ b/models/demos/llama3/tt/load_checkpoints.py
@@ -37,13 +37,16 @@ def load_hf_state_dict(ckpt_dir):
         raise FileNotFoundError(f"Neither model.safetensors.index.json nor model.safetensors found in {ckpt_dir}")
 
     loaded_weights = safetensors_load_file(safetensor_path)
-    if not "lm_head.weight" in loaded_weights:
-        # Assume tied to the embeddings if not present
-        loaded_weights["lm_head.weight"] = loaded_weights["model.embed_tokens.weight"]
-
     return loaded_weights
 
 
+def standardize_hf_keys(state_dict):
+    if not "lm_head.weight" in state_dict:
+        # Assume tied to the embeddings if not present
+        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
+    return state_dict
+
+
 def convert_hf_to_meta(state_dict, head_dim):
     state_dict = convert_hf_qkv_to_meta_format(state_dict, head_dim)
     state_dict = map_hf_to_meta_keys(state_dict)
diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index db7b9e207c5..c58ea0a9eaa 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -31,6 +31,7 @@
     convert_hf_to_meta,
     convert_meta_to_hf,
     reverse_permute,
+    standardize_hf_keys,
 )
 
 
@@ -114,8 +115,10 @@ def __init__(
         self.max_batch_size = max_batch_size
         self.tile_size = 32
         self.is_70b = False
+        self.from_hf_url = False  # updated below if true
 
         LLAMA_DIR = os.getenv("LLAMA_DIR")
+        HF_MODEL = os.getenv("HF_MODEL")
         if LLAMA_DIR:
             if any([os.getenv("LLAMA_CKPT_DIR"), os.getenv("LLAMA_TOKENIZER_PATH"), os.getenv("LLAMA_CACHE_PATH")]):
                 logger.warning(
@@ -125,10 +128,18 @@ def __init__(
             self.DEFAULT_TOKENIZER_PATH = LLAMA_DIR
             self.DEFAULT_CACHE_PATH = os.path.join(LLAMA_DIR, self.device_name)
             self.model_name = os.path.basename(LLAMA_DIR)  # May be overridden by config
+        elif HF_MODEL:
+            self.DEFAULT_CKPT_DIR = HF_MODEL
+            self.DEFAULT_TOKENIZER_PATH = HF_MODEL
+            self.DEFAULT_CACHE_PATH = os.getenv("LLAMA_CACHE_PATH")
+            if not self.DEFAULT_CACHE_PATH:
+                self.DEFAULT_CACHE_PATH = os.path.join("model_cache", HF_MODEL, self.device_name)
+            self.model_name = HF_MODEL  # May be overridden by config
+            self.from_hf_url = True
         else:
             assert "Please set $LLAMA_DIR to a valid checkpoint directory"
 
-        if not dummy_weights:
+        if not dummy_weights and not HF_MODEL:
             # Assert if all folders and files exist
             assert os.path.exists(
                 self.DEFAULT_CKPT_DIR
@@ -157,7 +168,10 @@ def __init__(
         self.instruct = True
 
         # Load model params
-        if not dummy_weights:
+        if HF_MODEL:
+            self.checkpoint_type = CheckpointType.HuggingFace
+            self._set_hf_params(self.DEFAULT_CKPT_DIR)
+        elif not dummy_weights:
             self.checkpoint_type = self.detect_checkpoint_type()
             self._set_model_params(self.DEFAULT_CKPT_DIR)
         else:  # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders.
@@ -1107,10 +1121,15 @@ def _set_llama_params(self, checkpoint_dir):
         self.orig_context_len = 8192
 
     def _set_hf_params(self, checkpoint_dir):
-        config_file = os.path.join(checkpoint_dir, "config.json")
-        assert os.path.exists(config_file), f"config.json file not found at {config_file}"
-        with open(config_file, "r") as f:
-            config = json.load(f)
+        if self.from_hf_url:
+            from transformers import AutoConfig
+
+            config = AutoConfig.from_pretrained(self.model_name).to_dict()
+        else:
+            config_file = os.path.join(checkpoint_dir, "config.json")
+            assert os.path.exists(config_file), f"config.json file not found at {config_file}"
+            with open(config_file, "r") as f:
+                config = json.load(f)
         self._set_params_from_dict(config)
 
     def __repr__(self):
@@ -1172,7 +1191,14 @@ def load_state_dict(self):
             state_dict = load_meta_state_dict(self.DEFAULT_CKPT_DIR, self.n_layers)
         else:
             assert self.checkpoint_type == CheckpointType.HuggingFace
-            state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+            if self.from_hf_url:
+                from transformers import AutoModelForCausalLM
+
+                model = AutoModelForCausalLM.from_pretrained(self.DEFAULT_CKPT_DIR)
+                state_dict = model.state_dict()
+            else:
+                state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+            state_dict = standardize_hf_keys(state_dict)
             state_dict = convert_hf_to_meta(state_dict, self.head_dim)
         keys_dict = list(state_dict.keys())[:]
         remv = [f"layers.{i}." for i in list(range(self.n_layers, self.full_model_n_layers))]
@@ -1210,7 +1236,7 @@ def matmul_config(
         )  # TODO: Needed for TG hang workaround
 
         if in0_block_w is None:
-            in0_block_w = min(4, max(1, k // (self.tile_size * grid_size[0])))
+            in0_block_w = self.find_largest_divisor(k // (self.tile_size * grid_size[1]))
 
         return ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
             compute_with_storage_grid_size=grid_size,
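
---

Usage sketch (illustrative, not applied by the patch; the model id below is an
example): with HF_MODEL set, ModelArgs resolves the checkpoint, tokenizer, and
cache paths from a Hugging Face model id instead of LLAMA_DIR. The new loading
path boils down to the following, mirroring _set_hf_params() and
load_state_dict() above:

    # Run the accuracy test against a model pulled from Hugging Face, e.g.:
    #   HF_MODEL=meta-llama/Llama-3.2-1B-Instruct pytest models/demos/llama3/tests/test_llama_accuracy.py
    # What that does internally, in miniature:
    import os

    from transformers import AutoConfig, AutoModelForCausalLM

    hf_model = os.environ["HF_MODEL"]  # e.g. "meta-llama/Llama-3.2-1B-Instruct"
    config = AutoConfig.from_pretrained(hf_model).to_dict()  # config, as in _set_hf_params()
    model = AutoModelForCausalLM.from_pretrained(hf_model)  # weights, as in load_state_dict()
    state_dict = model.state_dict()
    # standardize_hf_keys(): fall back to the tied embeddings when lm_head is absent
    if "lm_head.weight" not in state_dict:
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]

LLAMA_CACHE_PATH still overrides the cache location; otherwise it defaults to
model_cache/<HF_MODEL>/<device>.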