Skip to content

Commit

Permalink
lowering store/load
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 5, 2024
1 parent 27bf788 commit 6e8d642
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/mlir-gccjit/IR/GCCJITOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def AccessFieldOp : GCCJIT_Op<"access_field"> {
```
}];
let arguments = (ins AnyType:$composite, IndexAttr:$field);
let results = (outs GCCJIT_LValueType:$result);
let results = (outs AnyType:$result);
let assemblyFormat = [{
$composite `[` $field `]` `:` functional-type(operands, results) attr-dict
}];
Expand Down
39 changes: 32 additions & 7 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ Value getMemRefDiscriptorAlignedPtr(OpBuilder &builder, Value descriptor,
const GCCJITTypeConverter &converter,
Location loc, MemRefType type) {
auto elementType = converter.convertType(type.getElementType());
return builder.create<gccjit::AccessFieldOp>(
loc, PointerType::get(builder.getContext(), elementType), descriptor,
builder.getIndexAttr(1));
auto ptrTy = PointerType::get(builder.getContext(), elementType);
return builder.create<gccjit::AccessFieldOp>(loc, ptrTy, descriptor,
builder.getIndexAttr(1));
}

Value getMemRefDescriptorBufferPtr(OpBuilder &builder, Location loc,
Expand Down Expand Up @@ -130,8 +130,9 @@ Value getStridedElementLValue(Location loc, MemRefType type, Value descriptor,
auto descriptorTy = cast<StructType>(descriptor.getType());
auto fieldTy = cast<ArrayType>(
cast<FieldAttr>(descriptorTy.getRecordFields()[4]).getType());
auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy);
auto strideField = rewriter.create<gccjit::AccessFieldOp>(
loc, fieldTy, descriptor, rewriter.getIndexAttr(4));
loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4));
auto ptrToStrideArray = rewriter.create<gccjit::AddrOp>(
loc, PointerType::get(rewriter.getContext(), fieldTy), strideField);
ptrToStrideField = rewriter.create<gccjit::BitCastOp>(
Expand Down Expand Up @@ -174,7 +175,6 @@ class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
mlir::LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
OpBuilder::InsertionGuard guard(rewriter);
auto type = op.getMemRefType();
auto retTy = typeConverter->convertType(op.getResult().getType());
auto exprBundle = rewriter.replaceOpWithNewOp<ExprOp>(op, retTy);
Expand All @@ -189,6 +189,30 @@ class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
}
};

class StoreOpLowering : public GCCJITLoweringPattern<memref::StoreOp> {
public:
using GCCJITLoweringPattern::GCCJITLoweringPattern;
mlir::LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
auto elemTy = typeConverter->convertType(type.getElementType());
auto elemLValueTy = LValueType::get(rewriter.getContext(), elemTy);
auto expr = rewriter.create<ExprOp>(op->getLoc(), elemLValueTy, true);
auto *block = rewriter.createBlock(&expr.getBody());
{
rewriter.setInsertionPointToStart(block);
Value dataLValue = getStridedElementLValue(
op.getLoc(), type, adaptor.getMemref(), expr, adaptor.getIndices(),
*getTypeConverter(), rewriter);
rewriter.create<ReturnOp>(op.getLoc(), dataLValue);
}
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<AssignOp>(op, adaptor.getValue(), expr);
return success();
}
};

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
Expand All @@ -204,10 +228,11 @@ void ConvertMemrefToGCCJITPass::runOnOperation() {
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LoadOpLowering>(typeConverter, &getContext());
patterns.insert<LoadOpLowering, StoreOpLowering>(typeConverter,
&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalOp<memref::LoadOp>();
target.addIllegalDialect<memref::MemRefDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
Expand Down
2 changes: 1 addition & 1 deletion test/lowering/gemm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-memref-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s
module {
// CHECK-NOT: func.func
// CHECK-NOT: func.return
Expand Down

0 comments on commit 6e8d642

Please sign in to comment.