Skip to content

Commit

Permalink
[gccjit] alloc op lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 7, 2024
1 parent ca71fa0 commit a85b8b9
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ Value AllocationLowering<OpType>::getAlignment(
if (auto alignmentAttr = op.getAlignment()) {
Type indexType = this->getIndexType();
alignment =
createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
this->createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
alignment =
this->getAlignInBytes(loc, memRefType.getElementType(), rewriter);
Expand Down Expand Up @@ -603,6 +603,16 @@ struct AllocaOpLowering : public AllocationLowering<memref::AllocaOp> {
}
};

struct AllocOpLowering : public AllocationLowering<memref::AllocOp> {
std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes, memref::AllocOp op) const override final {
return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
getAlignment(rewriter, loc, op));
}
using AllocationLowering<memref::AllocOp>::AllocationLowering;
};

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
Expand All @@ -618,8 +628,8 @@ void ConvertMemrefToGCCJITPass::runOnOperation() {
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering>(
typeConverter, &getContext());
patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering,
AllocOpLowering>(typeConverter, &getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalDialect<memref::MemRefDialect>();
Expand Down

0 comments on commit a85b8b9

Please sign in to comment.