
Commit

Fix memory leak
jhalakpatel committed Nov 9, 2024
1 parent b80b459 commit 6e647ca
Showing 7 changed files with 52 additions and 4 deletions.
@@ -795,13 +795,20 @@ class AllocTracker {
/// Returns true if the ptr is released internally.
bool isReleasedInternally(uintptr_t ptr) const;

/// Mark the pointer as allocated by TensorRT.
void setTensorRTAllocated(uintptr_t ptr);

/// Return whether the pointer was allocated by TensorRT.
bool getTensorRTAllocated(uintptr_t ptr);

private:
struct Metadata {
std::atomic<int32_t> externalReferenceCount = {0};
// Whether we freed/released this buffer internally. If this is true,
// the buffer should be truly released and untracked when
// decrementExternalCount causes the count to go to zero.
bool releasedInternally{false};
bool tensorrtAllocated{false};
PointerInfo info;
};

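Note: a minimal usage sketch of the new flag, for orientation only. It assumes an AllocTracker named tracker and a PointerInfo named info produced by whichever code path performed the allocation; it is not part of this change.

tracker.track(info);                    // tensorrtAllocated starts out false
tracker.setTensorRTAllocated(info.ptr); // mark buffers handed out by TensorRT
// ... later, once the contents have been copied out ...
if (tracker.getTensorRTAllocated(info.ptr)) {
  // No external memref will ever release this buffer, so the runtime
  // frees the underlying memory and drops it from the tracker.
  tracker.untrack(info.ptr);
}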
3 changes: 3 additions & 0 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
@@ -324,7 +324,9 @@ MTRT_Status mtrtMemRefCreateExternal(

MTRT_Status mtrtMemRefValueDestroyAsync(MTRT_MemRefValue buffer,
MTRT_Stream stream) {

MemRefValue *memref = unwrap(buffer);
llvm::dbgs() << "[MLIR-TRT] Deallocating memref pointer " << memref->getMemory() << "\n";
Status s = memref->getClient()->deallocate(
std::unique_ptr<MemRefValue>(memref),
mtrtStreamIsNull(stream) ? std::nullopt
@@ -336,6 +338,7 @@ MTRT_Status mtrtMemRefValueDestroyAsync(MTRT_MemRefValue buffer,

MTRT_Status mtrtMemRefValueDestroy(MTRT_MemRefValue buffer) {
MemRefValue *memref = unwrap(buffer);
llvm::dbgs() << "[MLIR-TRT] Deallocating memref pointer " << memref->getMemory() << "\n";
Status s =
memref->getClient()->deallocate(std::unique_ptr<MemRefValue>(memref));
if (!s.isOk())
16 changes: 16 additions & 0 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
@@ -396,6 +396,20 @@ AllocTracker::~AllocTracker() {
MTRT_DBGF("freed %zu bytes of unfreed memory", totalSize);
}

void AllocTracker::setTensorRTAllocated(uintptr_t ptr) {
assert(llvm::is_contained(map, ptr) &&
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
std::unique_ptr<Metadata> const &metadata = map.at(ptr);
metadata->tensorrtAllocated = true;
}

bool AllocTracker::getTensorRTAllocated(uintptr_t ptr) {
assert(llvm::is_contained(map, ptr) &&
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
std::unique_ptr<Metadata> const &metadata = map.at(ptr);
return metadata->tensorrtAllocated;
}

void AllocTracker::markReleasedInternally(uintptr_t ptr) {
assert(llvm::is_contained(map, ptr) &&
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
@@ -473,6 +487,7 @@ void AllocTracker::track(PointerInfo info) {
value->externalReferenceCount.store(0);
value->releasedInternally = false;
value->info = info;
value->tensorrtAllocated = false;
if (!contains(info.ptr)) {
map.insert(std::make_pair(info.ptr, std::move(value)));
return;
@@ -669,6 +684,7 @@ ResourceTracker::~ResourceTracker() {

void ResourceTracker::track(uintptr_t ptr, Deleter deleter) {
assert(ptr && deleter && "expected valid ptr and deleter");
MTRT_DBGF("tracking resource at 0x%lx", ptr);
tracker.insert(std::make_pair(ptr, deleter));
}

@@ -98,7 +98,7 @@ static void registerLuaRuntimeMethodsCommon(
}

void mlirtrt::runtime::registerLuaRuntimeMethods(
lua_State *state, const RuntimeSessionOptions &options,
lua_State *state, const RuntimeSessionOptions &options,
PinnedMemoryAllocator *pinnedMemoryAllocator, AllocTracker *allocTracker,
ResourceTracker *resourceTracker) {
registerLuaRuntimeMethodsCommon(state, pinnedMemoryAllocator, allocTracker,
@@ -650,9 +650,6 @@ parseResults(const sol::protected_function_result &pfr,
if (!memref.isOk())
return memref.getStatus();

// Increment external reference count since we are returning a memref
allocator.incrementExternalCount(info.ptr);

results.push_back(std::move(*memref));
}

@@ -432,6 +432,13 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
cudaMemcpyDeviceToHost,
stream),
state);
if (allocTracker->getTensorRTAllocated(
reinterpret_cast<uintptr_t>(srcPtr))) {
// Free the TensorRT-allocated source pointer, since it won't be
// released by an external memref.
SET_LUA_ERROR_IF_CUDART_ERROR(cudaFreeAsync(srcPtr, stream), state);
allocTracker->untrack(reinterpret_cast<uintptr_t>(srcPtr));
}
};

lua["__cuda_memcpy_host_pinned2device"] =
@@ -480,6 +487,13 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
cudaMemcpyDeviceToHost,
stream),
state);
if (allocTracker->getTensorRTAllocated(
reinterpret_cast<uintptr_t>(srcPtr))) {
// Free the TensorRT-allocated source pointer, since it won't be
// released by an external memref.
SET_LUA_ERROR_IF_CUDART_ERROR(cudaFreeAsync(srcPtr, stream), state);
allocTracker->untrack(reinterpret_cast<uintptr_t>(srcPtr));
}
};
lua["__cuda_memcpy_device2device"] = [allocTracker](
sol::this_state state,
@@ -504,6 +518,13 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
cudaMemcpyDeviceToDevice,
stream),
state);
if (allocTracker->getTensorRTAllocated(
reinterpret_cast<uintptr_t>(srcPtr))) {
// Free the TensorRT-allocated source pointer, since it won't be
// released by an external memref.
SET_LUA_ERROR_IF_CUDART_ERROR(cudaFreeAsync(srcPtr, stream), state);
allocTracker->untrack(reinterpret_cast<uintptr_t>(srcPtr));
}
return;
};
}
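Note: the three copy paths above share the same release pattern for TensorRT-owned source buffers. A standalone sketch of that pattern, with a hypothetical helper name and simplified error handling (not part of this change):

#include <cuda_runtime_api.h>

// After the copy has been enqueued on stream, a source buffer that was
// allocated by TensorRT is freed asynchronously and dropped from the
// tracker, because no external memref will ever release it.
static cudaError_t releaseTensorRTSource(AllocTracker &tracker, void *srcPtr,
                                         cudaStream_t stream) {
  uintptr_t src = reinterpret_cast<uintptr_t>(srcPtr);
  if (!tracker.getTensorRTAllocated(src))
    return cudaSuccess;
  cudaError_t err = cudaFreeAsync(srcPtr, stream);
  if (err != cudaSuccess)
    return err;
  tracker.untrack(src);
  return cudaSuccess;
}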
@@ -155,6 +155,7 @@ class OutputAllocatorImpl : public nvinfer1::IOutputAllocator {
if (memory.isOk()) {
mOutputPtr = (*memory).ptr;
mOutputSize = memory->size;
mTracker->setTensorRTAllocated(memory->ptr);
MTRT_DBGF(
"tensorrt module output allocator allocating %lu bytes at 0x%lx",
mOutputSize, mOutputPtr);
3 changes: 3 additions & 0 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
@@ -345,6 +345,9 @@ static std::unique_ptr<PyMemRefValue> createMemRef(
static std::unique_ptr<PyMemRefValue>
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule,
std::optional<bool> assertCanonicalStrides) {

llvm::dbgs() << "Creating a memref view from DL pack tensors\n";

DLManagedTensor *managedTensor = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(capsule.ptr(), "dltensor"));

