Skip to content

Commit

Permalink
[Python/Bindings] Allow execute_function API to return values
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Oct 18, 2024
1 parent 5e4f406 commit 406ef76
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 51 deletions.
30 changes: 25 additions & 5 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) {
return !client.ptr;
}

/// Returns a null (empty) `MTRT_RuntimeClient` handle, usable as a sentinel
/// when no client is supplied.
static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() {
  MTRT_RuntimeClient client{nullptr};
  return client;
}

/// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the
/// program execution.
/// The `stream` passed to the client is used by all underlying CUDA methods
Expand Down Expand Up @@ -291,6 +296,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
return !value.ptr;
}

// Returns whether the RuntimeValue is MemRef.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);

// Returns whether the RuntimeValue is Scalar.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);

/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
Expand Down Expand Up @@ -374,16 +385,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
return !session.ptr;
}

/// Using `session`, execute the pubic function with the specified name.
/// The `inArgs` and `outArgs` are arrays for input arguments and destination
/// arguments, respectively. Input arguments may be MemRefs or scalars, but
/// destination arguments must be MemRefs.
/// Using `session`, execute the public function with the specified name.
/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments,
/// output arguments, and return values, respectively. Arguments and results
/// can be MemRefs, scalars, or other supported types. Both `outArgs` and
/// `results` can be used simultaneously, allowing for functions that both
/// modify arguments and return values.
/// A stream may optionally be specified, otherwise pass the result of
/// `mtrtStreamGetNull()`.
///
/// The `results` array must point to an array with at least the number of
/// elements returned by mtrtRuntimeSessionGetNumResults for the given function.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client);

MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults(
MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);

//===----------------------------------------------------------------------===//
// DLPack
Expand Down
40 changes: 34 additions & 6 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
return wrap(static_cast<ScalarValue *>(x));
}

/// Returns true if the given runtime value wraps a MemRef.
bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
  return unwrap(value)->getKind() == RuntimeValue::Kind::MemRef;
}

/// Returns true if the given runtime value wraps a scalar.
bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
  return unwrap(value)->getKind() == RuntimeValue::Kind::Scalar;
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeSessionOptions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -691,7 +701,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) {
MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) {
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));

Expand All @@ -701,19 +712,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
llvm::SmallVector<RuntimeValue *> outArgValues =
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
[](MTRT_RuntimeValue arg) { return unwrap(arg); });

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> resultValues =
executeFunctionWithLuaBackend(
*cppSession, std::string_view(name.data, name.length), inArgValues,
outArgValues,
!mtrtStreamIsNull(stream)
? std::optional(unwrap(stream)->getRawStream())
: std::nullopt);
if (!result.isOk())
return wrap(result.getStatus());
: std::nullopt,
!mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client))
: std::nullopt);
if (!resultValues.isOk())
return wrap(resultValues.getStatus());

for (size_t i = 0; i < resultValues->size(); ++i)
results[i] = wrap((*resultValues)[i].release());

return mtrtStatusGetOk();
}

/// Writes to `*numResults` the number of results declared by the signature of
/// the public function `name` in the session's loaded executable.
MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session,
                                            MTRT_StringView name,
                                            int64_t *numResults) {
  LuaRuntimeSession *cppSession =
      static_cast<LuaRuntimeSession *>(unwrap(session));
  const std::string_view funcName(name.data, name.length);
  // Chain through executable -> function -> signature to read the declared
  // result count.
  *numResults = cppSession->getExecutable()
                    .getFunction(funcName)
                    .getSignature()
                    .getNumResults();
  return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeClient
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 6 additions & 3 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ PYBIND11_MODULE(_api, m) {
[](PyStableHLOToExecutableOptions &self, bool enabled,
std::vector<std::string> debugTypes,
std::optional<std::string> dumpIrTreeDir,
std::optional<std::string> dumpTensorRTDir) {
std::optional<std::string> dumpTensorRTDir,
std::optional<bool> dumpTextualPipeline) {
// The strings are copied by the CAPI call, so we just need to
// reference the C-strings temporarily.
std::vector<const char *> literals;
Expand All @@ -270,12 +271,14 @@ PYBIND11_MODULE(_api, m) {
THROW_IF_MTRT_ERROR(mtrtStableHloToExecutableOptionsSetDebugOptions(
self, enabled, literals.data(), literals.size(),
dumpIrTreeDir ? dumpIrTreeDir->c_str() : nullptr,
dumpTensorRTDir ? dumpTensorRTDir->c_str() : nullptr));
dumpTensorRTDir ? dumpTensorRTDir->c_str() : nullptr,
dumpTextualPipeline ? *dumpTextualPipeline : false));
},
py::arg("enabled"),
py::arg("debug_types") = std::vector<std::string>{},
py::arg("dump_ir_tree_dir") = py::none(),
py::arg("dump_tensorrt_dir") = py::none())
py::arg("dump_tensorrt_dir") = py::none(),
py::arg("dump_textual_pipeline") = py::none())

#ifdef MLIR_TRT_TARGET_TENSORRT
.def(
Expand Down
45 changes: 38 additions & 7 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,15 @@ static MTRT_RuntimeValue convertArgType(py::object obj) {
throw std::runtime_error("argument must be MemRef or scalar");
}

/// Convert a runtime value into the corresponding Python object:
/// PyMemRefValue for MemRefs, PyScalarValue for scalars.
/// Throws for any other kind, mirroring the error handling of
/// `convertArgType`; silently returning `py::none()` would mask
/// unsupported result kinds from the caller.
static py::object convertGenericArgToPyObject(MTRT_RuntimeValue value) {
  if (mtrtRuntimeValueIsMemRef(value))
    return py::cast<PyMemRefValue>(mtrtRuntimeValueDynCastToMemRef(value));
  if (mtrtRuntimeValueIsScalar(value))
    return py::cast<PyScalarValue>(mtrtRuntimeValueDynCastToScalar(value));
  throw std::runtime_error("result must be MemRef or scalar");
}

//===----------------------------------------------------------------------===//
// Declare the bindings.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -920,19 +929,41 @@ PYBIND11_MODULE(_api, m) {
.def(
"execute_function",
[](PyRuntimeSession &self, std::string name,
std::vector<py::object> inArgs, std::vector<py::object> outArgs,
std::optional<MTRT_Stream> stream) {
std::vector<py::object> inArgs,
std::optional<std::vector<py::object>> outArgs,
std::optional<MTRT_Stream> stream, PyRuntimeClient &client) {
MTRT_StringView nameRef{name.data(), name.size()};

int64_t numResults;
MTRT_Status s =
mtrtRuntimeSessionGetNumResults(self, nameRef, &numResults);
THROW_IF_MTRT_ERROR(s);

auto inArgsGeneric = llvm::map_to_vector(inArgs, convertArgType);
auto outArgsGeneric = llvm::map_to_vector(outArgs, convertArgType);
auto outArgsGeneric =
outArgs ? llvm::map_to_vector(*outArgs, convertArgType)
: llvm::SmallVector<MTRT_RuntimeValue>{};

MTRT_Status s = mtrtRuntimeSessionExecuteFunction(
std::vector<MTRT_RuntimeValue> resultsGeneric(numResults);

s = mtrtRuntimeSessionExecuteFunction(
self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(),
outArgsGeneric.data(), outArgsGeneric.size(),
stream ? *stream : mtrtStreamGetNull());
resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(),
client);
THROW_IF_MTRT_ERROR(s);

std::vector<py::object> resultPyObject;
if (numResults > 0) {
for (const auto &arg : resultsGeneric)
resultPyObject.push_back(convertGenericArgToPyObject(arg));
}

return resultPyObject;
},
py::arg("name"), py::arg("in_args"), py::arg("out_args"),
py::arg("stream") = py::none());
py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(),
py::arg("stream") = py::none(), py::arg("client"),
"Execute a function given input and optional output arguments. "
"Return optional results as a Python object if output arguments are "
"not present.");
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def test_stablehlo_add(
session = runtime.RuntimeSession(session_options, exe)

session.execute_function(
"main", in_args=test.in_args, out_args=test.out_args, stream=stream
"main",
in_args=test.in_args,
out_args=test.out_args,
stream=stream,
client=runtime_client,
)
output = [
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def execute(self, arg: runtime.RuntimeValue):
session = runtime.RuntimeSession(self.session_options, self.exe)
try:
session.execute_function(
"main", in_args=[arg], out_args=[arg], stream=self.stream
"main",
in_args=[arg],
out_args=[arg],
stream=self.stream,
client=self.client,
)
print("Test passed succesfully")
except runtime.MTRTException as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def test_serialize(ASM):
device=devices[0],
stream=stream,
)
session0.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
session0.execute_function(
"main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
)
output0 = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()

Expand All @@ -57,7 +59,9 @@ def test_serialize(ASM):
exe_reconstructed = compiler.Executable(serialized_exe)

session1 = runtime.RuntimeSession(session_options, exe_reconstructed)
session1.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
session1.execute_function(
"main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
)
output1 = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()

Expand Down
79 changes: 55 additions & 24 deletions mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@
"""


def stablehlo_add():
def stablehlo_add(use_non_dps=False, debug=False):
# Build/parse the main function.
with ir.Context() as context:
m = ir.Module.parse(ASM)

# Use the compiler API to compile to executable.
client = compiler.CompilerClient(context)
opts = compiler.StableHLOToExecutableOptions(
client,
["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"],
)
c_opts = [
"--tensorrt-builder-opt-level=3",
"--tensorrt-strongly-typed=false",
]
if use_non_dps:
c_opts.append("--use-non-dps-call-conv")
if debug:
c_opts.append("--debug=true")
c_opts.append(f"--mlir-print-ir-tree-dir=mlir-dumps-add-no-clone")
opts = compiler.StableHLOToExecutableOptions(client, c_opts)
exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts)

# The RuntimeClient can and should persist across multiple Executables, RuntimeSessions, etc.
Expand All @@ -44,36 +50,53 @@ def stablehlo_add():
device=devices[0],
stream=stream,
)
arg1 = client.create_memref(
np.zeros(shape=(2, 3, 4), dtype=np.float32).data,
device=devices[0],
stream=stream,
)
session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)

data = np.asarray(client.copy_to_host(arg1, stream=stream))
result = None
if use_non_dps:
results = session.execute_function(
"main", in_args=[arg0], stream=stream, client=client
)
result = results[0]
else:
result = client.create_memref(
np.zeros(shape=(2, 3, 4), dtype=np.float32).data,
device=devices[0],
stream=stream,
)
session.execute_function(
"main", in_args=[arg0], out_args=[result], stream=stream, client=client
)

data = np.asarray(client.copy_to_host(result, stream=stream))
stream.sync()

print(data)

# Run execution a bunch more times asynchronously so that it calculates
# `x * 2**num_iter`.
num_iter = 5
start_time = time.time()
for _ in range(0, num_iter):
session.execute_function("main", in_args=[arg0], out_args=[arg0], stream=stream)
data = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()
end_time = time.time()
elapsed = end_time - start_time
if not use_non_dps:
# Run execution a bunch more times asynchronously so that it calculates
# `x * 2**num_iter`.
num_iter = 5
start_time = time.time()
for _ in range(0, num_iter):
session.execute_function(
"main", in_args=[arg0], out_args=[arg0], stream=stream, client=client
)
data = np.asarray(client.copy_to_host(arg0, stream=stream))
stream.sync()
end_time = time.time()
elapsed = end_time - start_time

print(np.asarray(client.copy_to_host(arg0)))
print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration")
print(np.asarray(client.copy_to_host(arg0)))
print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration")


if __name__ == "__main__":
print("DPS style execution:")
stablehlo_add()
print("Non DPS style execution:")
stablehlo_add(use_non_dps=True)

# CHECK-LABEL: DPS style execution:
# CHECK: [ 0. 2. 4. 6.]
# CHECK-NEXT: [ 8. 10. 12. 14.]
# CHECK-NEXT: [16. 18. 20. 22.]]
Expand All @@ -88,3 +111,11 @@ def stablehlo_add():
# CHECK-NEXT: [384. 416. 448. 480.]
# CHECK-NEXT: [512. 544. 576. 608.]
# CHECK-NEXT: [640. 672. 704. 736.]
# CHECK-LABEL: DPS style execution:
# CHECK: [ 0. 2. 4. 6.]
# CHECK-NEXT: [ 8. 10. 12. 14.]
# CHECK-NEXT: [16. 18. 20. 22.]]
# CHECK-NEXT:
# CHECK-NEXT: [24. 26. 28. 30.]
# CHECK-NEXT: [32. 34. 36. 38.]
# CHECK-NEXT: [40. 42. 44. 46.]]]
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def infer_output_shape(client, session, exe, input_shape):
outs = [client.create_memref(out_0, shape=shape, dtype=runtime.ScalarTypeCode.i64)]

session.execute_function(
exe.get_signature("main").get_shape_func_name(), in_args=ins, out_args=outs
exe.get_signature("main").get_shape_func_name(),
in_args=ins,
out_args=outs,
client=client,
)

# Copy output shape from device to host. Also, convert to int32 type since shape calculation uses int64 type.
Expand Down Expand Up @@ -135,7 +138,7 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True):
)

session.execute_function(
"main", in_args=[arg0, arg1], out_args=[arg2], stream=stream
"main", in_args=[arg0, arg1], out_args=[arg2], stream=stream, client=client
)
data = np.asarray(client.copy_to_host(arg2, stream=stream))
stream.sync()
Expand Down

0 comments on commit 406ef76

Please sign in to comment.