Skip to content

Commit

Permalink
Fix crash if MeshDevice is deallocated before MeshBuffer (#18181)
Browse files Browse the repository at this point in the history
### Ticket

### Problem description
Currently there is a crash if MeshDevice is deallocated or closed before
MeshBuffer.
There are two semi-independent issues:
1. Lifetime issue if MeshDevice is deallocated
2. Destruction order is inconsistent in MeshDevice destructor and close
method, because `sub_device_manager_tracker_` may perform buffer
deallocation and this would call back to MeshDevice, so member
destruction order actually matters here.

### What's changed
Added a test to reproduce the issue
Stored MeshDevice as weak_ptr inside of MeshBuffer to be able to detect
this case
Added special handling for this case, skipping buffer deallocation call
Changed reset order in MeshDevice close

### Checklist
- [x] [All post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/13496304356)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sminakov-tt authored Feb 24, 2025
1 parent aa09c9f commit 4a0562c
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 8 deletions.
38 changes: 38 additions & 0 deletions tests/tt_metal/distributed/test_mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,44 @@ TEST_F(MeshBufferTestT3000, Deallocation) {
EXPECT_FALSE(buffer_view->is_allocated());
}

// Checks that a MeshBuffer can be destroyed after the test has released its MeshDevice handle.
// The create / release / destroy sequence is repeated 100 times.
TEST(MeshBufferTest, DeallocationWithoutMeshDevice) {
    constexpr int kNumIterations = 100;
    for (int iteration = 0; iteration != kNumIterations; ++iteration) {
        auto mesh_config =
            MeshDeviceConfig{.mesh_shape = SimpleMeshShape(1, 1), .offset = std::nullopt, .physical_device_ids = {}};
        auto device = MeshDevice::create(
            mesh_config, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, DispatchCoreType::WORKER);

        const DeviceLocalBufferConfig per_device_config{
            .page_size = 2048,
            .buffer_type = BufferType::DRAM,
            .buffer_layout = TensorMemoryLayout::INTERLEAVED,
            .bottom_up = false};
        const ReplicatedBufferConfig replicated_config{.size = 2048};

        // Declared after `device`, so the buffer is destroyed at the end of the iteration,
        // i.e. after the reset() below has already run.
        auto mesh_buffer = MeshBuffer::create(replicated_config, per_device_config, device.get());

        // Release this test's reference to the device while the buffer is still in scope.
        device.reset();
    }
}

// Checks that a MeshBuffer can be destroyed after close() has been called on its MeshDevice.
// The create / close / destroy sequence is repeated 100 times.
TEST(MeshBufferTest, DeallocationWithMeshDeviceClosed) {
    constexpr int kNumIterations = 100;
    for (int iteration = 0; iteration != kNumIterations; ++iteration) {
        auto mesh_config =
            MeshDeviceConfig{.mesh_shape = SimpleMeshShape(1, 1), .offset = std::nullopt, .physical_device_ids = {}};
        auto device = MeshDevice::create(
            mesh_config, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, DispatchCoreType::WORKER);

        const DeviceLocalBufferConfig per_device_config{
            .page_size = 2048,
            .buffer_type = BufferType::DRAM,
            .buffer_layout = TensorMemoryLayout::INTERLEAVED,
            .bottom_up = false};
        const ReplicatedBufferConfig replicated_config{.size = 2048};

        // Declared after `device`, so at the end of the iteration the buffer is destroyed
        // before the (already closed) device handle goes out of scope.
        auto mesh_buffer = MeshBuffer::create(replicated_config, per_device_config, device.get());

        // Close the device explicitly while the buffer is still in scope.
        device->close();
    }
}

TEST_F(MeshBufferTestT3000, GetDeviceBuffer) {
const DeviceLocalBufferConfig device_local_config{
.page_size = 1024,
Expand Down
3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ class Buffer final {

size_t unique_id() const { return unique_id_; }

// Mark the buffer as deallocated, without releasing underlying device memory
void mark_as_deallocated();

Buffer(
IDevice* device,
DeviceAddr size,
Expand Down
10 changes: 6 additions & 4 deletions tt_metal/api/tt-metalium/mesh_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_layout,
MeshDevice* mesh_device,
std::optional<DeviceAddr> address = std::nullopt);
~MeshBuffer();

// Returns true if the MeshBuffer is allocated. Note that MeshBuffer is created in the allocated state; either the
// destructor or the `deallocate` method deallocate the MeshBuffer.
Expand All @@ -85,7 +86,8 @@ class MeshBuffer {
// resources.
void deallocate();

MeshDevice* device() const { return mesh_device_; }
// Throws an exception if the corresponding MeshDevice is already deallocated
MeshDevice* device() const;
DeviceAddr size() const;
DeviceAddr device_local_size() const { return device_local_size_; }
DeviceAddr address() const { return address_; };
Expand Down Expand Up @@ -114,7 +116,7 @@ class MeshBuffer {
buffers_(SimpleMeshShape(mesh_device->shape()), nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
mesh_device_(mesh_device->shared_from_this()),
address_(backing_buffer->address()),
device_local_size_(device_local_size),
state_(OwnedBufferState{std::move(backing_buffer)}) {}
Expand All @@ -129,15 +131,15 @@ class MeshBuffer {
buffers_(SimpleMeshShape(mesh_device->shape()), /*fill_value=*/nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device),
mesh_device_(mesh_device->shared_from_this()),
address_(address),
device_local_size_(device_local_size),
state_(ExternallyOwnedState{}) {}

void initialize_device_buffers();
MeshBufferConfig config_;
DeviceLocalBufferConfig device_local_config_;
MeshDevice* mesh_device_ = nullptr;
std::weak_ptr<MeshDevice> mesh_device_;
DeviceAddr address_ = 0;
DeviceAddr device_local_size_ = 0;

Expand Down
25 changes: 23 additions & 2 deletions tt_metal/distributed/mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ std::shared_ptr<MeshBuffer> MeshBuffer::create(
void MeshBuffer::initialize_device_buffers() {
auto init_device_buffer_at_address = [this](const MeshCoordinate& coord) {
std::shared_ptr<Buffer> buffer = Buffer::create(
mesh_device_->get_device(coord),
device()->get_device(coord),
address_,
device_local_size_,
device_local_config_.page_size,
Expand All @@ -132,7 +132,28 @@ void MeshBuffer::initialize_device_buffers() {

bool MeshBuffer::is_allocated() const { return not std::holds_alternative<DeallocatedState>(state_); }

void MeshBuffer::deallocate() { state_ = DeallocatedState{}; }
// Destroying a MeshBuffer runs the same path as an explicit deallocate() call.
MeshBuffer::~MeshBuffer() {
    this->deallocate();
}

void MeshBuffer::deallocate() {
auto mesh_device = mesh_device_.lock();
if (mesh_device) {
state_ = DeallocatedState{};
return;
}

// Special handling is required if MeshDevice is already deallocated
if (std::holds_alternative<OwnedBufferState>(state_)) {
auto& owned_state = std::get<OwnedBufferState>(state_);
owned_state.backing_buffer->mark_as_deallocated();
}
state_ = DeallocatedState{};
}

// Returns a raw pointer to the MeshDevice referenced by this MeshBuffer.
// Fails via TT_FATAL if the weak reference to the MeshDevice can no longer be locked,
// i.e. the MeshDevice has already been deallocated.
// The strong reference obtained here is local to this function and is released on return,
// so the returned pointer does not by itself keep the MeshDevice alive.
MeshDevice* MeshBuffer::device() const {
auto device = mesh_device_.lock();
TT_FATAL(device, "Can't get device from mesh buffer, already deallocated");
return device.get();
}

std::shared_ptr<Buffer> MeshBuffer::get_device_buffer(const MeshCoordinate& device_coord) const {
return buffers_.at(device_coord);
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const Mesh
return submeshes;
}

MeshDevice::~MeshDevice() {}
MeshDevice::~MeshDevice() { close(); }

IDevice* MeshDevice::get_device(chip_id_t physical_device_id) const {
for (auto device : this->get_devices()) {
Expand Down Expand Up @@ -327,12 +327,12 @@ bool MeshDevice::close() {
submesh->close();
}
submeshes_.clear();
sub_device_manager_tracker_.reset();
if (scoped_devices_) {
scoped_devices_.reset();
}
parent_mesh_.reset();
view_.reset();
sub_device_manager_tracker_.reset();
return true;
}

Expand Down
8 changes: 8 additions & 0 deletions tt_metal/impl/buffers/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,15 @@ void Buffer::deallocate() {
});
}

// Mark the buffer as deallocated, without releasing underlying device memory.
// Only the allocation status flag is updated; nothing else is touched here.
void Buffer::mark_as_deallocated() {
    constexpr auto kOrder = std::memory_order::relaxed;
    allocation_status_.store(AllocationStatus::DEALLOCATED, kOrder);
}

void Buffer::deleter(Buffer* buffer) {
if (buffer->allocation_status_.load(std::memory_order::relaxed) == AllocationStatus::DEALLOCATED) {
delete buffer;
return;
}
buffer->device_->push_work([buffer] {
std::unique_ptr<Buffer> unique_buffer = std::unique_ptr<Buffer>(buffer);
buffer->deallocate_impl();
Expand Down

0 comments on commit 4a0562c

Please sign in to comment.