#2022: Added data parallel and model parallel tests for all multichip configurations
tapspatel committed Feb 28, 2025
1 parent 8242514 commit 50c2b11
Showing 8 changed files with 610 additions and 0 deletions.


test/ttmlir/Silicon/TTNN/llmbox/ccl/ccl_2x4.mlir: 27 additions, 0 deletions
@@ -81,3 +81,30 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1
// CHECK: "ttnn.mesh_shard"
return %5 : tensor<1x1x8192x128xf32>
}

func.func public @jit_data_tensor_parallel_t3000(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<32x1x1024x512xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: 0, 3>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 2, 1, 1, 4>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x512x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 4, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x512x512xf32>) -> tensor<1x1x512x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<32x1024x512xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [32 : i32, 1024 : i32, 512 : i32]}> : (tensor<32x1x1024x512xf32>, tensor<32x1024x512xf32>) -> tensor<32x1024x512xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x512x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 512 : i32, 512 : i32]}> : (tensor<1x1x512x512xf32>, tensor<1x512x512xf32>) -> tensor<1x512x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<32x1024x512xf32>, tensor<1x512x512xf32>) -> tensor<32x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<32x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<32x1024x1x512xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<32x1x256x512xf32>
%12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type<sum>, scatter_dim = 2 : si32}> : (tensor<32x1x1024x512xf32>, tensor<32x1x256x512xf32>) -> tensor<32x1x256x512xf32>
// CHECK: "ttnn.reduce_scatter"
%13 = tensor.empty() : tensor<64x1x1024x512xf32>
%14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array<i64: 0, 2>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 2, 1, 4, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<32x1x256x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %14 : tensor<64x1x1024x512xf32>
}
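To make the sharding arithmetic concrete, here is a minimal NumPy sketch of the hybrid data + tensor parallel pattern this function exercises, assuming mesh axis 0 (size 2) carries the batch split and mesh axis 1 (size 4) carries the contraction split; variable names are illustrative, not part of the test:

```python
import numpy as np

full_act = np.random.rand(64, 1, 1024, 2048)
full_w = np.random.rand(1, 1, 2048, 512)
ref = full_act.reshape(64, 1024, 2048) @ full_w.reshape(2048, 512)

# full_to_shard, shard_shape 2x1x1x4 (shard_dims 0 and 3): split the batch
# 2-way and the contraction dim 4-way, one 32x1x1024x512 shard per device.
act_rows = [np.split(b, 4, axis=3) for b in np.split(full_act, 2, axis=0)]
# shard_shape 1x1x4x1 (shard_dims -1 and 2): split the weight's contraction dim.
w_shards = np.split(full_w, 4, axis=2)

# Each device matmuls its local shards; summing the four partials along mesh
# axis 1 recovers the full 2048-wide contraction, which is what the
# reduce_scatter + shard_to_full pair accomplishes on device.
out = np.concatenate([
    sum(a.reshape(32, 1024, 512) @ w.reshape(512, 512)
        for a, w in zip(row, w_shards))
    for row in act_rows
], axis=0)
assert np.allclose(out, ref)
```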
@@ -0,0 +1,54 @@ (new file; path not shown)
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=1,8" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

func.func public @jit_data_parallel_t3000(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<8x1x1024x2048xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 8, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<8x1x1024x2048xf32>) -> tensor<8x1x1024x2048xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x2048x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1>, shard_type = #tt.shard_type<replicate>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x2048x512xf32>) -> tensor<1x1x2048x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<8x1024x2048xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [8 : i32, 1024 : i32, 2048 : i32]}> : (tensor<8x1x1024x2048xf32>, tensor<8x1024x2048xf32>) -> tensor<8x1024x2048xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x2048x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 2048 : i32, 512 : i32]}> : (tensor<1x1x2048x512xf32>, tensor<1x2048x512xf32>) -> tensor<1x2048x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<8x1024x2048xf32>, tensor<1x2048x512xf32>) -> tensor<8x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<8x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<8x1024x1x512xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<64x1x1024x512xf32>
%12 = "ttir.mesh_shard"(%10, %11) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 8, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<8x1x1024x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %12 : tensor<64x1x1024x512xf32>
}
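The pure data parallel case above needs no cross-device reduction: the batch splits 8 ways, the weight is replicated (shard_type<replicate>), and the per-device outputs concatenate directly. A hedged NumPy sketch of that identity:

```python
import numpy as np

act = np.random.rand(64, 1, 1024, 2048)
w = np.random.rand(1, 1, 2048, 512)      # replicated across all 8 devices

# shard_shape 8x1x1x1: 64 / 8 = 8 batch entries per device.
outs = [s.reshape(8, 1024, 2048) @ w.reshape(2048, 512)
        for s in np.split(act, 8, axis=0)]

# shard_to_full along dim 0 is a plain concatenation of independent results.
out = np.concatenate(outs, axis=0)
assert np.allclose(out, act.reshape(64, 1024, 2048) @ w.reshape(2048, 512))
```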

func.func public @jit_tensor_parallel_t3000(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<64x1x1024x256xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 3>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 1, 8>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<64x1x1024x256xf32>) -> tensor<64x1x1024x256xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x256x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 8, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x256x512xf32>) -> tensor<1x1x256x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<64x1024x256xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [64 : i32, 1024 : i32, 256 : i32]}> : (tensor<64x1x1024x256xf32>, tensor<64x1024x256xf32>) -> tensor<64x1024x256xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x256x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 256 : i32, 512 : i32]}> : (tensor<1x1x256x512xf32>, tensor<1x256x512xf32>) -> tensor<1x256x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<64x1024x256xf32>, tensor<1x256x512xf32>) -> tensor<64x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<64x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<64x1024x1x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<64x1x128x512xf32>
%12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type<sum>, scatter_dim = 2 : si32}> : (tensor<64x1x1024x512xf32>, tensor<64x1x128x512xf32>) -> tensor<64x1x128x512xf32>
// CHECK: "ttnn.reduce_scatter"
%13 = tensor.empty() : tensor<64x1x1024x512xf32>
%14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 1, 1, 8, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x128x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %14 : tensor<64x1x1024x512xf32>
}
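The tensor parallel variant is what gives reduce_scatter real work: each of the 8 devices computes a full-shaped 64x1x1024x512 partial product, the sum of the partials is the true matmul, and scatter_dim = 2 leaves device i with rows i*128:(i+1)*128 of that sum (1024 / 8 = 128), which the final shard_to_full reassembles. A small sketch of just the collective, under those assumptions:

```python
import numpy as np

partials = [np.random.rand(64, 1, 1024, 512) for _ in range(8)]  # one per device
total = np.sum(partials, axis=0)          # the reduce half (reduce_type<sum>)

# The scatter half: device i keeps its 128-row slice of dim 2.
scattered = [total[:, :, i * 128:(i + 1) * 128, :] for i in range(8)]

# shard_to_full with shard_shape 1x1x8x1 concatenates back along dim 2.
assert np.array_equal(np.concatenate(scattered, axis=2), total)
```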
@@ -0,0 +1,30 @@ (new file; path not shown)
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=2,4" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

func.func public @jit_data_tensor_parallel_t3000(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<32x1x1024x512xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: 0, 3>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 2, 1, 1, 4>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x512x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 4, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x512x512xf32>) -> tensor<1x1x512x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<32x1024x512xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [32 : i32, 1024 : i32, 512 : i32]}> : (tensor<32x1x1024x512xf32>, tensor<32x1024x512xf32>) -> tensor<32x1024x512xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x512x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 512 : i32, 512 : i32]}> : (tensor<1x1x512x512xf32>, tensor<1x512x512xf32>) -> tensor<1x512x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<32x1024x512xf32>, tensor<1x512x512xf32>) -> tensor<32x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<32x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<32x1024x1x512xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<32x1x256x512xf32>
%12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type<sum>, scatter_dim = 2 : si32}> : (tensor<32x1x1024x512xf32>, tensor<32x1x256x512xf32>) -> tensor<32x1x256x512xf32>
// CHECK: "ttnn.reduce_scatter"
%13 = tensor.empty() : tensor<64x1x1024x512xf32>
%14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array<i64: 0, 2>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 2, 1, 4, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<32x1x256x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %14 : tensor<64x1x1024x512xf32>
}
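This is the same 2x4 workload that appears in ccl_2x4.mlir above; the detail worth spelling out once is how shard_dims reads: entry k names the tensor dim sharded by mesh axis k, and -1 means the tensor is not split along that mesh axis. A hypothetical helper (device_slice is illustrative, not a ttmlir API) that derives the slice a device at given mesh coordinates would own:

```python
def device_slice(shape, mesh_coords, mesh_shape, shard_dims):
    """Slice of the full tensor owned by the device at mesh_coords."""
    idx = [slice(None)] * len(shape)
    for axis, dim in enumerate(shard_dims):
        if dim < 0:
            continue                  # -1: replicated along this mesh axis
        step = shape[dim] // mesh_shape[axis]
        idx[dim] = slice(mesh_coords[axis] * step, (mesh_coords[axis] + 1) * step)
    return tuple(idx)

# Device (1, 2) on the 2x4 mesh, shard_dims = [0, 3], 64x1x1024x2048 input:
# it owns batch rows 32:64 and contraction columns 1024:1536.
print(device_slice((64, 1, 1024, 2048), (1, 2), (2, 4), (0, 3)))
```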
@@ -0,0 +1,54 @@ (new file; path not shown)
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=1,2" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

func.func public @jit_data_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<32x1x1024x2048xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<32x1x1024x2048xf32>) -> tensor<32x1x1024x2048xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x2048x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1>, shard_type = #tt.shard_type<replicate>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x2048x512xf32>) -> tensor<1x1x2048x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<32x1024x2048xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [32 : i32, 1024 : i32, 2048 : i32]}> : (tensor<32x1x1024x2048xf32>, tensor<32x1024x2048xf32>) -> tensor<32x1024x2048xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x2048x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 2048 : i32, 512 : i32]}> : (tensor<1x1x2048x512xf32>, tensor<1x2048x512xf32>) -> tensor<1x2048x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<32x1024x2048xf32>, tensor<1x2048x512xf32>) -> tensor<32x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<32x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<32x1024x1x512xf32>, tensor<32x1x1024x512xf32>) -> tensor<32x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<64x1x1024x512xf32>
%12 = "ttir.mesh_shard"(%10, %11) <{shard_dims = array<i64: -1, 0>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 2, 1, 1, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<32x1x1024x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %12 : tensor<64x1x1024x512xf32>
}
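The n300 data parallel function is the t3000 pattern scaled down to a 1x2 mesh: the batch splits 2 ways (64 / 2 = 32 per device), the weight is again replicated, and the gather back to 64x1x1024x512 is a plain concatenation, as in the sketch after the 1x8 version above.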

func.func public @jit_tensor_parallel_n300(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<64x1x1024x1024xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: -1, 3>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 1, 2>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<64x1x1024x1024xf32>) -> tensor<64x1x1024x1024xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x1024x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 2, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x1024x512xf32>) -> tensor<1x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<64x1024x1024xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [64 : i32, 1024 : i32, 1024 : i32]}> : (tensor<64x1x1024x1024xf32>, tensor<64x1024x1024xf32>) -> tensor<64x1024x1024xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x1024x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 1024 : i32, 512 : i32]}> : (tensor<1x1x1024x512xf32>, tensor<1x1024x512xf32>) -> tensor<1x1024x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<64x1024x1024xf32>, tensor<1x1024x512xf32>) -> tensor<64x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<64x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<64x1024x1x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<64x1x512x512xf32>
%12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type<sum>, scatter_dim = 2 : si32}> : (tensor<64x1x1024x512xf32>, tensor<64x1x512x512xf32>) -> tensor<64x1x512x512xf32>
// CHECK: "ttnn.reduce_scatter"
%13 = tensor.empty() : tensor<64x1x1024x512xf32>
%14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 1, 1, 2, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x512x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %14 : tensor<64x1x1024x512xf32>
}
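With only two devices, the n300 tensor parallel pipeline is easy to check end to end; this sketch mirrors the function above step by step (shapes taken from the IR, names illustrative):

```python
import numpy as np

act = np.random.rand(64, 1, 1024, 2048)
w = np.random.rand(1, 1, 2048, 512)
ref = act.reshape(64, 1024, 2048) @ w.reshape(2048, 512)

# full_to_shard: both operands split in half along the 2048 contraction dim.
a0, a1 = np.split(act, 2, axis=3)
w0, w1 = np.split(w, 2, axis=2)
partials = [a.reshape(64, 1024, 1024) @ v.reshape(1024, 512)
            for a, v in ((a0, w0), (a1, w1))]

# reduce_scatter(sum, scatter_dim=2): device i keeps rows i*512:(i+1)*512.
total = partials[0] + partials[1]
pieces = [total[:, i * 512:(i + 1) * 512, :] for i in range(2)]

# shard_to_full with shard_shape 1x1x2x1: concatenate back along that dim.
assert np.allclose(np.concatenate(pieces, axis=1), ref)
```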
test/ttmlir/Silicon/TTNN/tg/ccl/ccl_8x4.mlir: 27 additions, 0 deletions
@@ -81,3 +81,30 @@ func.func @reduce_scatter_cluster1(%arg0: tensor<1x1x8192x512xf32>) -> (tensor<1
// CHECK: "ttnn.mesh_shard"
return %5 : tensor<1x1x8192x128xf32>
}

func.func public @jit_data_tensor_parallel_tg(%arg0: tensor<64x1x1024x2048xf32>, %arg1: tensor<1x1x2048x512xf32>) -> (tensor<64x1x1024x512xf32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<8x1x1024x512xf32>
%1 = "ttir.mesh_shard"(%arg0, %0) <{shard_dims = array<i64: 0, 3>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 8, 1, 1, 4>, shard_type = #tt.shard_type<devices>}> : (tensor<64x1x1024x2048xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
%2 = tensor.empty() : tensor<1x1x512x512xf32>
%3 = "ttir.mesh_shard"(%arg1, %2) <{shard_dims = array<i64: -1, 2>, shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = array<i64: 1, 1, 4, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<1x1x2048x512xf32>, tensor<1x1x512x512xf32>) -> tensor<1x1x512x512xf32>
// CHECK: "ttnn.mesh_shard"
%4 = tensor.empty() : tensor<8x1024x512xf32>
%5 = "ttir.reshape"(%1, %4) <{shape = [8 : i32, 1024 : i32, 512 : i32]}> : (tensor<8x1x1024x512xf32>, tensor<8x1024x512xf32>) -> tensor<8x1024x512xf32>
// CHECK: = "ttnn.reshape"
%6 = tensor.empty() : tensor<1x512x512xf32>
%7 = "ttir.reshape"(%3, %6) <{shape = [1 : i32, 512 : i32, 512 : i32]}> : (tensor<1x1x512x512xf32>, tensor<1x512x512xf32>) -> tensor<1x512x512xf32>
// CHECK: = "ttnn.reshape"
%8 = "ttir.dot_general"(%5, %7) <{batch_dims_lhs = array<i64>, batch_dims_rhs = array<i64>, contract_dims_lhs = array<i64: 2>, contract_dims_rhs = array<i64: 1>}> : (tensor<8x1024x512xf32>, tensor<1x512x512xf32>) -> tensor<8x1024x1x512xf32>
// CHECK: "ttir.matmul"
%9 = tensor.empty() : tensor<8x1x1024x512xf32>
%10 = "ttir.permute"(%8, %9) <{permutation = array<i64: 0, 2, 1, 3>}> : (tensor<8x1024x1x512xf32>, tensor<8x1x1024x512xf32>) -> tensor<8x1x1024x512xf32>
// CHECK: "ttnn.permute"
%11 = tensor.empty() : tensor<8x1x256x512xf32>
%12 = "ttir.reduce_scatter"(%10, %11) <{cluster_axis = 1 : ui32, reduce_type = #tt.reduce_type<sum>, scatter_dim = 2 : si32}> : (tensor<8x1x1024x512xf32>, tensor<8x1x256x512xf32>) -> tensor<8x1x256x512xf32>
// CHECK: "ttnn.reduce_scatter"
%13 = tensor.empty() : tensor<64x1x1024x512xf32>
%14 = "ttir.mesh_shard"(%12, %13) <{shard_dims = array<i64: 0, 2>, shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = array<i64: 8, 1, 4, 1>, shard_type = #tt.shard_type<devices>}> : (tensor<8x1x256x512xf32>, tensor<64x1x1024x512xf32>) -> tensor<64x1x1024x512xf32>
// CHECK: "ttnn.mesh_shard"
return %14 : tensor<64x1x1024x512xf32>
}
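On the 8x4 Galaxy mesh the same hybrid pattern runs 32-wide: mesh axis 0 splits the batch 8 ways, mesh axis 1 splits the contraction 4 ways, and the closing shard_shape of 8x1x4x1 folds mesh axis 0 back into dim 0 and mesh axis 1 into the scattered dim 2. The per-device shape arithmetic, as a quick sketch:

```python
# 64/8 batch entries and 2048/4 contraction columns per device; after
# reduce_scatter each device holds 1024/4 = 256 rows of dim 2, matching the
# 8x1x256x512 operand in the IR above.
mesh = (8, 4)
full_act = (64, 1, 1024, 2048)
per_device = tuple(d // s for d, s in zip(full_act, (mesh[0], 1, 1, mesh[1])))
print(per_device)   # (8, 1, 1024, 512)
```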