From 3c8b1eae32e36431fdd9397c1a0e1010d9838ed6 Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan Date: Mon, 4 Nov 2024 20:36:25 -0500 Subject: [PATCH] [gccjit] lower memref (#22) * [gccjit] stage * lowering store/load * fixes * enable opt --- include/mlir-gccjit/IR/GCCJITAttrs.td | 7 +- include/mlir-gccjit/IR/GCCJITOps.td | 2 +- src/Conversion/ConvertMemrefToGCCJIT.cpp | 199 ++++++++++++++++++++++- src/Conversion/TypeConverter.cpp | 6 +- src/Translation/TranslateToGCCJIT.cpp | 31 +++- test/lowering/gemm.mlir | 20 ++- test/syntax/record.mlir | 20 +-- 7 files changed, 257 insertions(+), 28 deletions(-) diff --git a/include/mlir-gccjit/IR/GCCJITAttrs.td b/include/mlir-gccjit/IR/GCCJITAttrs.td index 9b4c16e..edaeed6 100644 --- a/include/mlir-gccjit/IR/GCCJITAttrs.td +++ b/include/mlir-gccjit/IR/GCCJITAttrs.td @@ -180,7 +180,7 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> { let builders = [ AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type), [{ - return get($_ctxt, name, type, 0, std::nullopt); + return get($_ctxt, name, type, std::nullopt, std::nullopt); }]>, AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type, "unsigned":$bitWidth), [{ @@ -188,12 +188,13 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> { }]>, AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type, "mlir::gccjit::SourceLocAttr":$loc), [{ - return get($_ctxt, name, type, 0, loc); + return get($_ctxt, name, type, std::nullopt, loc); }]>, ]; + // attribute can eat up the `:` separator, so we need to move the name to the front let assemblyFormat = [{ - `<` $type $name (`:` $bitWidth^)? ($loc^)? `>` + `<` $name $type (`:` $bitWidth^)? ($loc^)? `>` }]; } diff --git a/include/mlir-gccjit/IR/GCCJITOps.td b/include/mlir-gccjit/IR/GCCJITOps.td index 34f1e68..111ae9f 100644 --- a/include/mlir-gccjit/IR/GCCJITOps.td +++ b/include/mlir-gccjit/IR/GCCJITOps.td @@ -779,7 +779,7 @@ def AccessFieldOp : GCCJIT_Op<"access_field"> { ``` }]; let arguments = (ins AnyType:$composite, IndexAttr:$field); - let results = (outs GCCJIT_LValueType:$result); + let results = (outs AnyType:$result); let assemblyFormat = [{ $composite `[` $field `]` `:` functional-type(operands, results) attr-dict }]; diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp index 540e3cc..389ea6d 100644 --- a/src/Conversion/ConvertMemrefToGCCJIT.cpp +++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp @@ -12,11 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include +#include +#include +#include +#include +#include +#include +#include "libgccjit.h" #include "mlir-gccjit/Conversion/Conversions.h" #include "mlir-gccjit/Conversion/TypeConverter.h" +#include "mlir-gccjit/IR/GCCJITAttrs.h" +#include "mlir-gccjit/IR/GCCJITOps.h" +#include "mlir-gccjit/IR/GCCJITOpsEnums.h" +#include "mlir-gccjit/IR/GCCJITTypes.h" #include "mlir-gccjit/Passes.h" using namespace mlir; @@ -29,13 +42,197 @@ struct ConvertMemrefToGCCJITPass void runOnOperation() override final; }; +template +class GCCJITLoweringPattern : public mlir::OpConversionPattern { +protected: + const GCCJITTypeConverter *getTypeConverter() const { + return static_cast(this->typeConverter); + } + +public: + using OpConversionPattern::OpConversionPattern; +}; + +Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, + int64_t value) { + + auto indexTy = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T); + auto intAttr = IntAttr::get(builder.getContext(), indexTy, + {64, static_cast(value)}); + return builder.create(loc, resultType, intAttr); +} + +Value getMemRefDescriptorOffset(OpBuilder &builder, Value descriptor, + Location loc) { + auto indexTy = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T); + return builder.create(loc, indexTy, descriptor, + builder.getIndexAttr(2)); +} + +Value getMemRefDiscriptorAlignedPtr(OpBuilder &builder, Value descriptor, + const GCCJITTypeConverter &converter, + Location loc, MemRefType type) { + auto elementType = converter.convertType(type.getElementType()); + auto ptrTy = PointerType::get(builder.getContext(), elementType); + return builder.create(loc, ptrTy, descriptor, + builder.getIndexAttr(1)); +} + +Value getMemRefDescriptorBufferPtr(OpBuilder &builder, Location loc, + Value descriptor, + const GCCJITTypeConverter &converter, + MemRefType type) { + auto [strides, offsetCst] = getStridesAndOffset(type); + auto alignedPtr = + getMemRefDiscriptorAlignedPtr(builder, descriptor, converter, loc, type); + + // For zero offsets, we already have the base pointer. + if (offsetCst == 0) + return alignedPtr; + + // Otherwise add the offset to the aligned base. + Type indexType = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T); + Value offsetVal = + ShapedType::isDynamic(offsetCst) + ? getMemRefDescriptorOffset(builder, descriptor, loc) + : createIndexAttrConstant(builder, loc, indexType, offsetCst); + Type elementType = converter.convertType(type.getElementType()); + auto lvalueTy = LValueType::get(builder.getContext(), elementType); + auto lvalue = + builder.create(loc, lvalueTy, alignedPtr, offsetVal); + return builder.create( + loc, PointerType::get(builder.getContext(), elementType), lvalue); +} + +Value getStridedElementLValue(Location loc, MemRefType type, Value descriptor, + ExprOp parent, ValueRange indices, + const GCCJITTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { + Value materializedMemref = nullptr; + Value ptrToStrideField = nullptr; + auto [strides, offset] = getStridesAndOffset(type); + auto indexTy = IntType::get(rewriter.getContext(), GCC_JIT_TYPE_SIZE_T); + auto elementType = typeConverter.convertType(type.getElementType()); + auto doMaterialization = [&]() { + if (materializedMemref) + return; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(parent); + auto lvalueTy = + LValueType::get(rewriter.getContext(), descriptor.getType()); + materializedMemref = rewriter.create( + loc, lvalueTy, nullptr, nullptr, nullptr); + rewriter.create(loc, descriptor, materializedMemref); + }; + auto generateStride = [&](size_t i) -> Value { + doMaterialization(); + if (!ptrToStrideField) { + auto descriptorTy = cast(descriptor.getType()); + auto fieldTy = cast( + cast(descriptorTy.getRecordFields()[4]).getType()); + auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy); + auto strideField = rewriter.create( + loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4)); + auto ptrToStrideArray = rewriter.create( + loc, PointerType::get(rewriter.getContext(), fieldTy), strideField); + ptrToStrideField = rewriter.create( + loc, PointerType::get(rewriter.getContext(), indexTy), + ptrToStrideArray); + } + auto offset = rewriter.create( + loc, indexTy, ptrToStrideField, rewriter.getIndexAttr(i)); + auto strideLValue = rewriter.create( + loc, LValueType::get(rewriter.getContext(), indexTy), ptrToStrideField, + offset); + return rewriter.create(loc, indexTy, strideLValue); + }; + + Value base = getMemRefDescriptorBufferPtr(rewriter, loc, descriptor, + typeConverter, type); + Value index; + for (int i = 0, e = indices.size(); i < e; ++i) { + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = + ShapedType::isDynamic(strides[i]) + ? generateStride(i) + : createIndexAttrConstant(rewriter, loc, indexTy, strides[i]); + increment = rewriter.create(loc, indexTy, BOp::Mult, + increment, stride); + } + index = index ? rewriter.create(loc, indexTy, BOp::Plus, + index, increment) + : increment; + } + + return rewriter.create( + loc, LValueType::get(rewriter.getContext(), elementType), base, index); +} + +class LoadOpLowering : public GCCJITLoweringPattern { +public: + using GCCJITLoweringPattern::GCCJITLoweringPattern; + mlir::LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto type = op.getMemRefType(); + auto retTy = typeConverter->convertType(op.getResult().getType()); + auto exprBundle = rewriter.replaceOpWithNewOp(op, retTy); + auto *block = rewriter.createBlock(&exprBundle.getBody()); + rewriter.setInsertionPointToStart(block); + Value dataLValue = getStridedElementLValue( + op.getLoc(), type, adaptor.getMemref(), exprBundle, + adaptor.getIndices(), *getTypeConverter(), rewriter); + auto rvalue = rewriter.create(op.getLoc(), retTy, dataLValue); + rewriter.create(op.getLoc(), rvalue); + return success(); + } +}; + +class StoreOpLowering : public GCCJITLoweringPattern { +public: + using GCCJITLoweringPattern::GCCJITLoweringPattern; + mlir::LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto type = op.getMemRefType(); + auto elemTy = typeConverter->convertType(type.getElementType()); + auto elemLValueTy = LValueType::get(rewriter.getContext(), elemTy); + auto expr = rewriter.create(op->getLoc(), elemLValueTy, true); + auto *block = rewriter.createBlock(&expr.getBody()); + { + rewriter.setInsertionPointToStart(block); + Value dataLValue = getStridedElementLValue( + op.getLoc(), type, adaptor.getMemref(), expr, adaptor.getIndices(), + *getTypeConverter(), rewriter); + rewriter.create(op.getLoc(), dataLValue); + } + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), expr); + return success(); + } +}; + void ConvertMemrefToGCCJITPass::runOnOperation() { auto moduleOp = getOperation(); auto typeConverter = GCCJITTypeConverter(); + auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + + return builder.create(loc, resultType, inputs) + .getResult(0); + }; + typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); + typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); mlir::RewritePatternSet patterns(&getContext()); + patterns.insert(typeConverter, + &getContext()); mlir::ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalDialect(); + target.addIllegalDialect(); llvm::SmallVector ops; for (auto func : moduleOp.getOps()) ops.push_back(func); diff --git a/src/Conversion/TypeConverter.cpp b/src/Conversion/TypeConverter.cpp index 17a7cb0..3332ed8 100644 --- a/src/Conversion/TypeConverter.cpp +++ b/src/Conversion/TypeConverter.cpp @@ -177,7 +177,7 @@ GCCJITTypeConverter::getMemrefDescriptorType(mlir::MemRefType type) const { llvm::enumerate(ArrayRef{elementPtrType, elementPtrType, indexType, dimOrStrideType, dimOrStrideType})) { auto nameAttr = StringAttr::get(type.getContext(), names[idx]); - fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0)); + fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field)); } auto fieldsAttr = ArrayAttr::get(type.getContext(), fields); return StructType::get(type.getContext(), nameAttr, fieldsAttr); @@ -199,7 +199,7 @@ gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType( llvm::enumerate(ArrayRef{indexType, opaquePtrType})) { auto name = Twine("__field_").concat(Twine(idx)).str(); auto nameAttr = StringAttr::get(type.getContext(), name); - fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0)); + fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field)); } auto fieldsAttr = ArrayAttr::get(type.getContext(), fields); return StructType::get(type.getContext(), nameAttr, fieldsAttr); @@ -220,7 +220,7 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton( for (auto [idx, type] : llvm::enumerate(types)) { auto name = Twine("__field_").concat(Twine(idx)).str(); auto nameAttr = StringAttr::get(func.getContext(), name); - fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type, 0)); + fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type)); } auto nameAttr = StringAttr::get(func.getContext(), name); auto fieldsAttr = ArrayAttr::get(func.getContext(), fields); diff --git a/src/Translation/TranslateToGCCJIT.cpp b/src/Translation/TranslateToGCCJIT.cpp index c19ced3..c71bede 100644 --- a/src/Translation/TranslateToGCCJIT.cpp +++ b/src/Translation/TranslateToGCCJIT.cpp @@ -107,8 +107,9 @@ class RegionVisitor { gcc_jit_rvalue *visitExprWithoutCache(AddrOp op); gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op); gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op); - gcc_jit_rvalue *visitExprWithoutCache(ExprOp op); + Expr visitExprWithoutCache(ExprOp op); gcc_jit_lvalue *visitExprWithoutCache(DerefOp op); + Expr visitExprWithoutCache(AccessFieldOp op); /// The following operations are entrypoints for real codegen. void visitAssignOp(gcc_jit_block *blk, AssignOp op); @@ -515,18 +516,17 @@ Expr RegionVisitor::translateIntoContext() { Block &block = region.getBlocks().front(); auto terminator = cast(block.getTerminator()); auto value = terminator->getOperand(0); - auto rvalue = visitExpr(value, true); + auto expr = visitExpr(value, true); if (auto globalOp = dyn_cast(parent)) { auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName()); auto *lvalue = getTranslator().getGlobalLValue(symName); - gcc_jit_global_set_initializer_rvalue(lvalue, rvalue); + gcc_jit_global_set_initializer_rvalue(lvalue, expr); return {}; } - if (auto exprOp = dyn_cast(parent)) { - return rvalue; - } + if (auto exprOp = dyn_cast(parent)) + return expr; llvm_unreachable("unknown region parent"); } @@ -558,8 +558,8 @@ Expr RegionVisitor::visitExpr(Value value, bool toplevel) { .Case([&](GetGlobalOp op) { return visitExprWithoutCache(op); }) .Case([&](ExprOp op) { return visitExprWithoutCache(op); }) .Case([&](DerefOp op) { return visitExprWithoutCache(op); }) + .Case([&](AccessFieldOp op) { return visitExprWithoutCache(op); }) .Default([](Operation *op) -> Expr { - op->dump(); llvm::report_fatal_error("unknown expression type"); }); @@ -579,7 +579,22 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) { return gcc_jit_context_new_array_access(getContext(), loc, ptr, offset); } -gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ExprOp op) { +Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) { + auto composite = visitExpr(op.getComposite()); + auto *loc = getTranslator().getLocation(op.getLoc()); + auto *compositeTy = getTranslator().convertType(op.getComposite().getType()); + auto index = op.getField().getZExtValue(); + // TODO: support union and query from cache instead + auto *structure = gcc_jit_type_is_struct(compositeTy); + if (!structure) + llvm_unreachable("expected struct type"); + auto *field = gcc_jit_struct_get_field(structure, index); + if (isa(op.getType())) + return gcc_jit_lvalue_access_field(composite, loc, field); + return gcc_jit_rvalue_access_field(composite, loc, field); +} + +Expr RegionVisitor::visitExprWithoutCache(ExprOp op) { RegionVisitor visitor(getTranslator(), op.getRegion(), this); return visitor.translateIntoContext(); } diff --git a/test/lowering/gemm.mlir b/test/lowering/gemm.mlir index e645f8a..099fdf4 100644 --- a/test/lowering/gemm.mlir +++ b/test/lowering/gemm.mlir @@ -1,5 +1,16 @@ -// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s -module { +// RUN: %gccjit-opt %s \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-arith-to-gccjit \ +// RUN: -convert-memref-to-gccjit \ +// RUN: -convert-func-to-gccjit \ +// RUN: -reconcile-unrealized-casts -mlir-print-debuginfo -o %t.mlir +// RUN: %filecheck --input-file=%t.mlir %s +// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-gimple | %filecheck %s --check-prefix=CHECK-GIMPLE +module @test attributes { + gccjit.opt_level = #gccjit.opt_level +} +{ // CHECK-NOT: func.func // CHECK-NOT: func.return // CHECK-NOT: cf.cond_br @@ -15,11 +26,15 @@ module { %acc0 = arith.constant 0.0 : f32 %sum = affine.for %k = 0 to 100 iter_args(%acc = %acc0) -> f32 { // Load values from A and B + // CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] %a_val = affine.load %A[%i, %k] : memref<100x100xf32> + // CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] %b_val = affine.load %B[%k, %j] : memref<100x100xf32> // Multiply and accumulate + // CHECK-GIMPLE: %[[V:[0-9]+]] = %{{[0-9]+}} * %{{[0-9]+}} %prod = arith.mulf %a_val, %b_val : f32 + // CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9]+}} + %[[V]] %new_acc = arith.addf %acc, %prod : f32 // Yield the new accumulated value @@ -33,6 +48,7 @@ module { %final_val = arith.addf %c_scaled, %result : f32 // Store the final result back to matrix C + // CHECK-GIMPLE: %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] = %{{[0-9]+}} affine.store %final_val, %C[%i, %j] : memref<100x100xf32> } } diff --git a/test/syntax/record.mlir b/test/syntax/record.mlir index f98f701..71f29a7 100644 --- a/test/syntax/record.mlir +++ b/test/syntax/record.mlir @@ -4,19 +4,19 @@ module @test { gccjit.func imported @gemm ( !gccjit.struct<"__memref_188510220862752" { - #gccjit.field> "base">, - #gccjit.field> "aligned">, - #gccjit.field "offset">, - #gccjit.field, 2> "sizes">, - #gccjit.field, 2> "strides"> + #gccjit.field<"base" !gccjit.ptr>>, + #gccjit.field<"aligned" !gccjit.ptr>>, + #gccjit.field<"offset" !gccjit.int : 32>, + #gccjit.field<"sizes" !gccjit.array, 2>>, + #gccjit.field<"strides" !gccjit.array, 2>> }> ) // CHECK: @gemm // CHECK-SAME: !gccjit.struct<"__memref_188510220862752" { - // CHECK-SAME: #gccjit.field> "base"> - // CHECK-SAME: #gccjit.field> "aligned"> - // CHECK-SAME: #gccjit.field "offset"> - // CHECK-SAME: #gccjit.field, 2> "sizes"> - // CHECK-SAME: #gccjit.field, 2> "strides"> + // CHECK-SAME: #gccjit.field<"base" !gccjit.ptr>> + // CHECK-SAME: #gccjit.field<"aligned" !gccjit.ptr>> + // CHECK-SAME: #gccjit.field<"offset" !gccjit.int : 32> + // CHECK-SAME: #gccjit.field<"sizes" !gccjit.array, 2>> + // CHECK-SAME: #gccjit.field<"strides" !gccjit.array, 2>> // CHECK-SAME: } }