From a85b8b98890b06b745d6667552935fb1cc89a454 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <i@zhuyi.fan>
Date: Wed, 6 Nov 2024 22:27:19 -0500
Subject: [PATCH] [gccjit] alloc op lowering

---
 src/Conversion/ConvertMemrefToGCCJIT.cpp | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp
index ee9a812..e1a0955 100644
--- a/src/Conversion/ConvertMemrefToGCCJIT.cpp
+++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp
@@ -365,7 +365,7 @@ Value AllocationLowering<OpType>::getAlignment(
   if (auto alignmentAttr = op.getAlignment()) {
     Type indexType = this->getIndexType();
     alignment =
-        createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
+        this->createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
   } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
     alignment =
         this->getAlignInBytes(loc, memRefType.getElementType(), rewriter);
@@ -603,6 +603,16 @@ struct AllocaOpLowering : public AllocationLowering<memref::AllocaOp> {
   }
 };
 
+struct AllocOpLowering : public AllocationLowering<memref::AllocOp> {
+  std::tuple<Value, Value>
+  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
+                 Value sizeBytes, memref::AllocOp op) const override final {
+    return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
+                                       getAlignment(rewriter, loc, op));
+  }
+  using AllocationLowering<memref::AllocOp>::AllocationLowering;
+};
+
 void ConvertMemrefToGCCJITPass::runOnOperation() {
   auto moduleOp = getOperation();
   auto typeConverter = GCCJITTypeConverter();
@@ -618,8 +628,8 @@ void ConvertMemrefToGCCJITPass::runOnOperation() {
   typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
   typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
   mlir::RewritePatternSet patterns(&getContext());
-  patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering>(
-      typeConverter, &getContext());
+  patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering,
+                  AllocOpLowering>(typeConverter, &getContext());
   mlir::ConversionTarget target(getContext());
   target.addLegalDialect<gccjit::GCCJITDialect>();
   target.addIllegalDialect<memref::MemRefDialect>();