Merge pull request #104 from Xilinx/matthias.bump_llvm_green-2b4807
Update to LLVM green commit 2b4807
mgehre-amd authored Jun 16, 2023
2 parents 8ca7252 + 4e2282d commit 198c510
Showing 12 changed files with 252 additions and 26 deletions.
6 changes: 6 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -535,6 +535,7 @@
"IndexSelectNegativeDimModule_basic",
"IndexSelectStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorModule3dInputStatic_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
@@ -986,6 +987,7 @@
"ReduceAmaxKeepDim_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ElementwiseLog2Module_basic",
@@ -1036,6 +1038,7 @@
"ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseGeluModule_basic",
"GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic",
@@ -1053,6 +1056,7 @@
"BaddbmmWithBetaModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
@@ -1089,6 +1093,7 @@
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorModule3dInputStatic_basic",
"ElementwiseWhereScalarModule_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullModuleDefaultDtype_basic",
@@ -1358,6 +1363,7 @@
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule3dInputStatic_basic",
"IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 11612 files
2 changes: 1 addition & 1 deletion externals/mlir-hlo
Submodule mlir-hlo updated 7942 files
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -127,6 +127,10 @@ TypedValue<RankedTensorType> transposeBy(Location loc,
PatternRewriter &rewriter, Value val,
ArrayRef<int32_t> permutation);

// Get accumulator type for AvgPool2dOp.
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
TypeAttr &accType);

} // namespace tosa
} // namespace mlir

5 changes: 1 addition & 4 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -98,9 +98,6 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
}

DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
RankedTensorType::get({}, rewriter.getI64Type()), dim);

auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
@@ -115,7 +112,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initValue,
initIndex,
},
dimensions);
rewriter.getI64TensorAttr(dim));

Block &block = stablehloReduceOp.getBody().emplaceBlock();

183 changes: 179 additions & 4 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3993,6 +3993,11 @@ class ConvertAtenIndexTensorOpNone
op.getLoc(), "unimplemented: index must be ranked tensor");
}

if (indices.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op.getLoc(), "unimplemented: index must be 1d tensor");
}

auto input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy || !inputTy.hasStaticShape())
@@ -4724,10 +4729,25 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Failed to process inputs for pooling");

auto pooledOutput =
rewriter
.create<TosaOpT>(op->getLoc(), outputTy, input, kernel, stride, pad)
.getResult();
Value pooledOutput;
static_assert(std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
"Expected either tosa::MaxPool2dOp or tosa::AvgPool2dOp");
if constexpr (std::is_same<TosaOpT, tosa::MaxPool2dOp>::value) {
pooledOutput = rewriter
.create<TosaOpT>(op->getLoc(), outputTy, input, kernel,
stride, pad)
.getResult();
} else if constexpr (std::is_same<TosaOpT, tosa::AvgPool2dOp>::value) {
TypeAttr accType;
if (failed(tosa::getAvgPool2dAccType(rewriter, input, accType)))
return rewriter.notifyMatchFailure(
op, "Failed to get accumulator type for pooling");
pooledOutput = rewriter
.create<TosaOpT>(op->getLoc(), outputTy, input, kernel,
stride, pad, accType)
.getResult();
}

auto transposedOutput =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
@@ -5499,7 +5519,159 @@ class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
std::string implementedWithOpAttr;
};

class SimplifyAtenIndexTensorWithSliceIndex
: public OpRewritePattern<AtenIndexTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AtenIndexTensorOp op,
PatternRewriter &rewriter) const override {
auto outTy = dyn_cast<BaseTensorType>(op.getType());
if (!outTy) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}

SmallVector<Value> indices;
if (!getListConstructElements(op.getIndices(), indices))
return failure();

TypedValue<BaseTensorType> input =
dyn_cast<TypedValue<BaseTensorType>>(op.getSelf());
if (!input) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}

if (llvm::count_if(indices, [](Value v) {
return !isa<Torch::NoneType>(v.getType());
}) == 1) {
return rewriter.notifyMatchFailure(op, "nothing to do");
}

auto loc = op->getLoc();

for (size_t i = 0; i < indices.size(); ++i) {
if (isa<Torch::NoneType>(indices[i].getType()))
continue;

auto indicesTy = dyn_cast<BaseTensorType>(indices[i].getType());
if (!indicesTy || !indicesTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "requires indices with static shape");
}
int64_t numIndices = std::accumulate(
indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1,
[&](int64_t a, int64_t b) { return a * b; });
if (numIndices != 1)
continue;

auto inputTy = input.getType();
SmallVector<int64_t> slicedShape{inputTy.getSizes()};
slicedShape[i] = 1;
auto slicedType =
inputTy.getWithSizesAndDtype(slicedShape, inputTy.getDtype());

auto none = rewriter.create<Torch::ConstantNoneOp>(op->getLoc());
SmallVector<Value> sliceIndices{inputTy.getSizes().size(), none};
sliceIndices[i] = reshapeTo(loc, rewriter, indices[i], {1});

Value sliceIndicesV = rewriter.create<PrimListConstructOp>(
loc, op.getIndices().getType(), sliceIndices);
auto slicedInput = rewriter.create<AtenIndexTensorOp>(
loc, slicedType, input, sliceIndicesV);

SmallVector<int64_t> reshapedShape = slicedShape;
reshapedShape.erase(reshapedShape.begin() + i);

auto reshaped = reshapeTo(loc, rewriter, slicedInput, reshapedShape);

SmallVector<Value> newIndicesList{indices};
newIndicesList.erase(newIndicesList.begin() + i);

Value newIndicesListV = rewriter.create<PrimListConstructOp>(
loc, op.getIndices().getType(), newIndicesList);

rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(), reshaped,
newIndicesListV);
return success();
}
return failure();
}
};

class SimplifyAtenIndexTensorWithNdIndex
: public OpRewritePattern<AtenIndexTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AtenIndexTensorOp op,
PatternRewriter &rewriter) const override {
auto outTy = dyn_cast<BaseTensorType>(op.getType());
if (!outTy) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}

SmallVector<Value> indices;
if (!getListConstructElements(op.getIndices(), indices))
return failure();

TypedValue<BaseTensorType> input =
dyn_cast<TypedValue<BaseTensorType>>(op.getSelf());
if (!input) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}
auto loc = op->getLoc();

if (llvm::count_if(indices, [](Value v) {
return !isa<Torch::NoneType>(v.getType());
}) != 1) {
return rewriter.notifyMatchFailure(op, "requires exactly one non-None index");
}

for (size_t i = 0; i < indices.size(); ++i) {
if (isa<Torch::NoneType>(indices[i].getType()))
continue;

auto indicesTy = dyn_cast<BaseTensorType>(indices[i].getType());
if (!indicesTy || !indicesTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "requires indices with static shape");
}
if (indicesTy.getSizes().size() == 1) {
continue;
}

// flatten indices
int64_t numIndices = std::accumulate(
indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1,
[&](int64_t a, int64_t b) { return a * b; });

auto newIndices =
reshapeTo(op.getLoc(), rewriter, indices[i], {numIndices});

SmallVector<Value> newIndicesList{indices};
newIndicesList[i] = newIndices;

Value newIndicesListV = rewriter.create<PrimListConstructOp>(
loc, op.getIndices().getType(), newIndicesList);

SmallVector<int64_t> indexOpShape{outTy.getSizes()};
indexOpShape.erase(indexOpShape.begin() + i,
indexOpShape.begin() + i + indicesTy.getSizes().size());
indexOpShape.insert(indexOpShape.begin() + i, numIndices);

auto indexOpType =
outTy.getWithSizesAndDtype(indexOpShape, outTy.getOptionalDtype());
auto indexed = rewriter.create<AtenIndexTensorOp>(
loc, indexOpType, input, newIndicesListV);

auto reshaped =
reshapeTo(loc, rewriter, indexed, outTy.getSizes());
rewriter.replaceOp(op, reshaped);
return success();
}
return failure();
}
};
} // namespace

// -----------------------------------------------------------------------------
@@ -5542,6 +5714,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {

patterns.add<SimplifyAten_IndexPutImplOp>(context);
patterns.add<SimplifyAten_IndexPutImplOpNone>(context);
patterns.add<SimplifyAtenIndexTensorWithSliceIndex>(context);
patterns.add<SimplifyAtenIndexTensorWithNdIndex>(context);
patterns.add<ConvertAtenIndexTensorOpNone>(typeConverter, context);

#define INSERT_SIMPLIFY_OP_PATTERN(AtenOp) \
@@ -5567,6 +5741,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
INSERT_UNARY_PATTERN(AtenErfOp, tosa::ErfOp)
#undef INSERT_UNARY_PATTERN

#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
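
For context, a rough tensor-level sketch of what the two new index-tensor simplification patterns aim for, written in plain PyTorch (illustration only, not code from this commit; shapes and values are invented):

import torch

x = torch.randn(4, 5, 6)
i = torch.tensor([2])        # single-element index (numel == 1)
j = torch.tensor([0, 3, 1])  # larger index

# SimplifyAtenIndexTensorWithSliceIndex: peel off the single-element index
# by slicing, collapse that dimension, then apply the remaining index.
ref = x[i, j]                      # original two-index aten.index.Tensor
sliced = x[i]                      # shape (1, 5, 6)
collapsed = sliced.reshape(5, 6)   # single-element dim dropped
assert torch.equal(ref, collapsed[j])

# SimplifyAtenIndexTensorWithNdIndex: flatten a multi-dimensional index to
# 1-D before indexing, then reshape the result back to the expected shape.
k = torch.tensor([[0, 3], [1, 2]])           # 2-D index, shape (2, 2)
ref2 = x[k]                                  # shape (2, 2, 5, 6)
out2 = x[k.reshape(-1)].reshape(2, 2, 5, 6)  # flatten, index, reshape
assert torch.equal(ref2, out2)
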
23 changes: 22 additions & 1 deletion lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -181,7 +181,6 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
.getResult();
}


// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
@@ -439,5 +438,27 @@ template std::optional<Value> getConstTensor<int64_t>(PatternRewriter &,
ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape,
std::optional<Type> dtype);

LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
TypeAttr &accType) {
auto inputTy = llvm::dyn_cast<ShapedType>(input.getType());
if (!inputTy)
return failure();
auto inputETy = inputTy.getElementType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
inputETy = quantType.getStorageType();

// TOSA supports both FP16 and FP32 accumulator types for FP16 input. Once
// FP16 support is in place, the accumulator type can be chosen based on the
// trade-off between performance and accuracy. Default to FP32 for now.
accType = inputETy.isa<FloatType>()
? mlir::TypeAttr::get(rewriter.getF32Type())
: mlir::TypeAttr::get(rewriter.getIntegerType(32));

return success();
}

} // namespace tosa
} // namespace mlir
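
As a side note, here is a hypothetical Python paraphrase of the accumulator-type rule implemented by getAvgPool2dAccType above (illustration only, not part of the commit): float inputs accumulate in f32, integer and quantized-storage inputs accumulate in i32.

def avg_pool2d_acc_type(input_elem_type: str) -> str:
    # Mirrors the C++ logic: FloatType -> f32 accumulator, otherwise i32.
    float_types = {"f16", "bf16", "f32"}
    return "f32" if input_elem_type in float_types else "i32"

assert avg_pool2d_acc_type("f16") == "f32"
assert avg_pool2d_acc_type("i8") == "i32"
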
5 changes: 3 additions & 2 deletions lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -239,13 +240,13 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
[](Torch::TupleType type,
SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
SmallVectorImpl<Type> &types) -> LogicalResult {
llvm::append_range(types, type.getContainedTypes());
return success();
});
typeConverter.addConversion(
[](Torch::NoneType type,
SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
SmallVectorImpl<Type> &types) -> LogicalResult {
return success();
});

(4 more changed files not shown)