Skip to content

Commit

Permalink
[gccjit] wrap the type of OptionalParameter in std::optional
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancern committed Nov 3, 2024
1 parent 8e454c1 commit 8babb40
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 35 deletions.
18 changes: 10 additions & 8 deletions include/mlir-gccjit/IR/GCCJITAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,19 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> {
bitfield. The `bitWidth` parameter gives the width of the bitfield.
}];

let parameters = (ins "mlir::StringAttr":$name, "mlir::Type":$type,
OptionalParameter<"unsigned">:$bitWidth,
OptionalParameter<"mlir::gccjit::SourceLocAttr">:$loc);
let parameters = (ins
"mlir::StringAttr":$name, "mlir::Type":$type,
OptionalParameter<"std::optional<unsigned>">:$bitWidth,
OptionalParameter<"std::optional<mlir::gccjit::SourceLocAttr>">:$loc
);

let builders = [
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type), [{
return get($_ctxt, name, type, 0, mlir::gccjit::SourceLocAttr{});
return get($_ctxt, name, type, 0, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"unsigned":$bitWidth), [{
return get($_ctxt, name, type, bitWidth, mlir::gccjit::SourceLocAttr{});
return get($_ctxt, name, type, bitWidth, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"mlir::gccjit::SourceLocAttr":$loc), [{
Expand All @@ -191,7 +193,7 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> {
];

let assemblyFormat = [{
`<` $type $name (`:` $bitWidth^)? `>` ($loc^)?
`<` $type $name (`:` $bitWidth^)? ($loc^)? `>`
}];
}

Expand Down Expand Up @@ -259,8 +261,8 @@ def FnAttr : GCCJIT_Attr<"Function", "fn_attr"> {

let parameters = (ins
"FnAttrEnumAttr":$attr,
OptionalParameter<"::mlir::StringAttr">:$strValue,
OptionalParameter<"::mlir::DenseI32ArrayAttr">:$intArrayValue
OptionalParameter<"std::optional<mlir::StringAttr>">:$strValue,
OptionalParameter<"std::optional<mlir::DenseI32ArrayAttr>">:$intArrayValue
);

let hasCustomAssemblyFormat = 1;
Expand Down
8 changes: 5 additions & 3 deletions include/mlir-gccjit/IR/GCCJITTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,11 @@ class GCCJIT_RecordType<string name, string typeMnemonic>
: GCCJIT_Type<name, typeMnemonic, [
DeclareTypeInterfaceMethods<GCCJITRecordTypeInterface>
]> {
let parameters = (ins "mlir::StringAttr":$name,
"mlir::ArrayAttr":$fields,
OptionalParameter<"mlir::gccjit::SourceLocAttr">:$loc);
let parameters = (ins
"mlir::StringAttr":$name,
"mlir::ArrayAttr":$fields,
OptionalParameter<"std::optional<mlir::gccjit::SourceLocAttr>">:$loc
);
let assemblyFormat = [{
`<` custom<RecordBody>($name, $fields, $loc) `>`
}];
Expand Down
4 changes: 2 additions & 2 deletions src/GCCJITAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ void FunctionAttr::print(AsmPrinter &printer) const {
printer << "<" << getAttr().getValue();
if (isStringFnAttr(getAttr().getValue())) {
printer << ", ";
printer.printAttribute(getStrValue());
printer.printAttribute(getStrValue().value());
} else if (isIntArrayFnAttr(getAttr().getValue())) {
printer << ", ";
printer.printAttribute(getIntArrayValue());
printer.printAttribute(getIntArrayValue().value());
}
printer << ">";
}
Expand Down
2 changes: 1 addition & 1 deletion src/GCCJITOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ FlatSymbolRefAttr FuncOp::getAliasee() {
for (auto attr : getGccjitFnAttrs()) {
auto fnAttr = cast<FunctionAttr>(attr);
if (fnAttr.getAttr().getValue() == FnAttrEnum::Alias) {
res = FlatSymbolRefAttr::get(getContext(), fnAttr.getStrValue());
res = FlatSymbolRefAttr::get(getContext(), fnAttr.getStrValue().value());
break;
}
}
Expand Down
25 changes: 17 additions & 8 deletions src/GCCJITTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITDialect.h"
Expand All @@ -44,7 +45,8 @@ using namespace mlir::gccjit;
//===----------------------------------------------------------------------===//

static LogicalResult parseRecordBody(AsmParser &parser, StringAttr &name,
ArrayAttr &fields, SourceLocAttr &loc) {
ArrayAttr &fields,
std::optional<SourceLocAttr> &loc) {
if (parser.parseAttribute(name))
return failure();

Expand All @@ -61,26 +63,33 @@ static LogicalResult parseRecordBody(AsmParser &parser, StringAttr &name,
};
if (parser.parseCommaSeparatedList(fieldParser))
return failure();
fields = ArrayAttr::get(parser.getContext(), fieldAttrs);

if (parser.parseRBrace())
return failure();

OptionalParseResult parseLocResult = parser.parseOptionalAttribute(loc);
SourceLocAttr locAttr;
OptionalParseResult parseLocResult = parser.parseOptionalAttribute(locAttr);
if (parseLocResult.has_value() && parseLocResult.value())
return failure();
if (locAttr)
loc.emplace(locAttr);
else
loc.reset();

return success();
}

static void printRecordBody(AsmPrinter &printer, StringAttr name,
ArrayAttr fields, SourceLocAttr loc) {
ArrayAttr fields,
std::optional<SourceLocAttr> loc) {
printer << name << " {";
llvm::interleaveComma(fields, printer, [&printer](mlir::Attribute elem) {
printer << cast<FieldAttr>(elem);
});
printer << "}";
if (loc)
printer << " " << loc;
printer << " " << *loc;
}

#define GET_TYPEDEF_CLASSES
Expand Down Expand Up @@ -488,7 +497,7 @@ verifyRecordFields(llvm::function_ref<InFlightDiagnostic()> emitError,

LogicalResult mlir::gccjit::StructType::verify(
llvm::function_ref<InFlightDiagnostic()> emitError, StringAttr name,
ArrayAttr fields, SourceLocAttr loc) {
ArrayAttr fields, std::optional<SourceLocAttr> loc) {
return verifyRecordFields(emitError, fields);
}

Expand All @@ -501,12 +510,12 @@ mlir::ArrayAttr mlir::gccjit::StructType::getRecordFields() const {
}

mlir::gccjit::SourceLocAttr mlir::gccjit::StructType::getRecordLoc() const {
return getLoc();
return getLoc().value_or(mlir::gccjit::SourceLocAttr{});
}

LogicalResult mlir::gccjit::UnionType::verify(
llvm::function_ref<InFlightDiagnostic()> emitError, StringAttr name,
ArrayAttr fields, SourceLocAttr loc) {
ArrayAttr fields, std::optional<SourceLocAttr> loc) {
return verifyRecordFields(emitError, fields);
}

Expand All @@ -519,7 +528,7 @@ mlir::ArrayAttr mlir::gccjit::UnionType::getRecordFields() const {
}

mlir::gccjit::SourceLocAttr mlir::gccjit::UnionType::getRecordLoc() const {
return getLoc();
return getLoc().value_or(mlir::gccjit::SourceLocAttr{});
}

bool mlir::gccjit::UnionType::isUnion() const { return true; }
Expand Down
20 changes: 11 additions & 9 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,9 @@ static void processFunctionAttrs(gccjit::FuncOp func,
auto fnAttr = cast<FunctionAttr>(attr);
switch (fnAttr.getAttr().getValue()) {
case FnAttrEnum::Alias:
gcc_jit_function_add_string_attribute(handle, GCC_JIT_FN_ATTRIBUTE_ALIAS,
fnAttr.getStrValue().str().c_str());
gcc_jit_function_add_string_attribute(
handle, GCC_JIT_FN_ATTRIBUTE_ALIAS,
fnAttr.getStrValue().value().str().c_str());
break;
case FnAttrEnum::AlwaysInline:
gcc_jit_function_add_attribute(handle,
Expand All @@ -242,16 +243,17 @@ static void processFunctionAttrs(gccjit::FuncOp func,
gcc_jit_function_add_attribute(handle, GCC_JIT_FN_ATTRIBUTE_NOINLINE);
break;
case FnAttrEnum::Target:
gcc_jit_function_add_string_attribute(handle, GCC_JIT_FN_ATTRIBUTE_TARGET,
fnAttr.getStrValue().str().c_str());
gcc_jit_function_add_string_attribute(
handle, GCC_JIT_FN_ATTRIBUTE_TARGET,
fnAttr.getStrValue().value().str().c_str());
break;
case FnAttrEnum::Used:
gcc_jit_function_add_attribute(handle, GCC_JIT_FN_ATTRIBUTE_USED);
break;
case FnAttrEnum::Visibility:
gcc_jit_function_add_string_attribute(handle,
GCC_JIT_FN_ATTRIBUTE_VISIBILITY,
fnAttr.getStrValue().str().c_str());
gcc_jit_function_add_string_attribute(
handle, GCC_JIT_FN_ATTRIBUTE_VISIBILITY,
fnAttr.getStrValue().value().str().c_str());
break;
case FnAttrEnum::Cold:
gcc_jit_function_add_attribute(handle, GCC_JIT_FN_ATTRIBUTE_COLD);
Expand All @@ -273,8 +275,8 @@ static void processFunctionAttrs(gccjit::FuncOp func,
gcc_jit_function_add_integer_array_attribute(
handle, GCC_JIT_FN_ATTRIBUTE_NONNULL,
reinterpret_cast<const int *>(
fnAttr.getIntArrayValue().asArrayRef().data()),
fnAttr.getIntArrayValue().size());
fnAttr.getIntArrayValue().value().asArrayRef().data()),
fnAttr.getIntArrayValue().value().size());
break;
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/Translation/TypeTranslation.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "libgccjit.h"
#include "mlir-gccjit/Translation/TranslateToGCCJIT.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>

namespace mlir::gccjit {
void GCCJITTranslation::convertTypes(
Expand Down Expand Up @@ -86,14 +87,13 @@ convertRecordType(GCCJITTranslation &translation,
for (Attribute fieldOpaqueAttr : type.getRecordFields()) {
auto fieldAttr = cast<FieldAttr>(fieldOpaqueAttr);

int fieldBitWidth = fieldAttr.getBitWidth();
int fieldBitWidth = fieldAttr.getBitWidth().value_or(0);
std::string fieldName = fieldAttr.getName().str();
gcc_jit_type *fieldType = translation.convertType(fieldAttr.getType());

SourceLocAttr fieldLoc = fieldAttr.getLoc();
gcc_jit_location *loc = nullptr;
if (fieldLoc)
loc = translation.convertLocation(fieldLoc);
if (auto fieldLoc = fieldAttr.getLoc())
loc = translation.convertLocation(*fieldLoc);

gcc_jit_field *field =
fieldAttr.getBitWidth()
Expand Down
22 changes: 22 additions & 0 deletions test/syntax/record.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %gccjit-opt -o %t.mlir %s
// RUN: %filecheck --input-file=%t.mlir %s

module @test {
gccjit.func imported @gemm (
!gccjit.struct<"__memref_188510220862752" {
#gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "base">,
#gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "aligned">,
#gccjit.field<!gccjit.int<size_t> "offset">,
#gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "sizes">,
#gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "strides">
}>
)
// CHECK: @gemm
// CHECK-SAME: !gccjit.struct<"__memref_188510220862752" {
// CHECK-SAME: #gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "base">
// CHECK-SAME: #gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "aligned">
// CHECK-SAME: #gccjit.field<!gccjit.int<size_t> "offset">
// CHECK-SAME: #gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "sizes">
// CHECK-SAME: #gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "strides">
// CHECK-SAME: }
}

0 comments on commit 8babb40

Please sign in to comment.