[TT-Transformer] Add HF_MODEL to load models directly from huggingface
Co-authored-by: mtairum <[email protected]>
yieldthought and mtairum authored Feb 12, 2025
1 parent 441142f commit d2f0b15
Showing 5 changed files with 67 additions and 37 deletions.
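
With this change the llama3 demo code can pull a model straight from the Hugging Face hub instead of a local Meta-style checkpoint directory. A minimal usage sketch follows; the model id and cache directory are illustrative placeholders, not values from this commit:

```python
# Illustrative usage of the new HF_MODEL environment variable; the model id
# and cache directory below are placeholders, not values from this commit.
import os

os.environ["HF_MODEL"] = "meta-llama/Llama-3.2-1B-Instruct"
# Optional: where converted/cached weights land; defaults to model_cache/<HF_MODEL>/<device_name>
os.environ["LLAMA_CACHE_PATH"] = "/tmp/tt_cache/llama3.2-1b"

# With HF_MODEL set, model_config.py resolves the checkpoint and tokenizer paths
# from the model id (via transformers) instead of requiring LLAMA_DIR.
```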
16 changes: 8 additions & 8 deletions models/demos/llama3/PERF.md
@@ -11,16 +11,16 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.
| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|----------------|--------|-----------|-----------|---------------|
| Llama3.2-1B | N150 | 89 | 98 | 86.9 |
- | Llama3.2-1B | N300 | 91 | 98 | 104.3 |
- | Llama3.2-1B | T3K | 91 | 98 | 118.5 |
+ | Llama3.2-1B | N300 | 90 | 98 | 104.3 |
+ | Llama3.2-1B | T3K | 87 | 98 | 118.5 |
| Llama3.2-1B | TG | | | 72.3 |
- | Llama3.2-3B | N150 | 92 | 96 | 53.3 |
+ | Llama3.2-3B | N150 | 91 | 96 | 53.3 |
| Llama3.2-3B | N300 | 91 | 96 | 66.1 |
| Llama3.2-3B | T3K | 91 | 96 | 66.9 |
| Llama3.2-3B | TG | | | 48.5 |
| Llama3.1-8B | N150 | 87 | 99 | 27.9 |
| Llama3.1-8B | N300 | 88 | 99 | 43.7 |
- | Llama3.1-8B | T3K | 88 | 100 | 64.2 |
+ | Llama3.1-8B | T3K | 88 | 99 | 64.2 |
| Llama3.1-8B | TG | | | 41.0 |
| Llama3.2-11B | N300 | 89 | 99 | 43.5 |
| Llama3.2-11B | T3K | 88 | 99 | 63.4 |
@@ -37,12 +37,12 @@ This configuration uses bfp4 MLP FF1+FF3 only for the Llama-3.1-70B model and th
| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|----------------|--------|-----------|-----------|---------------|
| Llama3.2-1B | N150 | 88 | 98 | 86.8 |
- | Llama3.2-1B | N300 | 90 | 98 | 98.1 |
- | Llama3.2-1B | T3K | 90 | 98 | 97.5 |
+ | Llama3.2-1B | N300 | 88 | 98 | 98.1 |
+ | Llama3.2-1B | T3K | 89 | 99 | 97.5 |
| Llama3.2-1B | TG | 87 | 98 | 51.3 |
- | Llama3.2-3B | N150 | 93 | 99 | 44.2 |
+ | Llama3.2-3B | N150 | 92 | 99 | 44.2 |
| Llama3.2-3B | N300 | 92 | 98 | 54.2 |
- | Llama3.2-3B | T3K | 93 | 98 | 55.6 |
+ | Llama3.2-3B | T3K | 91 | 100 | 55.6 |
| Llama3.2-3B | TG | 91 | 98 | 33.6 |
| Llama3.1-8B | N150 | 93 | 100 | 23.6 |
| Llama3.1-8B | N300 | 93 | 100 | 34.5 |
29 changes: 15 additions & 14 deletions models/demos/llama3/tests/test_llama_accuracy.py
@@ -157,7 +157,7 @@ def test_tt_model_acc(
text = f.read()

# Encode text to tokens
- encoded_tokens = tokenizer.encode(text, bos=True, eos=False)
+ encoded_tokens = model_args.encode_prompt(text, system_prompt_text=None, instruct=False)
total_length = prefill_len + decode_len + 1
reference_tokens = torch.tensor(encoded_tokens[:total_length]).unsqueeze(0)
top5_tokens = None # Will be computed during inference
@@ -439,17 +439,18 @@ def test_tt_model_acc(
true_word = sanitize(tokenizer.decode([true_token]))
logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")

- # Get accuracy thresholds from PERF.md
- min_top1_acc, min_top5_acc = get_accuracy_thresholds(
-     model_args.base_model_name,
-     model_args.device_name,
-     optimizations,
- )
+ if use_reference_file:
+     # Get accuracy thresholds from PERF.md
+     min_top1_acc, min_top5_acc = get_accuracy_thresholds(
+         model_args.base_model_name,
+         model_args.device_name,
+         optimizations,
+     )

logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
assert (
total_top1_acc >= min_top1_acc
), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
assert (
total_top5_acc >= min_top5_acc
), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
assert (
total_top1_acc >= min_top1_acc
), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
assert (
total_top5_acc >= min_top5_acc
), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
6 changes: 3 additions & 3 deletions models/demos/llama3/tt/llama_attention.py
@@ -8,8 +8,6 @@
import ttnn
from models.common.lightweightmodule import LightweightModule
from models.demos.llama3.tt.llama_ccl import tt_all_reduce, tt_all_gather
- from models.demos.llama3.tt.llama_common import first_five
- from models.demos.llama3.tt.load_checkpoints import permute


class TtLlamaAttention(LightweightModule):
@@ -138,7 +136,9 @@ def __init__(
)
# as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size
self.wqkv_bias_prefill = ttnn.reshape(
-     self.wqkv_bias_prefill, ttnn.Shape([1, 1, 1, self.wqkv_bias_prefill.shape[-1]])
+     self.wqkv_bias_prefill,
+     (1, 1, 1, self.wqkv_bias_prefill.shape[-1]),
+     (1, 1, self.wqkv_bias_prefill.shape[-2], self.wqkv_bias_prefill.shape[-1]),
)

# Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size
11 changes: 7 additions & 4 deletions models/demos/llama3/tt/load_checkpoints.py
@@ -37,13 +37,16 @@ def load_hf_state_dict(ckpt_dir):
raise FileNotFoundError(f"Neither model.safetensors.index.json nor model.safetensors found in {ckpt_dir}")
loaded_weights = safetensors_load_file(safetensor_path)

if not "lm_head.weight" in loaded_weights:
# Assume tied to the embeddings if not present
loaded_weights["lm_head.weight"] = loaded_weights["model.embed_tokens.weight"]

return loaded_weights


+ def standardize_hf_keys(state_dict):
+     if not "lm_head.weight" in state_dict:
+         # Assume tied to the embeddings if not present
+         state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
+     return state_dict


def convert_hf_to_meta(state_dict, head_dim):
state_dict = convert_hf_qkv_to_meta_format(state_dict, head_dim)
state_dict = map_hf_to_meta_keys(state_dict)
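
Moving the tied-embedding fallback out of `load_hf_state_dict` and into the new `standardize_hf_keys` helper lets the same normalization run on state dicts loaded straight through `transformers` (see the model_config.py changes below). Models such as Llama 3.2 1B/3B tie `lm_head.weight` to `model.embed_tokens.weight`, so the key is absent from their Hugging Face checkpoints. A toy illustration of the behaviour, with plain strings standing in for weight tensors:

```python
# Toy illustration of standardize_hf_keys: checkpoints with tied embeddings
# ship no lm_head.weight, so the helper aliases it to the embedding table.
def standardize_hf_keys(state_dict):
    if not "lm_head.weight" in state_dict:
        # Assume tied to the embeddings if not present
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
    return state_dict

tied = {"model.embed_tokens.weight": "EMB"}  # no lm_head.weight present
untied = {"model.embed_tokens.weight": "EMB", "lm_head.weight": "HEAD"}

print(standardize_hf_keys(tied)["lm_head.weight"])    # "EMB" (aliased to the embeddings)
print(standardize_hf_keys(untied)["lm_head.weight"])  # "HEAD" (left untouched)
```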
42 changes: 34 additions & 8 deletions models/demos/llama3/tt/model_config.py
Expand Up @@ -31,6 +31,7 @@
convert_hf_to_meta,
convert_meta_to_hf,
reverse_permute,
+ standardize_hf_keys,
)


@@ -114,8 +115,10 @@ def __init__(
self.max_batch_size = max_batch_size
self.tile_size = 32
self.is_70b = False
+ self.from_hf_url = False # updated below if true

LLAMA_DIR = os.getenv("LLAMA_DIR")
+ HF_MODEL = os.getenv("HF_MODEL")
if LLAMA_DIR:
if any([os.getenv("LLAMA_CKPT_DIR"), os.getenv("LLAMA_TOKENIZER_PATH"), os.getenv("LLAMA_CACHE_PATH")]):
logger.warning(
@@ -125,10 +128,18 @@
self.DEFAULT_TOKENIZER_PATH = LLAMA_DIR
self.DEFAULT_CACHE_PATH = os.path.join(LLAMA_DIR, self.device_name)
self.model_name = os.path.basename(LLAMA_DIR) # May be overridden by config
+ elif HF_MODEL:
+     self.DEFAULT_CKPT_DIR = HF_MODEL
+     self.DEFAULT_TOKENIZER_PATH = HF_MODEL
+     self.DEFAULT_CACHE_PATH = os.getenv("LLAMA_CACHE_PATH")
+     if not self.DEFAULT_CACHE_PATH:
+         self.DEFAULT_CACHE_PATH = os.path.join("model_cache", HF_MODEL, self.device_name)
+     self.model_name = HF_MODEL # May be overridden by config
+     self.from_hf_url = True
else:
assert "Please set $LLAMA_DIR to a valid checkpoint directory"

- if not dummy_weights:
+ if not dummy_weights and not HF_MODEL:
# Assert if all folders and files exist
assert os.path.exists(
self.DEFAULT_CKPT_DIR
@@ -157,7 +168,10 @@ def __init__(
self.instruct = True

# Load model params
- if not dummy_weights:
+ if HF_MODEL:
+     self.checkpoint_type = CheckpointType.HuggingFace
+     self._set_hf_params(self.DEFAULT_CKPT_DIR)
+ elif not dummy_weights:
self.checkpoint_type = self.detect_checkpoint_type()
self._set_model_params(self.DEFAULT_CKPT_DIR)
else: # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders.
@@ -1107,10 +1121,15 @@ def _set_llama_params(self, checkpoint_dir):
self.orig_context_len = 8192

def _set_hf_params(self, checkpoint_dir):
- config_file = os.path.join(checkpoint_dir, "config.json")
- assert os.path.exists(config_file), f"config.json file not found at {config_file}"
- with open(config_file, "r") as f:
-     config = json.load(f)
+ if self.from_hf_url:
+     from transformers import AutoConfig
+
+     config = AutoConfig.from_pretrained(self.model_name).to_dict()
+ else:
+     config_file = os.path.join(checkpoint_dir, "config.json")
+     assert os.path.exists(config_file), f"config.json file not found at {config_file}"
+     with open(config_file, "r") as f:
+         config = json.load(f)
self._set_params_from_dict(config)

def __repr__(self):
@@ -1172,7 +1191,14 @@ def load_state_dict(self):
state_dict = load_meta_state_dict(self.DEFAULT_CKPT_DIR, self.n_layers)
else:
assert self.checkpoint_type == CheckpointType.HuggingFace
- state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+ if self.from_hf_url:
+     from transformers import AutoModelForCausalLM
+
+     model = AutoModelForCausalLM.from_pretrained(self.DEFAULT_CKPT_DIR)
+     state_dict = model.state_dict()
+ else:
+     state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+ state_dict = standardize_hf_keys(state_dict)
state_dict = convert_hf_to_meta(state_dict, self.head_dim)
keys_dict = list(state_dict.keys())[:]
remv = [f"layers.{i}." for i in list(range(self.n_layers, self.full_model_n_layers))]
@@ -1210,7 +1236,7 @@ def matmul_config(
) # TODO: Needed for TG hang workaround

if in0_block_w is None:
- in0_block_w = min(4, max(1, k // (self.tile_size * grid_size[0])))
+ in0_block_w = self.find_largest_divisor(k // (self.tile_size * grid_size[1]))

return ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
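
Putting the pieces together, the HF_MODEL path fetches both the config and the weights through `transformers` and then reuses the existing HF-to-Meta conversion helpers. A condensed, standalone sketch of that flow; the model id is a placeholder and the `head_dim` derivation is an assumption, with the real logic living in model_config.py above:

```python
# Condensed sketch of the new HF_MODEL code path (the real logic lives in
# model_config.py above); the model id is an illustrative placeholder.
import os

from transformers import AutoConfig, AutoModelForCausalLM

from models.demos.llama3.tt.load_checkpoints import convert_hf_to_meta, standardize_hf_keys

hf_model = os.getenv("HF_MODEL", "meta-llama/Llama-3.2-1B-Instruct")

# 1. Hyperparameters come from the hub config instead of a local config.json.
config = AutoConfig.from_pretrained(hf_model).to_dict()
# head_dim derivation shown here is an assumption for the sketch; the repo
# computes it from its own parsed model args.
head_dim = config["hidden_size"] // config["num_attention_heads"]

# 2. Weights come from AutoModelForCausalLM rather than local safetensors files.
state_dict = AutoModelForCausalLM.from_pretrained(hf_model).state_dict()

# 3. The rest of the pipeline is unchanged: alias tied embeddings, convert to Meta layout.
state_dict = standardize_hf_keys(state_dict)
state_dict = convert_hf_to_meta(state_dict, head_dim)
```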
