Remove dps interface from ttnn ops
mtopalovicTT committed Feb 28, 2025
1 parent 54a459c commit 92b1ad9
Showing 43 changed files with 196 additions and 454 deletions.
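Every op touched by this commit follows the same pattern: the destination-passing-style (DPS) $output operand and its getDpsInitsMutable() accessor are removed, so the output is described only by the op's $result type. A minimal before/after sketch of the effect in generic MLIR assembly (the tensor shapes and the ttnn.empty producer are illustrative, not taken from this commit):

    // Before: the destination is materialized and passed as an operand.
    %out = "ttnn.empty"(%device) : (!tt.device) -> tensor<32x128xbf16>
    %r = "ttnn.matmul"(%a, %b, %out)
         : (tensor<32x64xbf16>, tensor<64x128xbf16>, tensor<32x128xbf16>)
         -> tensor<32x128xbf16>

    // After: no destination operand; the result type alone carries the
    // output shape and layout.
    %r = "ttnn.matmul"(%a, %b)
         : (tensor<32x64xbf16>, tensor<64x128xbf16>) -> tensor<32x128xbf16>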
61 changes: 13 additions & 48 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -753,20 +753,18 @@ def TTNN_ProdOp : TTNN_Op<"prod"> {
let hasVerifier = 1;
}

def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
def TTNN_EmbeddingOp : TTNN_Op<"embedding"> {
let summary = "Embedding op.";
let description = [{
Embedding operation.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
AnyRankedTensor:$output);
AnyRankedTensor:$weight);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds();
}
@@ -809,7 +807,7 @@ def TTNN_FillCacheOp : TTNN_InplaceOp<"fill_cache"> {
let hasVerifier = 1;
}

def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> {
def TTNN_EmbeddingBackwardOp : TTNN_Op<"embedding_bw"> {
let summary = "Embedding backward op.";
let description = [{
Embedding backward operation. Generates the gradient of the embedding operation with respect to the input.
@@ -819,13 +817,11 @@ def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> {
AnyRankedTensor:$weight,
AnyRankedTensor:$in_gradient,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
AnyRankedTensor:$output);
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingBackwardOpOperandsWorkarounds();
}
@@ -834,7 +830,7 @@ def TTNN_EmbeddingBackwardOp : TTNN_NamedDPSOp<"embedding_bw"> {
let hasVerifier = 1;
}

def TTNN_MorehCumSumOp : TTNN_NamedDPSOp<"moreh_cumsum"> {
def TTNN_MorehCumSumOp : TTNN_Op<"moreh_cumsum"> {
let summary = "Moreh cumulative sum op.";
let description = [{
Computes the cumulative sum of elements of a tensor along specified dimension.
@@ -854,14 +850,11 @@ def TTNN_MorehCumSumOp : TTNN_NamedDPSOp<"moreh_cumsum"> {

let arguments = (ins AnyRankedTensor:$input,
I64Attr:$dim,
AnyRankedTensor:$output,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }

wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
RankedTensorType inputType = getInput().getType();
return wa::TTNNOperandsWorkaroundsFactory::createCumSumOpOperandsWorkarounds(inputType);
@@ -924,22 +917,19 @@ def TTNN_RepeatInterleaveOp : TTNN_Op<"repeat_interleave"> {
let hasVerifier = 1;
}

def TTNN_ConcatOp : TTNN_NamedDPSOp<"concat", [HasMemoryConfigTrait]> {
def TTNN_ConcatOp : TTNN_Op<"concat", [HasMemoryConfigTrait]> {
let summary = "Concat op.";
let description = [{
Concat tensors along a given dimension.
}];

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
AnyRankedTensor:$output,
SI32Attr:$dim,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }

wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
::mlir::Operation::operand_range inputs = getInputs();
int64_t numOperands = getOperands().size();
@@ -1016,27 +1006,24 @@ def TTNN_PadOp: TTNN_Op<"pad"> {
let hasVerifier = 1;
}

def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
def TTNN_SliceOp: TTNN_Op<"slice"> {
let summary = "Slice op.";
let description = [{
Extract a portion of a tensor based on the specified start (`begins`), stop (`ends`), and step
indices for each dimension.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I32ArrayAttr:$begins,
I32ArrayAttr:$ends,
I32ArrayAttr:$step);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }

wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
ttnn::TTNNLayoutAttr layoutAttr = mlir::cast<ttnn::TTNNLayoutAttr>(
getOutput().getType().getEncoding());
getResult().getType().getEncoding());
::mlir::ArrayAttr begins = getBegins();
::mlir::ArrayAttr step = getStep();
return wa::TTNNOperandsWorkaroundsFactory::
@@ -1047,7 +1034,7 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
let hasVerifier = 1;
}

def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> {
def TTNN_LinearOp : TTNN_Op<"linear"> {
let summary = "Linear transformation of inputs.";

let description = [{
@@ -1064,41 +1051,31 @@ def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> {
let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}


// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul",
def TTNN_MatmulOp : TTNN_Op<"matmul",
[DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
> {
let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
AnyRankedTensor:$output,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}
// ANCHOR_END: adding_an_op_matmul_ttnn

def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
def TTNN_Conv2dOp : TTNN_Op<"conv2d"> {
let summary = "Conv2d operation.";
let description = [{
Applies a 2D convolution over an input image composed of several input planes.
@@ -1107,7 +1084,6 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
TT_Device:$device,
I32Attr:$in_channels,
I32Attr:$out_channels,
@@ -1123,14 +1099,10 @@

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
def TTNN_ConvTranspose2dOp : TTNN_Op<"conv_transpose2d"> {
let summary = "ConvTranspose2d operation.";
let description = [{
Applies a 2D transposed convolution operator over an input image composed of several input planes.
@@ -1191,7 +1163,6 @@ def TTNN_ConvTranspose2dOp : TTNN_NamedDPSOp<"conv_transpose2d"> {
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
TT_Device:$device,
I32Attr:$in_channels,
I32Attr:$out_channels,
@@ -1207,21 +1178,16 @@

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
def TTNN_MaxPool2dOp : TTNN_Op<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
Applies a 2D max pooling over an input signal composed of several input planes.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
TT_Device:$device,
SI32Attr:$batch_size,
SI32Attr:$input_height,
@@ -1240,7 +1206,6 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds();
}
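Because ODS generates each op's C++ builders from its arguments list, dropping $output removes a parameter from every generated builder, which is what the conversion-pattern changes in the next file account for. A sketch of a create call under the new signature (a minimal example assuming standard ODS builder generation; variable names are illustrative):

    // No empty destination tensor is created first; the result type alone
    // tells the op what to produce.
    ttnn::MatmulOp matmul = rewriter.create<ttnn::MatmulOp>(
        op.getLoc(), resultType, lhs, rhs,
        /*transpose_a=*/false, /*transpose_b=*/false);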
59 changes: 24 additions & 35 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -399,7 +399,7 @@ class EmbeddingOpConversionPattern
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput());
adaptor.getInput(), adaptor.getWeight());

return success();
}
@@ -455,7 +455,7 @@ class EmbeddingBackwardOpConversionPattern
rewriter.replaceOpWithNewOp<ttnn::EmbeddingBackwardOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getWeight(), reshapedGrad, dTypeAttr,
memoryConfigAttr, adaptor.getOutput());
memoryConfigAttr);
return success();
}
};
@@ -471,7 +471,7 @@ class CumSumOpConversionPattern : public OpConversionPattern<ttir::CumSumOp> {
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::MorehCumSumOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getDim(), adaptor.getOutput(), nullptr);
adaptor.getInput(), adaptor.getDim(), nullptr);
return success();
}
};
@@ -654,7 +654,7 @@ class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
}
rewriter.replaceOpWithNewOp<ttnn::ConcatOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInputs(), adaptor.getOutput(), dim,
adaptor.getInputs(), dim,
/* memory_config */ nullptr);
return success();
}
@@ -687,8 +687,8 @@ class SliceOpConversionPattern : public OpConversionPattern<ttir::SliceOp> {
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::SliceOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getOutput(), adaptor.getBegins(),
adaptor.getEnds(), adaptor.getStep());
adaptor.getInput(), adaptor.getBegins(), adaptor.getEnds(),
adaptor.getStep());
return success();
}
};
@@ -932,8 +932,8 @@ class LinearOpConversionPattern : public OpConversionPattern<ttir::LinearOp> {
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::LinearOp>(
op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
adaptor.getB(), adaptor.getBias(), adaptor.getOutput(),
adaptor.getTransposeA(), adaptor.getTransposeB());
adaptor.getB(), adaptor.getBias(), adaptor.getTransposeA(),
adaptor.getTransposeB());
return success();
}
};
@@ -950,8 +950,7 @@ class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
adaptor.getB(), adaptor.getOutput(), adaptor.getTransposeA(),
adaptor.getTransposeB());
adaptor.getB(), adaptor.getTransposeA(), adaptor.getTransposeB());
return success();
}
};
@@ -971,7 +970,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {

auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
auto outputTy = op.getResult().getType();

auto batchSizeAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(0));
auto inputHeightAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(1));
@@ -1028,12 +1027,11 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
outputTy.getElementType(),
outputTy.getEncoding());

ttnn::Conv2dOp newConv = ttmlir::utils::createDPSOp<ttnn::Conv2dOp>(
rewriter, op.getLoc(), outputTy, adaptor.getInput(),
adaptor.getWeight(), adaptor.getBias(), device, inChannelsAttr,
outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
kernelSizeAttr, *strideAttr, reducedPaddingAttr, *dilationAttr,
groupsAttr, nullptr);
ttnn::Conv2dOp newConv = rewriter.create<ttnn::Conv2dOp>(
op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
adaptor.getBias(), device, inChannelsAttr, outChannelsAttr,
batchSizeAttr, inputHeightAttr, inputWidthAttr, kernelSizeAttr,
*strideAttr, reducedPaddingAttr, *dilationAttr, groupsAttr, nullptr);

Value output =
ttir_to_ttnn::utils::generateReshape(newConv, outputShape, rewriter);
@@ -1091,7 +1089,7 @@ class ConvTranspose2dOpConversionPattern

auto inputTy = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
auto kernelTy = mlir::cast<RankedTensorType>(adaptor.getWeight().getType());
auto outputTy = mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
auto outputTy = op.getResult().getType();

auto batchSizeAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(0));
auto inputHeightAttr = rewriter.getI32IntegerAttr(inputTy.getDimSize(1));
@@ -1151,21 +1149,12 @@ class ConvTranspose2dOpConversionPattern
outputTy = mlir::cast<RankedTensorType>(getTypeConverter()->convertType(
outputTy.cloneWith(flattenedOutputShape, outputTy.getElementType())));

// Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the
// attribute determination
auto convDPSOutput = rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
adaptor.getOutput().getDefiningOp(), flattenedOutputShape,
outputTy.getElementType());

// Must set the type to the output type to maintain the layout attributes
convDPSOutput.getResult().setType(outputTy);

ttnn::ConvTranspose2dOp new_conv = rewriter.create<ttnn::ConvTranspose2dOp>(
op.getLoc(), outputTy, adaptor.getInput(), adaptor.getWeight(),
adaptor.getBias(), convDPSOutput, device, inChannelsAttr,
outChannelsAttr, batchSizeAttr, inputHeightAttr, inputWidthAttr,
kernelSizeAttr, *strideAttr, reducedPaddingAttr, *outputPaddingAttr,
*dilationAttr, groupsAttr);
adaptor.getBias(), device, inChannelsAttr, outChannelsAttr,
batchSizeAttr, inputHeightAttr, inputWidthAttr, kernelSizeAttr,
*strideAttr, reducedPaddingAttr, *outputPaddingAttr, *dilationAttr,
groupsAttr);

// Restore the normal shape (N x H x W x C)
Value output =
@@ -1239,8 +1228,7 @@ class MaxPool2dOpConversionPattern
mlir::cast<mlir::TypedValue<RankedTensorType>>(adaptor.getInput()),
rewriter);

auto outputType =
mlir::cast<RankedTensorType>(adaptor.getOutput().getType());
auto outputType = op.getResult().getType();
llvm::ArrayRef<std::int64_t> outputShape = outputType.getShape();

llvm::SmallVector<int64_t> flattenedOutputShape{
@@ -1250,8 +1238,9 @@ class MaxPool2dOpConversionPattern
outputType.getElementType(),
outputType.getEncoding());

auto newPool = ttmlir::utils::createDPSOp<ttnn::MaxPool2dOp>(
rewriter, op.getLoc(), outputType, flattenedInput, device, batchSize,
auto newPool = rewriter.create<ttnn::MaxPool2dOp>(
op.getLoc(), this->getTypeConverter()->convertType(outputType),
flattenedInput, device, batchSize,
static_cast<int32_t>(inputShape[inputShape.size() - 3]),
static_cast<int32_t>(inputShape[inputShape.size() - 2]), channels,
adaptor.getKernelHeight(), adaptor.getKernelWidth(),
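A consequence visible in several patterns above (SliceOp, Conv2dOp, MaxPool2dOp): code that previously inspected the type of the DPS output operand now reads the op's result type instead. A minimal sketch, mirroring the SliceOp workaround earlier in this diff (variable names are illustrative):

    // The output layout now lives on the result type, not on an operand.
    RankedTensorType outputTy = op.getResult().getType();
    auto layoutAttr =
        mlir::cast<ttnn::TTNNLayoutAttr>(outputTy.getEncoding());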