#0: Work around bad PCC in dram-sharded matmul in vision test by ensuring N is an even number of tiles

yieldthought committed Nov 22, 2024
1 parent e9191c8 commit 4e3ff75
Showing 2 changed files with 7 additions and 5 deletions.
models/demos/llama3/tt/model_config.py (10 changes: 6 additions & 4 deletions)
@@ -439,7 +439,8 @@ def find_largest_divisor(n, max_divisor=8):
         )
 
         # Width sharded
-        mlp_core_grid = self.dram_shard_core_grid_for_k(self.dim)  # , self.hidden_dim // self.num_devices)
+        # mlp_core_grid = self.dram_shard_core_grid_for_k(self.dim)
+        mlp_core_grid = self.dram_shard_core_grid_for_k_and_n(self.dim, self.hidden_dim // self.num_devices)
         self.model_config["SHARDED_MLP_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
             (
                 self.tile_padded_batch_rows,
@@ -457,7 +458,8 @@ def find_largest_divisor(n, max_divisor=8):
             num_cores=mlp_core_grid.num_cores,
         )
 
-        mlp2_core_grid = self.dram_shard_core_grid_for_k(self.hidden_dim // self.num_devices)  # , self.dim)
+        # mlp2_core_grid = self.dram_shard_core_grid_for_k(self.hidden_dim // self.num_devices)
+        mlp2_core_grid = self.dram_shard_core_grid_for_k_and_n(self.hidden_dim // self.num_devices, self.dim)
         self.model_config["SHARDED_MLP2_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
             (
                 self.tile_padded_batch_rows,
@@ -968,8 +970,8 @@ def dram_matmul_config(
     ) -> ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig:
         # in0_block_w must evenly divide k and be no larger than tile_size * num_cores
         if num_cores is None:
-            num_cores = self.dram_shard_core_grid_for_k(k).num_cores
-            # num_cores = self.dram_shard_core_grid_for_k_and_n(k, n).num_cores
+            # num_cores = self.dram_shard_core_grid_for_k(k).num_cores
+            num_cores = self.dram_shard_core_grid_for_k_and_n(k, n).num_cores
         assert (
             k % (self.tile_size * num_cores) == 0
         ), f"k must be divisible by tile_size * num_cores: {k} % {self.tile_size * num_cores} != 0"
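
The key change across these hunks is swapping dram_shard_core_grid_for_k, which sizes the core grid from K alone, for dram_shard_core_grid_for_k_and_n, which also considers N so that each core gets a whole (in the commit title's words, "even") number of N tiles. A minimal sketch of the idea follows; it is hypothetical: the real helper returns a core-grid object with a num_cores attribute, and the tile width and core-count cap here are assumptions, not taken from this diff.

# Hypothetical sketch of the core-count selection behind
# dram_shard_core_grid_for_k_and_n: choose the largest core count that
# divides both the K and N tile counts, so no core is left with a
# padded, partial slice of N (the suspected source of the bad PCC).
TILE_SIZE = 32   # assumed tile width
MAX_CORES = 64   # assumed 8x8 worker grid

def num_cores_for_k_and_n(k: int, n: int) -> int:
    k_tiles = k // TILE_SIZE
    n_tiles = n // TILE_SIZE
    for cores in range(MAX_CORES, 0, -1):
        if k_tiles % cores == 0 and n_tiles % cores == 0:
            return cores
    return 1

A selection like this also keeps the assert in dram_matmul_config satisfied: k stays divisible by tile_size * num_cores, while the extra condition on n rules out the uneven-N case the commit works around.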

@@ -283,7 +283,7 @@ def forward(

         h = self.norm(h, mode=mode)
 
-        # TODO: Switch to using dram-sharded LM haed and remove this
+        # TODO: Switch to using dram-sharded LM head and remove this
         # Note: workaround for sharded_to_interleaved memory corruption (#15113)
         h = ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)

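The comment in this hunk references issue #15113: converting the sharded activations with ttnn.sharded_to_interleaved was corrupting memory, so the model routes through ttnn.to_memory_config instead until a dram-sharded LM head lands. As a sketch of the pattern (assuming the usual ttnn call signatures; h stands in for the model's activation tensor):

# Avoided path (sketch): direct sharded-to-interleaved conversion,
# which hit memory corruption per issue #15113.
# h = ttnn.sharded_to_interleaved(h, ttnn.DRAM_MEMORY_CONFIG)

# Workaround used in the hunk above: move the tensor into
# interleaved DRAM via a memory-config conversion instead.
h = ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)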
