From 4e3ff75110fc0fad7e3c4ab8f3238226aa23dba2 Mon Sep 17 00:00:00 2001
From: Mark O'Connor
Date: Fri, 22 Nov 2024 15:05:23 +0000
Subject: [PATCH] #0: Work around bad PCC in dram-sharded matmul in vision
 test by ensuring N is an even number of tiles

---
 models/demos/llama3/tt/model_config.py                  | 10 ++++++----
 .../llama_cross_attention_transformer_text.py           |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index a79c2649e5df..aaf0352c8094 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -439,7 +439,8 @@ def find_largest_divisor(n, max_divisor=8):
         )
 
         # Width sharded
-        mlp_core_grid = self.dram_shard_core_grid_for_k(self.dim)  # , self.hidden_dim // self.num_devices)
+        # mlp_core_grid = self.dram_shard_core_grid_for_k(self.dim)
+        mlp_core_grid = self.dram_shard_core_grid_for_k_and_n(self.dim, self.hidden_dim // self.num_devices)
         self.model_config["SHARDED_MLP_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
             (
                 self.tile_padded_batch_rows,
@@ -457,7 +458,8 @@ def find_largest_divisor(n, max_divisor=8):
             num_cores=mlp_core_grid.num_cores,
         )
 
-        mlp2_core_grid = self.dram_shard_core_grid_for_k(self.hidden_dim // self.num_devices)  # , self.dim)
+        # mlp2_core_grid = self.dram_shard_core_grid_for_k(self.hidden_dim // self.num_devices)
+        mlp2_core_grid = self.dram_shard_core_grid_for_k_and_n(self.hidden_dim // self.num_devices, self.dim)
         self.model_config["SHARDED_MLP2_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
             (
                 self.tile_padded_batch_rows,
@@ -968,8 +970,8 @@ def dram_matmul_config(
     ) -> ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig:
         # in0_block_w must evenly divide k and be no larger than tile_size * num_cores
         if num_cores is None:
-            num_cores = self.dram_shard_core_grid_for_k(k).num_cores
-            # num_cores = self.dram_shard_core_grid_for_k_and_n(k, n).num_cores
+            # num_cores = self.dram_shard_core_grid_for_k_and_n(k).num_cores
+            num_cores = self.dram_shard_core_grid_for_k_and_n(k, n).num_cores
         assert (
             k % (self.tile_size * num_cores) == 0
         ), f"k must be divisible by tile_size * num_cores: {k} % {self.tile_size * num_cores} != 0"
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
index 8abf0588cee4..f657abb86726 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
@@ -283,7 +283,7 @@ def forward(
 
         h = self.norm(h, mode=mode)
 
-        # TODO: Switch to using dram-sharded LM haed and remove this
+        # TODO: Switch to using dram-sharded LM head and remove this
         # Note: workaround for sharded_to_interleaved memory corruption (#15113)
         h = ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)
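
As context for the change above: a minimal sketch of what switching from dram_shard_core_grid_for_k to dram_shard_core_grid_for_k_and_n is meant to achieve, namely picking a core count whose per-core shard is a whole number of tiles in both K and N rather than in K alone. The helper bodies are not part of this diff, so the tile size, the max_cores limit, and the exact divisibility rule below are assumptions, not the repository's implementation.

    # Hypothetical sketch only; dram_shard_core_grid_for_k_and_n itself is not
    # shown in this patch. tile_size=32 and max_cores=8 are assumed values.
    TILE_SIZE = 32

    def core_count_for_k_and_n(k: int, n: int, max_cores: int = 8) -> int:
        """Largest core count that splits both K and N into whole tiles per core."""
        k_tiles, n_tiles = k // TILE_SIZE, n // TILE_SIZE
        for cores in range(max_cores, 0, -1):
            # Requiring both dimensions to divide evenly is the guessed intent of
            # "ensuring N is an even number of tiles" from the commit subject.
            if k_tiles % cores == 0 and n_tiles % cores == 0:
                return cores
        return 1

    # Illustrative values only: k=4096 and n=3584 both split cleanly across 8 cores.
    assert core_count_for_k_and_n(4096, 3584) == 8

Under these assumptions, a k-only choice could land on a core count where each core's slice of N is not tile-aligned, which matches the kind of dram-sharded matmul PCC issue the subject line describes working around.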