Check and update return sharding #2298

Merged
1 commit merged on Feb 28, 2025
52 changes: 48 additions & 4 deletions include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h
@@ -8,14 +8,16 @@
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir::tt::sharding_utils {

#if TTMLIR_ENABLE_STABLEHLO
#if defined(TTMLIR_ENABLE_STABLEHLO) && (TTMLIR_ENABLE_STABLEHLO != 0)

class MeshSharding {
public:
@@ -32,9 +34,47 @@ class MeshSharding {
mlir::sdy::MeshAttr mesh,
mlir::tt::MeshShardDirection direction);

// Force dummy sharding op by setting shard_type to manual. The mesh_shard op
// will be ignored at runtime by simply copying input tensor to output.
void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; }
// Check and update function arg sharding
template <typename AttrType>
void checkAndUpdateFuncArgSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t argNum,
AttrType shardingAttr,
llvm::StringRef argShardingStrRef) {
if (auto argShardingAttr =
funcOp.getArgAttrOfType<AttrType>(argNum, argShardingStrRef)) {
if (argShardingAttr == shardingAttr) {
setDummyShardingOp();
rewriter.modifyOpInPlace(
funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); });
} else {
llvm_unreachable(
"MeshSharding operation and function argument shardings "
"are different.");
}
}
}

// Check and update function ret sharding
template <typename AttrType>
void checkAndUpdateFuncReturnSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp,
uint64_t retNum, AttrType shardingAttr,
llvm::StringRef retShardingStrRef) {
if (auto retShardingAttr =
funcOp.getResultAttrOfType<AttrType>(retNum, retShardingStrRef)) {
if (retShardingAttr == shardingAttr) {
setDummyShardingOp();
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeResultAttr(
retNum,
mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef));
});
} else {
llvm_unreachable("MeshSharding operation and function return shardings "
"are different.");
}
}
}

// Getter functions.
mlir::tt::MeshShardDirection getShardDirection() const {
@@ -63,6 +103,10 @@
meshShape = llvm::SmallVector<int64_t>{-1};
}

// Force dummy sharding op by setting shard_type to manual. The mesh_shard op
// will be ignored at runtime by simply copying input tensor to output.
void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; }

private:
mlir::tt::MeshShardDirection shardDirection =
mlir::tt::MeshShardDirection::ShardToFull;
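For orientation, the sketch below condenses how a conversion pattern is expected to call the two new helpers. It only combines the call sites that appear later in this diff (Shardy TensorShardingAttr entries keyed by sdy::kShardingAttr, GSPMD StringAttr entries keyed by kXlaShardingAttr) and assumes the rewriter, funcOp, indices, sharding attributes, and the kXlaShardingAttr constant are already in scope; it is an editorial illustration, not part of the change.

#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h"

// Sketch (illustration only): intended usage of checkAndUpdateFuncArgSharding
// and checkAndUpdateFuncReturnSharding from a conversion pattern.
void exampleHelperUsage(mlir::PatternRewriter &rewriter,
                        mlir::func::FuncOp funcOp, uint64_t argNum,
                        uint64_t retNum,
                        mlir::sdy::TensorShardingAttr argSharding,
                        mlir::StringAttr gspmdSharding) {
  mlir::tt::sharding_utils::MeshSharding meshSharding;
  // Shardy path: function argument attributes are sdy::TensorShardingAttr
  // entries stored under sdy::kShardingAttr.
  meshSharding.checkAndUpdateFuncArgSharding<mlir::sdy::TensorShardingAttr>(
      rewriter, funcOp, argNum, argSharding, mlir::sdy::kShardingAttr);
  // GSPMD path: function result attributes are StringAttr entries stored
  // under sharding_utils::kXlaShardingAttr.
  meshSharding.checkAndUpdateFuncReturnSharding<mlir::StringAttr>(
      rewriter, funcOp, retNum, gspmdSharding,
      mlir::tt::sharding_utils::kXlaShardingAttr);
}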
32 changes: 14 additions & 18 deletions lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp
@@ -108,23 +108,12 @@ class ShardyToTTIRManualComputationOpConversionPattern
// JAX automatic sharding pre-shards input tensors and provides multiple
// buffers. Thus, mesh sharding operations should not shard the tensors
// twice if they are function arguments and pre-sharded by frontend.
// Runtime ignores mesh sharding operation if it is set as manual
// sharding.
if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(globalOperand)) {
auto argNum = blockArg.getArgNumber();
if (mlir::sdy::TensorShardingAttr argShardingAttr =
funcOp.getArgAttrOfType<mlir::sdy::TensorShardingAttr>(
argNum, mlir::sdy::kShardingAttr)) {
if (argShardingAttr == argSharding) {
meshSharding.setDummyShardingOp();
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeArgAttr(argNum, mlir::sdy::kShardingAttr);
});
} else {
llvm_unreachable("Manual computation op and function argument "
"shardings are different.");
}
}
meshSharding
.checkAndUpdateFuncArgSharding<mlir::sdy::TensorShardingAttr>(
rewriter, funcOp, argNum, argSharding,
mlir::sdy::kShardingAttr);
}

auto outputType = mlir::cast<mlir::RankedTensorType>(
@@ -142,17 +131,24 @@
// Add mesh_shard (ShardToFullShape) for outputs.
rewriter.setInsertionPointAfter(srcOp);
mlir::Operation *sdyReturn = getBodyTerminator(srcOp);
for (auto [returnOperand, outSharding, opResult] : llvm::zip_equal(
for (auto [retNum, args] : llvm::enumerate(llvm::zip_equal(
sdyReturn->getOpOperands(), srcOp.getOutShardings().getShardings(),
srcOp.getResults())) {

srcOp.getResults()))) {
auto [returnOperand, outSharding, opResult] = args;
mlir::tt::sharding_utils::MeshSharding meshSharding;
auto error = meshSharding.convertSdyShardingToMeshSharding(
outSharding, targetMesh, mlir::tt::MeshShardDirection::ShardToFull);
if (auto e = error.takeError()) {
return rewriter.notifyMatchFailure(srcOp, llvm::toString(std::move(e)));
}

// JAX automatic sharding may expect pre-sharded output tensors. Thus,
// mesh sharding operations should not concatenate the tensors twice if
// the frontend expects a pre-sharded tensor.
meshSharding
.checkAndUpdateFuncReturnSharding<mlir::sdy::TensorShardingAttr>(
rewriter, funcOp, retNum, outSharding, mlir::sdy::kShardingAttr);

auto inputOperand = returnOperand.get();
auto inputType = mlir::cast<mlir::RankedTensorType>(
getTypeConverter()->convertType(inputOperand.getType()));
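A side note on the output loop above: iterating llvm::enumerate over llvm::zip_equal yields an (index, tuple) pair, which is how retNum becomes available for the return-sharding check while the zipped elements are unpacked with a second structured binding. A minimal standalone sketch of that idiom, with illustrative values only (not project code):

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>

int main() {
  llvm::SmallVector<int> values{10, 20, 30};
  llvm::SmallVector<char> tags{'a', 'b', 'c'};
  // enumerate(zip_equal(...)) yields the element index plus a tuple of the
  // zipped elements; the tuple is unpacked with a second structured binding,
  // mirroring the converted loop over return operands, out shardings, and
  // op results.
  for (auto [idx, pair] : llvm::enumerate(llvm::zip_equal(values, tags))) {
    auto [value, tag] = pair;
    std::printf("%zu: %d %c\n", idx, value, tag);
  }
  return 0;
}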
37 changes: 21 additions & 16 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1749,6 +1749,7 @@ class StableHLOToTTIRCustomCallOpConversionPattern
meshShape);
}

auto funcOp = srcOp->getParentOfType<mlir::func::FuncOp>();
if (callTargetName ==
mlir::tt::sharding_utils::kSPMDFullToShardShapeCallTargetName) {
// @Sharding => @SPMDFullToShardShape pattern
@@ -1767,6 +1768,23 @@
srcOp, "Requires operand to be defined by prior Sharding op.");
}

// JAX automatic sharding may expect pre-sharded output tensors. Thus,
// mesh sharding operations should not concatenate the tensors twice if
// the frontend expects a pre-sharded tensor.
if (auto *funcReturnOp = funcOp.getBody().front().getTerminator()) {
auto returnOperands = funcReturnOp->getOperands();
auto returnOperandIt =
llvm::find_if(returnOperands, [&](Value operand) {
return operand == srcOp->getResult(0);
});
if (returnOperandIt != returnOperands.end()) {
auto retNum = std::distance(returnOperands.begin(), returnOperandIt);
meshSharding.checkAndUpdateFuncReturnSharding<mlir::StringAttr>(
rewriter, funcOp, retNum, shardingAttr,
mlir::tt::sharding_utils::kXlaShardingAttr);
}
}

auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp->getResult(0).getType()));

@@ -1794,25 +1812,12 @@
// JAX automatic sharding pre-shards input tensors and provides multiple
// buffers. Thus, mesh sharding operations should not shard the tensors
// twice if they are function arguments and pre-sharded by frontend.
// Runtime ignores mesh sharding operation if it is set as manual
// sharding.
auto inputOperand = adaptor.getInputs().front();
auto funcOp = srcOp->getParentOfType<mlir::func::FuncOp>();
if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(inputOperand)) {
auto argNum = blockArg.getArgNumber();
if (auto argShardingAttr = funcOp.getArgAttrOfType<mlir::StringAttr>(
argNum, mlir::tt::sharding_utils::kXlaShardingAttr)) {
if (argShardingAttr == shardingAttr) {
meshSharding.setDummyShardingOp();
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeArgAttr(
argNum, mlir::tt::sharding_utils::kXlaShardingAttr);
});
} else {
llvm_unreachable("GSPMD customCallOp and function argument "
"shardings are different.");
}
}
meshSharding.checkAndUpdateFuncArgSharding<mlir::StringAttr>(
rewriter, funcOp, argNum, shardingAttr,
mlir::tt::sharding_utils::kXlaShardingAttr);
}

auto outputType =
2 changes: 2 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -823,6 +823,8 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op) {
meshShardType = ::tt::target::ttnn::MeshShardType::Replicate;
} else if (shardType == mlir::tt::MeshShardType::Devices) {
meshShardType = ::tt::target::ttnn::MeshShardType::Devices;
} else if (shardType == mlir::tt::MeshShardType::Manual) {
meshShardType = ::tt::target::ttnn::MeshShardType::Manual;
} else {
llvm_unreachable("unhandled mesh_shard type");
}
9 changes: 5 additions & 4 deletions runtime/lib/ttnn/operations/ccl/mesh_shard.cpp
@@ -98,10 +98,11 @@ void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) {
// pre-sharded by frontend. Thus, no sharding is required, but we need to make
// sure the tensor is a multi-device host tensor.
if (shardType == ::tt::target::ttnn::MeshShardType::Manual) {
LOG_ASSERT(
input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE,
"Input of mesh_shard with manual sharding must be MULTIDEVICE. id:",
op->in()->global_id());
LOG_ASSERT(input.storage_type() ==
::tt::tt_metal::StorageType::MULTI_DEVICE_HOST,
"Input of mesh_shard with manual sharding must be MULTI DEVICE "
"HOST Storage. id:",
op->in()->global_id());
tensorPool.insert_or_assign(op->out()->global_id(), input);
return;
}
33 changes: 33 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir
@@ -1474,3 +1474,36 @@ module @jit_neg_basic7 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_repli
return %0 : tensor<1x128x128x1024xf32>
}
}

// -----

// jax/pjrt automatic input/output sharding tests
module @jit_negative_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<256x256xf32> {mhlo.sharding = "{devices=[1,2]<=[2]}"}) -> (tensor<256x128xf32> {jax.result_info = "", mhlo.sharding = "{replicated}"}) {
%0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = "{devices=[1,2]<=[2]}"} : (tensor<256x256xf32>) -> tensor<256x256xf32>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<256x256xf32>) -> tensor<256x128xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 1>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 1, 2>
// CHECK-SAME: shard_type = #tt.shard_type<manual>
%2 = call @shmap_body(%1) : (tensor<256x128xf32>) -> tensor<256x128xf32>
%3 = stablehlo.custom_call @Sharding(%2) {mhlo.sharding = "{manual}"} : (tensor<256x128xf32>) -> tensor<256x128xf32>
%4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {mhlo.sharding = "{replicated}"} : (tensor<256x128xf32>) -> tensor<256x128xf32>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 1>
// CHECK-SAME: shard_type = #tt.shard_type<manual>
return %4 : tensor<256x128xf32>
}
func.func private @shmap_body(%arg0: tensor<256x128xf32>) -> (tensor<256x128xf32> {jax.result_info = "[None, None]"}) {
%0 = stablehlo.negate %arg0 : tensor<256x128xf32>
%1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = stablehlo.add %arg1, %arg2 : tensor<f32>
stablehlo.return %2 : tensor<f32>
}) : (tensor<256x128xf32>) -> tensor<256x128xf32>
return %1 : tensor<256x128xf32>
}
}
35 changes: 35 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir
@@ -330,3 +330,38 @@ module @jit_matmul_shardy_automatic attributes {mhlo.num_partitions = 8 : i32, m
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 2, 1>
// CHECK-SAME: shard_type = #tt.shard_type<devices>

// -----

// jax/pjrt automatic input/output sharding tests
module @jit_matmul_shardy1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <["x"=2, "y"=4]>
func.func public @main(%arg0: tensor<8192x784xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<784x16384xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8192x16384xf32> {jax.result_info = "", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
%0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x", "y"} (%arg2: tensor<4096x196xf32>, %arg3: tensor<196x16384xf32>) {
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32>
%2 = "stablehlo.all_reduce"(%1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
%3 = stablehlo.add %arg4, %arg5 : tensor<f32>
stablehlo.return %3 : tensor<f32>
}) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32>
sdy.return %2 : tensor<4096x16384xf32>
} : (tensor<8192x784xf32>, tensor<784x16384xf32>) -> tensor<8192x16384xf32>
return %0 : tensor<8192x16384xf32>
}
}
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 0, 1>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 2, 4>
// CHECK-SAME: shard_type = #tt.shard_type<manual>
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: -1, 0>
// CHECK-SAME: shard_direction = #tt.shard_direction<full_to_shard>
// CHECK-SAME: shard_shape = array<i64: 4, 1>
// CHECK-SAME: shard_type = #tt.shard_type<manual>
// CHECK: = "ttir.all_reduce"
// CHECK: "ttir.mesh_shard"
// CHECK-SAME: shard_dims = array<i64: 0, -1>
// CHECK-SAME: shard_direction = #tt.shard_direction<shard_to_full>
// CHECK-SAME: shard_shape = array<i64: 2, 1>
// CHECK-SAME: shard_type = #tt.shard_type<manual>