From ba839d4e4f753cededbf7cfe141fe5c7d01da033 Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 20 Nov 2024 10:19:48 -0800 Subject: [PATCH] Exposes base `OptionsContext` class in the API In order to support more than just the StableHloToExecutable pipeline, we need to be able to create different option types from the API. This commit exposes the base `OptionsContext` class in the Python API and includes a mechanism for child classes to register themselves with the client, allowing them to be created through a common API. --- .../mlir-tensorrt-c/Compiler/Compiler.h | 22 +++++++ .../include/mlir-tensorrt/Compiler/Client.h | 62 +++++++++++++++++++ .../Compiler/StableHloToExecutable.h | 7 +++ .../compiler/lib/CAPI/Compiler/Compiler.cpp | 33 ++++++++++ .../compiler/lib/Compiler/Client.cpp | 2 + .../bindings/Compiler/CompilerPyBind.cpp | 34 ++++++++++ mlir-tensorrt/python/bindings/Utils.h | 2 +- .../compiler_api/test_options_context.py | 29 +++++++++ 8 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h index 64cf76258..ea85651d0 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h @@ -51,6 +51,28 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) { return !options.ptr; } +//===----------------------------------------------------------------------===// +// MTRT_OptionsContext +//===----------------------------------------------------------------------===// + +typedef struct MTRT_OptionsContext { + void *ptr; +} MTRT_OptionsContext; + +MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs( + MTRT_CompilerClient client, MTRT_OptionsContext *options, + const char *optionsType, const MlirStringRef *argv, unsigned argc); + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextPrint(MTRT_OptionsContext options); + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextDestroy(MTRT_OptionsContext options); + +static inline bool mtrtOptionsConextIsNull(MTRT_OptionsContext options) { + return !options.ptr; +} + //===----------------------------------------------------------------------===// // MTRT_StableHLOToExecutableOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h index b4a92b25c..e75f0c551 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h @@ -28,14 +28,20 @@ #define MLIR_TENSORRT_COMPILER_CLIENT #include "mlir-executor/Support/Status.h" +#include "mlir-tensorrt-dialect/Utils/Options.h" #include "mlir-tensorrt/Compiler/Options.h" +#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ManagedStatic.h" +#include #include +#include namespace mlirtrt::compiler { @@ -95,6 +101,16 @@ class CompilationTask : public CompilationTaskBase { /// builder caches are being persisted to disk. class CompilerClient { public: + using OptionsConstructorFuncT = + std::function>( + const CompilerClient &client, const std::vector &)>; + + static bool registerOption(std::string optionsType, + OptionsConstructorFuncT func) { + (*registry)[optionsType] = std::move(func); + return true; + } + static StatusOr> create(mlir::MLIRContext *context); @@ -131,6 +147,17 @@ class CompilerClient { static void setupPassManagerLogging(mlir::PassManager &pm, const DebugOptions &options); + StatusOr> + createOptions(const std::string &optionsType, + const std::vector &args) { + if (!registry->contains(optionsType)) + return getInvalidArgStatus( + "{0} is not a valid option type. Valid options were: {1}", + optionsType, llvm::join(registry->keys(), ",")); + + return (*registry)[optionsType](*this, args); + } + protected: CompilerClient(mlir::MLIRContext *context); @@ -145,8 +172,43 @@ class CompilerClient { /// used to create the PM. llvm::DenseMap> cachedPassManagers; + + static llvm::ManagedStatic> registry; }; +/// Helper to register option types with the client +template +StatusOr> +optionsCreateFromArgs(const CompilerClient &client, + const std::vector &args) { + // Load available extensions. + mlir::MLIRContext *context = client.getContext(); + mlir::plan::PlanDialect *planDialect = + context->getLoadedDialect(); + compiler::TaskExtensionRegistry extensions = + planDialect->extensionConstructors.getExtensionRegistryForTask(); + + auto result = std::make_unique(std::move(extensions)); + + std::string err; + if (failed(result->parse(args, err))) { + std::string line = llvm::join(args, " "); + return getInternalErrorStatus( + "failed to parse options string {0} due to error: {1}", line, err); + } + + // TODO: Figure out whether to add a method in the base class like + // "finalizeOptions" or a callback here, or something else if + // `inferDeviceOptionsFromHost` is unique to StableHLO. + // + // Populate device options from host information. + Status inferStatus = result->inferDeviceOptionsFromHost(); + if (!inferStatus.isOk()) + return inferStatus; + + return std::unique_ptr(result.release()); +} + } // namespace mlirtrt::compiler #endif // MLIR_TENSORRT_COMPILER_CLIENT diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h index 0e3b8e3ab..4134e3cbb 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h @@ -189,6 +189,13 @@ class StableHloToExecutableTask const StableHLOToExecutableOptions &options); }; +// TODO: Figure out if we want to register in a different way. +static const bool registeredStableHloToExecutable = + CompilerClient::registerOption( + "stable-hlo-to-executable", + optionsCreateFromArgs); + //===----------------------------------------------------------------------===// // Pipeline Registrations //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp index 185824226..5d817bdc7 100644 --- a/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp +++ b/mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp @@ -26,6 +26,7 @@ #include "mlir-c/Support.h" #include "mlir-executor-c/Support/Status.h" #include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h" +#include "mlir-tensorrt-dialect/Utils/Options.h" #include "mlir-tensorrt/Compiler/Extension.h" #include "mlir-tensorrt/Compiler/StableHloToExecutable.h" #include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h" @@ -44,6 +45,7 @@ using namespace mlir; DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient) DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions, StableHLOToExecutableOptions) +DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext) #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif @@ -99,6 +101,37 @@ MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) { return mtrtStatusGetOk(); } +//===----------------------------------------------------------------------===// +// MTRT_OptionsContext +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs( + MTRT_CompilerClient client, MTRT_OptionsContext *options, + const char *optionsType, const MlirStringRef *argv, unsigned argc) { + std::vector argvStrRef(argc); + for (unsigned i = 0; i < argc; i++) + argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length); + + auto result = unwrap(client)->createOptions(optionsType, argvStrRef); + if (!result.isOk()) + return wrap(result.getStatus()); + + *options = wrap(result->release()); + return mtrtStatusGetOk(); +} + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextPrint(MTRT_OptionsContext options) { + unwrap(options)->print(llvm::outs()); + return mtrtStatusGetOk(); +} + +MLIR_CAPI_EXPORTED MTRT_Status +mtrtOptionsContextDestroy(MTRT_OptionsContext options) { + delete unwrap(options); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_StableHLOToExecutableOptions //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/compiler/lib/Compiler/Client.cpp b/mlir-tensorrt/compiler/lib/Compiler/Client.cpp index dca44e340..baec0bb2e 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/Client.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/Client.cpp @@ -46,6 +46,8 @@ CompilationTaskBase::~CompilationTaskBase() {} // CompilerClient //===----------------------------------------------------------------------===// +decltype(CompilerClient::registry) CompilerClient::registry = {}; + StatusOr> CompilerClient::create(MLIRContext *context) { context->disableMultithreading(); diff --git a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp index 912df1fcf..b06d6dfb1 100644 --- a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp +++ b/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp @@ -52,6 +52,17 @@ class PyCompilerClient mtrtCompilerClientIsNull, mtrtCompilerClientDestroy}; }; +/// Python object type wrapper for `MTRT_OptionsContext`. +class PyOptionsContext + : public PyMTRTWrapper { +public: + using PyMTRTWrapper::PyMTRTWrapper; + DECLARE_WRAPPER_CONSTRUCTORS(PyOptionsContext); + + static constexpr auto kMethodTable = CAPITable{ + mtrtOptionsConextIsNull, mtrtOptionsContextDestroy}; +}; + /// Python object type wrapper for `MTRT_StableHLOToExecutableOptions`. class PyStableHLOToExecutableOptions : public PyMTRTWrapper(m, "OptionsContext", py::module_local()) + .def(py::init<>([](PyCompilerClient &client, + const std::string &optionsType, + const std::vector &args) { + std::vector refs(args.size()); + for (unsigned i = 0; i < args.size(); i++) + refs[i] = mlirStringRefCreate(args[i].data(), args[i].size()); + + MTRT_OptionsContext options; + MTRT_Status s = mtrtOptionsContextCreateFromArgs( + client, &options, optionsType.c_str(), refs.data(), + refs.size()); + THROW_IF_MTRT_ERROR(s); + return new PyOptionsContext(options); + }), + py::arg("client"), py::arg("options_type"), py::arg("args")) + + .def("print", [](PyOptionsContext &self) { + // TODO: Should this be exposed via __str__ instead? + auto s = mtrtOptionsContextPrint(self); + THROW_IF_MTRT_ERROR(s); + }); + py::class_(m, "StableHLOToExecutableOptions", py::module_local()) .def(py::init<>([](PyCompilerClient &client, diff --git a/mlir-tensorrt/python/bindings/Utils.h b/mlir-tensorrt/python/bindings/Utils.h index b12830389..fdc2dee19 100644 --- a/mlir-tensorrt/python/bindings/Utils.h +++ b/mlir-tensorrt/python/bindings/Utils.h @@ -163,7 +163,7 @@ class PyMTRTWrapper { static py::object createFromCapsule(py::object capsule) { if constexpr (cFuncTable.capsuleToCApi == nullptr) { - throw py::value_error("boject cannot be converted from opaque capsule"); + throw py::value_error("object cannot be converted from opaque capsule"); } else { MTRT_StableHLOToExecutableOptions cObj = cFuncTable.capsuleToCApi(capsule.ptr()); diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py new file mode 100644 index 000000000..bc0ee2916 --- /dev/null +++ b/mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_options_context.py @@ -0,0 +1,29 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import mlir_tensorrt.compiler.api as api +from mlir_tensorrt.compiler.ir import * + + +with Context() as context: + client = api.CompilerClient(context) + # Try to create a non-existent option type + try: + opts = api.OptionsContext(client, "non-existent-options-type", []) + except Exception as err: + print(err) + + opts = api.OptionsContext( + client, + "stable-hlo-to-executable", + [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + "--tensorrt-workspace-memory-pool-limit=1gb", + ], + ) + + opts.print() + + +# CHECK: InvalidArgument: InvalidArgument: non-existent-options-type is not a valid option type. Valid options were: stable-hlo-to-executable +# CHECK: --tensorrt-timing-cache-path= --device-infer-from-host=true --debug-only= --executor-index-bitwidth=64 --entrypoint=main --plan-clustering-disallow-host-tensors-in-tensorrt-clusters=false --tensorrt-workspace-memory-pool-limit=1073741824 --device-max-registers-per-block=65536 --tensorrt-strongly-typed=false --tensorrt-layer-info-dir= --device-compute-capability=86 --debug=false --mlir-print-ir-tree-dir= --disable-tensorrt-extension=false --tensorrt-builder-opt-level=3 --tensorrt-engines-dir=