diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h index 64cf76258..ea85651d0 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h @@ -51,6 +51,28 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) { return !options.ptr; } +//===----------------------------------------------------------------------===// +// MTRT_OptionsContext +//===----------------------------------------------------------------------===// + +typedef struct MTRT_OptionsContext { + void *ptr; +} MTRT_OptionsContext; + +MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs( + MTRT_CompilerClient client, MTRT_OptionsContext *options, + const char *optionsType, const MlirStringRef *argv, unsigned argc); + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextPrint(MTRT_OptionsContext options); + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextDestroy(MTRT_OptionsContext options); + +static inline bool mtrtOptionsConextIsNull(MTRT_OptionsContext options) { + return !options.ptr; +} + //===----------------------------------------------------------------------===// // MTRT_StableHLOToExecutableOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h index b4a92b25c..e75f0c551 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h @@ -28,14 +28,20 @@ #define MLIR_TENSORRT_COMPILER_CLIENT #include "mlir-executor/Support/Status.h" +#include "mlir-tensorrt-dialect/Utils/Options.h" #include "mlir-tensorrt/Compiler/Options.h" +#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ManagedStatic.h" +#include #include +#include namespace mlirtrt::compiler { @@ -95,6 +101,16 @@ class CompilationTask : public CompilationTaskBase { /// builder caches are being persisted to disk. class CompilerClient { public: + using OptionsConstructorFuncT = + std::function>( + const CompilerClient &client, const std::vector &)>; + + static bool registerOption(std::string optionsType, + OptionsConstructorFuncT func) { + (*registry)[optionsType] = std::move(func); + return true; + } + static StatusOr> create(mlir::MLIRContext *context); @@ -131,6 +147,17 @@ class CompilerClient { static void setupPassManagerLogging(mlir::PassManager &pm, const DebugOptions &options); + StatusOr> + createOptions(const std::string &optionsType, + const std::vector &args) { + if (!registry->contains(optionsType)) + return getInvalidArgStatus( + "{0} is not a valid option type. Valid options were: {1}", + optionsType, llvm::join(registry->keys(), ",")); + + return (*registry)[optionsType](*this, args); + } + protected: CompilerClient(mlir::MLIRContext *context); @@ -145,8 +172,43 @@ class CompilerClient { /// used to create the PM. llvm::DenseMap> cachedPassManagers; + + static llvm::ManagedStatic> registry; }; +/// Helper to register option types with the client +template +StatusOr> +optionsCreateFromArgs(const CompilerClient &client, + const std::vector &args) { + // Load available extensions. + mlir::MLIRContext *context = client.getContext(); + mlir::plan::PlanDialect *planDialect = + context->getLoadedDialect(); + compiler::TaskExtensionRegistry extensions = + planDialect->extensionConstructors.getExtensionRegistryForTask(); + + auto result = std::make_unique(std::move(extensions)); + + std::string err; + if (failed(result->parse(args, err))) { + std::string line = llvm::join(args, " "); + return getInternalErrorStatus( + "failed to parse options string {0} due to error: {1}", line, err); + } + + // TODO: Figure out whether to add a method in the base class like + // "finalizeOptions" or a callback here, or something else if + // `inferDeviceOptionsFromHost` is unique to StableHLO. + // + // Populate device options from host information. + Status inferStatus = result->inferDeviceOptionsFromHost(); + if (!inferStatus.isOk()) + return inferStatus; + + return std::unique_ptr(result.release()); +} + } // namespace mlirtrt::compiler #endif // MLIR_TENSORRT_COMPILER_CLIENT diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h index 0e3b8e3ab..4134e3cbb 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h @@ -189,6 +189,13 @@ class StableHloToExecutableTask const StableHLOToExecutableOptions &options); }; +// TODO: Figure out if we want to register in a different way. +static const bool registeredStableHloToExecutable = + CompilerClient::registerOption( + "stable-hlo-to-executable", + optionsCreateFromArgs); + //===----------------------------------------------------------------------===// // Pipeline Registrations //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp index 185824226..5d817bdc7 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp @@ -26,6 +26,7 @@ #include "mlir-c/Support.h" #include "mlir-executor-c/Support/Status.h" #include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h" +#include "mlir-tensorrt-dialect/Utils/Options.h" #include "mlir-tensorrt/Compiler/Extension.h" #include "mlir-tensorrt/Compiler/StableHloToExecutable.h" #include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h" @@ -44,6 +45,7 @@ using namespace mlir; DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient) DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions, StableHLOToExecutableOptions) +DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext) #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif @@ -99,6 +101,37 @@ MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) { return mtrtStatusGetOk(); } +//===----------------------------------------------------------------------===// +// MTRT_OptionsContext +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs( + MTRT_CompilerClient client, MTRT_OptionsContext *options, + const char *optionsType, const MlirStringRef *argv, unsigned argc) { + std::vector argvStrRef(argc); + for (unsigned i = 0; i < argc; i++) + argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length); + + auto result = unwrap(client)->createOptions(optionsType, argvStrRef); + if (!result.isOk()) + return wrap(result.getStatus()); + + *options = wrap(result->release()); + return mtrtStatusGetOk(); +} + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextPrint(MTRT_OptionsContext options) { + unwrap(options)->print(llvm::outs()); + return mtrtStatusGetOk(); +} + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextDestroy(MTRT_OptionsContext options) { + delete unwrap(options); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_StableHLOToExecutableOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/Compiler/Client.cpp b/mlir-tensorrt/compiler/lib/Compiler/Client.cpp index dca44e340..baec0bb2e 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/Client.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/Client.cpp @@ -46,6 +46,8 @@ CompilationTaskBase::~CompilationTaskBase() {} // CompilerClient //===----------------------------------------------------------------------===// +decltype(CompilerClient::registry) CompilerClient::registry = {}; + StatusOr> CompilerClient::create(MLIRContext *context) { context->disableMultithreading(); diff --git a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp index 912df1fcf..b06d6dfb1 100644 --- a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp +++ b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp @@ -52,6 +52,17 @@ class PyCompilerClient mtrtCompilerClientIsNull, mtrtCompilerClientDestroy}; }; +/// Python object type wrapper for `MTRT_OptionsContext`. +class PyOptionsContext + : public PyMTRTWrapper { +public: + using PyMTRTWrapper::PyMTRTWrapper; + DECLARE_WRAPPER_CONSTRUCTORS(PyOptionsContext); + + static constexpr auto kMethodTable = CAPITable{ + mtrtOptionsConextIsNull, mtrtOptionsContextDestroy}; +}; + /// Python object type wrapper for `MTRT_StableHLOToExecutableOptions`. class PyStableHLOToExecutableOptions : public PyMTRTWrapper(m, "OptionsContext", py::module_local()) + .def(py::init<>([](PyCompilerClient &client, + const std::string &optionsType, + const std::vector &args) { + std::vector refs(args.size()); + for (unsigned i = 0; i < args.size(); i++) + refs[i] = mlirStringRefCreate(args[i].data(), args[i].size()); + + MTRT_OptionsContext options; + MTRT_Status s = mtrtOptionsContextCreateFromArgs( + client, &options, optionsType.c_str(), refs.data(), + refs.size()); + THROW_IF_MTRT_ERROR(s); + return new PyOptionsContext(options); + }), + py::arg("client"), py::arg("options_type"), py::arg("args")) + + .def("print", [](PyOptionsContext &self) { + // TODO: Should this be exposed via __str__ instead? + auto s = mtrtOptionsContextPrint(self); + THROW_IF_MTRT_ERROR(s); + }); + py::class_(m, "StableHLOToExecutableOptions", py::module_local()) .def(py::init<>([](PyCompilerClient &client, diff --git a/mlir-tensorrt/python/bindings/Utils.h b/mlir-tensorrt/python/bindings/Utils.h index b12830389..fdc2dee19 100644 --- a/mlir-tensorrt/python/bindings/Utils.h +++ b/mlir-tensorrt/python/bindings/Utils.h @@ -163,7 +163,7 @@ class PyMTRTWrapper { static py::object createFromCapsule(py::object capsule) { if constexpr (cFuncTable.capsuleToCApi == nullptr) { - throw py::value_error("boject cannot be converted from opaque capsule"); + throw py::value_error("object cannot be converted from opaque capsule"); } else { MTRT_StableHLOToExecutableOptions cObj = cFuncTable.capsuleToCApi(capsule.ptr()); diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py new file mode 100644 index 000000000..bc0ee2916 --- /dev/null +++ b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py @@ -0,0 +1,29 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import mlir_tensorrt.compiler.api as api +from mlir_tensorrt.compiler.ir import * + + +with Context() as context: + client = api.CompilerClient(context) + # Try to create a non-existent option type + try: + opts = api.OptionsContext(client, "non-existent-options-type", []) + except Exception as err: + print(err) + + opts = api.OptionsContext( + client, + "stable-hlo-to-executable", + [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + "--tensorrt-workspace-memory-pool-limit=1gb", + ], + ) + + opts.print() + + +# CHECK: InvalidArgument: InvalidArgument: non-existent-options-type is not a valid option type. Valid options were: stable-hlo-to-executable +# CHECK: --tensorrt-timing-cache-path= --device-infer-from-host=true --debug-only= --executor-index-bitwidth=64 --entrypoint=main --plan-clustering-disallow-host-tensors-in-tensorrt-clusters=false --tensorrt-workspace-memory-pool-limit=1073741824 --device-max-registers-per-block=65536 --tensorrt-strongly-typed=false --tensorrt-layer-info-dir= --device-compute-capability=86 --debug=false --mlir-print-ir-tree-dir= --disable-tensorrt-extension=false --tensorrt-builder-opt-level=3 --tensorrt-engines-dir=