Skip to content

Commit

Permalink
Exposes base OptionsContext class in the API
Browse files Browse the repository at this point in the history
In order to support more than just the StableHloToExecutable pipeline, we
need to be able to create different option types from the API. This commit
exposes the base `OptionsContext` class in the Python API and includes a mechanism
for child classes to register themselves with the client, allowing them to be
created through a common API.
  • Loading branch information
pranavm-nvidia committed Nov 20, 2024
1 parent 89e2090 commit ba839d4
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 1 deletion.
22 changes: 22 additions & 0 deletions mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) {
return !options.ptr;
}

//===----------------------------------------------------------------------===//
// MTRT_OptionsContext
//===----------------------------------------------------------------------===//

/// Opaque C handle to an options context. `ptr` holds the address of the
/// underlying object; a handle whose `ptr` is null is a null handle.
struct MTRT_OptionsContext {
  void *ptr;
};
typedef struct MTRT_OptionsContext MTRT_OptionsContext;

/// Creates an options context of the type named by `optionsType` through the
/// given client, forwarding the `argc` string arguments found in `argv`. On
/// success the new handle is written to `*options`; on failure the returned
/// status carries the error and `*options` is not written.
MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs(
MTRT_CompilerClient client, MTRT_OptionsContext *options,
const char *optionsType, const MlirStringRef *argv, unsigned argc);

/// Prints the options context referenced by `options` to the standard output
/// stream (`llvm::outs()`).
MLIR_CAPI_EXPORTED MTRT_Status
mtrtOptionsContextPrint(MTRT_OptionsContext options);

/// Destroys the options context referenced by `options`. The handle must not
/// be used afterwards.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtOptionsContextDestroy(MTRT_OptionsContext options);

/// Returns true when `options` is a null handle, i.e. its underlying pointer
/// is null.
static inline bool mtrtOptionsContextIsNull(MTRT_OptionsContext options) {
  return !options.ptr;
}

/// Misspelled alias of `mtrtOptionsContextIsNull`, retained so that existing
/// callers keep compiling. New code should use the correctly spelled name.
static inline bool mtrtOptionsConextIsNull(MTRT_OptionsContext options) {
  return mtrtOptionsContextIsNull(options);
}

//===----------------------------------------------------------------------===//
// MTRT_StableHLOToExecutableOptions
//===----------------------------------------------------------------------===//
Expand Down
62 changes: 62 additions & 0 deletions mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@
#define MLIR_TENSORRT_COMPILER_CLIENT

#include "mlir-executor/Support/Status.h"
#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir-tensorrt/Compiler/Options.h"
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ManagedStatic.h"
#include <functional>
#include <memory>
#include <unordered_map>

namespace mlirtrt::compiler {

Expand Down Expand Up @@ -95,6 +101,16 @@ class CompilationTask : public CompilationTaskBase {
/// builder caches are being persisted to disk.
class CompilerClient {
public:
/// Signature of the factory callbacks kept in the options registry. A
/// callback receives the client and the list of option arguments and returns
/// either a newly created `mlir::OptionsContext` or an error status.
using OptionsConstructorFuncT =
std::function<StatusOr<std::unique_ptr<mlir::OptionsContext>>(
const CompilerClient &client, const std::vector<llvm::StringRef> &)>;

/// Stores `func` in the static registry under the name `optionsType` so that
/// `createOptions` can later build options of that type. A callback already
/// registered under the same name is replaced. Always returns true, which
/// allows the call to serve as the initializer of a static variable.
static bool registerOption(std::string optionsType,
                           OptionsConstructorFuncT func) {
  registry->insert_or_assign(optionsType, std::move(func));
  return true;
}

static StatusOr<std::unique_ptr<CompilerClient>>
create(mlir::MLIRContext *context);

Expand Down Expand Up @@ -131,6 +147,17 @@ class CompilerClient {
static void setupPassManagerLogging(mlir::PassManager &pm,
const DebugOptions &options);

/// Creates an options object of the registered type `optionsType` by invoking
/// the constructor callback registered under that name with this client and
/// `args`. When no such type is registered, an invalid-argument status that
/// lists the registered type names is returned instead.
StatusOr<std::unique_ptr<mlir::OptionsContext>>
createOptions(const std::string &optionsType,
              const std::vector<llvm::StringRef> &args) {
  auto entry = registry->find(optionsType);
  if (entry == registry->end())
    return getInvalidArgStatus(
        "{0} is not a valid option type. Valid options were: {1}",
        optionsType, llvm::join(registry->keys(), ","));

  return entry->second(*this, args);
}

protected:
CompilerClient(mlir::MLIRContext *context);

Expand All @@ -145,8 +172,43 @@ class CompilerClient {
/// used to create the PM.
llvm::DenseMap<PassManagerKey, std::unique_ptr<CompilationTaskBase>>
cachedPassManagers;

/// Static table, shared by all clients, that maps an options type name to the
/// callback constructing options of that type. Written by `registerOption`
/// and read by `createOptions`.
static llvm::ManagedStatic<llvm::StringMap<OptionsConstructorFuncT>> registry;
};

/// Generic options factory whose instantiations can be handed to
/// `CompilerClient::registerOption`.
///
/// Builds an `OptionsT` from the task extensions that the Plan dialect has
/// registered for `TaskT`, parses `args` into it, and then populates device
/// options from the host via `inferDeviceOptionsFromHost`.
///
/// \param client Client whose MLIRContext supplies the loaded Plan dialect.
/// \param args   Option arguments to parse.
/// \returns The new options object as a `mlir::OptionsContext`, or an error
///          status if the Plan dialect is not loaded, if parsing fails, or if
///          inferring the device options fails.
template <typename OptionsT, typename TaskT>
StatusOr<std::unique_ptr<mlir::OptionsContext>>
optionsCreateFromArgs(const CompilerClient &client,
                      const std::vector<llvm::StringRef> &args) {
  // Load available extensions.
  mlir::MLIRContext *context = client.getContext();
  mlir::plan::PlanDialect *planDialect =
      context->getLoadedDialect<mlir::plan::PlanDialect>();
  // `getLoadedDialect` yields null when the dialect has not been loaded into
  // the context; report that instead of dereferencing a null pointer below.
  if (!planDialect)
    return getInternalErrorStatus(
        "failed to create options because the '{0}' dialect is not loaded in "
        "the client's MLIRContext",
        mlir::plan::PlanDialect::getDialectNamespace());

  compiler::TaskExtensionRegistry extensions =
      planDialect->extensionConstructors.getExtensionRegistryForTask<TaskT>();

  auto result = std::make_unique<OptionsT>(std::move(extensions));

  std::string err;
  if (failed(result->parse(args, err))) {
    std::string line = llvm::join(args, " ");
    return getInternalErrorStatus(
        "failed to parse options string {0} due to error: {1}", line, err);
  }

  // TODO: Figure out whether to add a method in the base class like
  // "finalizeOptions" or a callback here, or something else if
  // `inferDeviceOptionsFromHost` is unique to StableHLO.
  //
  // Populate device options from host information.
  Status inferStatus = result->inferDeviceOptionsFromHost();
  if (!inferStatus.isOk())
    return inferStatus;

  // Hand ownership over as a pointer to the base options type.
  return std::unique_ptr<mlir::OptionsContext>(result.release());
}

} // namespace mlirtrt::compiler

#endif // MLIR_TENSORRT_COMPILER_CLIENT
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ class StableHloToExecutableTask
const StableHLOToExecutableOptions &options);
};

// Registers the StableHLO-to-executable options type with the client under
// the name "stable-hlo-to-executable" during static initialization.
//
// The variable is `inline` rather than `static` so that all translation units
// including this declaration share one definition, instead of each getting
// its own copy that repeats the registration.
//
// TODO: Figure out if we want to register in a different way.
inline const bool registeredStableHloToExecutable =
    CompilerClient::registerOption(
        "stable-hlo-to-executable",
        optionsCreateFromArgs<StableHLOToExecutableOptions,
                              StableHloToExecutableTask>);

//===----------------------------------------------------------------------===//
// Pipeline Registrations
//===----------------------------------------------------------------------===//
Expand Down
33 changes: 33 additions & 0 deletions mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "mlir-c/Support.h"
#include "mlir-executor-c/Support/Status.h"
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir-tensorrt/Compiler/Extension.h"
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
#include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h"
Expand All @@ -44,6 +45,7 @@ using namespace mlir;
// Define the `wrap`/`unwrap` helpers used throughout this file to convert
// between each opaque C handle type and a pointer to the corresponding C++
// class.
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
StableHLOToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext)
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
Expand Down Expand Up @@ -99,6 +101,37 @@ MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) {
return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_OptionsContext
//===----------------------------------------------------------------------===//

/// Creates an options context of the type named by `optionsType` from the
/// `argc` string arguments in `argv`. On success the new handle is stored in
/// `*options` and an OK status is returned; otherwise the failure status
/// reported by the client is returned and `*options` is left untouched.
MLIR_CAPI_EXPORTED MTRT_Status mtrtOptionsContextCreateFromArgs(
    MTRT_CompilerClient client, MTRT_OptionsContext *options,
    const char *optionsType, const MlirStringRef *argv, unsigned argc) {
  // View every C string reference as an llvm::StringRef; no data is copied.
  std::vector<llvm::StringRef> parsedArgs;
  parsedArgs.reserve(argc);
  for (unsigned idx = 0; idx != argc; ++idx)
    parsedArgs.emplace_back(argv[idx].data, argv[idx].length);

  auto created = unwrap(client)->createOptions(optionsType, parsedArgs);
  if (created.isOk()) {
    // Ownership of the C++ object moves into the returned handle.
    *options = wrap(created->release());
    return mtrtStatusGetOk();
  }
  return wrap(created.getStatus());
}

/// Prints the options context referenced by `options` to `llvm::outs()`.
/// Always returns an OK status.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtOptionsContextPrint(MTRT_OptionsContext options) {
  auto *context = unwrap(options);
  auto &out = llvm::outs();
  context->print(out);
  return mtrtStatusGetOk();
}

/// Deletes the C++ object referenced by `options`. Always returns an OK
/// status.
MLIR_CAPI_EXPORTED MTRT_Status
mtrtOptionsContextDestroy(MTRT_OptionsContext options) {
  auto *context = unwrap(options);
  delete context;
  return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_StableHLOToExecutableOptions
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir-tensorrt/compiler/lib/Compiler/Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ CompilationTaskBase::~CompilationTaskBase() {}
// CompilerClient
//===----------------------------------------------------------------------===//

// Out-of-line definition of the static options registry declared in
// `CompilerClient`. `decltype` keeps the type in sync with the declaration.
decltype(CompilerClient::registry) CompilerClient::registry = {};

StatusOr<std::unique_ptr<CompilerClient>>
CompilerClient::create(MLIRContext *context) {
context->disableMultithreading();
Expand Down
34 changes: 34 additions & 0 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ class PyCompilerClient
mtrtCompilerClientIsNull, mtrtCompilerClientDestroy};
};

/// Python-side wrapper class around the `MTRT_OptionsContext` C handle.
class PyOptionsContext
    : public PyMTRTWrapper<PyOptionsContext, MTRT_OptionsContext> {
public:
  using PyMTRTWrapper::PyMTRTWrapper;
  DECLARE_WRAPPER_CONSTRUCTORS(PyOptionsContext);

  /// C API entry points for this handle type: the null check followed by the
  /// destroy function.
  static constexpr CAPITable<MTRT_OptionsContext> kMethodTable{
      mtrtOptionsConextIsNull, mtrtOptionsContextDestroy};
};

/// Python object type wrapper for `MTRT_StableHLOToExecutableOptions`.
class PyStableHLOToExecutableOptions
: public PyMTRTWrapper<PyStableHLOToExecutableOptions,
Expand Down Expand Up @@ -240,6 +251,29 @@ PYBIND11_MODULE(_api, m) {
return new PyCompilerClient(client);
}));

// Python class `OptionsContext`: constructed from a client, the name of an
// options type and a list of argument strings; `print()` dumps the options.
py::class_<PyOptionsContext>(m, "OptionsContext", py::module_local())
    .def(py::init<>([](PyCompilerClient &client,
                       const std::string &optionsType,
                       const std::vector<std::string> &args) {
           // Borrow the argument strings as MlirStringRefs; `args` stays
           // alive for the duration of the C API call below.
           std::vector<MlirStringRef> argRefs;
           argRefs.reserve(args.size());
           for (const std::string &arg : args)
             argRefs.push_back(mlirStringRefCreate(arg.data(), arg.size()));

           MTRT_OptionsContext handle;
           MTRT_Status createStatus = mtrtOptionsContextCreateFromArgs(
               client, &handle, optionsType.c_str(), argRefs.data(),
               argRefs.size());
           THROW_IF_MTRT_ERROR(createStatus);
           return new PyOptionsContext(handle);
         }),
         py::arg("client"), py::arg("options_type"), py::arg("args"))

    .def("print", [](PyOptionsContext &self) {
      // TODO: Should this be exposed via __str__ instead?
      MTRT_Status printStatus = mtrtOptionsContextPrint(self);
      THROW_IF_MTRT_ERROR(printStatus);
    });

py::class_<PyStableHLOToExecutableOptions>(m, "StableHLOToExecutableOptions",
py::module_local())
.def(py::init<>([](PyCompilerClient &client,
Expand Down
2 changes: 1 addition & 1 deletion mlir-tensorrt/python/bindings/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class PyMTRTWrapper {

static py::object createFromCapsule(py::object capsule) {
if constexpr (cFuncTable.capsuleToCApi == nullptr) {
throw py::value_error("boject cannot be converted from opaque capsule");
throw py::value_error("object cannot be converted from opaque capsule");
} else {
MTRT_StableHLOToExecutableOptions cObj =
cFuncTable.capsuleToCApi(capsule.ptr());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import mlir_tensorrt.compiler.api as api
from mlir_tensorrt.compiler.ir import *


with Context() as context:
    client = api.CompilerClient(context)

    # Requesting an options type that was never registered must fail with an
    # error that lists the registered types.
    try:
        opts = api.OptionsContext(client, "non-existent-options-type", [])
    except Exception as err:
        # Flush right away so this line is emitted before the output that the
        # native `opts.print()` call produces further down.
        print(err, flush=True)

    opts = api.OptionsContext(
        client,
        "stable-hlo-to-executable",
        [
            "--tensorrt-builder-opt-level=3",
            "--tensorrt-strongly-typed=false",
            "--tensorrt-workspace-memory-pool-limit=1gb",
        ],
    )

    opts.print()


# CHECK: InvalidArgument: InvalidArgument: non-existent-options-type is not a valid option type. Valid options were: stable-hlo-to-executable
#
# Only the options passed explicitly above are verified, in any order. The
# remaining printed values (for example the device compute capability) are
# inferred from the host, and the print order is incidental, so neither is
# pinned here.
# CHECK-DAG: --tensorrt-builder-opt-level=3
# CHECK-DAG: --tensorrt-strongly-typed=false
# CHECK-DAG: --tensorrt-workspace-memory-pool-limit=1073741824

0 comments on commit ba839d4

Please sign in to comment.