Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir-tensorrt] Add support for non-DPS calling convention #258

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption(
"enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
llvm::cl::init(64));
addOption("device-compute-capability", deviceComputeCapability,
Expand Down Expand Up @@ -306,6 +310,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.disallowHostTensorsInTensorRTClusters =
opts.disallowHostTensorsInTensorRTClusters;
clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
clusteringOpts.entrypoint = opts.entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

Expand Down Expand Up @@ -339,7 +344,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

// Perform bufferization.
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(plan::createPlanAllocTensorsPass());
plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(createCanonicalizerPass());
Expand Down Expand Up @@ -525,6 +532,11 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
*this, "device-max-smem-per-block",
llvm::cl::desc("max shared memory per block (in kilobytes)"),
Expand Down Expand Up @@ -552,6 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
opts.deviceMaxSharedMemoryPerBlockKb =
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
opts.entrypoint = cliOpts.entrypoint;
return opts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
SmallVector<int64_t> hostTensorArgs;
for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
if (!isa<RankedTensorType>(arg.getType()))
continue;
RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
// To be conservative, we only do this if type is i32 and num elements
// <= 8.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
} // namespace

/// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
/// operations.
static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
/// and `tensorrt.call_alloc` operations.
static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
op->walk([&](tensorrt::CallOp callOp) {
func::FuncOp func = callOp.getFuncCallee(collection);
if (map.contains(func)) {
map[func].push_back(callOp);
llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
op->walk([&](Operation *callOp) {
if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
return;
}
map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));

func::FuncOp func;
if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
func = call.getFuncCallee(collection);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
func = callAlloc.getFuncCallee(collection);
else
return;

if (map.count(func))
map[func].push_back(callOp);
else
map.insert({func, SmallVector<Operation *>{callOp}});
});
return map;
}
Expand All @@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
/// `tensorrt.call` operations.
static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
ModuleOp op, func::FuncOp funcOp,
ArrayRef<tensorrt::CallOp> callOps) {
ArrayRef<Operation *> callOps) {
llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
for (BlockArgument arg : funcOp.getArguments()) {
if (arg.use_empty())
Expand All @@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
funcOp.eraseArgument(i);

// Update the call ops.
for (tensorrt::CallOp callOp : callOps)
callOp.getInputsMutable().erase(i);
for (Operation *callOp : callOps) {
if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
call.getInputsMutable().erase(i);
} else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp)) {
callAlloc.getInputsMutable().erase(i);
} else {
llvm::errs() << "Unexpected operation type in callOps\n";
callOp->dump();
}
}
}

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
plan::InlineClosedAllocGroupOp op) {
return op.emitError("outlinining inline closed alloc group ops to tensorrt "
"dialect is not yet implemented");
tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
FailureOr<FunctionOpInterface> func = createOutlinedFunc(
rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
"cluster.tensorrt", TypeRange(op.getInputs()),
op.getYield()->getOperandTypes());
if (failed(func))
return failure();
assert(func->getFunctionBody().getBlocks().size() == 1 &&
"expected body with one block");
func->setPublic();

rewriter.setInsertionPoint(op);

auto callOp = rewriter.create<tensorrt::CallAllocOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(*func)}));

// Populate the function arguments attributes.
for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
// We may have scalar (index|signless int)-typed values since we haven't
// eliminated `plan.(with_shape|with_values)` ops yet.
if (!op.argHasTensorType(i) || srcAttr.isNone())
continue;
FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
if (failed(boundAttr))
return op->emitOpError("failed to create TensorRT shape profile "
"attribute from Plan BoundsAttr for argument #")
<< i << " (" << srcAttr << ")";
if (srcAttr.isShapeBound()) {
func->setArgAttr(i,
tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
*boundAttr);
continue;
}
assert(srcAttr.isValueBound() && "expected value bound or shape bound");
func->setArgAttr(
i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
*boundAttr);
func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}

// Populate the function entry block.
rewriter.eraseBlock(&func->getFunctionBody().front());

// Move private decomposition funcs associated with all `stablehlo.composite`
// ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
// has its own symbol table.
SymbolTableCollection symbolTable;
for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
compositeOp.getDecompositionAttr()));
if (!decompositionFunc)
return emitError(compositeOp.getLoc())
<< "failed to lookup stablehlo.composite decomposition "
"function: "
<< compositeOp.getDecompositionAttr();
rewriter.moveOpAfter(decompositionFunc, func->getOperation());
}

// Move region op operations to the func body.
Operation *regionYieldOp = op.getYield();
rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
func->getFunctionBody().end());
rewriter.setInsertionPoint(regionYieldOp);
rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
regionYieldOp->getOperands());

// replace the original region results.
rewriter.replaceOp(op, callOp);
return success();
}

/// Create outlined functions for each `scf.execute_region` operation within
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) {
return !client.ptr;
}

/// Returns null client.
static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() {
return MTRT_RuntimeClient{nullptr};
}

/// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the
/// program execution.
/// The `stream` passed to the client is used by all underlying CUDA methods
Expand Down Expand Up @@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
return !value.ptr;
}

// Returns whether the RuntimeValue is MemRef.
jhalakpatel marked this conversation as resolved.
Show resolved Hide resolved
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);

// Returns whether the RuntimeValue is Scalar.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);

/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
Expand Down Expand Up @@ -391,16 +402,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
return !session.ptr;
}

/// Using `session`, execute the pubic function with the specified name.
/// The `inArgs` and `outArgs` are arrays for input arguments and destination
/// arguments, respectively. Input arguments may be MemRefs or scalars, but
/// destination arguments must be MemRefs.
/// Using `session`, execute the public function with the specified name.
/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments,
/// output arguments, and return values, respectively. Arguments and results
/// can be MemRefs, scalars, or other supported types. Both `outArgs` and
/// `results` can be used simultaneously, allowing for functions that both
/// modify arguments and return values.
/// A stream may optionally be specified, otherwise pass the result of
/// `mtrtStreamGetNull()`.
///
/// The `results` array must point to an array with at least the number of
/// elements returned by mtrtRuntimeSessionGetNumResults for the given function.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to expose an API entirely for the function signature.

MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults(
MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);

//===----------------------------------------------------------------------===//
// DLPack
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
std::optional<CudaStream> stream = {},
std::optional<RuntimeClient *> client = {});

// Parses the results of a function call, handling both scalar and MemRef return
// types
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
parseResults(const sol::protected_function_result &pfr,
const FunctionSignatureView &sig,
std::optional<RuntimeClient *> client);

} // namespace mlirtrt::runtime

#endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H
40 changes: 34 additions & 6 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
return wrap(static_cast<ScalarValue *>(x));
}

bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::MemRef;
}

bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::Scalar;
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeSessionOptions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -721,7 +731,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) {
MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) {
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));

Expand All @@ -731,19 +742,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
llvm::SmallVector<RuntimeValue *> outArgValues =
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
[](MTRT_RuntimeValue arg) { return unwrap(arg); });

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> resultValues =
executeFunctionWithLuaBackend(
*cppSession, std::string_view(name.data, name.length), inArgValues,
outArgValues,
!mtrtStreamIsNull(stream)
? std::optional(unwrap(stream)->getRawStream())
: std::nullopt);
if (!result.isOk())
return wrap(result.getStatus());
: std::nullopt,
!mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client))
: std::nullopt);
if (!resultValues.isOk())
return wrap(resultValues.getStatus());

for (size_t i = 0; i < resultValues->size(); ++i)
results[i] = wrap((*resultValues)[i].release());

return mtrtStatusGetOk();
}

MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session,
MTRT_StringView name,
int64_t *numResults) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));
*numResults = cppSession->getExecutable()
.getFunction(std::string_view(name.data, name.length))
.getSignature()
.getNumResults();
return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeClient
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir-executor/Conversion/ConvertToExecutorCommon.h"
#include "mlir-executor/Conversion/Passes.h"
#include "mlir-executor/Executor/IR/Executor.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -548,6 +549,21 @@ void executor::populateMemRefToExecutorPatterns(
}

namespace {

class RemoveNoOpClonePattern : public OpRewritePattern<bufferization::CloneOp> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is illegal. Clone indicates creating a copy, which you are not doing here. This indicates something else is missing in the pipeline before this point.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@christopherbate Let me know if you any suggestion on how to fix this.

public:
using OpRewritePattern<bufferization::CloneOp>::OpRewritePattern;

LogicalResult matchAndRewrite(bufferization::CloneOp op,
PatternRewriter &rewriter) const override {
if (op.getInput().getType() == op.getOutput().getType()) {
rewriter.replaceOp(op, op.getInput());
return success();
}
return failure();
}
};

/// Pass to convert `memref` to `executor` dialect operrations.
class ConvertMemRefToExecutorPass
: public mlir::executor::impl::ConvertMemRefToExecutorPassBase<
Expand Down Expand Up @@ -579,6 +595,10 @@ class ConvertMemRefToExecutorPass
RewritePatternSet patterns(ctx);
executor::populateMemRefToExecutorPatterns(
patterns, typeConverter, allowUncheckedMemrefCastConversion);

// Remove unrealized cast and redundant clone operations.
patterns.add<RemoveNoOpClonePattern>(ctx);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
Expand Down
Loading