Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#2005: Added representation of reduce scatter in ttir and ttnn dialect #2127

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,24 @@ def TTIR_AllReduceOp : TTIR_NamedOp<"all_reduce"> {
let hasVerifier = 1;
}

def TTIR_ReduceScatterOp : TTIR_DPSOp<"reduce_scatter"> {
  let summary = "Reduce scatter operation.";
  let description = [{
    Collective reduce-scatter operation. Reduces the input tensor across the
    devices along `cluster_axis` of the device mesh using the reduction
    selected by `reduce_type` (currently sum, max, or min), then scatters the
    reduced result along `scatter_dim`, so each participating device ends up
    with a disjoint slice of the reduced tensor. Mirrors the semantics of
    `stablehlo.reduce_scatter`, from which this op is typically converted.
  }];
  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$output,
                       TT_ReduceTypeAttr:$reduce_type,
                       SI32Attr:$scatter_dim,
                       UI32Attr:$cluster_axis);

  let results = (outs AnyRankedTensor:$result);
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}

def TTIR_MeshShardOp : TTIR_NamedOp<"mesh_shard"> {
let summary = "Mesh shard operation.";
let description = [{
Expand Down
60 changes: 59 additions & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,9 @@ class StableHLOToTTIRLogicalAndBitwiseOpConversionPattern

template <typename SrcOpTy>
LogicalResult getReduceType(SrcOpTy &srcOp, ReduceType &reduceType) {
if constexpr (!std::is_same<SrcOpTy, mlir::stablehlo::AllReduceOp>::value) {
if constexpr (!std::is_same<SrcOpTy, mlir::stablehlo::AllReduceOp>::value &&
!std::is_same<SrcOpTy,
mlir::stablehlo::ReduceScatterOp>::value) {
return failure();
}
// Check operations in the first block and determine reduce type for now
Expand Down Expand Up @@ -1638,6 +1640,60 @@ class StableHLOToTTIRAllReduceOpConversionPattern
};
} // namespace

namespace {
// Converts stablehlo.reduce_scatter into ttir.reduce_scatter (DPS form).
class StableHLOToTTIRReduceScatterOpConversionPattern
    : public OpConversionPattern<mlir::stablehlo::ReduceScatterOp> {
  using OpConversionPattern<
      mlir::stablehlo::ReduceScatterOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(mlir::stablehlo::ReduceScatterOp srcOp,
                  mlir::stablehlo::ReduceScatterOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the result type through the type converter; the scattered
    // output shape differs from the input along the scatter dimension.
    auto outputType = mlir::cast<RankedTensorType>(
        getTypeConverter()->convertType(srcOp.getResult().getType()));

    // Create an empty tensor to serve as the DPS output operand.
    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

    if (auto srcChannelHandleAttr = adaptor.getChannelHandleAttr()) {
      // channelType is supposed to be DEVICE_TO_DEVICE or Invalid for CCL
      // ops; reject anything else. The channel handle itself is dropped.
      // Consider preserving this information in the future if the attribute
      // carries non-DEVICE_TO_DEVICE values.
      auto channelType = static_cast<int32_t>(srcChannelHandleAttr.getType());
      if (channelType != StableHLOChannelType::kChannelTypeDeviceToDevice &&
          channelType != StableHLOChannelType::kChannelTypeInvalid) {
        return failure();
      }
    }

    // Derive the mesh cluster axis from the replica-group layout; fail the
    // match if the groups do not map onto a single mesh axis.
    uint32_t clusterAxis;
    if (failed(determineClusterAxis(adaptor.getReplicaGroups(), clusterAxis))) {
      return rewriter.notifyMatchFailure(
          srcOp, "ReduceScatterOp cannot specify cluster axis.");
    }

    // Convert the StableHLO reduction region into a TTIR ReduceType attr
    // (fails for reduction bodies TTIR does not support).
    ReduceType reduceType;
    if (failed(getReduceType(srcOp, reduceType))) {
      return rewriter.notifyMatchFailure(
          srcOp, "ReduceScatterOp cannot specify reduce type.");
    }

    rewriter.replaceOpWithNewOp<mlir::tt::ttir::ReduceScatterOp>(
        srcOp, outputType, adaptor.getOperands()[0], outputTensor, reduceType,
        adaptor.getScatterDimension(), clusterAxis);

    return success();
  }
};
} // namespace

namespace {
class StableHLOToTTIRAllGatherOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::AllGatherOp> {
Expand Down Expand Up @@ -2357,6 +2413,8 @@ static void addCCLOpsConversionPattern(MLIRContext *ctx,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRAllReduceOpConversionPattern>(typeConverter, ctx);
patterns.add<StableHLOToTTIRAllGatherOpConversionPattern>(typeConverter, ctx);
patterns.add<StableHLOToTTIRReduceScatterOpConversionPattern>(typeConverter,
ctx);
patterns.add<StableHLOToTTIRCustomCallOpConversionPattern>(typeConverter,
ctx);
}
Expand Down
22 changes: 22 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,27 @@ class AllReduceOpConversionPattern
};
} // namespace

namespace {
// Lowers ttir.reduce_scatter to ttnn.reduce_scatter, supplying the device
// operand that the TTNN op requires.
class ReduceScatterOpConversionPattern
    : public OpConversionPattern<ttir::ReduceScatterOp> {
public:
  using OpConversionPattern<ttir::ReduceScatterOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ReduceScatterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Reuse an existing device handle for this op, inserting one if needed.
    auto targetDevice = ::ttnn::utils::getOrInsertDevice(rewriter, op);
    Type convertedResultType =
        this->getTypeConverter()->convertType(op.getType());

    // The collective attributes (reduce type, scatter dim, cluster axis)
    // carry over unchanged from the TTIR op.
    rewriter.replaceOpWithNewOp<ttnn::ReduceScatterOp>(
        op, convertedResultType, adaptor.getInput(), targetDevice,
        adaptor.getReduceType(), adaptor.getScatterDim(),
        adaptor.getClusterAxis());

    return success();
  }
};
} // namespace

namespace {
class MeshShardOpConversionPattern
: public OpConversionPattern<ttir::MeshShardOp> {
Expand Down Expand Up @@ -1574,6 +1595,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
MeshShardOpConversionPattern,
AllReduceOpConversionPattern,
AllGatherOpConversionPattern,
ReduceScatterOpConversionPattern,
ArangeOpConversionPattern,
UpdateCacheOpConversionPattern,
FillCacheOpConversionPattern,
Expand Down
29 changes: 29 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2257,6 +2257,35 @@ ::mlir::LogicalResult mlir::tt::ttir::AllReduceOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//

// ReduceScatterOp verification: checks that the reduction kind is supported
// and that scatter_dim names a valid dimension of the input tensor.
::mlir::LogicalResult mlir::tt::ttir::ReduceScatterOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::mlir::tt::ReduceType reduceType = getReduceType();
  int32_t scatterDim = getScatterDim();

  // Currently TTIR only supports the following reduce types.
  if (reduceType != ::mlir::tt::ReduceType::Sum &&
      reduceType != ::mlir::tt::ReduceType::Max &&
      reduceType != ::mlir::tt::ReduceType::Min) {
    return emitOpError("Invalid reduction op for reduce scatter op.");
  }

  // Negative values index from the back (Python-style), so the valid range
  // is [-rank, rank). The original message described the *invalid* range as
  // the requirement; state the valid range instead.
  if (scatterDim >= inputType.getRank() || scatterDim < -inputType.getRank()) {
    return emitOpError("Invalid dimension for reduce scatter op. Scatter "
                       "dimension must be >= -input tensor rank and < input "
                       "tensor rank, got scatter_dim = ")
           << scatterDim;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// MeshShardOp
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 5 additions & 4 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1557,9 +1557,10 @@ ::mlir::LogicalResult ReduceScatterOp::verify() {
::mlir::tt::ReduceType reduceType = getReduceType();

if (scatterDim >= inputType.getRank() || scatterDim < -inputType.getRank()) {
return emitOpError("Invalid scatter dimension for all reduce op. Scatter "
"dimension must be >= to input tensor rank or < -input "
"tensor rank, got scatter_dim = ")
return emitOpError(
"Invalid scatter dimension for reduce scatter op. Scatter "
"dimension must be >= to input tensor rank or < -input "
"tensor rank, got scatter_dim = ")
<< scatterDim;
}

Expand All @@ -1569,7 +1570,7 @@ ::mlir::LogicalResult ReduceScatterOp::verify() {
if (reduceType != ::mlir::tt::ReduceType::Sum &&
reduceType != ::mlir::tt::ReduceType::Max &&
reduceType != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for all reduce op.");
return emitOpError("Invalid reduction op for reduce scatter op.");
}

return success();
Expand Down
Loading
Loading