Skip to content

Commit

Permalink
[gccjit] refactor memref lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 6, 2024
1 parent 451f7cf commit d9f55b1
Showing 1 changed file with 172 additions and 114 deletions.
286 changes: 172 additions & 114 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,125 +49,42 @@ class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
return static_cast<const GCCJITTypeConverter *>(this->typeConverter);
}

public:
using OpConversionPattern<T>::OpConversionPattern;
};
IntType getIndexType() const;
Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) const;

Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType,
int64_t value) {
class MemRefDescriptor {
private:
Value descriptor;
MemRefType type;

auto indexTy = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T);
auto intAttr = IntAttr::get(builder.getContext(), indexTy,
{64, static_cast<uint64_t>(value)});
return builder.create<gccjit::ConstantOp>(loc, resultType, intAttr);
}
ConversionPatternRewriter &rewriter;
const GCCJITLoweringPattern<T> &pattern;

Value getMemRefDescriptorOffset(OpBuilder &builder, Value descriptor,
Location loc) {
auto indexTy = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T);
return builder.create<gccjit::AccessFieldOp>(loc, indexTy, descriptor,
builder.getIndexAttr(2));
}
MemRefDescriptor(Value descriptor, MemRefType type,
ConversionPatternRewriter &rewriter,
const GCCJITLoweringPattern<T> &pattern);

Value getMemRefDiscriptorAlignedPtr(OpBuilder &builder, Value descriptor,
const GCCJITTypeConverter &converter,
Location loc, MemRefType type) {
auto elementType = converter.convertType(type.getElementType());
auto ptrTy = PointerType::get(builder.getContext(), elementType);
return builder.create<gccjit::AccessFieldOp>(loc, ptrTy, descriptor,
builder.getIndexAttr(1));
}
public:
friend class GCCJITLoweringPattern<T>;

Value getMemRefDescriptorBufferPtr(OpBuilder &builder, Location loc,
Value descriptor,
const GCCJITTypeConverter &converter,
MemRefType type) {
auto [strides, offsetCst] = getStridesAndOffset(type);
auto alignedPtr =
getMemRefDiscriptorAlignedPtr(builder, descriptor, converter, loc, type);
Value getOffset(Location loc) const;

// For zero offsets, we already have the base pointer.
if (offsetCst == 0)
return alignedPtr;
Value getAlignedPtr(Location loc) const;

// Otherwise add the offset to the aligned base.
Type indexType = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T);
Value offsetVal =
ShapedType::isDynamic(offsetCst)
? getMemRefDescriptorOffset(builder, descriptor, loc)
: createIndexAttrConstant(builder, loc, indexType, offsetCst);
Type elementType = converter.convertType(type.getElementType());
auto lvalueTy = LValueType::get(builder.getContext(), elementType);
auto lvalue =
builder.create<gccjit::DerefOp>(loc, lvalueTy, alignedPtr, offsetVal);
return builder.create<gccjit::AddrOp>(
loc, PointerType::get(builder.getContext(), elementType), lvalue);
}
Value getMemRefDescriptorBufferPtr(Location loc) const;

Value getStridedElementLValue(Location loc, MemRefType type, Value descriptor,
ExprOp parent, ValueRange indices,
const GCCJITTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
Value materializedMemref = nullptr;
Value ptrToStrideField = nullptr;
auto [strides, offset] = getStridesAndOffset(type);
auto indexTy = IntType::get(rewriter.getContext(), GCC_JIT_TYPE_SIZE_T);
auto elementType = typeConverter.convertType(type.getElementType());
auto doMaterialization = [&]() {
if (materializedMemref)
return;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(parent);
auto lvalueTy =
LValueType::get(rewriter.getContext(), descriptor.getType());
materializedMemref = rewriter.create<gccjit::LocalOp>(
loc, lvalueTy, nullptr, nullptr, nullptr);
rewriter.create<gccjit::AssignOp>(loc, descriptor, materializedMemref);
};
auto generateStride = [&](size_t i) -> Value {
doMaterialization();
if (!ptrToStrideField) {
auto descriptorTy = cast<StructType>(descriptor.getType());
auto fieldTy = cast<ArrayType>(
cast<FieldAttr>(descriptorTy.getRecordFields()[4]).getType());
auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy);
auto strideField = rewriter.create<gccjit::AccessFieldOp>(
loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4));
auto ptrToStrideArray = rewriter.create<gccjit::AddrOp>(
loc, PointerType::get(rewriter.getContext(), fieldTy), strideField);
ptrToStrideField = rewriter.create<gccjit::BitCastOp>(
loc, PointerType::get(rewriter.getContext(), indexTy),
ptrToStrideArray);
}
auto offset = rewriter.create<gccjit::AccessFieldOp>(
loc, indexTy, ptrToStrideField, rewriter.getIndexAttr(i));
auto strideLValue = rewriter.create<gccjit::DerefOp>(
loc, LValueType::get(rewriter.getContext(), indexTy), ptrToStrideField,
offset);
return rewriter.create<gccjit::AsRValueOp>(loc, indexTy, strideLValue);
Value getStridedElementLValue(Location loc, Operation *materializationPoint,
ValueRange indices) const;
};

Value base = getMemRefDescriptorBufferPtr(rewriter, loc, descriptor,
typeConverter, type);
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
Value stride =
ShapedType::isDynamic(strides[i])
? generateStride(i)
: createIndexAttrConstant(rewriter, loc, indexTy, strides[i]);
increment = rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Mult,
increment, stride);
}
index = index ? rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Plus,
index, increment)
: increment;
}
MemRefDescriptor
getMemRefDescriptor(Value descriptor, MemRefType type,
ConversionPatternRewriter &rewriter) const;

return rewriter.create<gccjit::DerefOp>(
loc, LValueType::get(rewriter.getContext(), elementType), base, index);
}
public:
using OpConversionPattern<T>::OpConversionPattern;
};

class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
public:
Expand All @@ -180,9 +97,10 @@ class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
auto exprBundle = rewriter.replaceOpWithNewOp<ExprOp>(op, retTy);
auto *block = rewriter.createBlock(&exprBundle.getBody());
rewriter.setInsertionPointToStart(block);
Value dataLValue = getStridedElementLValue(
op.getLoc(), type, adaptor.getMemref(), exprBundle,
adaptor.getIndices(), *getTypeConverter(), rewriter);
MemRefDescriptor descriptor =
getMemRefDescriptor(adaptor.getMemref(), type, rewriter);
Value dataLValue = descriptor.getStridedElementLValue(op.getLoc(), op,
adaptor.getIndices());
auto rvalue = rewriter.create<AsRValueOp>(op.getLoc(), retTy, dataLValue);
rewriter.create<ReturnOp>(op.getLoc(), rvalue);
return success();
Expand All @@ -202,9 +120,10 @@ class StoreOpLowering : public GCCJITLoweringPattern<memref::StoreOp> {
auto *block = rewriter.createBlock(&expr.getBody());
{
rewriter.setInsertionPointToStart(block);
Value dataLValue = getStridedElementLValue(
op.getLoc(), type, adaptor.getMemref(), expr, adaptor.getIndices(),
*getTypeConverter(), rewriter);
MemRefDescriptor descriptor =
getMemRefDescriptor(adaptor.getMemref(), type, rewriter);
Value dataLValue = descriptor.getStridedElementLValue(
op->getLoc(), op, adaptor.getIndices());
rewriter.create<ReturnOp>(op.getLoc(), dataLValue);
}
rewriter.setInsertionPoint(op);
Expand Down Expand Up @@ -239,6 +158,145 @@ void ConvertMemrefToGCCJITPass::runOnOperation() {
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
signalPassFailure();
}

template <typename T> IntType GCCJITLoweringPattern<T>::getIndexType() const {
return IntType::get(this->getContext(), GCC_JIT_TYPE_SIZE_T);
}

template <typename T>
Value GCCJITLoweringPattern<T>::createIndexAttrConstant(OpBuilder &builder,
Location loc,
Type resultType,
int64_t value) const {
auto indexTy = getIndexType();
auto intAttr = IntAttr::get(this->getContext(), indexTy,
{64, static_cast<uint64_t>(value)});
return builder.create<gccjit::ConstantOp>(loc, resultType, intAttr);
}

template <typename T>
GCCJITLoweringPattern<T>::MemRefDescriptor::MemRefDescriptor(
Value descriptor, MemRefType type, ConversionPatternRewriter &rewriter,
const GCCJITLoweringPattern<T> &pattern)
: descriptor(descriptor), type(type), rewriter(rewriter), pattern(pattern) {
}
template <typename T>
Value GCCJITLoweringPattern<T>::MemRefDescriptor::getOffset(
Location loc) const {
auto indexTy = pattern.getIndexType();
return rewriter.create<gccjit::AccessFieldOp>(loc, indexTy, descriptor,
rewriter.getIndexAttr(2));
}

template <typename T>
Value GCCJITLoweringPattern<T>::MemRefDescriptor::getAlignedPtr(
Location loc) const {
auto elementType =
pattern.getTypeConverter()->convertType(type.getElementType());
auto ptrTy = PointerType::get(pattern.getContext(), elementType);
return rewriter.create<gccjit::AccessFieldOp>(loc, ptrTy, descriptor,
rewriter.getIndexAttr(1));
}

template <typename T>
Value GCCJITLoweringPattern<T>::MemRefDescriptor::getMemRefDescriptorBufferPtr(
Location loc) const {
auto [strides, offsetCst] = getStridesAndOffset(type);
auto alignedPtr = getAlignedPtr(loc);
// For zero offsets, we already have the base pointer.
if (offsetCst == 0)
return alignedPtr;

// Otherwise add the offset to the aligned base.
Type indexType = pattern.getIndexType();
Value offsetVal = ShapedType::isDynamic(offsetCst)
? getOffset(loc)
: pattern.createIndexAttrConstant(rewriter, loc,
indexType, offsetCst);
Type elementType =
pattern.getTypeConverter()->convertType(type.getElementType());
auto lvalueTy = LValueType::get(rewriter.getContext(), elementType);
auto lvalue =
rewriter.create<gccjit::DerefOp>(loc, lvalueTy, alignedPtr, offsetVal);
return rewriter.create<gccjit::AddrOp>(
loc, PointerType::get(rewriter.getContext(), elementType), lvalue);
}

template <typename T>
Value GCCJITLoweringPattern<T>::MemRefDescriptor::getStridedElementLValue(
Location loc, Operation *materializationPoint, ValueRange indices) const {
Value materializedMemref = nullptr;
Value ptrToStrideField = nullptr;
auto [strides, offset] = getStridesAndOffset(type);
auto indexTy = IntType::get(rewriter.getContext(), GCC_JIT_TYPE_SIZE_T);
auto elementType =
pattern.getTypeConverter()->convertType(type.getElementType());
auto doMaterialization = [&]() {
if (materializedMemref)
return;
OpBuilder::InsertionGuard guard(rewriter);
if (materializationPoint)
rewriter.setInsertionPoint(materializationPoint);
else
rewriter.setInsertionPointAfter(descriptor.getDefiningOp());
auto lvalueTy =
LValueType::get(rewriter.getContext(), descriptor.getType());
materializedMemref = rewriter.create<gccjit::LocalOp>(
loc, lvalueTy, nullptr, nullptr, nullptr);
rewriter.create<gccjit::AssignOp>(loc, descriptor, materializedMemref);
};
auto generateStride = [&](size_t i) -> Value {
doMaterialization();
if (!ptrToStrideField) {
auto descriptorTy = cast<StructType>(descriptor.getType());
auto fieldTy = cast<ArrayType>(
cast<FieldAttr>(descriptorTy.getRecordFields()[4]).getType());
auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy);
auto strideField = rewriter.create<gccjit::AccessFieldOp>(
loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4));
auto ptrToStrideArray = rewriter.create<gccjit::AddrOp>(
loc, PointerType::get(rewriter.getContext(), fieldTy), strideField);
ptrToStrideField = rewriter.create<gccjit::BitCastOp>(
loc, PointerType::get(rewriter.getContext(), indexTy),
ptrToStrideArray);
}
auto offset = rewriter.create<gccjit::AccessFieldOp>(
loc, indexTy, ptrToStrideField, rewriter.getIndexAttr(i));
auto strideLValue = rewriter.create<gccjit::DerefOp>(
loc, LValueType::get(rewriter.getContext(), indexTy), ptrToStrideField,
offset);
return rewriter.create<gccjit::AsRValueOp>(loc, indexTy, strideLValue);
};

Value base = getMemRefDescriptorBufferPtr(loc);
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
Value stride = ShapedType::isDynamic(strides[i])
? generateStride(i)
: pattern.createIndexAttrConstant(rewriter, loc,
indexTy, strides[i]);
increment = rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Mult,
increment, stride);
}
index = index ? rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Plus,
index, increment)
: increment;
}

return rewriter.create<gccjit::DerefOp>(
loc, LValueType::get(rewriter.getContext(), elementType), base, index);
}

template <typename T>
typename GCCJITLoweringPattern<T>::MemRefDescriptor
GCCJITLoweringPattern<T>::getMemRefDescriptor(
Value descriptor, MemRefType type,
ConversionPatternRewriter &rewriter) const {
return {descriptor, type, rewriter, *this};
}

} // namespace

std::unique_ptr<Pass> mlir::gccjit::createConvertMemrefToGCCJITPass() {
Expand Down

0 comments on commit d9f55b1

Please sign in to comment.