From 4c7b5b8be8a7dd56b6624f1bd574a78d0d8bf4b9 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Mon, 4 Nov 2024 11:38:22 -0800 Subject: [PATCH 1/2] Tripy changes for non-DPS --- tripy/tests/backend/api/test_executable.py | 2 +- tripy/tests/frontend/test_tensor.py | 3 +- tripy/tests/integration/test_iota.py | 15 +-- tripy/tests/integration/test_quantize.py | 3 +- tripy/tripy/backend/api/compile.py | 1 - tripy/tripy/backend/api/executable.py | 50 +++++---- tripy/tripy/backend/mlir/compiler.py | 1 + tripy/tripy/backend/mlir/executor.py | 121 ++------------------- tripy/tripy/flat_ir/ops/copy.py | 9 ++ tripy/tripy/frontend/tensor.py | 6 +- 10 files changed, 61 insertions(+), 150 deletions(-) diff --git a/tripy/tests/backend/api/test_executable.py b/tripy/tests/backend/api/test_executable.py index 3b588b463..40ac8d01a 100644 --- a/tripy/tests/backend/api/test_executable.py +++ b/tripy/tests/backend/api/test_executable.py @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable): assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD assert param.annotation == tp.Tensor - assert signature.return_annotation == tp.Tensor + assert signature.return_annotation == Sequence[tp.Tensor] def test_signature_multiple_return_values(self, multiple_return_executable): signature = inspect.signature(multiple_return_executable) diff --git a/tripy/tests/frontend/test_tensor.py b/tripy/tests/frontend/test_tensor.py index f0a21ded2..837ec8af0 100644 --- a/tripy/tests/frontend/test_tensor.py +++ b/tripy/tests/frontend/test_tensor.py @@ -226,8 +226,7 @@ def test_no_explicit_cast(self): "devices", [ ("cpu", "gpu"), - # TODO(#155) - # ("gpu", "cpu"), + ("gpu", "cpu"), ], ) def test_explicit_copy(self, devices): diff --git a/tripy/tests/integration/test_iota.py b/tripy/tests/integration/test_iota.py index 39df35787..48094c882 100644 --- a/tripy/tests/integration/test_iota.py +++ b/tripy/tests/integration/test_iota.py @@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim): @pytest.mark.parametrize("dtype", DATA_TYPES.values()) def test_negative_no_casting(self, dtype): - from tripy.frontend.trace.ops.iota import Iota + with tp.logger.use_verbosity("ir"): + from tripy.frontend.trace.ops.iota import Iota - if dtype in [tp.float32, tp.int32, tp.int64]: - pytest.skip("tp.iota() supports float32, int32, and int64 without cast") + if dtype in [tp.float32, tp.int32, tp.int64]: + pytest.skip("tp.iota() supports float32, int32, and int64 without cast") - # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint - a = tp.ones((2, 2)) - out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) + # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint + a = tp.ones((2, 2)) + out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) - exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values" + exception_str = "InternalError: failed to run compilation on module with symbol name." 
if dtype == tp.bool: exception_str = "InternalError: failed to run compilation" with helper.raises( diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py index b50293869..9bf96be98 100644 --- a/tripy/tests/integration/test_quantize.py +++ b/tripy/tests/integration/test_quantize.py @@ -117,5 +117,6 @@ def test_non_constant_scale(self): input = tp.ones((4, 4)) scale = tp.ones((4,)) quantized = tp.quantize(input, scale, tp.int8, dim=0) + quantized_int32 = tp.cast(quantized, tp.int32) - assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8))) + assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32))) diff --git a/tripy/tripy/backend/api/compile.py b/tripy/tripy/backend/api/compile.py index 1f6315805..0491f8618 100644 --- a/tripy/tripy/backend/api/compile.py +++ b/tripy/tripy/backend/api/compile.py @@ -196,5 +196,4 @@ def process_arg(name, arg): return Executable( executable, compiled_arg_names, - output_devices=[out.device for out in trace.outputs], ) diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index 33347b314..8c9771531 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -14,7 +14,7 @@ # limitations under the License. import base64 import inspect -from typing import Sequence, Union +from typing import Sequence, Union, Tuple, Callable import mlir_tensorrt.runtime.api as runtime @@ -37,13 +37,11 @@ class Executable: """ # The constructor is intentionally undocumented because it is not meant to be called by users. - # TODO(#155): output_devices is not needed after they can be queried from executable - def __init__(self, executable, arg_names, output_devices): + def __init__(self, executable, arg_names): self._executable = executable self._executor = Executor(self._executable) self._arg_names = arg_names self._num_expected_args = len(arg_names) - self._output_devices = output_devices self._executable_signature = self._executable.get_signature("main") # Build a signature so the executable works with `inspect.signature` @@ -128,7 +126,7 @@ def add(a, b): tensor.eval() try: - executor_outputs = self._executor.execute(self._output_devices, input_tensors) + executor_outputs = self._executor.execute(input_tensors) except runtime.MTRTException as err: # TODO: Evaluate whether this should be moved into the executor if "function expects a memref type with element type" in str(err): @@ -170,15 +168,22 @@ def add(a, b): output_tensors = output_tensors[0] return output_tensors - def _get_arg_info(self, idx): - arg = self._executable_signature.get_arg(idx) - arg = runtime.MemRefType(arg) - arg_bound = self._executable_signature.get_arg_bound(idx) - shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max())) - if len(shape_bounds) == 0: - # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape - shape_bounds = tuple((x, x) for x in arg.shape) - return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype)) + def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo: + item = runtime.MemRefType(get_item(idx)) + bound = get_bound(idx) + shape_bounds = tuple(zip(bound.min(), bound.max())) + + if not shape_bounds: + # For static shape, fallback to item.shape + shape_bounds = tuple((x, x) for x in item.shape) + + return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype)) + + def _get_arg_info(self, idx: int) -> ArgInfo: + return 
self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound) + + def _get_result_info(self, idx: int) -> ArgInfo: + return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound) def get_input_info(self) -> Sequence[ArgInfo]: """ @@ -221,11 +226,16 @@ def add(a, b): compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)]) print(compiled_add.get_output_info()) """ - output_info = [] - offset = self._executable_signature.get_num_input_args() - for idx in range(self._executable_signature.get_num_output_args()): - output_info.append(self._get_arg_info(idx + offset)) - return output_info + num_input_args = self._executable_signature.get_num_input_args() + num_output_args = self._executable_signature.get_num_output_args() + num_results = self._executable_signature.get_num_results() + + assert not (num_output_args and num_results), "Cannot have both output arguments and results" + + if num_output_args: + return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)] + else: + return [self._get_result_info(idx) for idx in range(num_results)] def save(self, path: str) -> None: """ @@ -289,7 +299,6 @@ def add(a, b): def encode_executable(executable): return { "arg_names": executable._arg_names, - "output_devices": executable._output_devices, "executable": base64.b64encode(executable._executable.serialize()).decode(), } @@ -300,5 +309,4 @@ def decode_executable(executable_dict): return Executable( runtime.Executable(executable_bytes), executable_dict["arg_names"], - executable_dict["output_devices"], ) diff --git a/tripy/tripy/backend/mlir/compiler.py b/tripy/tripy/backend/mlir/compiler.py index 1874e8938..517b978d5 100644 --- a/tripy/tripy/backend/mlir/compiler.py +++ b/tripy/tripy/backend/mlir/compiler.py @@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level): f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}", f"--tensorrt-builder-opt-level={trt_builder_opt_level}", "--tensorrt-strongly-typed=True", + "--enable-non-dps-returns", ] if config.enable_mlir_debug or config.enable_tensorrt_debug: opts.append("--debug=true") diff --git a/tripy/tripy/backend/mlir/executor.py b/tripy/tripy/backend/mlir/executor.py index b03c507f3..447ad2cfa 100644 --- a/tripy/tripy/backend/mlir/executor.py +++ b/tripy/tripy/backend/mlir/executor.py @@ -31,89 +31,17 @@ class Executor: def __init__(self, executable: runtime.Executable) -> None: - + runtime.GlobalDebug.flag = True + debug_types = ["allocator", "runtime"] + runtime.GlobalDebug.set_types(debug_types) self.runtime_client = MLIRRuntimeClient() session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) self.session = runtime.RuntimeSession(session_options, executable) self.device = self.runtime_client.get_devices()[0] # Assume a single device is available. 
self.signature = executable.get_signature("main") self.stream = default_stream() - self.num_input_args = self.signature.get_num_input_args() - self.num_output_args = self.signature.get_num_output_args() - self.output_args = [ - self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args) - ] - self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args] - - def _create_shape_memref(self, shape): - shape = make_tuple(shape) - if len(shape) == 0: - return create_memref( - shape=(0,), - dtype=datatype.int64, - device=device("cpu"), - ) - return create_memref( - array=convert_list_to_array(shape, datatype.int64), - shape=(len(shape),), - dtype=datatype.int64, - device=device("cpu"), - ) - - def _get_outputs_shape(self): - outputs_shape = [] - all_outputs_known = True - for memref in self.output_memrefs: - outputs_shape.append(memref.shape) - all_outputs_known &= all(dim >= 0 for dim in memref.shape) - return outputs_shape, all_outputs_known - - def _get_inputs_runtime_shape(self, inputs): - inputs_shape = [] - for input in inputs: - inputs_shape.append(input.trace_tensor.producer.data.shape) - return inputs_shape - - def _execute_shape_inference(self, inputs_shape, outputs_shape): - inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape] - outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape] - self.session.execute_function( - name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref - ) - - outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref] - return outputs_runtime_shape - - def _get_output_tensor_info(self, outputs_runtime_shape, output_devices): - outputs_tensor_info = [] - for index in range(self.num_output_args): - memref = self.output_memrefs[index] - dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype) - - output_device = output_devices[index] - if not output_device: - output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0)) - - runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])] - outputs_tensor_info.append( - TensorInfo( - len(runtime_shape), - tuple(runtime_shape), - dtype, - output_device, - ) - ) - return outputs_tensor_info - - def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]): - outputs_shape, all_outputs_known = self._get_outputs_shape() - if not all_outputs_known: - inputs_shape = self._get_inputs_runtime_shape(inputs) - outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape) - output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices) - return output_tensor_info - def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: + def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: in_args = [] for inp in inputs: memref = inp.trace_tensor.producer.data @@ -131,45 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> ) in_args.append(memref) - # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR. - out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices) - - # Allocate output memory and store buffer pointers. 
- outputs = [ - create_memref( - shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream - ) - for info in out_tensor_info - ] - - out_args = [] - for out in outputs: - memref = out - # HACK (#155): MLIR-TensorRT requires inputs to be on device. - # Remove explicit copy to device once #155 is addressed. - if memref.address_space != runtime.PointerType.device: - memref = self.runtime_client.copy_to_device( - host_memref=memref, - device=self.runtime_client.get_devices()[0], - stream=self.stream._active_cuda_stream, - ) - if not memref: - raise_error("Could not allocate output memref", details=memref.error_details) - out_args.append(memref) - # Execute and populate device pointers. - self.session.execute_function( - "main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream + outputs = self.session.execute_function( + "main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client ) - # For outputs that were on the host, do the copy back - # TODO(#155): MLIR-TensorRT should allow output tensor placements on host. - for idx, out_info in enumerate(out_tensor_info): - if out_info.device.kind != "gpu": - self.runtime_client.copy_to_host( - device_memref=out_args[idx], - existing_host_memref=outputs[idx], - stream=self.stream._active_cuda_stream, - ) - + # For now return results on GPU. return outputs diff --git a/tripy/tripy/flat_ir/ops/copy.py b/tripy/tripy/flat_ir/ops/copy.py index 48598b2c5..092ff9373 100644 --- a/tripy/tripy/flat_ir/ops/copy.py +++ b/tripy/tripy/flat_ir/ops/copy.py @@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp): target: tripy.common.device + def set_memory_space_attr(self, tensor, mem_space_attr): + current_type = tensor.type + # Set the encoding attribute on the operation's result + new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr) + tensor.set_type(new_type) + def to_mlir(self, operands): from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith @@ -46,7 +52,10 @@ def to_mlir(self, operands): sliced_dims.append(dim) alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr) + self.set_memory_space_attr(alloc_tensor, mem_space_attr) result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor) + self.set_memory_space_attr(result_tensor, mem_space_attr) cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor) + self.set_memory_space_attr(cast_tensor, mem_space_attr) return [cast_tensor] diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index 344567ba8..5c5e2983c 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -185,11 +185,11 @@ def eval(self) -> runtime.MemRefValue: compiler = Compiler(trt_builder_opt_level=0) executable = compiler.compile(mlir, flat_ir=flat_ir) - executor = Executor(executable) + self.executor = Executor(executable) # Upon computing the value of this tensor, we switch it to have a `Storage` # parameter so that it does not need to be computed again. 
- data = executor.execute([out.device for out in flat_ir.outputs]) - executor.stream.synchronize() + data = self.executor.execute() + self.executor.stream.synchronize() assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor" data = data[0] From 5ad00f29246544bb9c44b6feb6d796f944dbb652 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Mon, 30 Sep 2024 09:56:10 -0700 Subject: [PATCH 2/2] Add initial IR for alloc enqueue --- .../lib/Compiler/StableHloToExecutable.cpp | 15 +++- .../TensorRTToTensorRTRuntime.cpp | 2 + .../Plan/Transforms/EliminateShapeOps.cpp | 41 ++++++--- .../Plan/Transforms/OutlineClusters.cpp | 78 ++++++++++++++++- .../include/mlir-executor-c/Runtime/Runtime.h | 30 +++++-- .../Runtime/Backend/Lua/LuaRuntime.h | 7 -- .../executor/lib/CAPI/Runtime/Runtime.cpp | 40 +++++++-- .../lib/Conversion/MemRefToExecutor.cpp | 20 +++++ .../executor/lib/Runtime/API/API.cpp | 38 +++++--- .../lib/Runtime/Backend/Lua/LuaRuntime.cpp | 72 ++++++++++------ .../Lua/Modules/TensorRT/TensorRTModule.cpp | 6 +- .../executor/lib/Support/Allocators.cpp | 2 +- .../test/lib/BufferizationTestPass.cpp | 33 ++++--- .../python/bindings/Runtime/RuntimePyBind.cpp | 48 +++++++++-- .../TRT10/test_stablehlo_add.py | 6 +- .../IntegrationTests/test_call_validation.py | 6 +- .../test_executable_serialize.py | 8 +- .../IntegrationTests/test_stablehlo_add.py | 79 +++++++++++------ .../test_stablehlo_dynamic.py | 86 ++++++++++++++----- .../test_runtime_debug_dump.py | 4 +- 20 files changed, 477 insertions(+), 144 deletions(-) diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index c2bb73bcd..57fd29bb7 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -222,6 +222,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions( disallowHostTensorsInTensorRTClusters, llvm::cl::init(false), llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor " "calculations (but they can still be inputs)")); + addOption( + "enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false), + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator")); addOption("executor-index-bitwidth", executorIndexBitwidth, llvm::cl::init(64)); addOption("device-compute-capability", deviceComputeCapability, @@ -306,6 +310,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline( plan::StablehloClusteringPassOptions clusteringOpts{}; clusteringOpts.disallowHostTensorsInTensorRTClusters = opts.disallowHostTensorsInTensorRTClusters; + clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns; clusteringOpts.entrypoint = opts.entrypoint; plan::buildPlanSegmentationPipeline(pm, clusteringOpts); @@ -339,7 +344,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline( // Perform bufferization. 
pm.addPass(createMemRefCastEliminationPass()); - pm.addPass(plan::createPlanAllocTensorsPass()); + plan::PlanAllocTensorsPassOptions allocTensorsOpts{}; + allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns; + pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts)); pm.addPass(plan::createPlanBufferizePass()); pm.addPass(createMemRefCastEliminationPass()); pm.addPass(createCanonicalizerPass()); @@ -525,6 +532,11 @@ struct ClusteringPipelineCliOpts *this, "device-compute-capability", llvm::cl::desc("target device compute capability (SM version)"), llvm::cl::init(60)}; + Option enableNonDPSReturns{ + *this, "enable-non-dps-returns", + llvm::cl::desc( + "allow tensorrt based output allocations using output allocator"), + llvm::cl::init(false)}; Option deviceMaxSharedMemoryPerBlockKb{ *this, "device-max-smem-per-block", llvm::cl::desc("max shared memory per block (in kilobytes)"), @@ -552,6 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts( opts.deviceComputeCapability = cliOpts.deviceComputeCapability; opts.deviceMaxSharedMemoryPerBlockKb = cliOpts.deviceMaxSharedMemoryPerBlockKb; + opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns; opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost; opts.entrypoint = cliOpts.entrypoint; return opts; diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp index f8382596f..7552c8ce4 100644 --- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp +++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp @@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter, SmallVector hostTensorArgs; for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) { const TensorKindLattice *kind = solver.lookupState(arg); + if (!isa(arg.getType())) + continue; RankedTensorType rtt = cast(arg.getType()); // To be conservative, we only do this if type is i32 and num elements // <= 8. diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp index eb29494c9..393802270 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern { } // namespace /// Get a map from `tensorrt.func` functions to associated `tensorrt.call` -/// operations. -static llvm::DenseMap> +/// and `tensorrt.call_alloc` operations. 
+static llvm::DenseMap> getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) { - llvm::DenseMap> map; - op->walk([&](tensorrt::CallOp callOp) { - func::FuncOp func = callOp.getFuncCallee(collection); - if (map.contains(func)) { - map[func].push_back(callOp); + llvm::DenseMap> map; + op->walk([&](Operation *callOp) { + if (!isa(callOp)) return; - } - map.insert(std::make_pair(func, SmallVector{callOp})); + + func::FuncOp func; + if (auto call = dyn_cast(callOp)) + func = call.getFuncCallee(collection); + else if (auto callAlloc = dyn_cast(callOp)) + func = callAlloc.getFuncCallee(collection); + else + return; + + if (map.count(func)) + map[func].push_back(callOp); + else + map.insert({func, SmallVector{callOp}}); }); return map; } @@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) { /// `tensorrt.call` operations. static LogicalResult removeUnusedArgs(SymbolTableCollection &collection, ModuleOp op, func::FuncOp funcOp, - ArrayRef callOps) { + ArrayRef callOps) { llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0); for (BlockArgument arg : funcOp.getArguments()) { if (arg.use_empty()) @@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection, funcOp.eraseArgument(i); // Update the call ops. - for (tensorrt::CallOp callOp : callOps) - callOp.getInputsMutable().erase(i); + for (Operation *callOp : callOps) { + if (auto call = dyn_cast(callOp)) { + call.getInputsMutable().erase(i); + } else if (auto callAlloc = dyn_cast(callOp)) { + callAlloc.getInputsMutable().erase(i); + } else { + llvm::errs() << "Unexpected operation type in callOps\n"; + callOp->dump(); + } + } } return success(); diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp index d593393f6..0eae7f594 100644 --- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp +++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp @@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter, plan::InlineClosedAllocGroupOp op) { - return op.emitError("outlinining inline closed alloc group ops to tensorrt " - "dialect is not yet implemented"); + tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op); + auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs())); + FailureOr func = createOutlinedFunc( + rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster", + "cluster.tensorrt", TypeRange(op.getInputs()), + op.getYield()->getOperandTypes()); + if (failed(func)) + return failure(); + assert(func->getFunctionBody().getBlocks().size() == 1 && + "expected body with one block"); + func->setPublic(); + + rewriter.setInsertionPoint(op); + + auto callOp = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getInputs(), + SymbolRefAttr::get(trtModuleOp.getNameAttr(), + {FlatSymbolRefAttr::get(*func)})); + + // Populate the function arguments attributes. + for (unsigned i = 0; i < (*func).getNumArguments(); i++) { + BoundsAttr srcAttr = cast(op.getInputAttrs()[i]); + // We may have scalar (index|signless int)-typed values since we haven't + // eliminated `plan.(with_shape|with_values)` ops yet. 
+ if (!op.argHasTensorType(i) || srcAttr.isNone()) + continue; + FailureOr boundAttr = + getTensorRTShapeProfile(srcAttr, op.getInputs()[i]); + if (failed(boundAttr)) + return op->emitOpError("failed to create TensorRT shape profile " + "attribute from Plan BoundsAttr for argument #") + << i << " (" << srcAttr << ")"; + if (srcAttr.isShapeBound()) { + func->setArgAttr(i, + tensorrt::TensorRTDialect::getShapeProfileArgAttrName(), + *boundAttr); + continue; + } + assert(srcAttr.isValueBound() && "expected value bound or shape bound"); + func->setArgAttr( + i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(), + *boundAttr); + func->setArgAttr(i, mlir::getHostTensorArgAttrName(), + rewriter.getUnitAttr()); + } + + // Populate the function entry block. + rewriter.eraseBlock(&func->getFunctionBody().front()); + + // Move private decomposition funcs associated with all `stablehlo.composite` + // ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op + // has its own symbol table. + SymbolTableCollection symbolTable; + for (auto compositeOp : op.getBody().getOps()) { + auto decompositionFunc = dyn_cast_if_present( + symbolTable.lookupSymbolIn(op->getParentOfType(), + compositeOp.getDecompositionAttr())); + if (!decompositionFunc) + return emitError(compositeOp.getLoc()) + << "failed to lookup stablehlo.composite decomposition " + "function: " + << compositeOp.getDecompositionAttr(); + rewriter.moveOpAfter(decompositionFunc, func->getOperation()); + } + + // Move region op operations to the func body. + Operation *regionYieldOp = op.getYield(); + rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(), + func->getFunctionBody().end()); + rewriter.setInsertionPoint(regionYieldOp); + rewriter.replaceOpWithNewOp(regionYieldOp, + regionYieldOp->getOperands()); + + // replace the original region results. + rewriter.replaceOp(op, callOp); + return success(); } /// Create outlined functions for each `scf.execute_region` operation within diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 807d3dc23..6bdf72ad6 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) { return !client.ptr; } +/// Returns null client. +static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() { + return MTRT_RuntimeClient{nullptr}; +} + /// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the /// program execution. /// The `stream` passed to the client is used by all underlying CUDA methods @@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) { return !value.ptr; } +// Returns whether the RuntimeValue is MemRef. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value); + +// Returns whether the RuntimeValue is Scalar. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value); + /// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue. MLIR_CAPI_EXPORTED MTRT_RuntimeValue mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref); @@ -391,16 +402,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) { return !session.ptr; } -/// Using `session`, execute the pubic function with the specified name. 
-/// The `inArgs` and `outArgs` are arrays for input arguments and destination -/// arguments, respectively. Input arguments may be MemRefs or scalars, but -/// destination arguments must be MemRefs. +/// Using `session`, execute the public function with the specified name. +/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments, +/// output arguments, and return values, respectively. Arguments and results +/// can be MemRefs, scalars, or other supported types. Both `outArgs` and +/// `results` can be used simultaneously, allowing for functions that both +/// modify arguments and return values. /// A stream may optionally be specified, otherwise pass the result of /// `mtrtStreamGetNull()`. +/// +/// The `results` array must point to an array with at least the number of +/// elements returned by mtrtRuntimeSessionGetNumResults for the given function. MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream); + const MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client); + +MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults( + MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults); //===----------------------------------------------------------------------===// // DLPack diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index 88c616bc7..a58b1d022 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name, std::optional stream = {}, std::optional client = {}); -// Parses the results of a function call, handling both scalar and MemRef return -// types -StatusOr>> -parseResults(const sol::protected_function_result &pfr, - const FunctionSignatureView &sig, - std::optional client); - } // namespace mlirtrt::runtime #endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index f6320dce4..ab325a3ef 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -675,6 +675,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) { return wrap(static_cast(x)); } +bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::MemRef; +} + +bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::Scalar; +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeSessionOptions //===----------------------------------------------------------------------===// @@ -721,7 +731,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) { MTRT_Status mtrtRuntimeSessionExecuteFunction( MTRT_RuntimeSession session, MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, - const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) { + const 
MTRT_RuntimeValue *outArgs, size_t numOutArgs, + MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) { LuaRuntimeSession *cppSession = static_cast(unwrap(session)); @@ -731,19 +742,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction( llvm::SmallVector outArgValues = llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs), [](MTRT_RuntimeValue arg) { return unwrap(arg); }); - - StatusOr>> result = + StatusOr>> resultValues = executeFunctionWithLuaBackend( *cppSession, std::string_view(name.data, name.length), inArgValues, outArgValues, !mtrtStreamIsNull(stream) ? std::optional(unwrap(stream)->getRawStream()) - : std::nullopt); - if (!result.isOk()) - return wrap(result.getStatus()); + : std::nullopt, + !mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client)) + : std::nullopt); + if (!resultValues.isOk()) + return wrap(resultValues.getStatus()); + + for (size_t i = 0; i < resultValues->size(); ++i) + results[i] = wrap((*resultValues)[i].release()); return mtrtStatusGetOk(); } + +MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session, + MTRT_StringView name, + int64_t *numResults) { + LuaRuntimeSession *cppSession = + static_cast(unwrap(session)); + *numResults = cppSession->getExecutable() + .getFunction(std::string_view(name.data, name.length)) + .getSignature() + .getNumResults(); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeClient //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp index 90e4c9681..c595e5414 100644 --- a/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp +++ b/mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp @@ -24,6 +24,7 @@ #include "mlir-executor/Conversion/ConvertToExecutorCommon.h" #include "mlir-executor/Conversion/Passes.h" #include "mlir-executor/Executor/IR/Executor.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" @@ -548,6 +549,21 @@ void executor::populateMemRefToExecutorPatterns( } namespace { + +class RemoveNoOpClonePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(bufferization::CloneOp op, + PatternRewriter &rewriter) const override { + if (op.getInput().getType() == op.getOutput().getType()) { + rewriter.replaceOp(op, op.getInput()); + return success(); + } + return failure(); + } +}; + /// Pass to convert `memref` to `executor` dialect operrations. class ConvertMemRefToExecutorPass : public mlir::executor::impl::ConvertMemRefToExecutorPassBase< @@ -579,6 +595,10 @@ class ConvertMemRefToExecutorPass RewritePatternSet patterns(ctx); executor::populateMemRefToExecutorPatterns( patterns, typeConverter, allowUncheckedMemrefCastConversion); + + // Remove unrealized cast and redundant clone operations. 
+ patterns.add(ctx); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index 5ff748b5f..abf079655 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -367,13 +367,14 @@ RuntimeSession::RuntimeSession(RuntimeSessionOptions options, //===----------------------------------------------------------------------===// AllocTracker::~AllocTracker() { + MTRT_DBGF("Destroying alloc tracker %p", static_cast(this)); MTRT_DBGF("checking %u allocations", map.size()); llvm::SmallVector ptrsToFree; ptrsToFree.reserve(map.size()); for (const auto &[ptrVal, metadata] : map) { if (metadata->info.isInternallyManaged() && metadata->externalReferenceCount.load() == 0) { - MTRT_DBGF("still live: 0x%lx type %d size %lu", ptrVal, + MTRT_DBGF("still live: 0x%lx type %d size %lx", ptrVal, static_cast(metadata->info.type), metadata->info.size); ptrsToFree.push_back(metadata->info); } @@ -410,7 +411,7 @@ void AllocTracker::incrementExternalCount(uintptr_t ptr) { llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); std::unique_ptr const &metadata = map.at(ptr); int32_t ref = ++metadata->externalReferenceCount; - MTRT_DBG("Incremented external reference for pointer %d to %d", ptr, ref); + MTRT_DBGF("Incremented external reference for 0x%lx to %d", ptr, ref); } void AllocTracker::decrementExternalCount(uintptr_t ptr) { @@ -422,11 +423,12 @@ void AllocTracker::decrementExternalCount(uintptr_t ptr) { llvm::formatv("External reference count cannot be negative: {0}", ref) .str() .c_str()); - MTRT_DBG("Decremented external reference for pointer %d to %d", ptr, ref); + MTRT_DBGF("Decremented external reference for pointer 0x%lx to %d", ptr, ref); if (ref == 0 && metadata->releasedInternally) { - MTRT_DBG("External reference to an internally released pointer %d is 0, " - "try deallocating pointer memory of size %lu", - ptr, ref, metadata->info.size); + MTRT_DBGF( + "External reference to an internally released pointer 0x%lx is 0, " + "try deallocating pointer memory of size %lx", + ptr, metadata->info.size); Status s = safeDeallocate(*this, metadata->info.ptr); if (!s.isOk()) MTRT_DBGF("error while deallocating dangling memory: %s", @@ -452,9 +454,11 @@ void AllocTracker::track(PointerInfo info) { assert((!contains(info.ptr) || get(info.ptr).isExternallyManaged()) && "an internally managed pointer should not already be tracked"); } - MTRT_DBGF("AllocTracker is now tracking 0x%lx size=%lu space=%s ownership=%s", - info.ptr, info.size, runtime::impl::EnumNamePointerType(info.type), - runtime::impl::EnumNamePointerOwner(info.owner)); + MTRT_DBGF( + "AllocTracker %p is now tracking 0x%lx size=%lx space=%s ownership=%s", + static_cast(this), info.ptr, info.size, + runtime::impl::EnumNamePointerType(info.type), + runtime::impl::EnumNamePointerOwner(info.owner)); auto value = std::make_unique(); value->externalReferenceCount.store(0); value->releasedInternally = false; @@ -468,6 +472,8 @@ void AllocTracker::track(PointerInfo info) { } void AllocTracker::untrack(uintptr_t ptr) { + MTRT_DBGF( + "AllocTracker %p is now untracking 0x%lx)", static_cast(this), ptr); assert(llvm::is_contained(map, ptr) && llvm::formatv("Untracked pointer {0}", ptr).str().c_str()); map.erase(map.find(ptr)); @@ -574,7 +580,7 @@ mlirtrt::Status runtime::safeDeallocate(AllocTracker &tracker, uintptr_t ptr, PointerInfo obj = 
tracker.get(ptr); if (obj.owner == PointerOwner::external) { - MTRT_DBGF("Untracking externally managed pointer 0x%lx", ptr); + MTRT_DBGF("Untracking externally managed 0x%lx", ptr); tracker.untrack(obj.ptr); return mlirtrt::Status::getOk(); } @@ -725,9 +731,15 @@ StatusOr> MemRefValue::create( if (!::getFootprintInBytes(shape, strides, bitsPerElement).isOk()) return getInvalidArgStatus( "only memrefs with non-negative strides are allowed"); - if (!ptr) - return getInvalidArgStatus( - "MemRef objects must be created with a valid pointer"); + + auto is_empty_tensor = [](const llvm::ArrayRef &shape) -> bool { + return std::any_of(shape.begin(), shape.end(), + [](int64_t s) { return s == 0; }); + }; + + if (!ptr && !is_empty_tensor(shape)) return getInvalidArgStatus( + "MemRef objects must be created with a valid pointer for a non-empty " + "tensor"); if (isDeviceVisible(addressSpace) && (!device || !*device)) return getInvalidArgStatus("a specific device must be provided for MemRefs " diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index aea894610..7536fca31 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -565,37 +565,43 @@ getScalarValue(const sol::protected_function_result &pfr, int index, } } -// Parses the results of a function call, handling both scalar and MemRef return -// types -StatusOr>> -runtime::parseResults(const sol::protected_function_result &pfr, - const FunctionSignatureView &sig, - std::optional client) { +/// Parses the results of a function call, handling both scalar and MemRef return types. +/// +/// @param pfr The protected function result to parse. +/// @param sig The function signature view. +/// @param sessionAllocTracker The allocation tracker for the current session. +/// @param client Optional runtime client pointer. +/// @return A vector of unique pointers to RuntimeValue, or an error status. 
+static StatusOr>> parseResults( + const sol::protected_function_result &pfr, const FunctionSignatureView &sig, + AllocTracker &sessionAllocTracker, std::optional client) { llvm::SmallVector> results; + results.reserve(sig.getNumResults()); + for (unsigned i = 0; i < sig.getNumResults(); ++i) { + const auto &resultType = sig.getResult(i); - if (sig.getResult(i).isa()) { - auto scalar = getScalarValue(pfr, i, sig); - if (!scalar.isOk()) - return scalar.getStatus(); - results.push_back(std::move(*scalar)); + if (resultType.isa()) { + auto scalarValue = getScalarValue(pfr, i, sig); + if (!scalarValue.isOk()) + return scalarValue.getStatus(); + results.push_back(std::move(*scalarValue)); continue; } - MemRefTableReader reader(pfr, i); - - if (!sig.getResult(i).isa()) + if (!resultType.isa()) return getInvalidArgStatus("Result can only be a memref or scalar"); // Handle MemRef return values - const auto &resultView = sig.getResult(i).get(); - unsigned rank = resultView.getRank(); + const auto &memRefView = resultType.get(); + MemRefTableReader reader(pfr, i); // Extract MemRef metadata uintptr_t allocPtr = reader.getNextValue(); [[maybe_unused]] uintptr_t alignedPtr = reader.getNextValue(); int64_t offset = reader.getNextValue(); + unsigned rank = memRefView.getRank(); llvm::SmallVector shape(rank); llvm::SmallVector strides(rank); @@ -608,16 +614,34 @@ runtime::parseResults(const sol::protected_function_result &pfr, if (!client) return getInvalidArgStatus("Runtime client cannot be nullptr"); - // Create MemRefValue from extracted data - auto memref = (*client)->createExternalMemRef( - resultView.getAddressSpace(), resultView.getElementType().getBitWidth(), + // Track the allocation in the session allocation tracker. This increment + // ensures the session maintains a reference to the allocation, preventing + // premature deallocation, and creates an external reference count for the + // pointer. The external count is crucial for proper memory management, + // allowing the allocation to persist even if internal references are + // released, thus ensuring safe access from external systems. + sessionAllocTracker.incrementExternalCount(allocPtr); + + // Create an external MemRef and track it in both session and client allocation trackers + MTRT_DBGF( + "Creating external MemRef for ptr 0x%lx: " + "Session tracker: %p, Client: %p, Client tracker: %p. 
" + "This ptr is registered with the session and will now be tracked by the client as well.", + allocPtr, + static_cast(&sessionAllocTracker), + static_cast(*client), + static_cast(&(*client)->getAllocTracker())); + + auto memRef = (*client)->createExternalMemRef( + memRefView.getAddressSpace(), memRefView.getElementType().getBitWidth(), allocPtr, offset, shape, strides, (*client)->getDevices()[0].get(), - resultView.getElementType()); + memRefView.getElementType()); - if (!memref.isOk()) - return memref.getStatus(); + if (!memRef.isOk()) + return memRef.getStatus(); - results.push_back(std::move(*memref)); + (*client)->getAllocTracker().incrementExternalCount((*memRef)->getMemory()); + results.push_back(std::move(*memRef)); } return results; @@ -715,5 +739,5 @@ runtime::executeFunctionWithLuaBackend( "\": ", err.what()); } - return parseResults(pfr, sig, client); + return parseResults(pfr, sig, tracker, client); } diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp index b24355eb3..38a2db092 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp @@ -141,9 +141,11 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { size = std::max(size, static_cast(1)); if (size > mOutputSize) { size = roundUp(size, alignment); - if (mOutputPtr) + if (mOutputPtr) { + MTRT_DBGF("tensorrt module output allocator deallocating 0x%lx", mOutputPtr); mlirtrt::runtime::safeDeallocate(*mTracker, mOutputPtr, CudaStreamPtr(stream)); + } mOutputPtr = 0; mOutputSize = 0; StatusOr memory = @@ -152,6 +154,8 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator { if (memory.isOk()) { mOutputPtr = (*memory).ptr; mOutputSize = memory->size; + MTRT_DBGF("tensorrt module output allocator allocating %lu bytes at 0x%lx", mOutputSize, + mOutputPtr); } return reinterpret_cast(mOutputPtr); } diff --git a/mlir-tensorrt/executor/lib/Support/Allocators.cpp b/mlir-tensorrt/executor/lib/Support/Allocators.cpp index ce7310ad7..f65103972 100644 --- a/mlir-tensorrt/executor/lib/Support/Allocators.cpp +++ b/mlir-tensorrt/executor/lib/Support/Allocators.cpp @@ -296,4 +296,4 @@ Status PinnedMemoryAllocator::freeAsync(uintptr_t ptr, CudaStream stream) { return getInternalErrorStatus( "MLIR-Executor was not built with CUDA enabled"); #endif -} \ No newline at end of file +} diff --git a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp index 9e8a9e50a..0c20d7da4 100644 --- a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp +++ b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp @@ -53,22 +53,31 @@ class ExecutorBufferizationTestPass } } }; + +struct PlanBufferizationPipelineCliOpts + : public PassPipelineOptions { + Option enableNonDPSReturns{ + *this, "enable-non-dps-returns", + llvm::cl::desc("allow backend clusters to directly allocate outputs"), + llvm::cl::init(false)}; +}; + } // namespace namespace mlir::executor { void registerTestExecutorBufferizePass() { PassRegistration(); - - PassPipelineRegistration<> executorBufferizationPipeline( - "test-executor-bufferization-pipeline", - "Run one-shot-bufferization and buffer deallocation pipelines", - [](OpPassManager &pm) { - pm.addPass(std::make_unique()); - pm.addPass(bufferization::createDropEquivalentBufferResultsPass()); - 
bufferization::BufferDeallocationPipelineOptions deallocOptions{}; - bufferization::buildBufferDeallocationPipeline(pm, deallocOptions); - pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); - }); + PassPipelineRegistration + executorBufferizationPipeline( + "test-executor-bufferization-pipeline", + "Run one-shot-bufferization and buffer deallocation pipelines", + [](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) { + pm.addPass(std::make_unique()); + pm.addPass(bufferization::createDropEquivalentBufferResultsPass()); + bufferization::BufferDeallocationPipelineOptions deallocOptions{}; + bufferization::buildBufferDeallocationPipeline(pm, deallocOptions); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + }); } } // namespace mlir::executor diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index ddb0c74e2..1654549de 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -600,6 +600,15 @@ static MTRT_RuntimeValue convertArgType(py::object obj) { throw std::runtime_error("argument must be MemRef or scalar"); } +/// Convert Runtime value to PyMemRefValue or PyScalarValue object. +static py::object convertGenericArgToPyObject(MTRT_RuntimeValue value) { + if (mtrtRuntimeValueIsMemRef(value)) + return py::cast(mtrtRuntimeValueDynCastToMemRef(value)); + if (mtrtRuntimeValueIsScalar(value)) + return py::cast(mtrtRuntimeValueDynCastToScalar(value)); + return py::none(); +} + //===----------------------------------------------------------------------===// // Declare the bindings. //===----------------------------------------------------------------------===// @@ -950,22 +959,43 @@ PYBIND11_MODULE(_api, m) { .def( "execute_function", [](PyRuntimeSession &self, std::string name, - std::vector inArgs, std::vector outArgs, - std::optional stream) { + std::vector inArgs, + std::optional> outArgs, + std::optional stream, PyRuntimeClient &client) { MTRT_StringView nameRef{name.data(), name.size()}; + int64_t numResults; + MTRT_Status s = + mtrtRuntimeSessionGetNumResults(self, nameRef, &numResults); + THROW_IF_MTRT_ERROR(s); + auto inArgsGeneric = llvm::map_to_vector(inArgs, convertArgType); - auto outArgsGeneric = llvm::map_to_vector(outArgs, convertArgType); + auto outArgsGeneric = + outArgs ? llvm::map_to_vector(*outArgs, convertArgType) + : llvm::SmallVector{}; + + std::vector resultsGeneric(numResults); - MTRT_Status s = mtrtRuntimeSessionExecuteFunction( + s = mtrtRuntimeSessionExecuteFunction( self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(), outArgsGeneric.data(), outArgsGeneric.size(), - stream ? *stream : mtrtStreamGetNull()); + resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(), + client); THROW_IF_MTRT_ERROR(s); - }, - py::arg("name"), py::arg("in_args"), py::arg("out_args"), - py::arg("stream") = py::none()); + std::vector resultPyObject; + if (numResults > 0) { + for (const auto &arg : resultsGeneric) + resultPyObject.push_back(convertGenericArgToPyObject(arg)); + } + + return resultPyObject; + }, + py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(), + py::arg("stream") = py::none(), py::arg("client"), + "Execute a function given input and optional output arguments. 
" + "Return optional results as a Python object if output arguments are " + "not present."); py::class_(m, "GlobalDebug", py::module_local()) .def_property_static("flag", &PyGlobalDebugFlag::get, &PyGlobalDebugFlag::set, "LLVM-wide debug flag") @@ -977,4 +1007,4 @@ PYBIND11_MODULE(_api, m) { py::overload_cast &>( &PyGlobalDebugFlag::set_types), "Sets specific debug types to be produced by LLVM"); -} \ No newline at end of file +} diff --git a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py index 480ce74d4..9535aef79 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py +++ b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py @@ -36,7 +36,11 @@ def test_stablehlo_add( session = runtime.RuntimeSession(session_options, exe) session.execute_function( - "main", in_args=test.in_args, out_args=test.out_args, stream=stream + "main", + in_args=test.in_args, + out_args=test.out_args, + stream=stream, + client=runtime_client, ) output = [ ( diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py index 415767d70..ee3c784f8 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py @@ -73,7 +73,11 @@ def execute(self, arg: runtime.RuntimeValue): session = runtime.RuntimeSession(self.session_options, self.exe) try: session.execute_function( - "main", in_args=[arg], out_args=[arg], stream=self.stream + "main", + in_args=[arg], + out_args=[arg], + stream=self.stream, + client=self.client, ) print("Test passed succesfully") except runtime.MTRTException as e: diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py b/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py index e4bb3cba5..c841a334e 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py @@ -47,7 +47,9 @@ def test_serialize(ASM): device=devices[0], stream=stream, ) - session0.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + session0.execute_function( + "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client + ) output0 = np.asarray(client.copy_to_host(arg1, stream=stream)) stream.sync() @@ -57,7 +59,9 @@ def test_serialize(ASM): exe_reconstructed = compiler.Executable(serialized_exe) session1 = runtime.RuntimeSession(session_options, exe_reconstructed) - session1.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + session1.execute_function( + "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client + ) output1 = np.asarray(client.copy_to_host(arg1, stream=stream)) stream.sync() diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py index 2c95a3081..4e9b9ceff 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py @@ -14,17 +14,23 @@ """ -def stablehlo_add(): +def stablehlo_add(use_non_dps=False, debug=False): # Build/parse the main function. with ir.Context() as context: m = ir.Module.parse(ASM) # Use the compiler API to compile to executable. 
client = compiler.CompilerClient(context) - opts = compiler.StableHLOToExecutableOptions( - client, - ["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"], - ) + c_opts = [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + ] + if use_non_dps: + c_opts.append("--enable-non-dps-returns") + if debug: + c_opts.append("--debug=true") + c_opts.append(f"--mlir-print-ir-tree-dir=mlir-dumps-add-no-clone") + opts = compiler.StableHLOToExecutableOptions(client, c_opts) exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) # The RuntimeClient can and should persist across multiple Executables, RuntimeSessions, etc. @@ -44,36 +50,53 @@ def stablehlo_add(): device=devices[0], stream=stream, ) - arg1 = client.create_memref( - np.zeros(shape=(2, 3, 4), dtype=np.float32).data, - device=devices[0], - stream=stream, - ) - session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) - data = np.asarray(client.copy_to_host(arg1, stream=stream)) + result = None + if use_non_dps: + results = session.execute_function( + "main", in_args=[arg0], stream=stream, client=client + ) + result = results[0] + else: + result = client.create_memref( + np.zeros(shape=(2, 3, 4), dtype=np.float32).data, + device=devices[0], + stream=stream, + ) + session.execute_function( + "main", in_args=[arg0], out_args=[result], stream=stream, client=client + ) + + data = np.asarray(client.copy_to_host(result, stream=stream)) stream.sync() print(data) - # Run execution a bunch more times asynchronously so that it calculates - # `x * 2**num_iter`. - num_iter = 5 - start_time = time.time() - for _ in range(0, num_iter): - session.execute_function("main", in_args=[arg0], out_args=[arg0], stream=stream) - data = np.asarray(client.copy_to_host(arg1, stream=stream)) - stream.sync() - end_time = time.time() - elapsed = end_time - start_time + if not use_non_dps: + # Run execution a bunch more times asynchronously so that it calculates + # `x * 2**num_iter`. + num_iter = 5 + start_time = time.time() + for _ in range(0, num_iter): + session.execute_function( + "main", in_args=[arg0], out_args=[arg0], stream=stream, client=client + ) + data = np.asarray(client.copy_to_host(arg0, stream=stream)) + stream.sync() + end_time = time.time() + elapsed = end_time - start_time - print(np.asarray(client.copy_to_host(arg0))) - print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration") + print(np.asarray(client.copy_to_host(arg0))) + print(f"1000 iterations avg { (elapsed/num_iter)/1000.0} msec per iteration") if __name__ == "__main__": + print("DPS style execution:") stablehlo_add() + print("Non DPS style execution:") + stablehlo_add(use_non_dps=True) +# CHECK-LABEL: DPS style execution: # CHECK: [ 0. 2. 4. 6.] # CHECK-NEXT: [ 8. 10. 12. 14.] # CHECK-NEXT: [16. 18. 20. 22.]] @@ -88,3 +111,11 @@ def stablehlo_add(): # CHECK-NEXT: [384. 416. 448. 480.] # CHECK-NEXT: [512. 544. 576. 608.] # CHECK-NEXT: [640. 672. 704. 736.] +# CHECK-LABEL: DPS style execution: +# CHECK: [ 0. 2. 4. 6.] +# CHECK-NEXT: [ 8. 10. 12. 14.] +# CHECK-NEXT: [16. 18. 20. 22.]] +# CHECK-NEXT: +# CHECK-NEXT: [24. 26. 28. 30.] +# CHECK-NEXT: [32. 34. 36. 38.] +# CHECK-NEXT: [40. 42. 44. 
46.]]] diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py index 35515e054..364156667 100644 --- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py @@ -77,7 +77,10 @@ def infer_output_shape(client, session, exe, input_shape): outs = [client.create_memref(out_0, shape=shape, dtype=runtime.ScalarTypeCode.i64)] session.execute_function( - exe.get_signature("main").get_shape_func_name(), in_args=ins, out_args=outs + exe.get_signature("main").get_shape_func_name(), + in_args=ins, + out_args=outs, + client=client, ) # Copy output shape from device to host. Also, convert to int32 type since shape calculation uses int64 type. @@ -86,21 +89,23 @@ def infer_output_shape(client, session, exe, input_shape): return output_shape -def test_program(program: str, input_shape: Iterable[int], debug: bool = True): +def test_program( + program: str, input_shape: Iterable[int], use_non_dps=False, debug: bool = False +): # Build/parse the main function. with ir.Context() as context: m = ir.Module.parse(program) # Use the compiler API to compile to executable. client = compiler.CompilerClient(context) - opts = compiler.StableHLOToExecutableOptions( - client, - [ - "--tensorrt-builder-opt-level=3", - "--tensorrt-strongly-typed=false", - "--entrypoint=main", - ], - ) + c_opts = [ + "--tensorrt-builder-opt-level=3", + "--tensorrt-strongly-typed=false", + "--entrypoint=main", + ] + if use_non_dps: + c_opts.append("--enable-non-dps-returns") + opts = compiler.StableHLOToExecutableOptions(client, c_opts) if debug: opts.set_debug_options(False, [], "tmp") exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) @@ -126,29 +131,66 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True): np.ones(input_shape, dtype=np.float32).data, device=devices[0], stream=stream ) - output_shape = infer_output_shape(client, session, exe, input_shape) - - arg2 = client.create_memref( - np.zeros(output_shape, dtype=np.float32).data, - device=devices[0], - stream=stream, - ) - - session.execute_function( - "main", in_args=[arg0, arg1], out_args=[arg2], stream=stream - ) - data = np.asarray(client.copy_to_host(arg2, stream=stream)) + result = None + if use_non_dps: + results = session.execute_function( + "main", in_args=[arg0, arg1], stream=stream, client=client + ) + result = results[0] + else: + output_shape = infer_output_shape(client, session, exe, input_shape) + result = client.create_memref( + np.zeros(output_shape, dtype=np.float32).data, + device=devices[0], + stream=stream, + ) + session.execute_function( + "main", + in_args=[arg0, arg1], + out_args=[result], + stream=stream, + client=client, + ) + data = np.asarray(client.copy_to_host(result, stream=stream)) stream.sync() print(data) if __name__ == "__main__": + print("DPS style execution:") print("Test (3, ?, 2)") test_program(program1, (3, 4, 2)) print("Test (?, 2)") test_program(program2, (4, 2)) + print("Non DPS style execution:") + print("Test (3, ?, 2)") + test_program(program1, (3, 4, 2), use_non_dps=True) + print("Test (?, 2)") + test_program(program2, (4, 2), use_non_dps=True) + +# CHECK-LABEL: DPS style execution: +# CHECK-LABEL: Test (3, ?, 2) +# CHECK: [{{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]] +# CHECK: {{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]] +# CHECK: {{\[}}[2. 2.] 
+# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]]] + +# CHECK-LABEL: Test (?, 2) +# CHECK: {{\[}}[2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.] +# CHECK: [2. 2.]] +# CHECK-LABEL: Non DPS style execution: # CHECK-LABEL: Test (3, ?, 2) # CHECK: [{{\[}}[2. 2.] # CHECK: [2. 2.] diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py index 68ca684e8..4d379c3e6 100644 --- a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py +++ b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py @@ -49,7 +49,9 @@ def stablehlo_add(): device=devices[0], stream=stream, ) - session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream) + session.execute_function( + "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client + ) data = np.asarray(client.copy_to_host(arg1, stream=stream)) stream.sync()
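
Usage sketch (reviewer note, not part of the patch): the series switches execution from DPS-style out-arguments to runtime-allocated results when the executable is compiled with --enable-non-dps-returns. The snippet below distills the new calling convention from the Python binding changes and the integration tests above. It is illustrative only; `exe`, `client`, `devices`, and `stream` are assumed to be created exactly as the tests above create them, and the input shape is hypothetical.

    # Illustrative sketch -- assumes `exe` was compiled with --enable-non-dps-returns
    # and that `client`, `devices`, and `stream` are set up as in the tests above.
    session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
    session = runtime.RuntimeSession(session_options, exe)

    arg0 = client.create_memref(
        np.ones((2, 3, 4), dtype=np.float32).data,
        device=devices[0],
        stream=stream,
    )

    # With non-DPS returns, no `out_args` are passed; the runtime allocates the
    # outputs via the TensorRT output allocator and returns them as memrefs.
    # The client is passed so the returned allocations can be tracked and copied back.
    results = session.execute_function(
        "main", in_args=[arg0], stream=stream, client=client
    )
    data = np.asarray(client.copy_to_host(results[0], stream=stream))
    stream.sync()

This mirrors what tripy's Executor.execute now does internally: it forwards only the input memrefs and the runtime client to the session and receives the output memrefs (currently device-resident) as return values instead of pre-allocating destination buffers.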