diff --git a/models/demos/t3000/falcon40b/demo/demo.py b/models/demos/t3000/falcon40b/demo/demo.py index 75d615d0f8e..2ecf3a8d0c9 100644 --- a/models/demos/t3000/falcon40b/demo/demo.py +++ b/models/demos/t3000/falcon40b/demo/demo.py @@ -33,6 +33,22 @@ SPACE = 204 +# Used for debugging non-deterministic outputs of prefill stage +def save_kv_cache_to_file(device_mesh, kv_cache, kv_cache_path): + # generate tensor of 60 layers and key and value tensors for each layer where there is 60 layers, key and value and tensor shape (32, 1, 128, 64) + final_tensor = torch.zeros(60, 2, 32, 1, 128, 512) + for layer in range(60): + for type in range(len(kv_cache[layer])): + # get key tensor from device + tensor = ttnn.to_torch( + kv_cache[layer][type], device=device_mesh, mesh_composer=ttnn.ConcatMeshToTensor(device_mesh, dim=-1) + ) + # save tensor to file + final_tensor[layer][type] = tensor + + torch.save(final_tensor, kv_cache_path) + + # load from jason, return as a list def load_inputs(user_input, batch): if isinstance(user_input, str): diff --git a/models/demos/t3000/falcon40b/demo/expected_output_data.json b/models/demos/t3000/falcon40b/demo/expected_output_data.json index 47622d1ff3c..b936a880e7c 100644 --- a/models/demos/t3000/falcon40b/demo/expected_output_data.json +++ b/models/demos/t3000/falcon40b/demo/expected_output_data.json @@ -1 +1 @@ -["List the first 5 prime numbers \nThe first 5 prime numbers are 2, 3, 5, 7, and 11. ", "Give a brief history of the internet \nThe internet was invented in the late 1960s by computer scientists at the University of California, Los Angeles (UCLA) and the University of California, Santa Barbara (UCSB). The first message was sent in 1969 between two computers at UCLA and UCSB. The internet has since grown to become a global network of interconnected computers and devices, allowing people to communicate and share information across vast distances. ", "Describe to me some good coding practices \nSome good coding practices include: \n\n1. Properly commenting code to make it easier to understand and maintain. \n2. Using descriptive variable names to make code more readable. \n3. Keeping code organized and structured to make it easier to navigate and debug. \n4. Testing code thoroughly to ensure it works as intended. \n5. Using version control to track changes and revert mistakes. \n6. Avoiding unnecessary complexity and keeping code simple and concise. \n7. Using best practices for coding standards and conventions. \n8. Continuously learning and improving", "write a short poem about Paris in English\nParis is a city of love and romance,\nWhere the Eiffel Tower stands tall and proud,\nThe Seine River flows through the heart of the city,\nAnd the streets are filled with art and culture.\nThe city is alive with energy and excitement,\nAnd the people are warm and welcoming.\nParis is a city that never sleeps,\nAnd it's easy to fall in love with its charm and beauty. ", "Who is the inventor of the telephone?\nAlexander Graham Bell is credited with inventing the telephone in 1876. 
", "write a short poem about Istanbul in English\nIstanbul is a city of contrasts,\nWhere East meets West,\nWhere ancient meets modern,\nWhere old meets new,\nWhere past meets present,\nWhere history meets future,\nWhere culture meets civilization,\nWhere religion meets secularism,\nWhere tradition meets innovation,\nWhere art meets architecture,\nWhere beauty meets chaos,\nWhere diversity meets unity,\nWhere the old city meets the new city,\nWhere the past meets the future,\nWhere the East meets the West,\nWhere the old meets the new,\nWhere the ancient meets the modern,\nWhere the", "What are the tourist attractions in Paris?\nParis is home to many famous tourist attractions such as the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, the Champs-\u00c9lys\u00e9es, the Palace of Versailles, and the Seine River. Other popular attractions include the Arc de Triomphe, Montmartre, and the Parisian parks such as Jardin des Tuileries and Parc de la Villette. ", "How many countries are in Africa? \nThere are 54 countries in Africa. ", "what is the capital of USA? \nThe capital of USA is Washington D.C. ", "what is the capital of Canada? \nThe capital of Canada is Ottawa. ", "what is the capital of UK? \nThe capital of UK is London. ", "what is the capital of Germany? \nThe capital of Germany is Berlin. ", "what is the capital of France? \nThe capital of France is Paris. ", "what is the capital of Japan? \nThe capital of Japan is Tokyo. ", "what is the capital of India? \nThe capital of India is New Delhi. ", "what is the capital of China? \nThe capital of China is Beijing. ", "what is the currency of Cuba? \nThe currency of Cuba is the Cuban peso (CUP). ", "what is the currency of Lebanon? \nThe currency of Lebanon is the Lebanese pound (LBP). ", "what is the currency of Brazil? \nThe currency of Brazil is the Brazilian Real (BRL). ", "what is the currency of Australia? \nThe currency of Australia is the Australian dollar (AUD). ", "what is the currency of Jamaica? \nThe currency of Jamaica is the Jamaican dollar (JMD). ", "what is the currency of Egypt? \nThe currency of Egypt is the Egyptian pound (EGP). ", "what is the currency of Uzbekistan? \nThe currency of Uzbekistan is the Uzbekistani som (UZS). ", "what is the currency of Argentina? \nThe currency of Argentina is the Argentine peso. ", "describe the geographic location of London in UK\nLondon is located in the southeast of England, on the River Thames. It is the capital city of the United Kingdom and the largest city in Europe. ", "describe the geographic location of Toronto in Canada\nToronto is located in the province of Ontario, Canada. It is situated on the northwestern shore of Lake Ontario, and is the largest city in Canada. Toronto is a multicultural city with a diverse population and a thriving economy. ", "describe the geographic location of Madrid in Spain\nMadrid is located in the center of Spain, in the region of Madrid. It is the capital city of Spain and the largest city in the country. Madrid is situated on a plateau at an elevation of 2,180 feet (660 meters) above sea level. ", "describe the geographic location of Paris in France\nParis is located in the north-central part of France, on the River Seine. It is the capital city of France and the largest city in the country. ", "describe the geographic location of Rome in Italy\nRome is located in central Italy, on the Tiber River. It is the capital city of Italy and the largest city in the country. 
", "describe the geographic location of Istanbul in Turkey\nIstanbul is located in the northwest corner of Turkey, on the Bosphorus Strait, which connects the Black Sea to the Sea of Marmara. It is the largest city in Turkey and the fifth largest city in the world. ", "describe the geographic location of Shanghai in China\nShanghai is located in eastern China, on the Yangtze River Delta. It is the largest city in China and one of the largest cities in the world. ", "describe the geographic location of Lagos in Nigeria\nLagos is located in the southwestern part of Nigeria, on the Gulf of Guinea. It is the largest city in Nigeria and the second largest city in Africa. Lagos is also the economic and cultural center of Nigeria, with a population of over 20 million people. "] +["List the first 5 prime numbers \nThe first 5 prime numbers are 2, 3, 5, 7, and 11. ", "Give a brief history of the internet \nThe internet was invented in the late 1960s by a group of researchers at the University of California, Los Angeles (UCLA). It was originally called ARPANET and was designed to allow researchers to share information and communicate with each other. Over time, the internet grew and evolved into the global network we know today, with billions of users worldwide. ", "Describe to me some good coding practices \nSome good coding practices include: \n\n1. Use descriptive variable names \n2. Write clear and concise code \n3. Use comments to explain complex code \n4. Test your code thoroughly \n5. Use version control \n6. Keep your code organized and structured \n7. Avoid duplicating code \n8. Use proper indentation and formatting \n9. Use appropriate data types and avoid unnecessary complexity \n10. Keep your code up-to-date and maintainable. ", "write a short poem about Paris in English\nParis is a city of love and romance,\nWhere the streets are lined with trees and flowers,\nThe air is filled with the scent of fresh pastries,\nAnd the sound of laughter and chatter.\nThe city is alive with energy and excitement,\nAnd the people are friendly and welcoming.\nParis is a city of art and culture,\nWhere museums and galleries abound,\nAnd the streets are filled with artists and performers,\nExpressing their creativity and passion.\nParis is a city of history and architecture,\nWhere ancient buildings stand tall and proud,\nAnd the streets are", "Who is the inventor of the telephone?\nAlexander Graham Bell is credited with inventing the telephone in 1876. ", "write a short poem about Istanbul in English\nIstanbul is a city of contrasts,\nWhere East meets West,\nWhere ancient meets modern,\nWhere old meets new,\nWhere past meets present,\nWhere history meets future,\nWhere tradition meets innovation,\nWhere culture meets commerce,\nWhere religion meets secularism,\nWhere art meets architecture,\nWhere beauty meets chaos,\nWhere diversity meets unity,\nWhere the old city meets the new city,\nWhere the past meets the future,\nWhere the East meets the West,\nWhere the old meets the new,\nWhere the ancient meets the modern,\nWhere the", "What are the tourist attractions in Paris?\nParis is home to many famous landmarks and attractions such as the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, the Champs-\u00c9lys\u00e9es, the Palace of Versailles, and the Seine River. Other popular attractions include the Montmartre district, the Arc de Triomphe, and the Parisian parks such as Jardin des Tuileries and Parc de la Villette. ", "How many countries are in Africa? \nThere are 54 countries in Africa. 
", "what is the capital of USA? \nThe capital of USA is Washington D.C. ", "what is the capital of Canada? \nThe capital of Canada is Ottawa. ", "what is the capital of UK? \nThe capital of UK is London. ", "what is the capital of Germany? \nThe capital of Germany is Berlin. ", "what is the capital of France? \nThe capital of France is Paris. ", "what is the capital of Japan? \nThe capital of Japan is Tokyo. ", "what is the capital of India? \nThe capital of India is New Delhi. ", "what is the capital of China? \nThe capital of China is Beijing. ", "what is the currency of Cuba? \nThe currency of Cuba is the Cuban peso (CUP). ", "what is the currency of Lebanon? \nThe currency of Lebanon is the Lebanese pound (LBP). ", "what is the currency of Brazil? \nThe currency of Brazil is the Brazilian Real (BRL). ", "what is the currency of Australia? \nThe currency of Australia is the Australian dollar (AUD). ", "what is the currency of Jamaica? \nThe currency of Jamaica is the Jamaican dollar (JMD). ", "what is the currency of Egypt? \nThe currency of Egypt is the Egyptian pound (EGP). ", "what is the currency of Uzbekistan? \nThe currency of Uzbekistan is the Uzbekistani som (UZS). ", "what is the currency of Argentina? \nThe currency of Argentina is the Argentine peso. ", "describe the geographic location of London in UK\nLondon is located in the southeast of England, on the River Thames. It is the capital city of the United Kingdom and the largest city in Europe. ", "describe the geographic location of Toronto in Canada\nToronto is located in the province of Ontario, Canada. It is situated on the northwestern shore of Lake Ontario, and is the largest city in Canada with a population of over 2.8 million people. ", "describe the geographic location of Madrid in Spain\nMadrid is located in the center of Spain, in the region of Madrid. It is the capital city of Spain and is situated on the Manzanares River. Madrid is surrounded by mountains and has a continental climate with hot summers and cold winters. ", "describe the geographic location of Paris in France\nParis is located in the north-central part of France, on the Seine River. It is the capital city of France and the largest city in the country. ", "describe the geographic location of Rome in Italy\nRome is located in central Italy, on the Tiber River. It is the capital city of Italy and the largest city in the country. ", "describe the geographic location of Istanbul in Turkey\nIstanbul is located in the northwest corner of Turkey, on the Bosphorus Strait, which connects the Black Sea to the Sea of Marmara. It is the largest city in Turkey and the fifth largest city in the world. ", "describe the geographic location of Shanghai in China\nShanghai is located in eastern China, on the Yangtze River Delta. It is the largest city in China and one of the largest cities in the world. ", "describe the geographic location of Lagos in Nigeria\nLagos is located in the southwestern part of Nigeria, on the Gulf of Guinea. It is the largest city in Nigeria and the fifth largest city in Africa. Lagos is a coastal city with a tropical climate, and it is the economic and cultural center of Nigeria. 
"] diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py index 7318c562006..7f516c5a076 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py @@ -101,12 +101,18 @@ def run_test_FalconAttention_inference( tt_attention_mask = ttnn.as_tensor( tensor=attention_mask_bool, - dtype=model_config["ATTN_MASK_DTYPE"], - layout=ttnn.TILE_LAYOUT, + dtype=model_config["BFLOAT16_DTYPE"], + layout=ttnn.ROW_MAJOR_LAYOUT, device=device_mesh, memory_config=attention_mask_memconfig, mesh_mapper=ReplicateTensorToMesh(device_mesh), - preprocess=lambda x: x * (-1e5), + preprocess=lambda x: (x * (-1e5)).expand(1, 1, -1, -1), + ) + + tt_attention_mask = ttnn.tilize( + tt_attention_mask, + memory_config=model_config["DRAM_MEMCFG"], + dtype=model_config["ATTN_MASK_DTYPE"], ) tt_k_cache_host = torch.zeros(batch, configuration.num_kv_heads, max_position_embeddings, head_dim) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py b/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py index cd65593291e..90b9e761944 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py @@ -79,10 +79,10 @@ def run_test_FalconCausalLM_inference( use_global_cos_sin_cache = True if 1: - model_input = torch.arange(seq_len * batch).reshape(batch, seq_len) + model_input = torch.randint(0, seq_len * batch, (batch, seq_len)) else: # batch identical sequences for debugging - model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len) + model_input = torch.stack([torch.randint(0, seq_len)] * batch).reshape(batch, seq_len) # Generate dummy kv_cache -------------------------------------------------------------- if llm_mode == "prefill": diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py b/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py index d8c801462e4..eed55bbf83f 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py @@ -73,10 +73,10 @@ def run_test_FalconCausalLM_end_to_end( use_global_cos_sin_cache = True if 1: - model_input = torch.arange(seq_len * batch).reshape(batch, seq_len) + model_input = torch.randint(0, seq_len * batch, (batch, seq_len)) else: # batch identical sequences for debugging - model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len) + model_input = torch.stack([torch.randint(0, seq_len)] * batch).reshape(batch, seq_len) # Generate dummy kv_cache -------------------------------------------------------------- if llm_mode == "prefill": diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_model.py b/models/demos/t3000/falcon40b/tests/test_falcon_model.py index b01b48c9e4d..61e428be886 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_model.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_model.py @@ -73,10 +73,10 @@ def run_test_FalconModel_inference( use_global_cos_sin_cache = True if 1: - model_input = torch.arange(seq_len * batch).reshape(batch, seq_len) + model_input = torch.randint(0, seq_len * batch, (batch, seq_len)) else: # batch identical sequences for debugging - model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len) + model_input = torch.stack([torch.randint(0, seq_len)] * batch).reshape(batch, seq_len) # Generate dummy kv_cache 
-------------------------------------------------------------- if llm_mode == "prefill": diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py index 43b1fd1c439..22f1a8b59d4 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py @@ -99,7 +99,7 @@ def run_test_falcon_prefill_end_to_end_determinism( logger.info("Done loading TT Falcon Model") # Prepare inputs ----------------------------------------------------------------------- - model_input = torch.arange(seq_len * batch).reshape(batch, seq_len) + model_input = torch.randint(0, seq_len * batch, (batch, seq_len)) model_inputs = torch.split(model_input, 1) # First run to get reference output ---------------------------------------------------- diff --git a/models/demos/t3000/falcon40b/tests/test_perf_falcon.py b/models/demos/t3000/falcon40b/tests/test_perf_falcon.py index 18d9f047988..a6655eb460b 100644 --- a/models/demos/t3000/falcon40b/tests/test_perf_falcon.py +++ b/models/demos/t3000/falcon40b/tests/test_perf_falcon.py @@ -73,10 +73,10 @@ def run_test_FalconCausalLM_end_to_end( use_global_cos_sin_cache = True if True: - model_input = torch.arange(seq_len * batch).reshape(batch, seq_len) + model_input = torch.randint(0, seq_len * batch, (batch, seq_len)) else: # batch identical sequences for debugging - model_input = torch.stack([torch.arange(seq_len)] * batch).reshape(batch, seq_len) + model_input = torch.stack([torch.randint(0, seq_len)] * batch).reshape(batch, seq_len) # Generate dummy kv_cache -------------------------------------------------------------- if llm_mode == "prefill": diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py index 9539431c163..969091fa41d 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_attention.py +++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py @@ -10,15 +10,8 @@ import ttnn from ttnn import ShardTensorToMesh, ReplicateTensorToMesh -from models.utility_functions import ( - torch2tt_tensor, - tt2torch_tensor, - pad_by_zero, - nearest_32, -) -from models.demos.t3000.falcon40b.tt.model_utils import ( - convert_to_layout, -) +from models.utility_functions import nearest_32 +from models.demos.t3000.falcon40b.tt.model_utils import convert_to_layout from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation @@ -205,7 +198,7 @@ def __init__( # self.scalar = pad_by_zero(torch.Tensor([1 / math.sqrt(self.head_dim)]), self.device)[0] self.scalar = 1 / math.sqrt(self.head_dim) - self.init_preprocessing(self.model_config["LLM_MODE"], max_position_embeddings) + # self.init_preprocessing(self.model_config["LLM_MODE"], max_position_embeddings) self.layer_past = None def initialize_kvcache(self): @@ -252,22 +245,6 @@ def initialize_kvcache(self): def set_model_config(self, model_config): self.model_config = model_config - self.init_preprocessing(self.model_config["LLM_MODE"], self.max_position_embeddings) - - def init_preprocessing(self, llm_mode, max_sequence_size): - if llm_mode == "prefill": - self.attn_output = ttnn.as_tensor( - torch.zeros([1, self.num_heads_per_device, max_sequence_size, self.head_dim]), - dtype=self.model_config["POST_SOFTMAX_MM_OUTPUT_DTYPE"], - layout=ttnn.TILE_LAYOUT, - device=self.device_mesh, - memory_config=self.model_config["DRAM_MEMCFG"], - 
mesh_mapper=ReplicateTensorToMesh(self.device_mesh), - ) - - def online_preprocessing(self, llm_mode, sequence_size): - if llm_mode == "prefill": - self.sliced_attn_output = self.attn_output[:, :, :sequence_size, :] def __call__( self, @@ -369,72 +346,18 @@ def fwd_prefill( ttnn.experimental.tensor.typecast(value_layer, self.model_config["KV_CACHE_DTYPE"]), user_id, ) - key_layer_transposed = ttnn.transpose( + attn_output = ttnn.experimental.operations.primary.transformers.scaled_dot_product_attention( + query_layer, key_layer, - -2, - -1, - memory_config=self.model_config["K_TRANSPOSED_OUTPUT_MEMCFG"], + value_layer, + attention_mask, + is_causal=True, + scale=self.scalar, + program_config=self.model_config["SDPA_PROGCFG"], ) - key_layer.deallocate(True) - - slice_size = self.model_config["attention_params"]["attention_slice_size"] - num_slices = self.model_config["attention_params"]["attention_num_slices"] - - if num_slices > 1: - if not hasattr(self, "sliced_attn_output"): - self.online_preprocessing(llm_mode, q_len) - attn_output_tensor = self.sliced_attn_output - - for slice_i in range(num_slices): - # Partially slice and convert activations to sharded - q_slices = ttnn.experimental.tensor.interleaved_to_sharded_partial( - query_layer, - (8, 8), - [slice_size * 16 // 64, self.head_dim], # each slice is [1,16,128,64], we use 64 cores - num_slices, - slice_i, - ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, - ) - attn_output_slice = self.scaled_dot_product_attention( - q_slices, - key_layer_transposed, - attention_mask, - value_layer, - q_len, - ) - ttnn.experimental.tensor.sharded_to_interleaved_partial( - attn_output_slice, - attn_output_tensor, - num_slices, - slice_i, - self.model_config["DRAM_MEMCFG"], - ) - attn_output_slice.deallocate(True) - attn_output = attn_output_tensor - q_slices.deallocate(True) - else: - query_layer = convert_to_layout( - query_layer, - self.model_config["DRAM_MEMCFG"], - self.model_config["QUERY_HEIGHT_SHARDED_MEMCFG"], - ) - attn_output = self.scaled_dot_product_attention( - query_layer, - key_layer_transposed, - attention_mask, - value_layer, - q_len, - ) - attn_output = convert_to_layout( - attn_output, - self.model_config["ATTN_OUTPUT_HEIGHT_SHARDED_MEMCFG"], - self.model_config["DRAM_MEMCFG"], - ) - # Deallocate query, key, value query_layer.deallocate(True) - key_layer_transposed.deallocate(True) + key_layer.deallocate(True) value_layer.deallocate(True) # Output projection @@ -467,36 +390,6 @@ def fwd_prefill( layer_present = layer_past if use_cache else None return attn_output, layer_present - def scaled_dot_product_attention(self, q_slices, key_layer_transposed, attn_mask_slices, value_layer, q_len): - # Q * KˆT - attn_weights = ttnn.matmul( - q_slices, - key_layer_transposed, - compute_kernel_config=self.model_config["COMPUTE_KERNEL_FP16_ACC_CONFIG"], - memory_config=self.model_config["HEIGHT_SHARDED_MEMCFG"], - program_config=self.model_config["ATTENTION_MM_PROGCFG"], - dtype=self.model_config["ATTENTION_DTYPE"], - ) - # Softmax - attn_weights = ttnn.scale_causal_mask_hw_dims_softmax_in_place( - attn_weights, - self.scalar, - attn_mask_slices, - program_config=self.model_config["SOFTMAX_PROGCFG"], - ) - # Attention score * V - attn_output_slice = ttnn.matmul( - attn_weights, - value_layer, - compute_kernel_config=self.model_config["COMPUTE_KERNEL_FP16_ACC_CONFIG"], - memory_config=self.model_config["HEIGHT_SHARDED_MEMCFG"], - 
program_config=self.model_config["ATTENTION_MM_2_PROGCFG"], - dtype=self.model_config["ATTENTION_OUT_DTYPE"], - ) - attn_weights.deallocate(True) - - return attn_output_slice - def fwd_decode( self, hidden_states: ttnn.experimental.tensor.Tensor, diff --git a/models/demos/t3000/falcon40b/tt/falcon_decoder.py b/models/demos/t3000/falcon40b/tt/falcon_decoder.py index 6d6693a83fd..1c6bfd2cb1e 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tt/falcon_decoder.py @@ -131,19 +131,11 @@ def pad_ln_params(x): self.layernorm_eps = config.layer_norm_epsilon - self.init_preprocessing("prefill", 1, max_position_embeddings) - def set_model_config(self, model_config): self.model_config = model_config self.self_attn.set_model_config(model_config) self.mlp.set_model_config(model_config) - def init_preprocessing(self, llm_mode, batch_size, max_sequence_size): - self.self_attn.init_preprocessing(llm_mode, max_sequence_size) - - def online_preprocessing(self, llm_mode, sequence_size): - self.self_attn.online_preprocessing(llm_mode, sequence_size) - def __call__( self, hidden_states: ttnn.experimental.tensor.Tensor, diff --git a/models/demos/t3000/falcon40b/tt/falcon_model.py b/models/demos/t3000/falcon40b/tt/falcon_model.py index 7eb9d5f4728..9017781dadb 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_model.py +++ b/models/demos/t3000/falcon40b/tt/falcon_model.py @@ -222,11 +222,6 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token ), dim=-1, ) - attention_mask_memconfig = self.model_config["ATTN_MASK_MEMCFG"] - if attention_mask_memconfig.is_sharded(): - attn_mask_shard_shape = attention_mask_memconfig.shard_spec.shape - attn_mask_shard_shape[-1] = num_max_tokens - attention_mask_memconfig.shard_spec.shape = attn_mask_shard_shape # Push attention mask to device in row major order and then tilize on device (faster than tilizing on CPU) tt_attention_mask = ttnn.as_tensor( @@ -234,23 +229,20 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token dtype=self.model_config["BFLOAT16_DTYPE"], # subsequent tilize op expects bfloat16 inputs layout=ttnn.ROW_MAJOR_LAYOUT, device=self.device_mesh, - memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=1), - preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, self.config.num_attention_heads, -1, -1), + memory_config=self.model_config["DEFAULT_MEMCFG"], + mesh_mapper=ReplicateTensorToMesh(self.device_mesh), + preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(1, 1, -1, -1), ) tt_attention_mask = ttnn.tilize( tt_attention_mask, - memory_config=attention_mask_memconfig, + memory_config=self.model_config["DEFAULT_MEMCFG"], dtype=self.model_config["ATTN_MASK_DTYPE"], ) else: raise NotImplementedError(f"Llm mode {llm_mode} is not supported! 
Must be one of prefill or decode.") - for layer in self.layers: - layer.online_preprocessing(llm_mode, sequence_size) - return tt_inputs, tt_attention_mask @abstractmethod diff --git a/models/demos/t3000/falcon40b/tt/model_config.py b/models/demos/t3000/falcon40b/tt/model_config.py index 25450830a04..13b3f5d6c32 100644 --- a/models/demos/t3000/falcon40b/tt/model_config.py +++ b/models/demos/t3000/falcon40b/tt/model_config.py @@ -727,8 +727,6 @@ def get_prefill_model_config(model_config_str, input_shape, num_devices): model_config["KV_CACHE_MEMCFG"] = DRAM_MEMCFG model_config["KV_CACHE_DTYPE"] = BFP8_DTYPE - model_config["ATTN_MASK_DTYPE"] = BFP4_DTYPE - model_config["WORD_EMBEDDING_OUTPUT_DTYPE"] = BFLOAT16_DTYPE # embeddings output and the residual stream # Set input df for AllGathers to bfp8 to save data bandwidth @@ -736,11 +734,7 @@ def get_prefill_model_config(model_config_str, input_shape, num_devices): model_config["ATTENTION_OUT_DTYPE"] = BFP8_DTYPE # Attention AllGather model_config["SELFOUT_MM_OUTPUT_DTYPE"] = BFP8_DTYPE # AllGather at start of the decoder layer and final AllGather - head_dim = 64 hidden_size = model_config_entries["hidden_size"] - vocab_size = model_config_entries["vocab_size"] - num_attention_heads = model_config_entries["num_attention_heads"] - num_kv_heads = model_config_entries["num_kv_heads"] batch_size, seq_len = input_shape[0], input_shape[1] row_height = seq_len @@ -784,16 +778,11 @@ def get_grid_size_and_core_range_based_on_num_cores(num_cores): ) return attention_mm_grid_size, attn_core_range_set - # Attetnion in slices: determine number of cores and shard spec + # Attention in slices: determine number of cores and shard spec attention_max_slice_size = 1024 attention_slice_size = min(attention_max_slice_size, row_height) assert row_height % attention_slice_size == 0 - attention_num_slices = row_height // attention_slice_size - attention_num_cores = min(attention_slice_size * 16 // 32, 64) - - attention_mm_grid_size, attn_core_range_set = get_grid_size_and_core_range_based_on_num_cores(attention_num_cores) - # MLP in slices: determine number of cores and shard spec use_mm_2d_start = 512 model_config["MM_USE_MM2D_START"] = use_mm_2d_start @@ -839,81 +828,6 @@ def get_grid_size_and_core_range_based_on_num_cores(num_cores): model_config["layernorm_params"] = layernorm_params - model_config["attention_params"] = { - "attention_slice_size": attention_slice_size, - "attention_max_slice_size": attention_max_slice_size, - "attention_num_slices": attention_num_slices, - } - - # Specify program configs - attetnion_mm_M = ( - attention_slice_size * 16 // attention_num_cores // 32 - ) # attetnion_slice_size * 16 qheads // attention_num_cores // TILE_SIZE - - # Attention - model_config["ATTENTION_MM_PROGCFG"] = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=attention_mm_grid_size, - in0_block_w=head_dim // 32, - out_subblock_h=1, - out_subblock_w=1, # use 8 for S=2k when hang is fixed - per_core_M=attetnion_mm_M, - per_core_N=row_height // 32, - fuse_batch=True, - fused_activation=None, - mcast_in0=False, - ) - model_config["SOFTMAX_PROGCFG"] = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=attention_mm_grid_size, - subblock_w=1, - block_h=attetnion_mm_M, - block_w=row_height // 32, - ) - model_config["ATTENTION_MM_2_PROGCFG"] = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=attention_mm_grid_size, - in0_block_w=row_height // 32, - out_subblock_h=1, # use 4 for 
S=2k when hang is fixed - out_subblock_w=1, # use 2 for S=2k when hang is fixed - per_core_M=attetnion_mm_M, - per_core_N=head_dim // 32, - fuse_batch=True, - fused_activation=None, - mcast_in0=False, - ) - model_config["ATTENTION_DTYPE"] = dtype - - model_config["QUERY_HEIGHT_SHARDED_MEMCFG"] = ttnn.experimental.tensor.MemoryConfig( - ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.experimental.tensor.BufferType.L1, - ttnn.experimental.tensor.ShardSpec( - attn_core_range_set, - [16 * attention_slice_size // attention_num_cores, head_dim], - ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, - False, - ), - ) - - model_config["SOFTMAX_HEIGHT_SHARDED_MEMCFG"] = ttnn.experimental.tensor.MemoryConfig( - ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.experimental.tensor.BufferType.L1, - ttnn.experimental.tensor.ShardSpec( - attn_core_range_set, - [16 * attention_slice_size // attention_num_cores, row_height], - ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, - False, - ), - ) - - model_config["ATTN_OUTPUT_HEIGHT_SHARDED_MEMCFG"] = ttnn.experimental.tensor.MemoryConfig( - ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.experimental.tensor.BufferType.L1, - ttnn.experimental.tensor.ShardSpec( - attn_core_range_set, - [16 * attention_slice_size // attention_num_cores, head_dim], - ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, - False, - ), - ) - # MLP sharding specs if mlp_slice_size > use_mm_2d_start: @@ -985,8 +899,15 @@ def get_grid_size_and_core_range_based_on_num_cores(num_cores): ), ) - # uncomment if need to see all the configs - # logger.debug(f"Falcon model config: \n{pretty_print_model_config(model_config)}") + # Attention parameters + q_chunk_size = min(seq_len, 256) + k_chunk_size = min(seq_len, 256) + + model_config["SDPA_PROGCFG"] = ttnn.experimental.operations.primary.transformers.SDPAMultiCoreProgramConfig( + compute_with_storage_grid_size=[8, 7], + q_chunk_size=q_chunk_size, + k_chunk_size=k_chunk_size, + ) return model_config
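Note on the attention change in falcon_attention.py: fwd_prefill now calls the fused ttnn scaled_dot_product_attention primitive instead of the removed sliced Q @ K^T -> scaled causal softmax -> @ V pipeline. Below is a plain-PyTorch sketch of the math the fused call is expected to reproduce, using the same 1/sqrt(head_dim) scale and an additive -1e5 causal mask in the style built by the new preprocessing; shapes are illustrative toy sizes (in the spirit of the per-device 16 query heads / 1 KV head sharding used here), and none of this is the ttnn API itself.

```python
import math
import torch

def sdpa_reference(q, k, v, attn_mask, scale):
    # q: [batch, q_heads, seq, head_dim]; k, v: [batch, kv_heads, seq, head_dim], broadcast over q_heads
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # Q @ K^T, scaled by 1/sqrt(head_dim)
    scores = scores + attn_mask                            # additive causal mask: -1e5 above the diagonal
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)                          # attention probabilities @ V

seq_len, head_dim = 128, 64  # toy sizes for illustration only
mask = torch.triu(torch.full((1, 1, seq_len, seq_len), -1e5), diagonal=1)
q = torch.randn(1, 16, seq_len, head_dim)
k = torch.randn(1, 1, seq_len, head_dim)
v = torch.randn(1, 1, seq_len, head_dim)
out = sdpa_reference(q, k, v, mask, scale=1.0 / math.sqrt(head_dim))  # -> [1, 16, 128, 64]
```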
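Note on the new demo helper: save_kv_cache_to_file gathers the prefill KV cache across the device mesh and writes a single tensor of shape [60, 2, 32, 1, 128, 512] (layer, key/value, then the per-device caches concatenated on the last dimension). A minimal sketch of how two such dumps from separate runs could be compared to localize prefill non-determinism per layer; the file names and tolerance are placeholders.

```python
import torch

def compare_kv_dumps(path_a, path_b, atol=0.0):
    a = torch.load(path_a)  # [num_layers, 2, ...]: index 0 = key cache, 1 = value cache
    b = torch.load(path_b)
    for layer in range(a.shape[0]):
        for kv, name in enumerate(("key", "value")):
            max_diff = (a[layer, kv] - b[layer, kv]).abs().max().item()
            if max_diff > atol:
                print(f"layer {layer:02d} {name}: max abs diff {max_diff:.6f}")

compare_kv_dumps("prefill_kv_run1.pt", "prefill_kv_run2.pt")
```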
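One observation on the test updates: the "batch identical sequences for debugging" branches now build the input with torch.stack([torch.randint(0, seq_len)] * batch), but torch.randint requires an explicit size argument, so that (normally disabled) else branch would raise a TypeError if switched on. A working variant that keeps the intent, drawing one random sequence from the same range as the main branch and repeating it across the batch, might look like the following sketch (values are illustrative).

```python
import torch

batch, seq_len = 32, 128  # illustrative values
row = torch.randint(0, seq_len * batch, (seq_len,))
model_input = torch.stack([row] * batch).reshape(batch, seq_len)
assert model_input.shape == (batch, seq_len)
```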