From a85b8b98890b06b745d6667552935fb1cc89a454 Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan <i@zhuyi.fan> Date: Wed, 6 Nov 2024 22:27:19 -0500 Subject: [PATCH] [gccjit] alloc op lowering --- src/Conversion/ConvertMemrefToGCCJIT.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp index ee9a812..e1a0955 100644 --- a/src/Conversion/ConvertMemrefToGCCJIT.cpp +++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp @@ -365,7 +365,7 @@ Value AllocationLowering<OpType>::getAlignment( if (auto alignmentAttr = op.getAlignment()) { Type indexType = this->getIndexType(); alignment = - createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); + this->createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { alignment = this->getAlignInBytes(loc, memRefType.getElementType(), rewriter); @@ -603,6 +603,16 @@ struct AllocaOpLowering : public AllocationLowering<memref::AllocaOp> { } }; +struct AllocOpLowering : public AllocationLowering<memref::AllocOp> { + std::tuple<Value, Value> + allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, + Value sizeBytes, memref::AllocOp op) const override final { + return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op, + getAlignment(rewriter, loc, op)); + } + using AllocationLowering<memref::AllocOp>::AllocationLowering; +}; + void ConvertMemrefToGCCJITPass::runOnOperation() { auto moduleOp = getOperation(); auto typeConverter = GCCJITTypeConverter(); @@ -618,8 +628,8 @@ void ConvertMemrefToGCCJITPass::runOnOperation() { typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); mlir::RewritePatternSet patterns(&getContext()); - patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering>( - typeConverter, &getContext()); + patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering, + AllocOpLowering>(typeConverter, &getContext()); mlir::ConversionTarget target(getContext()); target.addLegalDialect<gccjit::GCCJITDialect>(); target.addIllegalDialect<memref::MemRefDialect>();