From 6e8d64283398fb127a202827e2a7fceeae5b07ab Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan Date: Mon, 4 Nov 2024 19:31:33 -0500 Subject: [PATCH] lowering store/load --- include/mlir-gccjit/IR/GCCJITOps.td | 2 +- src/Conversion/ConvertMemrefToGCCJIT.cpp | 39 +++++++++++++++++++----- test/lowering/gemm.mlir | 2 +- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/include/mlir-gccjit/IR/GCCJITOps.td b/include/mlir-gccjit/IR/GCCJITOps.td index 34f1e68..111ae9f 100644 --- a/include/mlir-gccjit/IR/GCCJITOps.td +++ b/include/mlir-gccjit/IR/GCCJITOps.td @@ -779,7 +779,7 @@ def AccessFieldOp : GCCJIT_Op<"access_field"> { ``` }]; let arguments = (ins AnyType:$composite, IndexAttr:$field); - let results = (outs GCCJIT_LValueType:$result); + let results = (outs AnyType:$result); let assemblyFormat = [{ $composite `[` $field `]` `:` functional-type(operands, results) attr-dict }]; diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp index 71d88b5..7816318 100644 --- a/src/Conversion/ConvertMemrefToGCCJIT.cpp +++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp @@ -73,9 +73,9 @@ Value getMemRefDiscriptorAlignedPtr(OpBuilder &builder, Value descriptor, const GCCJITTypeConverter &converter, Location loc, MemRefType type) { auto elementType = converter.convertType(type.getElementType()); - return builder.create( - loc, PointerType::get(builder.getContext(), elementType), descriptor, - builder.getIndexAttr(1)); + auto ptrTy = PointerType::get(builder.getContext(), elementType); + return builder.create(loc, ptrTy, descriptor, + builder.getIndexAttr(1)); } Value getMemRefDescriptorBufferPtr(OpBuilder &builder, Location loc, @@ -130,8 +130,9 @@ Value getStridedElementLValue(Location loc, MemRefType type, Value descriptor, auto descriptorTy = cast(descriptor.getType()); auto fieldTy = cast( cast(descriptorTy.getRecordFields()[4]).getType()); + auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy); auto strideField = rewriter.create( - loc, fieldTy, descriptor, rewriter.getIndexAttr(4)); + loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4)); auto ptrToStrideArray = rewriter.create( loc, PointerType::get(rewriter.getContext(), fieldTy), strideField); ptrToStrideField = rewriter.create( @@ -174,7 +175,6 @@ class LoadOpLowering : public GCCJITLoweringPattern { mlir::LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - OpBuilder::InsertionGuard guard(rewriter); auto type = op.getMemRefType(); auto retTy = typeConverter->convertType(op.getResult().getType()); auto exprBundle = rewriter.replaceOpWithNewOp(op, retTy); @@ -189,6 +189,30 @@ class LoadOpLowering : public GCCJITLoweringPattern { } }; +class StoreOpLowering : public GCCJITLoweringPattern { +public: + using GCCJITLoweringPattern::GCCJITLoweringPattern; + mlir::LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto type = op.getMemRefType(); + auto elemTy = typeConverter->convertType(type.getElementType()); + auto elemLValueTy = LValueType::get(rewriter.getContext(), elemTy); + auto expr = rewriter.create(op->getLoc(), elemLValueTy, true); + auto *block = rewriter.createBlock(&expr.getBody()); + { + rewriter.setInsertionPointToStart(block); + Value dataLValue = getStridedElementLValue( + op.getLoc(), type, adaptor.getMemref(), expr, adaptor.getIndices(), + *getTypeConverter(), rewriter); + rewriter.create(op.getLoc(), dataLValue); + } + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), expr); + return success(); + } +}; + void ConvertMemrefToGCCJITPass::runOnOperation() { auto moduleOp = getOperation(); auto typeConverter = GCCJITTypeConverter(); @@ -204,10 +228,11 @@ void ConvertMemrefToGCCJITPass::runOnOperation() { typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); mlir::RewritePatternSet patterns(&getContext()); - patterns.insert(typeConverter, &getContext()); + patterns.insert(typeConverter, + &getContext()); mlir::ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalDialect(); llvm::SmallVector ops; for (auto func : moduleOp.getOps()) ops.push_back(func); diff --git a/test/lowering/gemm.mlir b/test/lowering/gemm.mlir index e645f8a..73aaf64 100644 --- a/test/lowering/gemm.mlir +++ b/test/lowering/gemm.mlir @@ -1,4 +1,4 @@ -// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s +// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-memref-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s module { // CHECK-NOT: func.func // CHECK-NOT: func.return