Skip to content

Commit

Permalink
stage
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 6, 2024
1 parent d9f55b1 commit 96125ce
Showing 1 changed file with 280 additions and 0 deletions.
280 changes: 280 additions & 0 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <llvm-20/llvm/Support/LogicalResult.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
Expand All @@ -31,6 +32,7 @@
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir-gccjit/Passes.h"
#include "mlir/IR/Types.h"

using namespace mlir;
using namespace mlir::gccjit;
Expand All @@ -50,8 +52,24 @@ class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
}

IntType getIndexType() const;
PointerType getVoidPtrType() const {
return PointerType::get(this->getContext(),
VoidType::get(this->getContext()));
}
Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) const;
Value getSizeInBytes(Location loc, Type type,
ConversionPatternRewriter &rewriter) const;
Value getAlignInBytes(Location loc, Type type,
ConversionPatternRewriter &rewriter) const;
PointerType getElementPtrType(MemRefType type) const;

void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
ValueRange dynamicSizes,
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &sizes,
SmallVectorImpl<Value> &strides, Value &size,
bool sizeInBytes) const;

class MemRefDescriptor {
private:
Expand Down Expand Up @@ -86,6 +104,47 @@ class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
using OpConversionPattern<T>::OpConversionPattern;
};

template <typename OpType>
class AllocationLowering : public GCCJITLoweringPattern<OpType> {
protected:
/// Computes the aligned value for 'input' as follows:
/// bumped = input + alignement - 1
/// aligned = bumped - bumped % alignment
Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
Value input, Value alignment) const;

MemRefType getMemRefResultType(OpType op) const;

Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
OpType op) const;

int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
Location loc, OpType op) const;

std::tuple<Value, Value>
allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes, OpType op,
Value alignment) const;

/// Allocates a memory buffer using an aligned allocation method.
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes, OpType op,
int64_t alignment) const;

virtual std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
Operation *op) const = 0;

private:
static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

public:
LogicalResult
matchAndRewrite(OpType op,
typename OpConversionPattern<OpType>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
};

class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
public:
using GCCJITLoweringPattern::GCCJITLoweringPattern;
Expand Down Expand Up @@ -297,6 +356,227 @@ GCCJITLoweringPattern<T>::getMemRefDescriptor(
return {descriptor, type, rewriter, *this};
}

template <typename T>
Value GCCJITLoweringPattern<T>::getSizeInBytes(
Location loc, Type type, ConversionPatternRewriter &rewriter) const {
Type gccjitType = getTypeConverter()->convertType(type);
auto indexType = getIndexType();
return rewriter.create<gccjit::SizeOfOp>(loc, indexType, gccjitType);
}

template <typename T>
Value GCCJITLoweringPattern<T>::getAlignInBytes(
Location loc, Type type, ConversionPatternRewriter &rewriter) const {
Type gccjitType = getTypeConverter()->convertType(type);
auto indexType = getIndexType();
return rewriter.create<gccjit::AlignOfOp>(loc, indexType, gccjitType);
}

template <typename T>
PointerType GCCJITLoweringPattern<T>::getElementPtrType(MemRefType type) const {
auto eltTy = getTypeConverter()->convertType(type.getElementType());
return PointerType::get(this->getContext(), eltTy);
}

template <typename OpType>
MemRefType AllocationLowering<OpType>::getMemRefResultType(OpType op) const {
return cast<MemRefType>(op->getResult(0).getType());
}

template <typename OpType>
Value AllocationLowering<OpType>::getAlignment(
ConversionPatternRewriter &rewriter, Location loc, OpType op) const {
MemRefType memRefType = op.getType();
Value alignment;
if (auto alignmentAttr = op.getAlignment()) {
Type indexType = this->getIndexType();
alignment =
createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
alignment =
this->getAlignInBytes(loc, memRefType.getElementType(), rewriter);
}
return alignment;
}

template <typename OpType>
Value AllocationLowering<OpType>::createAligned(
ConversionPatternRewriter &rewriter, Location loc, Value input,
Value alignment) const {
Value one =
this->createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
Value bump = rewriter.create<gccjit::BinaryOp>(loc, alignment.getType(),
BOp::Minus, alignment, one);
Value bumped = rewriter.create<gccjit::BinaryOp>(loc, alignment.getType(),
BOp::Plus, input, bump);
Value mod = rewriter.create<gccjit::BinaryOp>(loc, alignment.getType(),
BOp::Modulo, bumped, alignment);
return rewriter.create<gccjit::BinaryOp>(loc, alignment.getType(), BOp::Minus,
bumped, mod);
}

template <typename OpType>
std::tuple<Value, Value>
AllocationLowering<OpType>::allocateBufferManuallyAlign(
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
OpType op, Value alignment) const {
if (alignment) {
// Adjust the allocation size to consider alignment.
sizeBytes = rewriter.create<gccjit::BinaryOp>(
loc, sizeBytes.getType(), BOp::Plus, sizeBytes, alignment);
}

MemRefType memRefType = getMemRefResultType(op);
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
Value allocatedPtr = rewriter.create<gccjit::CallOp>(
loc, this->getVoidPtrType(),
SymbolRefAttr::get(this->getContext(), "malloc"), ValueRange{sizeBytes},
/* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr());

if (!allocatedPtr)
return std::make_tuple(Value(), Value());
Value alignedPtr = allocatedPtr;
if (alignment) {
// Compute the aligned pointer.
Value allocatedInt = rewriter.create<gccjit::BitCastOp>(
loc, this->getIndexType(), allocatedPtr);
Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
alignedPtr =
rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, alignmentInt);
} else {
alignedPtr =
rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, allocatedPtr);
}

return std::make_tuple(allocatedPtr, alignedPtr);
}

template <typename OpType>
Value AllocationLowering<OpType>::allocateBufferAutoAlign(
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
OpType op, int64_t alignment) const {
Value allocAlignment =
createIndexAttrConstant(rewriter, loc, this->getIndexType(), alignment);

MemRefType memRefType = getMemRefResultType(op);
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

Type elementPtrType = this->getElementPtrType(memRefType);
auto result = rewriter.create<gccjit::CallOp>(
loc, this->getVoidPtrType(),
SymbolRefAttr::get(this->getContext(), "aligned_alloc"),
ValueRange{allocAlignment, sizeBytes},
/* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr());

return rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, result);
}

bool isConvertibleAndHasIdentityMaps(MemRefType type,
const GCCJITTypeConverter &typeConverter) {
if (!typeConverter.convertType(type.getElementType()))
return false;
return type.getLayout().isIdentity();
}

template <typename OpType>
void GCCJITLoweringPattern<OpType>::getMemRefDescriptorSizes(
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
assert(
isConvertibleAndHasIdentityMaps(memRefType, this->getTypeConverter()) &&
"layout maps must have been normalized away");
assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
static_cast<ssize_t>(dynamicSizes.size()) &&
"dynamicSizes size doesn't match dynamic sizes count in memref shape");

sizes.reserve(memRefType.getRank());
unsigned dynamicIndex = 0;
Type indexType = getIndexType();
for (int64_t size : memRefType.getShape()) {
sizes.push_back(
size == ShapedType::kDynamic
? dynamicSizes[dynamicIndex++]
: createIndexAttrConstant(rewriter, loc, indexType, size));
}

// Strides: iterate sizes in reverse order and multiply.
int64_t stride = 1;
Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
strides.resize(memRefType.getRank());
for (auto i = memRefType.getRank(); i-- > 0;) {
strides[i] = runningStride;

int64_t staticSize = memRefType.getShape()[i];
bool useSizeAsStride = stride == 1;
if (staticSize == ShapedType::kDynamic)
stride = ShapedType::kDynamic;
if (stride != ShapedType::kDynamic)
stride *= staticSize;

if (useSizeAsStride)
runningStride = sizes[i];
else if (stride == ShapedType::kDynamic)
runningStride = rewriter.create<gccjit::BinaryOp>(
loc, indexType, BOp::Mult, runningStride, sizes[i]);
else
runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
}
if (sizeInBytes) {
// Buffer size in bytes.
Type elementType =
this->getTypeConverter()->convertType(memRefType.getElementType());
size = rewriter.create<gccjit::SizeOfOp>(loc, indexType, elementType);
} else {
size = runningStride;
}
}

template <typename OpType>
LogicalResult AllocationLowering<OpType>::matchAndRewrite(
OpType op, typename OpConversionPattern<OpType>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType memRefType = getMemRefResultType(op);
if (!isConvertibleAndHasIdentityMaps(memRefType, *this->getTypeConverter()))
return rewriter.notifyMatchFailure(op, "incompatible memref type");
auto loc = op->getLoc();
auto convertedType = this->getTypeConverter()->convertType(memRefType);
auto exprBundle = rewriter.replaceOpWithNewOp<ExprOp>(op, convertedType);
auto *block = rewriter.createBlock(&exprBundle.getBody());
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(block);
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value size;

this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
rewriter, sizes, strides, size, true);

// Allocate the underlying buffer.
auto [allocatedPtr, alignedPtr] =
this->allocateBuffer(rewriter, loc, size, op);

if (!allocatedPtr || !alignedPtr)
return rewriter.notifyMatchFailure(loc,
"underlying buffer allocation failed");

// Create the MemRef descriptor.
auto memRefDescriptor = rewriter.create<gccjit::NewStructOp>(
loc, convertedType, ArrayRef<int32_t>{0, 1, 2, 3, 4},
ValueRange{alignedPtr, allocatedPtr, size});

// Return the final value of the descriptor.
rewriter.create<ReturnOp>(loc, memRefDescriptor);
}
// Return the final value of the descriptor.
rewriter.replaceOp(op, exprBundle);
return success();
}
} // namespace

std::unique_ptr<Pass> mlir::gccjit::createConvertMemrefToGCCJITPass() {
Expand Down

0 comments on commit 96125ce

Please sign in to comment.