diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp index ee9a812..e1a0955 100644 --- a/src/Conversion/ConvertMemrefToGCCJIT.cpp +++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp @@ -365,7 +365,7 @@ Value AllocationLowering::getAlignment( if (auto alignmentAttr = op.getAlignment()) { Type indexType = this->getIndexType(); alignment = - createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); + this->createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { alignment = this->getAlignInBytes(loc, memRefType.getElementType(), rewriter); @@ -603,6 +603,16 @@ struct AllocaOpLowering : public AllocationLowering { } }; +struct AllocOpLowering : public AllocationLowering { + std::tuple + allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, + Value sizeBytes, memref::AllocOp op) const override final { + return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op, + getAlignment(rewriter, loc, op)); + } + using AllocationLowering::AllocationLowering; +}; + void ConvertMemrefToGCCJITPass::runOnOperation() { auto moduleOp = getOperation(); auto typeConverter = GCCJITTypeConverter(); @@ -618,8 +628,8 @@ void ConvertMemrefToGCCJITPass::runOnOperation() { typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); mlir::RewritePatternSet patterns(&getContext()); - patterns.insert( - typeConverter, &getContext()); + patterns.insert(typeConverter, &getContext()); mlir::ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect();