From 96125ceb5fd7a42e6ff723262b144551da1f6cc2 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan
Date: Wed, 6 Nov 2024 17:08:27 -0500
Subject: [PATCH] stage

---
 src/Conversion/ConvertMemrefToGCCJIT.cpp | 280 +++++++++++++++++++++++
 1 file changed, 280 insertions(+)

diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp
index ffad290..4e9ebda 100644
--- a/src/Conversion/ConvertMemrefToGCCJIT.cpp
+++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include
 #include
 #include
 #include
@@ -31,6 +32,7 @@
 #include "mlir-gccjit/IR/GCCJITOpsEnums.h"
 #include "mlir-gccjit/IR/GCCJITTypes.h"
 #include "mlir-gccjit/Passes.h"
+#include "mlir/IR/Types.h"
 
 using namespace mlir;
 using namespace mlir::gccjit;
@@ -50,8 +52,24 @@ class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
   }
 
   IntType getIndexType() const;
+  PointerType getVoidPtrType() const {
+    return PointerType::get(this->getContext(),
+                            VoidType::get(this->getContext()));
+  }
   Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                 Type resultType, int64_t value) const;
+  Value getSizeInBytes(Location loc, Type type,
+                       ConversionPatternRewriter &rewriter) const;
+  Value getAlignInBytes(Location loc, Type type,
+                        ConversionPatternRewriter &rewriter) const;
+  PointerType getElementPtrType(MemRefType type) const;
+
+  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
+                                ValueRange dynamicSizes,
+                                ConversionPatternRewriter &rewriter,
+                                SmallVectorImpl<Value> &sizes,
+                                SmallVectorImpl<Value> &strides, Value &size,
+                                bool sizeInBytes) const;
 
   class MemRefDescriptor {
   private:
@@ -86,6 +104,47 @@ class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
   using OpConversionPattern<T>::OpConversionPattern;
 };
 
+template <typename OpType>
+class AllocationLowering : public GCCJITLoweringPattern<OpType> {
+protected:
+  /// Computes the aligned value for 'input' as follows:
+  ///   bumped = input + alignment - 1
+  ///   aligned = bumped - bumped % alignment
+  Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+                      Value input, Value alignment) const;
+
+  MemRefType getMemRefResultType(OpType op) const;
+
+  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
+                     OpType op) const;
+
+  int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
+                                        Location loc, OpType op) const;
+
+  std::tuple<Value, Value>
+  allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
+                              Value sizeBytes, OpType op,
+                              Value alignment) const;
+
+  /// Allocates a memory buffer using an aligned allocation method.
+  Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
+                                Location loc, Value sizeBytes, OpType op,
+                                int64_t alignment) const;
+
+  virtual std::tuple<Value, Value>
+  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
+                 Operation *op) const = 0;
+
+private:
+  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
+
+public:
+  LogicalResult
+  matchAndRewrite(OpType op,
+                  typename OpConversionPattern<OpType>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
 class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
 public:
   using GCCJITLoweringPattern<memref::LoadOp>::GCCJITLoweringPattern;
@@ -297,6 +356,227 @@ GCCJITLoweringPattern<T>::getMemRefDescriptor(
   return {descriptor, type, rewriter, *this};
 }
 
+template <typename T>
+Value GCCJITLoweringPattern<T>::getSizeInBytes(
+    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
+  Type gccjitType = getTypeConverter()->convertType(type);
+  auto indexType = getIndexType();
+  return rewriter.create(loc, indexType, gccjitType);
+}
+
+template <typename T>
+Value GCCJITLoweringPattern<T>::getAlignInBytes(
+    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
+  Type gccjitType = getTypeConverter()->convertType(type);
+  auto indexType = getIndexType();
+  return rewriter.create(loc, indexType, gccjitType);
+}
+
+template <typename T>
+PointerType
+GCCJITLoweringPattern<T>::getElementPtrType(MemRefType type) const {
+  auto eltTy = getTypeConverter()->convertType(type.getElementType());
+  return PointerType::get(this->getContext(), eltTy);
+}
+
+template <typename OpType>
+MemRefType AllocationLowering<OpType>::getMemRefResultType(OpType op) const {
+  return cast<MemRefType>(op->getResult(0).getType());
+}
+
+template <typename OpType>
+Value AllocationLowering<OpType>::getAlignment(
+    ConversionPatternRewriter &rewriter, Location loc, OpType op) const {
+  MemRefType memRefType = op.getType();
+  Value alignment;
+  if (auto alignmentAttr = op.getAlignment()) {
+    Type indexType = this->getIndexType();
+    alignment = this->createIndexAttrConstant(rewriter, loc, indexType,
+                                              *alignmentAttr);
+  } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
+    alignment =
+        this->getAlignInBytes(loc, memRefType.getElementType(), rewriter);
+  }
+  return alignment;
+}
+
+template <typename OpType>
+Value AllocationLowering<OpType>::createAligned(
+    ConversionPatternRewriter &rewriter, Location loc, Value input,
+    Value alignment) const {
+  Value one =
+      this->createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
+  Value bump = rewriter.create(loc, alignment.getType(),
+                               BOp::Minus, alignment, one);
+  Value bumped = rewriter.create(loc, alignment.getType(),
+                                 BOp::Plus, input, bump);
+  Value mod = rewriter.create(loc, alignment.getType(),
+                              BOp::Modulo, bumped, alignment);
+  return rewriter.create(loc, alignment.getType(), BOp::Minus,
+                         bumped, mod);
+}
+
+template <typename OpType>
+std::tuple<Value, Value>
+AllocationLowering<OpType>::allocateBufferManuallyAlign(
+    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
+    OpType op, Value alignment) const {
+  if (alignment) {
+    // Adjust the allocation size to consider alignment.
+    sizeBytes = rewriter.create(
+        loc, sizeBytes.getType(), BOp::Plus, sizeBytes, alignment);
+  }
+
+  MemRefType memRefType = getMemRefResultType(op);
+  // Allocate the underlying buffer.
+  Type elementPtrType = this->getElementPtrType(memRefType);
+  Value allocatedPtr = rewriter.create(
+      loc, this->getVoidPtrType(),
+      SymbolRefAttr::get(this->getContext(), "malloc"), ValueRange{sizeBytes},
+      /* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr());
+
+  if (!allocatedPtr)
+    return std::make_tuple(Value(), Value());
+  Value alignedPtr = allocatedPtr;
+  if (alignment) {
+    // Compute the aligned pointer.
+    Value allocatedInt = rewriter.create(
+        loc, this->getIndexType(), allocatedPtr);
+    Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
+    alignedPtr =
+        rewriter.create(loc, elementPtrType, alignmentInt);
+  } else {
+    alignedPtr =
+        rewriter.create(loc, elementPtrType, allocatedPtr);
+  }
+
+  return std::make_tuple(allocatedPtr, alignedPtr);
+}
+
+template <typename OpType>
+Value AllocationLowering<OpType>::allocateBufferAutoAlign(
+    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
+    OpType op, int64_t alignment) const {
+  Value allocAlignment = this->createIndexAttrConstant(
+      rewriter, loc, this->getIndexType(), alignment);
+
+  MemRefType memRefType = getMemRefResultType(op);
+  sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
+
+  Type elementPtrType = this->getElementPtrType(memRefType);
+  auto result = rewriter.create(
+      loc, this->getVoidPtrType(),
+      SymbolRefAttr::get(this->getContext(), "aligned_alloc"),
+      ValueRange{allocAlignment, sizeBytes},
+      /* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr());
+
+  return rewriter.create(loc, elementPtrType, result);
+}
+
+bool isConvertibleAndHasIdentityMaps(MemRefType type,
+                                     const GCCJITTypeConverter &typeConverter) {
+  if (!typeConverter.convertType(type.getElementType()))
+    return false;
+  return type.getLayout().isIdentity();
+}
+
+template <typename T>
+void GCCJITLoweringPattern<T>::getMemRefDescriptorSizes(
+    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
+    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
+    SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
+  assert(
+      isConvertibleAndHasIdentityMaps(memRefType, *this->getTypeConverter()) &&
+      "layout maps must have been normalized away");
+  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
+             static_cast<int64_t>(dynamicSizes.size()) &&
+         "dynamicSizes size doesn't match dynamic sizes count in memref shape");
+
+  sizes.reserve(memRefType.getRank());
+  unsigned dynamicIndex = 0;
+  Type indexType = getIndexType();
+  for (int64_t size : memRefType.getShape()) {
+    sizes.push_back(
+        size == ShapedType::kDynamic
+            ? dynamicSizes[dynamicIndex++]
+            : createIndexAttrConstant(rewriter, loc, indexType, size));
+  }
+
+  // Strides: iterate sizes in reverse order and multiply.
+  int64_t stride = 1;
+  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
+  strides.resize(memRefType.getRank());
+  for (auto i = memRefType.getRank(); i-- > 0;) {
+    strides[i] = runningStride;
+
+    int64_t staticSize = memRefType.getShape()[i];
+    bool useSizeAsStride = stride == 1;
+    if (staticSize == ShapedType::kDynamic)
+      stride = ShapedType::kDynamic;
+    if (stride != ShapedType::kDynamic)
+      stride *= staticSize;
+
+    if (useSizeAsStride)
+      runningStride = sizes[i];
+    else if (stride == ShapedType::kDynamic)
+      runningStride = rewriter.create(
+          loc, indexType, BOp::Mult, runningStride, sizes[i]);
+    else
+      runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
+  }
+  if (sizeInBytes) {
+    // Buffer size in bytes.
+    Type elementType =
+        this->getTypeConverter()->convertType(memRefType.getElementType());
+    size = rewriter.create(loc, indexType, elementType);
+  } else {
+    size = runningStride;
+  }
+}
+
+template <typename OpType>
+LogicalResult AllocationLowering<OpType>::matchAndRewrite(
+    OpType op, typename OpConversionPattern<OpType>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  MemRefType memRefType = getMemRefResultType(op);
+  if (!isConvertibleAndHasIdentityMaps(memRefType, *this->getTypeConverter()))
+    return rewriter.notifyMatchFailure(op, "incompatible memref type");
+  auto loc = op->getLoc();
+  auto convertedType = this->getTypeConverter()->convertType(memRefType);
+  auto exprBundle = rewriter.replaceOpWithNewOp(op, convertedType);
+  auto *block = rewriter.createBlock(&exprBundle.getBody());
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(block);
+    // Get actual sizes of the memref as values: static sizes are constant
+    // values and dynamic sizes are passed to 'alloc' as operands. In case of
+    // zero-dimensional memref, assume a scalar (size 1).
+    SmallVector<Value> sizes;
+    SmallVector<Value> strides;
+    Value size;
+
+    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
+                                   rewriter, sizes, strides, size, true);
+
+    // Allocate the underlying buffer.
+    auto [allocatedPtr, alignedPtr] =
+        this->allocateBuffer(rewriter, loc, size, op);
+
+    if (!allocatedPtr || !alignedPtr)
+      return rewriter.notifyMatchFailure(loc,
+                                         "underlying buffer allocation failed");
+
+    // Create the MemRef descriptor.
+    auto memRefDescriptor = rewriter.create(
+        loc, convertedType, ArrayRef{0, 1, 2, 3, 4},
+        ValueRange{alignedPtr, allocatedPtr, size});
+
+    // Return the final value of the descriptor.
+    rewriter.create(loc, memRefDescriptor);
+  }
+  return success();
+}
 } // namespace
 
 std::unique_ptr<Pass> mlir::gccjit::createConvertMemrefToGCCJITPass() {
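
Note (not part of the patch): the staged AllocationLowering patterns above are only declarations and definitions; they still need to be registered with the dialect-conversion driver before the pass does anything. The sketch below shows one plausible way that could look, assuming a pass class named ConvertMemrefToGCCJITPass behind createConvertMemrefToGCCJITPass(), a GCCJITTypeConverter constructible from an MLIRContext, a gccjit::GCCJITDialect class, and a concrete AllocOpLowering derived from AllocationLowering<memref::AllocOp>; none of those names are confirmed by this patch. Only RewritePatternSet, ConversionTarget, and applyPartialConversion are stock MLIR framework calls.

// Sketch only, under the assumptions named above; it would live in the same
// file and reuse its existing includes.
void ConvertMemrefToGCCJITPass::runOnOperation() {
  MLIRContext *context = &getContext();
  GCCJITTypeConverter typeConverter(context); // assumed constructor

  RewritePatternSet patterns(context);
  // OpConversionPattern-derived patterns take the type converter and context.
  patterns.add<LoadOpLowering, AllocOpLowering>(typeConverter, context);

  ConversionTarget target(*context);
  target.addLegalDialect<gccjit::GCCJITDialect>(); // assumed dialect class name
  target.addIllegalDialect<memref::MemRefDialect>();

  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}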