Merge pull request #500 from Xilinx/bump_to_b6f04fa3
[AutoBump] Merge with fixes of b6f04fa (Nov 07) (111)
mgehre-amd authored Feb 4, 2025
2 parents 6acb6b6 + 283047b commit 70ddf37
Showing 24 changed files with 2,081 additions and 736 deletions.
50 changes: 50 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -16329,6 +16329,31 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
}];
}

+def Torch_AtenEqBoolOp : Torch_Op<"aten.eq.bool", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::eq.bool : (bool, bool) -> (bool)`";
+  let arguments = (ins
+    Torch_BoolType:$a,
+    Torch_BoolType:$b
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenEqBoolOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenEqBoolOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [
AllowsTypeRefinement,
HasValueSemantics,
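A note on the new `aten.eq.bool` op above: `hasFolder = 1` only declares the hook; the matching C++ `fold` implementation must live elsewhere in the commit (typically lib/Dialect/Torch/IR/TorchOps.cpp), among the files not shown on this page. A minimal sketch of what such a folder could look like, assuming the usual torch-mlir conventions; the adaptor getters and i1 attribute encoding are assumptions, not code from this commit:

// Hypothetical folder for aten.eq.bool: identical operands compare equal,
// and two constant bool operands fold to a constant i1 attribute.
OpFoldResult AtenEqBoolOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0) == getOperand(1))
    return IntegerAttr::get(IntegerType::get(getContext(), 1), true);
  auto a = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
  auto b = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
  if (!a || !b)
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1),
                          a.getValue() == b.getValue());
}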
@@ -16476,6 +16501,31 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
let hasCanonicalizer = 1;
}

+def Torch_AtenMulLeftTOp : Torch_Op<"aten.mul.left_t", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mul.left_t : (t[], int) -> (t[])`";
+  let arguments = (ins
+    AnyTorchListType:$l,
+    Torch_IntType:$n
+  );
+  let results = (outs
+    AnyTorchListType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMulLeftTOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMulLeftTOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [
AllowsTypeRefinement,
ReadOnly
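Likewise, `hasCanonicalizer = 1` on the new `aten.mul.left_t` op (list repetition, `(t[], int) -> (t[])`) requires a pattern registered in C++ elsewhere in the commit. A hypothetical sketch, assuming the torch-mlir helpers getListConstructElements and m_TorchConstantInt: when the list literal and the repeat count are both statically known, the op can be rewritten into one literal list with the elements repeated.

// Hypothetical canonicalization for aten.mul.left_t; illustrative only.
void AtenMulLeftTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add(+[](AtenMulLeftTOp op, PatternRewriter &rewriter) {
    SmallVector<Value> elements;
    if (!getListConstructElements(op.getL(), elements))
      return rewriter.notifyMatchFailure(op, "list elements not known");
    int64_t n;
    if (!matchPattern(op.getN(), m_TorchConstantInt(&n)) || n < 0)
      return rewriter.notifyMatchFailure(op, "repeat count not constant");
    SmallVector<Value> repeated;
    for (int64_t i = 0; i < n; ++i)
      repeated.append(elements.begin(), elements.end());
    rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
                                                            repeated);
    return success();
  });
}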
64 changes: 29 additions & 35 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -7,12 +7,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>

using namespace mlir;
@@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
});
patterns.onOp(
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value input, weight;
int64_t group;
@@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op,
"unsupported conversion: kernel_shape list size should have "
"number of values equal to weight_rank - 2");
-          } else {
-            for (unsigned i = 0; i < kernelShape.size(); i++) {
-              if (weightShape[i + 2] != kernelShape[i]) {
-                return rewriter.notifyMatchFailure(
-                    binder.op, "unsupported conversion: kernel_shape value "
-                               "should be equal to the weight tensor shape");
-              }
-            }
}
}

@@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
padding.resize_for_overwrite(2 * spatialRank);
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
+            if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
+                inputShape[dimIdx + 2] == Torch::kUnknownSize)
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "expected weight and input tensor to have static shape");
const int64_t dilatedKernelSize =
dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
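The hunk above is cut off mid-expression by the collapsed diff, but the visible parts (dilatedKernelSize and the start of totalPad) follow the standard ONNX auto_pad SAME computation. A standalone worked instance under assumed sizes (input extent 5, stride 2, kernel 3, dilation 1; illustrative values, not taken from the commit):

// Worked example of SAME auto-padding arithmetic (assumed formula).
#include <cstdint>
#include <cstdio>

int main() {
  int64_t in = 5, stride = 2, kernel = 3, dilation = 1;
  int64_t dilatedKernelSize = dilation * (kernel - 1) + 1; // 3
  // ceil(in / stride) output positions spaced `stride` apart, minus
  // what the input already covers:
  int64_t totalPad =
      ((in + stride - 1) / stride - 1) * stride + dilatedKernelSize - in; // 2
  if (totalPad < 0)
    totalPad = 0;
  int64_t padBegin = totalPad / 2;      // 1 (SAME_UPPER: extra unit at end)
  int64_t padEnd = totalPad - padBegin; // 1
  printf("total=%lld begin=%lld end=%lld\n", (long long)totalPad,
         (long long)padBegin, (long long)padEnd);
  return 0;
}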
@@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (padding.size() != 2 * (rank - 2)) {
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+                loc, rewriter.getI64IntegerAttr(i)));
}
paddingList = rewriter.create<Torch::PrimListConstructOp>(
-              binder.getLoc(),
+              loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstPadding);
@@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (matchedPads) {
for (unsigned i = 0; i < padding.size() / 2; i++) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+                  loc, rewriter.getI64IntegerAttr(padding[i])));
}
paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstPadding);
@@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
SmallVector<Value> inputPaddingList;
for (uint32_t i = 0; i < padding.size() / 2; i++) {
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(
-                                       padding[padding.size() / 2 - i - 1])));
+                  loc, rewriter.getI64IntegerAttr(
+                           padding[padding.size() / 2 - i - 1])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(),
+                  loc,
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
inputPaddingList.emplace_back(
rewriter.create<Torch::ConstantIntOp>(
-                      binder.getLoc(), rewriter.getI64IntegerAttr(0)));
+                      loc, rewriter.getI64IntegerAttr(0)));
}
// The conv op itself will have no padding since the actual padding
// is performed using the torch.pad preceding it.
paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
inputPaddingList);
Value padsSizeList =
rewriter
.create<Torch::PrimListConstructOp>(
-                      binder.getLoc(),
+                      loc,
Torch::ListType::get(
rewriter.getType<Torch::IntType>()),
padsRearrange)
.getResult();
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
-                binder.getLoc(), rewriter.getStringAttr("constant"));
+                loc, rewriter.getStringAttr("constant"));
Value constantValue;

if (isa<IntegerType>(inputTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>(
-                    binder.getLoc(), rewriter.getI64IntegerAttr(0));
+                    loc, rewriter.getI64IntegerAttr(0));
if (isa<FloatType>(inputTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                    binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+                    loc, rewriter.getF64FloatAttr(0.0f));
// Pad output shape must be computed explicitly from the pad values
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
for (uint32_t i = 0; i < padding.size() / 2; i++) {
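One detail worth calling out in the general-pads path above: ONNX stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end], while torch.aten.pad expects begin/end pairs starting from the innermost dimension, which is what the padsRearrange indexing achieves. A minimal standalone check of that re-ordering, using made-up 2-D values:

// Re-orders ONNX pads {h_begin, w_begin, h_end, w_end} into the
// torch.pad order {w_begin, w_end, h_begin, h_end}.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> padding = {1, 2, 3, 4}; // ONNX layout
  std::vector<int64_t> padsRearrange;
  for (size_t i = 0; i < padding.size() / 2; i++) {
    padsRearrange.push_back(padding[padding.size() / 2 - i - 1]); // begin
    padsRearrange.push_back(padding[padding.size() - i - 1]);     // end
  }
  for (int64_t p : padsRearrange)
    printf("%lld ", (long long)p); // prints: 2 4 1 3
  return 0;
}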
@@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
auto padTy = rewriter.getType<Torch::ValueTensorType>(
newInputShape, inputTensorType.getDtype());
paddedInput = rewriter.create<Torch::AtenPadOp>(
-                binder.getLoc(), padTy, input, padsSizeList, modeVal,
-                constantValue);
+                loc, padTy, input, padsSizeList, modeVal, constantValue);
}
}
for (int64_t i : dilations) {
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+                loc, rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+                loc, rewriter.getI64IntegerAttr(i)));
}
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(0));
+              loc, rewriter.getI64IntegerAttr(0));
cstOutputPadding = {cstZero, cstZero};

Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstDilations);
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstStrides);
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstOutputPadding);
-        Value transposed =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value bias;
if (binder.op->getNumOperands() == 3) {
if (binder.tensorOperandAtIndex(bias, 2)) {
return failure();
}
} else {
-          bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          bias = rewriter.create<Torch::ConstantNoneOp>(loc);
}
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(group));
+            loc, rewriter.getI64IntegerAttr(group));

rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
binder.op, resultType, paddedInput, weight, bias, stridesList,
(The remaining 22 of the 24 changed files are not shown on this page.)
