diff --git a/include/mlir-gccjit/IR/GCCJITAttrs.td b/include/mlir-gccjit/IR/GCCJITAttrs.td index a11360c..f073c71 100644 --- a/include/mlir-gccjit/IR/GCCJITAttrs.td +++ b/include/mlir-gccjit/IR/GCCJITAttrs.td @@ -149,7 +149,7 @@ def NullAttr : GCCJIT_Attr<"Null", "null", [TypedAttrInterface]> { def OptLevel_0 : I32EnumAttrCase<"O0", 0, "O0">; def OptLevel_1 : I32EnumAttrCase<"O1", 1, "O1">; def OptLevel_2 : I32EnumAttrCase<"O2", 2, "O2">; -def OptLevel_3 : I32EnumAttrCase<"O3", 3, "O2">; +def OptLevel_3 : I32EnumAttrCase<"O3", 3, "O3">; def OptLevelEnum : I32EnumAttr<"OptLevelEnum", "Optimization level", [OptLevel_0, OptLevel_1, OptLevel_2, OptLevel_3]> { let cppNamespace = "mlir::gccjit"; diff --git a/include/mlir-gccjit/IR/GCCJITOps.td b/include/mlir-gccjit/IR/GCCJITOps.td index fa8d6e7..bc0cce2 100644 --- a/include/mlir-gccjit/IR/GCCJITOps.td +++ b/include/mlir-gccjit/IR/GCCJITOps.td @@ -569,7 +569,7 @@ def CallOp : GCCJIT_Op<"call"> { UnitAttr:$tail, UnitAttr:$builtin ); - let results = (outs Optional:$result); + let results = (outs AnyType:$result); let assemblyFormat = [{ custom($tail) custom($builtin) @@ -590,7 +590,7 @@ def PtrCallOp : GCCJIT_Op<"ptr_call"> { ``` }]; let arguments = (ins GCCJIT_PointerType:$callee, Variadic:$args, UnitAttr:$tail); - let results = (outs Optional:$result); + let results = (outs AnyType:$result); let assemblyFormat = [{ custom($tail) $callee `(` $args `)` `:` functional-type(operands, results) attr-dict diff --git a/include/mlir-gccjit/Translation/TranslateToGCCJIT.h b/include/mlir-gccjit/Translation/TranslateToGCCJIT.h index 4ade0b3..9a8b77c 100644 --- a/include/mlir-gccjit/Translation/TranslateToGCCJIT.h +++ b/include/mlir-gccjit/Translation/TranslateToGCCJIT.h @@ -26,6 +26,10 @@ namespace mlir::gccjit { void registerToGCCJITGimpleTranslation(); void registerToGCCJITReproducerTranslation(); +void registerToGCCJITAssemblyTranslation(); +void registerToGCCJITObjectTranslation(); +void registerToGCCJITExecutableTranslation(); +void registerToGCCJITDylibTranslation(); struct GCCJITContextDeleter { void operator()(gcc_jit_context *ctxt) const; @@ -79,6 +83,7 @@ class GCCJITTranslation { void populateGCCJITModuleOptions(); void declareAllFunctionAndGlobals(); void translateGlobalInitializers(); + void translateFunctions(); }; llvm::Expected translateModuleToGCCJIT(ModuleOp op); diff --git a/src/Translation/Registration.cpp b/src/Translation/Registration.cpp index 315656b..4239141 100644 --- a/src/Translation/Registration.cpp +++ b/src/Translation/Registration.cpp @@ -25,15 +25,44 @@ namespace mlir::gccjit { namespace { +enum class OutputType { + Gimple, + Reproducer, + Assembly, + Object, + Executable, + Dylib +}; + llvm::Expected -dumpContextToTempfile(gcc_jit_context *ctxt, bool reproducer) { +dumpContextToTempfile(gcc_jit_context *ctxt, OutputType type) { auto file = llvm::sys::fs::TempFile::create("mlir-gccjit-%%%%%%%"); if (!file) return file.takeError(); - if (reproducer) - gcc_jit_context_dump_reproducer_to_file(ctxt, file->TmpName.c_str()); - else + switch (type) { + case OutputType::Gimple: gcc_jit_context_dump_to_file(ctxt, file->TmpName.c_str(), false); + break; + case OutputType::Reproducer: + gcc_jit_context_dump_reproducer_to_file(ctxt, file->TmpName.c_str()); + break; + case OutputType::Assembly: + gcc_jit_context_compile_to_file(ctxt, GCC_JIT_OUTPUT_KIND_ASSEMBLER, + file->TmpName.c_str()); + break; + case OutputType::Object: + gcc_jit_context_compile_to_file(ctxt, GCC_JIT_OUTPUT_KIND_OBJECT_FILE, + file->TmpName.c_str()); + break; + case OutputType::Executable: + gcc_jit_context_compile_to_file(ctxt, GCC_JIT_OUTPUT_KIND_EXECUTABLE, + file->TmpName.c_str()); + break; + case OutputType::Dylib: + gcc_jit_context_compile_to_file(ctxt, GCC_JIT_OUTPUT_KIND_DYNAMIC_LIBRARY, + file->TmpName.c_str()); + break; + } return file; } @@ -49,10 +78,10 @@ LogicalResult copyFileToStream(llvm::sys::fs::TempFile file, } void registerTranslation(llvm::StringRef name, llvm::StringRef desc, - bool reproducer) { + OutputType type) { TranslateFromMLIRRegistration registration( name, desc, - [reproducer](Operation *op, raw_ostream &output) { + [type](Operation *op, raw_ostream &output) { auto module = dyn_cast(op); if (!module) { op->emitError("expected 'module' operation"); @@ -63,7 +92,7 @@ void registerTranslation(llvm::StringRef name, llvm::StringRef desc, op->emitError("failed to translate to GCCJIT context"); return failure(); } - auto file = dumpContextToTempfile(context.get().get(), reproducer); + auto file = dumpContextToTempfile(context.get().get(), type); if (!file) { op->emitError("failed to dump GCCJIT context to tempfile"); return failure(); @@ -79,11 +108,36 @@ void registerTranslation(llvm::StringRef name, llvm::StringRef desc, void registerToGCCJITGimpleTranslation() { registerTranslation("mlir-to-gccjit-gimple", - "Translate MLIR to GCCJIT's GIMPLE format", false); + "Translate MLIR to GCCJIT's GIMPLE format", + OutputType::Gimple); } void registerToGCCJITReproducerTranslation() { registerTranslation("mlir-to-gccjit-reproducer", - "Translate MLIR to GCCJIT's reproducer format", true); + "Translate MLIR to GCCJIT's reproducer format", + OutputType::Reproducer); +} + +void registerToGCCJITAssemblyTranslation() { + registerTranslation("mlir-to-gccjit-assembly", + "Translate MLIR to GCCJIT's assembly format", + OutputType::Assembly); +} + +void registerToGCCJITObjectTranslation() { + registerTranslation("mlir-to-gccjit-object", + "Translate MLIR to GCCJIT's object file format", + OutputType::Object); +} + +void registerToGCCJITExecutableTranslation() { + registerTranslation("mlir-to-gccjit-executable", + "Translate MLIR to GCCJIT's executable format", + OutputType::Executable); +} +void registerToGCCJITDylibTranslation() { + registerTranslation("mlir-to-gccjit-dylib", + "Translate MLIR to GCCJIT's dynamic library format", + OutputType::Dylib); } } // namespace mlir::gccjit diff --git a/src/Translation/TranslateToGCCJIT.cpp b/src/Translation/TranslateToGCCJIT.cpp index 2b44c00..10e74fa 100644 --- a/src/Translation/TranslateToGCCJIT.cpp +++ b/src/Translation/TranslateToGCCJIT.cpp @@ -64,19 +64,17 @@ class RegionVisitor { GCCJITTranslation &translator; Region ®ion [[maybe_unused]]; llvm::DenseMap exprCache; - llvm::DenseMap variables; llvm::DenseMap blocks; public: RegionVisitor(GCCJITTranslation &translator, Region ®ion); - gcc_jit_lvalue *queryVariable(Value value); GCCJITTranslation &getTranslator() const; gcc_jit_context *getContext() const; MLIRContext *getMLIRContext() const; void translateIntoContext(); private: - Expr visitExpr(Operation *op); + Expr visitExpr(Value value); void visitExprAsRValue(ValueRange operands, llvm::SmallVectorImpl &result); gcc_jit_rvalue *visitExprWithoutCache(ConstantOp op); @@ -94,6 +92,14 @@ class RegionVisitor { gcc_jit_rvalue *visitExprWithoutCache(AddrOp op); gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op); gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op); + + /// The following operations are entrypoints for real codegen. + void visitAssignOp(gcc_jit_block *blk, AssignOp op); + void visitUpdateOp(gcc_jit_block *blk, UpdateOp op); + void visitReturnOp(gcc_jit_block *blk, ReturnOp op); + void visitSwitchOp(gcc_jit_block *blk, SwitchOp op); + void visitJumpOp(gcc_jit_block *blk, JumpOp op); + void visitEvalOp(gcc_jit_block *blk, EvalOp op); }; } // namespace @@ -175,6 +181,7 @@ void GCCJITTranslation::translateModuleToGCCJIT(ModuleOp op) { populateGCCJITModuleOptions(); declareAllFunctionAndGlobals(); translateGlobalInitializers(); + translateFunctions(); } gcc_jit_location *GCCJITTranslation::getLocation(LocationAttr loc) { @@ -329,7 +336,7 @@ static gcc_jit_tls_model convertTLSModel(TLSModelEnum model) { } void GCCJITTranslation::declareAllFunctionAndGlobals() { - moduleOp.walk([&](gccjit::FuncOp func) { + for (auto func : moduleOp.getOps()) { auto type = func.getFunctionType(); llvm::SmallVector paramTypes; llvm::SmallVector params; @@ -353,8 +360,8 @@ void GCCJITTranslation::declareAllFunctionAndGlobals() { processFunctionAttrs(func, funcHandle); SymbolRefAttr symRef = SymbolRefAttr::get(getMLIRContext(), name); functionMap[symRef] = {funcHandle, std::move(params)}; - }); - moduleOp.walk([&](gccjit::GlobalOp global) { + } + for (auto global : moduleOp.getOps()) { auto type = global.getType(); auto *typeHandle = convertType(type); auto name = global.getSymName().str(); @@ -394,16 +401,16 @@ void GCCJITTranslation::declareAllFunctionAndGlobals() { .Default([](Attribute) { llvm_unreachable("unknown initializer"); }); } // if the global has body, we translate them in the next pass - }); + } } void GCCJITTranslation::translateGlobalInitializers() { - moduleOp.walk([&](gccjit::GlobalOp global) { + for (auto global : moduleOp.getOps()) { if (global.getBody().empty()) return; RegionVisitor visitor(*this, global.getBody()); visitor.translateIntoContext(); - }); + } } ///===----------------------------------------------------------------------===/// @@ -419,7 +426,7 @@ RegionVisitor::RegionVisitor(GCCJITTranslation &translator, Region ®ion) auto *function = translator.getFunction(symName); for (auto arg : region.getArguments()) { auto *lvalue = gcc_jit_function_get_param(function, arg.getArgNumber()); - variables[arg] = gcc_jit_param_as_lvalue(lvalue); + exprCache[arg] = gcc_jit_param_as_lvalue(lvalue); } region.walk([&](LocalOp local) { auto *type = translator.convertType(local.getType()); @@ -427,15 +434,18 @@ RegionVisitor::RegionVisitor(GCCJITTranslation &translator, Region ®ion) auto name = llvm::Twine("var").concat(llvm::Twine(variableCount++)).str(); auto *lvalue = gcc_jit_function_new_local(function, loc, type, name.c_str()); - variables[local] = lvalue; + exprCache[local] = lvalue; }); + size_t blockCount = 0; + for (auto &block : region) { + auto *blk = gcc_jit_function_new_block( + function, + llvm::Twine("bb").concat(llvm::Twine(blockCount++)).str().c_str()); + blocks[&block] = blk; + } } } -gcc_jit_lvalue *RegionVisitor::queryVariable(Value value) { - return variables.lookup(value); -} - GCCJITTranslation &RegionVisitor::getTranslator() const { return translator; } gcc_jit_context *RegionVisitor::getContext() const { @@ -449,8 +459,20 @@ MLIRContext *RegionVisitor::getMLIRContext() const { void RegionVisitor::translateIntoContext() { auto *parent = region.getParentOp(); if (auto funcOp = dyn_cast(parent)) { - (void)funcOp; - llvm_unreachable("NYI"); + for (auto [mlirBlk, gccBlk] : blocks) { + auto *blk = gccBlk; + mlirBlk->walk([&](Operation *op) { + llvm::TypeSwitch(op) + .Case([&](AssignOp op) { visitAssignOp(blk, op); }) + .Case([&](UpdateOp op) { visitUpdateOp(blk, op); }) + .Case([&](ReturnOp op) { visitReturnOp(blk, op); }) + .Case([&](SwitchOp op) { visitSwitchOp(blk, op); }) + .Case([&](JumpOp op) { visitJumpOp(blk, op); }) + .Case([&](EvalOp op) { visitEvalOp(blk, op); }) + .Default([](auto) { return; }); + }); + } + return; } if (auto globalOp = dyn_cast(parent)) { assert(region.getBlocks().size() == 1 && @@ -458,7 +480,7 @@ void RegionVisitor::translateIntoContext() { Block &block = region.getBlocks().front(); auto terminator = cast(block.getTerminator()); auto value = terminator->getOperand(0); - auto rvalue = visitExpr(value.getDefiningOp()); + auto rvalue = visitExpr(value); auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName()); auto *lvalue = getTranslator().getGlobalLValue(symName); gcc_jit_global_set_initializer_rvalue(lvalue, rvalue); @@ -467,12 +489,14 @@ void RegionVisitor::translateIntoContext() { llvm_unreachable("unknown region parent"); } -Expr RegionVisitor::visitExpr(Operation *op) { - if (op->getNumResults() != 1) - llvm_unreachable("expected single result operation"); +Expr RegionVisitor::visitExpr(Value value) { + auto &cached = exprCache[value]; + + if (!cached) { + auto *op = value.getDefiningOp(); + if (op->getNumResults() != 1) + llvm_unreachable("expected single result operation"); - auto &cached = exprCache[op->getResult(0)]; - if (!cached) cached = llvm::TypeSwitch(op) .Case([&](ConstantOp op) { return visitExprWithoutCache(op); }) @@ -493,13 +517,15 @@ Expr RegionVisitor::visitExpr(Operation *op) { .Default([](Operation *) -> Expr { llvm_unreachable("unknown expression type"); }); + } + return cached; } void RegionVisitor::visitExprAsRValue( ValueRange operands, llvm::SmallVectorImpl &result) { for (auto operand : operands) - result.push_back(visitExpr(operand.getDefiningOp())); + result.push_back(visitExpr(operand)); } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ConstantOp op) { @@ -547,7 +573,7 @@ gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AlignOfOp op) { } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AsRValueOp op) { - auto lvalue = visitExpr(op.getLvalue().getDefiningOp()); + auto lvalue = visitExpr(op.getLvalue()); return gcc_jit_lvalue_as_rvalue(lvalue); } @@ -583,8 +609,8 @@ static gcc_jit_binary_op convertBinaryOp(BOp kind) { // RValue always has a defining operation gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(BinaryOp op) { - auto lhs = visitExpr(op.getLhs().getDefiningOp()); - auto rhs = visitExpr(op.getRhs().getDefiningOp()); + auto lhs = visitExpr(op.getLhs()); + auto rhs = visitExpr(op.getRhs()); auto kind = convertBinaryOp(op.getOp()); auto *loc = getTranslator().getLocation(op.getLoc()); auto *ctxt = getContext(); @@ -607,7 +633,7 @@ static gcc_jit_unary_op convertUnaryOp(UOp kind) { } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(UnaryOp op) { - auto operand = visitExpr(op.getOperand().getDefiningOp()); + auto operand = visitExpr(op.getOperand()); auto kind = convertUnaryOp(op.getOp()); auto *loc = getTranslator().getLocation(op.getLoc()); auto *ctxt = getContext(); @@ -634,8 +660,8 @@ static gcc_jit_comparison convertCompareOp(CmpOp kind) { } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CompareOp op) { - auto lhs = visitExpr(op.getLhs().getDefiningOp()); - auto rhs = visitExpr(op.getRhs().getDefiningOp()); + auto lhs = visitExpr(op.getLhs()); + auto rhs = visitExpr(op.getRhs()); auto kind = convertCompareOp(op.getOp()); auto *loc = getTranslator().getLocation(op.getLoc()); auto *ctxt = getContext(); @@ -657,12 +683,13 @@ gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CallOp op) { auto *ctxt = getContext(); auto *call = gcc_jit_context_new_call(ctxt, loc, callee, args.size(), args.data()); - gcc_jit_rvalue_set_bool_require_tail_call(call, op.getTail()); + if (op.getTail()) + gcc_jit_rvalue_set_bool_require_tail_call(call, true); return call; } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CastOp op) { - auto operand = visitExpr(op.getOperand().getDefiningOp()); + auto operand = visitExpr(op.getOperand()); auto *loc = getTranslator().getLocation(op.getLoc()); auto *ctxt = getContext(); auto *type = getTranslator().convertType(op.getType()); @@ -670,7 +697,7 @@ gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CastOp op) { } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(BitCastOp op) { - auto operand = visitExpr(op.getOperand().getDefiningOp()); + auto operand = visitExpr(op.getOperand()); auto *loc = getTranslator().getLocation(op.getLoc()); auto *ctxt = getContext(); auto *type = getTranslator().convertType(op.getType()); @@ -678,19 +705,20 @@ gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(BitCastOp op) { } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(PtrCallOp op) { - auto callee = visitExpr(op.getCallee().getDefiningOp()); + auto callee = visitExpr(op.getCallee()); llvm::SmallVector args; visitExprAsRValue(op.getArgs(), args); auto *loc = getTranslator().getLocation(op.getLoc()); auto *ctxt = getContext(); auto *call = gcc_jit_context_new_call_through_ptr(ctxt, loc, callee, args.size(), args.data()); - gcc_jit_rvalue_set_bool_require_tail_call(call, op.getTail()); + if (op.getTail()) + gcc_jit_rvalue_set_bool_require_tail_call(call, true); return call; } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AddrOp op) { - auto lvalue = visitExpr(op.getOperand().getDefiningOp()); + auto lvalue = visitExpr(op.getOperand()); auto *loc = getTranslator().getLocation(op.getLoc()); return gcc_jit_lvalue_get_address(lvalue, loc); } @@ -708,6 +736,73 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(GetGlobalOp op) { return lvalue; } +void RegionVisitor::visitAssignOp(gcc_jit_block *blk, AssignOp op) { + auto lvalue = visitExpr(op.getLvalue()); + auto rvalue = visitExpr(op.getRvalue()); + auto *loc = getTranslator().getLocation(op.getLoc()); + gcc_jit_block_add_assignment(blk, loc, lvalue, rvalue); +} + +void RegionVisitor::visitUpdateOp(gcc_jit_block *blk, UpdateOp op) { + auto lvalue = visitExpr(op.getLvalue()); + auto rvalue = visitExpr(op.getRvalue()); + auto *loc = getTranslator().getLocation(op.getLoc()); + auto kind = convertBinaryOp(op.getOp()); + gcc_jit_block_add_assignment_op(blk, loc, lvalue, kind, rvalue); +} + +void RegionVisitor::visitReturnOp(gcc_jit_block *blk, ReturnOp op) { + if (op->getNumOperands()) + gcc_jit_block_end_with_return(blk, getTranslator().getLocation(op.getLoc()), + visitExpr(op.getOperand(0))); + else + gcc_jit_block_end_with_void_return( + blk, getTranslator().getLocation(op.getLoc())); +} + +void RegionVisitor::visitSwitchOp(gcc_jit_block *blk, SwitchOp op) { + auto value = visitExpr(op.getValue()); + auto *loc = getTranslator().getLocation(op.getLoc()); + llvm::SmallVector cases; + for (auto [lb, ub, dst] : + llvm::zip(op.getCaseLowerbound(), op.getCaseUpperbound(), + op.getCaseDestinations())) { + // TODO: handle signedness + // TODO: generalize switch statement to support rvalue expressions + // (constant) + auto intLb = cast(lb).getValue().getZExtValue(); + auto intUb = cast(ub).getValue().getZExtValue(); + auto *dstBlk = blocks[dst]; + auto *lbv = + gcc_jit_context_new_rvalue_from_long(getContext(), nullptr, intLb); + auto *ubv = + gcc_jit_context_new_rvalue_from_long(getContext(), nullptr, intUb); + cases.push_back(gcc_jit_context_new_case(getContext(), lbv, ubv, dstBlk)); + } + auto *defaultBlk = blocks[op.getDefaultDestination()]; + gcc_jit_block_end_with_switch(blk, loc, value, defaultBlk, cases.size(), + cases.data()); +} + +void RegionVisitor::visitJumpOp(gcc_jit_block *blk, JumpOp op) { + auto *dst = blocks[op.getDest()]; + gcc_jit_block_end_with_jump(blk, getTranslator().getLocation(op.getLoc()), + dst); +} + +void RegionVisitor::visitEvalOp(gcc_jit_block *blk, EvalOp op) { + auto value = visitExpr(op.getOperand()); + gcc_jit_block_add_eval(blk, getTranslator().getLocation(op.getLoc()), value); +} + +void GCCJITTranslation::translateFunctions() { + for (auto func : moduleOp.getOps()) { + auto ®ion = func.getBody(); + RegionVisitor visitor(*this, region); + visitor.translateIntoContext(); + } +} + //===----------------------------------------------------------------------===// // TranslateModuleToGCCJIT //===----------------------------------------------------------------------===// diff --git a/test/syntax/hello_world.mlir b/test/syntax/hello_world.mlir new file mode 100644 index 0000000..df35de6 --- /dev/null +++ b/test/syntax/hello_world.mlir @@ -0,0 +1,20 @@ +!i32 = !gccjit.int +!char = !gccjit.int +!const_char = !gccjit.qualified +!str = !gccjit.ptr +module @test attributes { + gccjit.opt_level = #gccjit.opt_level, + gccjit.prog_name = "test", + gccjit.allow_unreachable = false, + gccjit.debug_info = true +} { + gccjit.func imported @puts(!str) -> !i32 + gccjit.func exported @main() -> !i32 { + %1 = gccjit.literal <"hello, world!\n"> : !str + %2 = gccjit.call @puts(%1) : (!str) -> !i32 + gccjit.eval (%2 : !i32) + + %0 = gccjit.const #gccjit.zero : !i32 + gccjit.return %0 : !i32 + } +} diff --git a/tools/gccjit-translate/main.cpp b/tools/gccjit-translate/main.cpp index bdf9f54..85d2808 100644 --- a/tools/gccjit-translate/main.cpp +++ b/tools/gccjit-translate/main.cpp @@ -23,6 +23,10 @@ int main(int argc, char **argv) { registerAllTranslations(); mlir::gccjit::registerToGCCJITGimpleTranslation(); mlir::gccjit::registerToGCCJITReproducerTranslation(); + mlir::gccjit::registerToGCCJITAssemblyTranslation(); + mlir::gccjit::registerToGCCJITObjectTranslation(); + mlir::gccjit::registerToGCCJITExecutableTranslation(); + mlir::gccjit::registerToGCCJITDylibTranslation(); return failed( mlirTranslateMain(argc, argv, "GCCJIT Translation Testing Tool")); }