Skip to content

Commit

Permalink
[gccjit] implement expression fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 2, 2024
1 parent 03e55f4 commit 8e454c1
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 2 deletions.
5 changes: 4 additions & 1 deletion include/mlir-gccjit/IR/GCCJITOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,15 @@ def ExprOp : GCCJIT_Op<"expr"> {
%2 = gccjit.binary plus %0, %1 : !gccjit.int<32>, !gccjit.int<32>
gccjit.return %2 : !gccjit.int<32>
} : !gccjit.int<32>
The expr operation can be marked with a lazy attribute. If such an attribute exists,
the expr will not be materialized until it is used.
```
}];
let arguments = (ins UnitAttr:$lazy);
let results = (outs AnyType:$result);
let regions = (region AnyRegion:$body);
let assemblyFormat = [{
$body `:` type($result) attr-dict
custom<LazyAttribute>($lazy) $body `:` type($result) attr-dict
}];
}

Expand Down
2 changes: 2 additions & 0 deletions src/GCCJITOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ constexpr ParseNamedUnitAttr parseAsmInlineAttr{"inline"};
constexpr PrintNamedUnitAttr printAsmInlineAttr{"inline"};
constexpr ParseNamedUnitAttr parseAsmVolatileAttr{"volatile"};
constexpr PrintNamedUnitAttr printAsmVolatileAttr{"volatile"};
constexpr ParseNamedUnitAttr parseLazyAttribute{"lazy"};
constexpr PrintNamedUnitAttr printLazyAttribute{"lazy"};

ParseResult
parseAsmOperands(OpAsmParser &parser, ArrayAttr &constrains, ArrayAttr &symbols,
Expand Down
10 changes: 9 additions & 1 deletion src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void GCCJITTranslation::declareAllFunctionAndGlobals() {
auto index = pair.index();
auto type = pair.value();
auto name =
llvm::Twine("__arg").concat(llvm::Twine(index)).str();
llvm::Twine("%arg").concat(llvm::Twine(index)).str();
return gcc_jit_context_new_param(
ctxt, /*todo: location*/ nullptr, type, name.c_str());
});
Expand Down Expand Up @@ -415,6 +415,11 @@ RegionVisitor::RegionVisitor(GCCJITTranslation &translator, Region &region,
if (isa<LValueType>(op->getResult(0).getType()) && !isa<LocalOp>(op))
return WalkResult::skip();

// skip lazy evaluated expressions
if (auto exprOp = dyn_cast<ExprOp>(op))
if (exprOp.getLazy())
return WalkResult::skip();

auto *type = translator.convertType(res.getType());
auto *loc = translator.getLocation(res.getLoc());
std::string name;
Expand Down Expand Up @@ -479,6 +484,9 @@ Expr RegionVisitor::translateIntoContext() {
auto *loc = translator.getLocation(op->getLoc());
if (op->getNumResults() == 1 &&
!isa<LValueType>(op->getResult(0).getType())) {
if (auto exprOp = dyn_cast<ExprOp>(op))
if (exprOp.getLazy())
return;
auto result = op->getResult(0);
auto rvalue = visitExpr(result, true);
auto lvalue = lookupExpr(result);
Expand Down
47 changes: 47 additions & 0 deletions test/compile/lazy_evaluation.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: %gccjit-translate -o %t.gimple %s -mlir-to-gccjit-gimple
// RUN: %filecheck --input-file=%t.gimple %s --check-prefix=CHECK-GIMPLE

!i32 = !gccjit.int<int32_t>
!i1 = !gccjit.int<bool>
!var32 = !gccjit.lvalue<!i32>
!char = !gccjit.int<char>
!const_char = !gccjit.qualified<!char, const>
!str = !gccjit.ptr<!const_char>
module @test attributes {
gccjit.opt_level = #gccjit.opt_level<O3>,
gccjit.prog_name = "test",
gccjit.allow_unreachable = false,
gccjit.debug_info = true
} {
// fuse expr into return
gccjit.func exported @add(!i32, !i32) -> !i32 {
^body(%arg0: !var32, %arg1: !var32):
%res = gccjit.expr lazy {
%0 = gccjit.as_rvalue %arg0 : !var32 to !i32
%1 = gccjit.as_rvalue %arg1 : !var32 to !i32
%2 = gccjit.binary plus (%0 : !i32, %1 : !i32) : !i32
gccjit.return %2 : !i32
} : !i32
// CHECK-GIMPLE: return %arg0 + %arg1;
gccjit.return %res : !i32
}

// fuse expr into conditional
gccjit.func exported @max(!i32, !i32) -> !i32 {
^body(%arg0: !var32, %arg1: !var32):
%0 = gccjit.as_rvalue %arg0 : !var32 to !i32
%1 = gccjit.as_rvalue %arg1 : !var32 to !i32
%3 = gccjit.expr lazy {
%2 = gccjit.compare gt (%0 : !i32, %1 : !i32) : !i1
gccjit.return %2 : !i1
} : !i1
// CHECK-GIMPLE: if (%0 > %1) goto bb1; else goto bb2;
gccjit.conditional (%3 : !i1), ^bb1, ^bb2
^bb1:
gccjit.return %0 : !i32
^bb2:
gccjit.return %1 : !i32
}


}

0 comments on commit 8e454c1

Please sign in to comment.