Skip to content

Commit

Permalink
[gccjit] implemented extended asm translation
Browse files Browse the repository at this point in the history
SchrodingerZhu committed Oct 27, 2024
1 parent fe29f6e commit 5dc6d97
Showing 1 changed file with 76 additions and 11 deletions.
87 changes: 76 additions & 11 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
@@ -34,32 +34,36 @@
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <utility>
#include <variant>

namespace mlir::gccjit {

namespace {

class Expr {
std::variant<gcc_jit_lvalue *, gcc_jit_rvalue *> value;
static constexpr intptr_t rvalueFlag = std::numeric_limits<intptr_t>::min();
union {
gcc_jit_lvalue *lvalue;
gcc_jit_rvalue *rvalue;
intptr_t dummy;
};

public:
Expr() : value(static_cast<gcc_jit_lvalue *>(nullptr)) {}
Expr(gcc_jit_lvalue *value) : value(value) {}
Expr(gcc_jit_rvalue *value) : value(value) {}
operator bool() const { return isRValue() || std::get<0>(value) != nullptr; }
Expr() : dummy(0) {}
Expr(gcc_jit_lvalue *value) : lvalue(value) {}
Expr(gcc_jit_rvalue *value) : rvalue(value) { dummy |= rvalueFlag; }
operator bool() const { return dummy != 0; }
operator gcc_jit_lvalue *() const {
if (isLValue())
return std::get<0>(value);
return lvalue;
llvm_unreachable("not an lvalue");
}
operator gcc_jit_rvalue *() const {
if (isRValue())
return std::get<1>(value);
return gcc_jit_lvalue_as_rvalue(std::get<0>(value));
return reinterpret_cast<gcc_jit_rvalue *>(dummy & ~rvalueFlag);
return gcc_jit_lvalue_as_rvalue(lvalue);
}
bool isLValue() const { return value.index() == 0; }
bool isRValue() const { return value.index() == 1; }
bool isLValue() const { return dummy > 0; }
bool isRValue() const { return dummy < 0; }
};

class RegionVisitor {
@@ -77,6 +81,7 @@ class RegionVisitor {

private:
Expr visitExpr(Value value);
void visitExprs(ValueRange values, llvm::SmallVectorImpl<Expr> &result);
void visitExprAsRValue(ValueRange operands,
llvm::SmallVectorImpl<gcc_jit_rvalue *> &result);
gcc_jit_rvalue *visitExprWithoutCache(ConstantOp op);
@@ -101,6 +106,11 @@ class RegionVisitor {
void visitReturnOp(gcc_jit_block *blk, ReturnOp op);
void visitSwitchOp(gcc_jit_block *blk, SwitchOp op);
void visitJumpOp(gcc_jit_block *blk, JumpOp op);
void visitAsmOp(gcc_jit_block *blk, AsmOp op);
void visitAsmGotoOp(gcc_jit_block *blk, AsmGotoOp op);

template <typename OpTy>
void populateExtendedAsm(gcc_jit_extended_asm *extAsm, OpTy op);
};

} // namespace
@@ -492,6 +502,8 @@ void RegionVisitor::translateIntoContext() {
.Case([&](ReturnOp op) { visitReturnOp(blk, op); })
.Case([&](SwitchOp op) { visitSwitchOp(blk, op); })
.Case([&](JumpOp op) { visitJumpOp(blk, op); })
.Case([&](AsmOp op) { visitAsmOp(blk, op); })
.Case([&](AsmGotoOp op) { visitAsmGotoOp(blk, op); })
.Default([&](Operation *op) {
if (op->hasAttr("gccjit.eval")) {
auto *loc = translator.getLocation(op->getLoc());
@@ -560,6 +572,12 @@ Expr RegionVisitor::visitExpr(Value value) {
return cached;
}

void RegionVisitor::visitExprs(ValueRange values,
llvm::SmallVectorImpl<Expr> &result) {
for (auto value : values)
result.push_back(visitExpr(value));
}

void RegionVisitor::visitExprAsRValue(
ValueRange operands, llvm::SmallVectorImpl<gcc_jit_rvalue *> &result) {
for (auto operand : operands)
@@ -828,6 +846,53 @@ void RegionVisitor::visitJumpOp(gcc_jit_block *blk, JumpOp op) {
dst);
}

template <typename OpTy>
void RegionVisitor::populateExtendedAsm(gcc_jit_extended_asm *extAsm, OpTy op) {
auto asmStr = op.getTemplateCode().str();
for (auto [output, constraint, symbol] : llvm::zip(
op.getOutputs(), op.getOutputConstraints(), op.getOutputSymbols())) {
auto constraintStr = cast<StringAttr>(constraint).getValue().str();
auto symbolStr = cast<StringAttr>(symbol).getValue().str();
auto lvalue = visitExpr(output);
assert(lvalue.isLValue() && "expected lvalue");
gcc_jit_extended_asm_add_output_operand(
extAsm, symbolStr.empty() ? nullptr : symbolStr.c_str(),
constraintStr.c_str(), lvalue);
}
for (auto [input, constraint, symbol] : llvm::zip(
op.getInputs(), op.getInputConstraints(), op.getInputSymbols())) {
auto constraintStr = cast<StringAttr>(constraint).getValue().str();
auto symbolStr = cast<StringAttr>(symbol).getValue().str();
auto rvalue = visitExpr(input);
gcc_jit_extended_asm_add_input_operand(
extAsm, symbolStr.empty() ? nullptr : symbolStr.c_str(),
constraintStr.c_str(), rvalue);
}
for (auto clobber : op.getClobbers()) {
auto clobberStr = cast<StringAttr>(clobber).getValue().str();
gcc_jit_extended_asm_add_clobber(extAsm, clobberStr.c_str());
}
}

void RegionVisitor::visitAsmOp(gcc_jit_block *blk, AsmOp op) {
auto *loc = getTranslator().getLocation(op.getLoc());
auto asmStr = op.getTemplateCode().str();
auto *extendedAsm = gcc_jit_block_add_extended_asm(blk, loc, asmStr.c_str());
populateExtendedAsm(extendedAsm, op);
}

void RegionVisitor::visitAsmGotoOp(gcc_jit_block *blk, AsmGotoOp op) {
auto *loc = getTranslator().getLocation(op.getLoc());
auto asmStr = op.getTemplateCode().str();
llvm::SmallVector<gcc_jit_block *> targets;
for (auto *target : op.getLabels())
targets.push_back(blocks.at(target));
gcc_jit_block *fallthrough = blocks.at(op.getFallthrough());
auto *extendedAsm = gcc_jit_block_end_with_extended_asm_goto(
blk, loc, asmStr.c_str(), targets.size(), targets.data(), fallthrough);
populateExtendedAsm(extendedAsm, op);
}

void GCCJITTranslation::translateFunctions() {
for (auto func : moduleOp.getOps<gccjit::FuncOp>()) {
auto &region = func.getBody();

0 comments on commit 5dc6d97

Please sign in to comment.