Skip to content

Commit

Permalink
[gccjit][conversion] add primitive conversions for func and cf (#17)
Browse files Browse the repository at this point in the history
* [gccjit][conversion] add primitive conversion patterns

* wip

* wip

* [gccjit] translate func/br/cond_br

* fix build

* more fix

* fix

* suppress GCC false positive

* suppress GCC false positive

* suppress GCC false positive

* fine, fine, I give up. can't you just shut the fucking up GCC

* suppress GCC false positive
  • Loading branch information
SchrodingerZhu authored Nov 4, 2024
1 parent ce1d726 commit c53909a
Show file tree
Hide file tree
Showing 12 changed files with 556 additions and 107 deletions.
29 changes: 29 additions & 0 deletions include/mlir-gccjit/Conversion/Conversions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2024 Schrodinger ZHU Yifan <[email protected]>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MLIR_GCCJIT_CONVERSION_CONVERTIONS_H
#define MLIR_GCCJIT_CONVERSION_CONVERTIONS_H

#include <mlir/IR/MLIRContext.h>

#include "mlir-gccjit/Conversion/TypeConverter.h"

namespace mlir::gccjit {
void populateFuncToGCCJITPatterns(MLIRContext *context,
GCCJITTypeConverter &typeConverter,
RewritePatternSet &patterns,
SymbolTable &symbolTable);
} // namespace mlir::gccjit

#endif
28 changes: 14 additions & 14 deletions include/mlir-gccjit/Conversion/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,38 @@

namespace mlir::gccjit {
class GCCJITTypeConverter : public TypeConverter {
llvm::DenseMap<mlir::Type, gccjit::StructType> packedTypes;

public:
GCCJITTypeConverter();
~GCCJITTypeConverter();
// integral types
gccjit::IntType convertIndexType(mlir::IndexType type);
gccjit::IntType convertIntegerType(mlir::IntegerType type);
gccjit::IntAttr convertIntegerAttr(mlir::IntegerAttr attr);
gccjit::IntType convertIndexType(mlir::IndexType type) const;
gccjit::IntType convertIntegerType(mlir::IntegerType type) const;
gccjit::IntAttr convertIntegerAttr(mlir::IntegerAttr attr) const;

// floating point types
gccjit::FloatType convertFloatType(mlir::FloatType type);
gccjit::FloatAttr convertFloatAttr(mlir::FloatAttr attr);
gccjit::FloatType convertFloatType(mlir::FloatType type) const;
gccjit::FloatAttr convertFloatAttr(mlir::FloatAttr attr) const;

// special composite types
gccjit::ComplexType convertComplexType(mlir::ComplexType type);
gccjit::VectorType convertVectorType(mlir::VectorType type);
gccjit::ComplexType convertComplexType(mlir::ComplexType type) const;
gccjit::VectorType convertVectorType(mlir::VectorType type) const;

// function prototype
gccjit::FuncType convertFunctionType(mlir::FunctionType type, bool isVarArg);
gccjit::FuncType convertFunctionType(mlir::FunctionType type,
bool isVarArg) const;

// function type to function pointer
gccjit::PointerType convertFunctionTypeAsPtr(mlir::FunctionType type,
bool isVarArg);
bool isVarArg) const;

// memref type
gccjit::StructType getMemrefDescriptorType(mlir::MemRefType type);
gccjit::StructType getMemrefDescriptorType(mlir::MemRefType type) const;
gccjit::StructType
getUnrankedMemrefDescriptorType(mlir::UnrankedMemRefType type);
getUnrankedMemrefDescriptorType(mlir::UnrankedMemRefType type) const;

private:
Type convertAndPackTypesIfNonSingleton(TypeRange types, FunctionType name);
Type convertAndPackTypesIfNonSingleton(TypeRange types,
FunctionType name) const;
};
} // namespace mlir::gccjit
#endif // MLIR_GCCJIT_CONVERSION_TYPECONVERTER_H
4 changes: 1 addition & 3 deletions include/mlir-gccjit/IR/GCCJITOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def FuncOp : GCCJIT_Op<"func", [IsolatedFromAbove]> {
SymbolNameAttr:$sym_name,
FnKind:$fn_kind,
TypeAttrOf<GCCJIT_FuncType>:$function_type,
ArrayAttr:$gccjit_fn_attrs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
ArrayAttr:$gccjit_fn_attrs
);
let regions = (region AnyRegion:$body);
let hasVerifier = 1;
Expand Down
2 changes: 1 addition & 1 deletion include/mlir-gccjit/IR/GCCJITTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class GCCJIT_RecordType<string name, string typeMnemonic>
let builders = [
TypeBuilder<(ins "mlir::StringAttr":$name,
"mlir::ArrayAttr":$fields), [{
return get($_ctxt, name, fields, mlir::gccjit::SourceLocAttr{});
return get($_ctxt, name, fields, std::nullopt);
}]>,
TypeBuilder<(ins "llvm::StringRef":$name,
"llvm::ArrayRef<mlir::gccjit::FieldAttr>":$fields), [{
Expand Down
16 changes: 10 additions & 6 deletions include/mlir-gccjit/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@
#ifndef MLIR_GCCJIT_PASSES_H
#define MLIR_GCCJIT_PASSES_H

#include <mlir/Dialect/ControlFlow/IR/ControlFlow.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/IR/BuiltinDialect.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/Pass.h>

#include "mlir-gccjit/IR/GCCJITDialect.h"
#include "mlir-gccjit/IR/GCCJITOps.h"

namespace mlir::gccjit {

// TODO: Add pass declarations here.
// #define GEN_PASS_CLASSES
// #define GEN_PASS_REGISTRATION
// #define GEN_PASS_DECL
// #include "mlir-gccjit/Passes.h.inc"
std::unique_ptr<Pass> createConvertCFToGCCJITPass();
std::unique_ptr<Pass> createConvertFuncToGCCJITPass();

#define GEN_PASS_CLASSES
#define GEN_PASS_REGISTRATION
#define GEN_PASS_DECL
#include "mlir-gccjit/Passes.h.inc"

} // namespace mlir::gccjit

Expand Down
9 changes: 8 additions & 1 deletion include/mlir-gccjit/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@

include "mlir/Pass/PassBase.td"

// TODO: Add passes here.
def ConvertFuncToGCCJIT : Pass<"convert-func-to-gccjit", "::mlir::ModuleOp"> {
let summary = "Convert Functions and control flows to GCCJIT Dialect";
let description = [{
This pass converts function operations and control flow operations to GCCJIT dialect.
}];
let constructor = "::mlir::gccjit::createConvertFuncToGCCJITPass()";
let dependentDialects = ["::mlir::gccjit::GCCJITDialect", "::mlir::func::FuncDialect", "::mlir::BuiltinDialect", "::mlir::cf::ControlFlowDialect"];
}

#endif // MLIR_GCCJIT_PASSES
2 changes: 2 additions & 0 deletions src/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
add_mlir_dialect_library(MLIRGCCJITConversion
TypeConverter.cpp
ConvertFuncToGCCJIT.cpp

DEPENDS
MLIRGCCJIT
MLIRGCCJITPassIncGen
${dialect_libs}
${conversion_libs}

Expand Down
Loading

0 comments on commit c53909a

Please sign in to comment.