Skip to content

Commit

Permalink
[gccjit] lower memref (#22)
Browse files Browse the repository at this point in the history
* [gccjit] stage

* lowering store/load

* fixes

* enable opt
  • Loading branch information
SchrodingerZhu authored Nov 5, 2024
1 parent 51e5908 commit 3c8b1ea
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 28 deletions.
7 changes: 4 additions & 3 deletions include/mlir-gccjit/IR/GCCJITAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,21 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> {

let builders = [
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type), [{
return get($_ctxt, name, type, 0, std::nullopt);
return get($_ctxt, name, type, std::nullopt, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"unsigned":$bitWidth), [{
return get($_ctxt, name, type, bitWidth, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"mlir::gccjit::SourceLocAttr":$loc), [{
return get($_ctxt, name, type, 0, loc);
return get($_ctxt, name, type, std::nullopt, loc);
}]>,
];

// attribute can eat up the `:` separator, so we need to move the name to the front
let assemblyFormat = [{
`<` $type $name (`:` $bitWidth^)? ($loc^)? `>`
`<` $name $type (`:` $bitWidth^)? ($loc^)? `>`
}];
}

Expand Down
2 changes: 1 addition & 1 deletion include/mlir-gccjit/IR/GCCJITOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def AccessFieldOp : GCCJIT_Op<"access_field"> {
```
}];
let arguments = (ins AnyType:$composite, IndexAttr:$field);
let results = (outs GCCJIT_LValueType:$result);
let results = (outs AnyType:$result);
let assemblyFormat = [{
$composite `[` $field `]` `:` functional-type(operands, results) attr-dict
}];
Expand Down
199 changes: 198 additions & 1 deletion src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <llvm/Support/Casting.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Value.h>

#include "libgccjit.h"
#include "mlir-gccjit/Conversion/Conversions.h"
#include "mlir-gccjit/Conversion/TypeConverter.h"
#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITOps.h"
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir-gccjit/Passes.h"

using namespace mlir;
Expand All @@ -29,13 +42,197 @@ struct ConvertMemrefToGCCJITPass
void runOnOperation() override final;
};

/// Shared base for the memref lowering patterns in this file. It behaves like
/// OpConversionPattern<T> and additionally hands out the stored type converter
/// already downcast to GCCJITTypeConverter.
template <typename T>
class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
public:
  // Inherit every constructor of the underlying conversion pattern.
  using OpConversionPattern<T>::OpConversionPattern;

protected:
  /// Returns the pattern's type converter viewed as a GCCJITTypeConverter.
  /// The downcast is unchecked.
  /// NOTE(review): assumes the pattern was constructed with a
  /// GCCJITTypeConverter — confirm at every registration site.
  const GCCJITTypeConverter *getTypeConverter() const {
    const auto *gccjitConverter =
        static_cast<const GCCJITTypeConverter *>(this->typeConverter);
    return gccjitConverter;
  }
};

/// Emits a gccjit constant op of type `resultType` at `loc`.
///
/// The constant's attribute is an IntAttr typed as gccjit `size_t` whose
/// payload is `value` reinterpreted as an unsigned 64-bit integer.
Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType,
                              int64_t value) {
  auto *ctx = builder.getContext();
  auto sizeTy = IntType::get(ctx, GCC_JIT_TYPE_SIZE_T);
  auto rawBits = static_cast<uint64_t>(value);
  auto constantAttr = IntAttr::get(ctx, sizeTy, {64, rawBits});
  return builder.create<gccjit::ConstantOp>(loc, resultType, constantAttr);
}

/// Emits an access to field #2 of the memref `descriptor`, typed as gccjit
/// `size_t`, and returns the resulting value.
/// NOTE(review): field #2 is presumably the descriptor's offset slot — verify
/// against the descriptor struct layout built by the type converter.
Value getMemRefDescriptorOffset(OpBuilder &builder, Value descriptor,
                                Location loc) {
  auto *ctx = builder.getContext();
  auto sizeTy = IntType::get(ctx, GCC_JIT_TYPE_SIZE_T);
  auto offsetFieldIndex = builder.getIndexAttr(2);
  return builder.create<gccjit::AccessFieldOp>(loc, sizeTy, descriptor,
                                               offsetFieldIndex);
}

/// Emits an access to field #1 of the memref `descriptor` and returns it as a
/// pointer to the converted element type of `type`.
/// NOTE(review): field #1 is presumably the aligned data pointer — verify
/// against the descriptor struct layout built by the type converter.
Value getMemRefDiscriptorAlignedPtr(OpBuilder &builder, Value descriptor,
                                    const GCCJITTypeConverter &converter,
                                    Location loc, MemRefType type) {
  auto *ctx = builder.getContext();
  auto convertedElemTy = converter.convertType(type.getElementType());
  auto elemPtrTy = PointerType::get(ctx, convertedElemTy);
  auto alignedFieldIndex = builder.getIndexAttr(1);
  return builder.create<gccjit::AccessFieldOp>(loc, elemPtrTy, descriptor,
                                               alignedFieldIndex);
}

/// Returns a pointer to the first addressable element of the memref described
/// by `descriptor`.
///
/// The aligned pointer is read from the descriptor. When the layout offset of
/// `type` is statically zero that pointer is returned directly; otherwise the
/// offset (read from the descriptor when dynamic, emitted as a constant when
/// static) is applied via a deref + address-of pair.
Value getMemRefDescriptorBufferPtr(OpBuilder &builder, Location loc,
                                   Value descriptor,
                                   const GCCJITTypeConverter &converter,
                                   MemRefType type) {
  const auto stridesAndOffset = getStridesAndOffset(type);
  const int64_t layoutOffset = stridesAndOffset.second;

  Value alignedPtr =
      getMemRefDiscriptorAlignedPtr(builder, descriptor, converter, loc, type);

  // A zero offset means the aligned pointer already is the buffer start.
  if (layoutOffset == 0)
    return alignedPtr;

  // Otherwise materialize the offset as a value...
  auto *ctx = builder.getContext();
  Type sizeTy = IntType::get(ctx, GCC_JIT_TYPE_SIZE_T);
  Value offsetVal;
  if (ShapedType::isDynamic(layoutOffset))
    offsetVal = getMemRefDescriptorOffset(builder, descriptor, loc);
  else
    offsetVal = createIndexAttrConstant(builder, loc, sizeTy, layoutOffset);

  // ...and compute &alignedPtr[offset].
  Type convertedElemTy = converter.convertType(type.getElementType());
  Value shiftedElement = builder.create<gccjit::DerefOp>(
      loc, LValueType::get(ctx, convertedElemTy), alignedPtr, offsetVal);
  return builder.create<gccjit::AddrOp>(
      loc, PointerType::get(ctx, convertedElemTy), shiftedElement);
}

/// Builds the lvalue of the memref element addressed by `indices`.
///
/// The linear index is the sum of `indices[i] * strides[i]`, where a stride of
/// exactly 1 skips the multiplication, static strides become constants, and
/// dynamic strides are loaded from the stride array stored in the descriptor.
/// The resulting lvalue is `base[linearIndex]`, with `base` obtained from
/// getMemRefDescriptorBufferPtr.
///
/// \param loc           location attached to every created op.
/// \param type          memref type being accessed (source of strides/offset).
/// \param descriptor    converted memref descriptor value.
/// \param parent        expression op enclosing the current insertion point;
///                      the descriptor is spilled to a local *before* this op
///                      when a dynamic stride has to be read.
/// \param indices       converted index operands, one per dimension.
/// \param typeConverter converter used for the element type.
/// \param rewriter      rewriter used to create all ops.
/// \returns the element lvalue (result of a gccjit deref op).
///
/// NOTE(review): with an empty `indices` range the deref is created with a
/// null offset value — confirm that rank-0 memrefs are handled elsewhere.
Value getStridedElementLValue(Location loc, MemRefType type, Value descriptor,
                              ExprOp parent, ValueRange indices,
                              const GCCJITTypeConverter &typeConverter,
                              ConversionPatternRewriter &rewriter) {
  Value materializedMemref = nullptr;
  Value ptrToStrideField = nullptr;
  auto [strides, offset] = getStridesAndOffset(type);
  auto indexTy = IntType::get(rewriter.getContext(), GCC_JIT_TYPE_SIZE_T);
  auto elementType = typeConverter.convertType(type.getElementType());

  // Spill the descriptor into a local variable (at most once) so that its
  // fields can be addressed. The spill is emitted before `parent`; the guard
  // restores the caller's insertion point afterwards.
  auto doMaterialization = [&]() {
    if (materializedMemref)
      return;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(parent);
    auto lvalueTy =
        LValueType::get(rewriter.getContext(), descriptor.getType());
    materializedMemref = rewriter.create<gccjit::LocalOp>(
        loc, lvalueTy, nullptr, nullptr, nullptr);
    rewriter.create<gccjit::AssignOp>(loc, descriptor, materializedMemref);
  };

  // Loads the i-th dynamic stride from the descriptor's stride array.
  auto generateStride = [&](size_t i) -> Value {
    doMaterialization();
    if (!ptrToStrideField) {
      // Take the address of descriptor field #4 (the stride array) and view it
      // as a pointer to its first `size_t` element. Computed once and reused.
      auto descriptorTy = cast<StructType>(descriptor.getType());
      auto fieldTy = cast<ArrayType>(
          cast<FieldAttr>(descriptorTy.getRecordFields()[4]).getType());
      auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy);
      auto strideField = rewriter.create<gccjit::AccessFieldOp>(
          loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4));
      auto ptrToStrideArray = rewriter.create<gccjit::AddrOp>(
          loc, PointerType::get(rewriter.getContext(), fieldTy), strideField);
      ptrToStrideField = rewriter.create<gccjit::BitCastOp>(
          loc, PointerType::get(rewriter.getContext(), indexTy),
          ptrToStrideArray);
    }
    // The element position inside the stride array is the plain constant `i`.
    // A field access on the `size_t *` is not valid here: the pointee is not a
    // struct, so translation would reject it.
    Value strideIdx = createIndexAttrConstant(rewriter, loc, indexTy,
                                              static_cast<int64_t>(i));
    auto strideLValue = rewriter.create<gccjit::DerefOp>(
        loc, LValueType::get(rewriter.getContext(), indexTy), ptrToStrideField,
        strideIdx);
    return rewriter.create<gccjit::AsRValueOp>(loc, indexTy, strideLValue);
  };

  Value base = getMemRefDescriptorBufferPtr(rewriter, loc, descriptor,
                                            typeConverter, type);
  Value index;
  for (int i = 0, e = indices.size(); i < e; ++i) {
    Value increment = indices[i];
    if (strides[i] != 1) { // Skip if stride is 1.
      Value stride =
          ShapedType::isDynamic(strides[i])
              ? generateStride(i)
              : createIndexAttrConstant(rewriter, loc, indexTy, strides[i]);
      increment = rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Mult,
                                                    increment, stride);
    }
    // Accumulate the running linear index.
    index = index ? rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Plus,
                                                      index, increment)
                  : increment;
  }

  return rewriter.create<gccjit::DerefOp>(
      loc, LValueType::get(rewriter.getContext(), elementType), base, index);
}

/// Lowers `memref.load` to a gccjit expression op whose body computes the
/// element lvalue, converts it to an rvalue and returns it.
class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
public:
  using GCCJITLoweringPattern::GCCJITLoweringPattern;

  mlir::LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto memrefTy = op.getMemRefType();
    auto resultTy = typeConverter->convertType(op.getResult().getType());

    // The expression op takes the place of the load; its region is filled in
    // below.
    auto exprOp = rewriter.replaceOpWithNewOp<ExprOp>(op, resultTy);
    auto *body = rewriter.createBlock(&exprOp.getBody());
    rewriter.setInsertionPointToStart(body);

    Value elementLValue = getStridedElementLValue(
        loc, memrefTy, adaptor.getMemref(), exprOp, adaptor.getIndices(),
        *getTypeConverter(), rewriter);
    auto loaded = rewriter.create<AsRValueOp>(loc, resultTy, elementLValue);
    rewriter.create<ReturnOp>(loc, loaded);
    return success();
  }
};

/// Lowers `memref.store` to a gccjit expression op that yields the element
/// lvalue, followed by an assign op writing the stored value into it.
class StoreOpLowering : public GCCJITLoweringPattern<memref::StoreOp> {
public:
  using GCCJITLoweringPattern::GCCJITLoweringPattern;

  mlir::LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto memrefTy = op.getMemRefType();
    auto convertedElemTy =
        typeConverter->convertType(memrefTy.getElementType());
    auto lvalueTy = LValueType::get(rewriter.getContext(), convertedElemTy);

    // Expression op producing the destination lvalue.
    // NOTE(review): the trailing `true` presumably marks the expression as
    // lvalue-producing — confirm against the ExprOp builder.
    auto destExpr = rewriter.create<ExprOp>(op->getLoc(), lvalueTy, true);
    auto *body = rewriter.createBlock(&destExpr.getBody());
    rewriter.setInsertionPointToStart(body);
    Value elementLValue = getStridedElementLValue(
        loc, memrefTy, adaptor.getMemref(), destExpr, adaptor.getIndices(),
        *getTypeConverter(), rewriter);
    rewriter.create<ReturnOp>(loc, elementLValue);

    // Move back out of the expression body before emitting the assignment
    // that replaces the store.
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<AssignOp>(op, adaptor.getValue(), destExpr);
    return success();
  }
};

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LoadOpLowering, StoreOpLowering>(typeConverter,
&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalDialect<mlir::memref::MemRefDialect>();
target.addIllegalDialect<memref::MemRefDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
Expand Down
6 changes: 3 additions & 3 deletions src/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ GCCJITTypeConverter::getMemrefDescriptorType(mlir::MemRefType type) const {
llvm::enumerate(ArrayRef<Type>{elementPtrType, elementPtrType, indexType,
dimOrStrideType, dimOrStrideType})) {
auto nameAttr = StringAttr::get(type.getContext(), names[idx]);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
return StructType::get(type.getContext(), nameAttr, fieldsAttr);
Expand All @@ -199,7 +199,7 @@ gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType(
llvm::enumerate(ArrayRef<Type>{indexType, opaquePtrType})) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(type.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
return StructType::get(type.getContext(), nameAttr, fieldsAttr);
Expand All @@ -220,7 +220,7 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(
for (auto [idx, type] : llvm::enumerate(types)) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(func.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type));
}
auto nameAttr = StringAttr::get(func.getContext(), name);
auto fieldsAttr = ArrayAttr::get(func.getContext(), fields);
Expand Down
31 changes: 23 additions & 8 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ class RegionVisitor {
gcc_jit_rvalue *visitExprWithoutCache(AddrOp op);
gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op);
gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op);
gcc_jit_rvalue *visitExprWithoutCache(ExprOp op);
Expr visitExprWithoutCache(ExprOp op);
gcc_jit_lvalue *visitExprWithoutCache(DerefOp op);
Expr visitExprWithoutCache(AccessFieldOp op);

/// The following operations are entrypoints for real codegen.
void visitAssignOp(gcc_jit_block *blk, AssignOp op);
Expand Down Expand Up @@ -515,18 +516,17 @@ Expr RegionVisitor::translateIntoContext() {
Block &block = region.getBlocks().front();
auto terminator = cast<gccjit::ReturnOp>(block.getTerminator());
auto value = terminator->getOperand(0);
auto rvalue = visitExpr(value, true);
auto expr = visitExpr(value, true);

if (auto globalOp = dyn_cast<gccjit::GlobalOp>(parent)) {
auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName());
auto *lvalue = getTranslator().getGlobalLValue(symName);
gcc_jit_global_set_initializer_rvalue(lvalue, rvalue);
gcc_jit_global_set_initializer_rvalue(lvalue, expr);
return {};
}

if (auto exprOp = dyn_cast<ExprOp>(parent)) {
return rvalue;
}
if (auto exprOp = dyn_cast<ExprOp>(parent))
return expr;

llvm_unreachable("unknown region parent");
}
Expand Down Expand Up @@ -558,8 +558,8 @@ Expr RegionVisitor::visitExpr(Value value, bool toplevel) {
.Case([&](GetGlobalOp op) { return visitExprWithoutCache(op); })
.Case([&](ExprOp op) { return visitExprWithoutCache(op); })
.Case([&](DerefOp op) { return visitExprWithoutCache(op); })
.Case([&](AccessFieldOp op) { return visitExprWithoutCache(op); })
.Default([](Operation *op) -> Expr {
op->dump();
llvm::report_fatal_error("unknown expression type");
});

Expand All @@ -579,7 +579,22 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) {
return gcc_jit_context_new_array_access(getContext(), loc, ptr, offset);
}

gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ExprOp op) {
/// Translates a field access into the matching libgccjit field-access
/// expression.
///
/// The composite operand is translated first, its type is converted and
/// required to be a struct, and the field is looked up by position. The
/// access is emitted as an lvalue access when the op's result type is an
/// LValueType and as an rvalue access otherwise.
///
/// Aborts with a fatal error when the converted composite type is not a
/// struct or when the field lookup fails (e.g. index out of range).
Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) {
  auto composite = visitExpr(op.getComposite());
  auto *loc = getTranslator().getLocation(op.getLoc());
  auto *compositeTy = getTranslator().convertType(op.getComposite().getType());
  auto index = op.getField().getZExtValue();
  // TODO: support union and query from cache instead
  auto *structure = gcc_jit_type_is_struct(compositeTy);
  // This validates the input IR, so it must also fire in release builds;
  // llvm_unreachable would be undefined behaviour there.
  if (!structure)
    llvm::report_fatal_error("expected struct type");
  auto *field = gcc_jit_struct_get_field(structure, index);
  // The lookup yields null on failure; stop here rather than passing a null
  // field on to the access builders.
  if (!field)
    llvm::report_fatal_error("invalid field index for struct type");
  if (isa<LValueType>(op.getType()))
    return gcc_jit_lvalue_access_field(composite, loc, field);
  return gcc_jit_rvalue_access_field(composite, loc, field);
}

/// Translates a nested expression op by running a child RegionVisitor over
/// the op's region (with this visitor passed as its third constructor
/// argument) and returning that visitor's result.
Expr RegionVisitor::visitExprWithoutCache(ExprOp op) {
  return RegionVisitor(getTranslator(), op.getRegion(), this)
      .translateIntoContext();
}
Expand Down
20 changes: 18 additions & 2 deletions test/lowering/gemm.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s
module {
// RUN: %gccjit-opt %s \
// RUN: -lower-affine \
// RUN: -convert-scf-to-cf \
// RUN: -convert-arith-to-gccjit \
// RUN: -convert-memref-to-gccjit \
// RUN: -convert-func-to-gccjit \
// RUN: -reconcile-unrealized-casts -mlir-print-debuginfo -o %t.mlir
// RUN: %filecheck --input-file=%t.mlir %s
// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-gimple | %filecheck %s --check-prefix=CHECK-GIMPLE
module @test attributes {
gccjit.opt_level = #gccjit.opt_level<O3>
}
{
// CHECK-NOT: func.func
// CHECK-NOT: func.return
// CHECK-NOT: cf.cond_br
Expand All @@ -15,11 +26,15 @@ module {
%acc0 = arith.constant 0.0 : f32
%sum = affine.for %k = 0 to 100 iter_args(%acc = %acc0) -> f32 {
// Load values from A and B
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})]
%a_val = affine.load %A[%i, %k] : memref<100x100xf32>
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})]
%b_val = affine.load %B[%k, %j] : memref<100x100xf32>

// Multiply and accumulate
// CHECK-GIMPLE: %[[V:[0-9]+]] = %{{[0-9]+}} * %{{[0-9]+}}
%prod = arith.mulf %a_val, %b_val : f32
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9]+}} + %[[V]]
%new_acc = arith.addf %acc, %prod : f32

// Yield the new accumulated value
Expand All @@ -33,6 +48,7 @@ module {
%final_val = arith.addf %c_scaled, %result : f32

// Store the final result back to matrix C
// CHECK-GIMPLE: %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] = %{{[0-9]+}}
affine.store %final_val, %C[%i, %j] : memref<100x100xf32>
}
}
Expand Down
Loading

0 comments on commit 3c8b1ea

Please sign in to comment.