Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#13690: Allow buffers to be allocated when a trace is live on device #13696

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ TEST_F(SingleDeviceTraceFixture, EnqueueProgramTraceCapture) {
uint32_t tid = BeginTraceCapture(this->device_, command_queue.id());
EnqueueProgram(command_queue, simple_program, false);
EndTraceCapture(this->device_, command_queue.id(), tid);

// Create and Enqueue a Program with a live trace to ensure that a warning is generated
Buffer input_temp(this->device_, 2048, 2048, BufferType::DRAM);
Buffer output_temp(this->device_, 2048, 2048, BufferType::DRAM);
Program simple_program_temp = create_simple_unary_program(input_temp, output_temp);
EnqueueProgram(command_queue, simple_program_temp, true);
// Run trace that can clobber the temporary buffers created above
EnqueueProgram(command_queue, simple_program, false);
EnqueueTrace(command_queue, tid, true);
EnqueueReadBuffer(command_queue, output, trace_output_data.data(), true);
EXPECT_TRUE(eager_output_data == trace_output_data);
Expand Down
20 changes: 16 additions & 4 deletions tt_metal/impl/allocator/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,21 @@ DeviceAddr base_alloc(
return bank_manager.allocate_buffer(size, page_size, bottom_up, config.compute_grid_size, num_shards);
}

// Sets the allocator's `disabled_allocs` flag.
void disable_allocs(Allocator &allocator) {
    allocator.disabled_allocs = true;
}

// Clears the allocator's `disabled_allocs` flag.
void enable_allocs(Allocator &allocator) {
    allocator.disabled_allocs = false;
}
// Sets the allocator's `allocations_unsafe` flag.
void mark_allocations_unsafe(Allocator &allocator) {
    allocator.allocations_unsafe = true;
}

// Clears the allocator's `allocations_unsafe` flag.
void mark_allocations_safe(Allocator &allocator) {
    allocator.allocations_unsafe = false;
}

// Emits a warning (at most once per calling thread) when the allocator has been
// marked as unsafe for allocations.
//
// It is unsafe to allocate buffers while a trace is live on device. A user who does so
// is expected to ensure that any buffer allocated while a trace is active has a lifetime
// that ends before the trace is executed.
//
// The warning is rate-limited so that user output is not flooded.
// NOTE(review): the guard below is a `thread_local static`, so the warning is emitted
// once per thread across all allocators, not once per device -- confirm this is intended.
void verify_safe_allocation(Allocator& allocator) {
    thread_local static bool already_warned = false;

    // Nothing to report if allocations are safe or this thread has already warned.
    if (!allocator.allocations_unsafe || already_warned) {
        return;
    }

    log_warning("Allocating device buffers is unsafe due to the existence of an active trace. These buffers may be corrupted once a trace is executed.");
    already_warned = true;
}

uint64_t allocate_buffer(
Allocator &allocator,
Expand All @@ -374,7 +386,7 @@ uint64_t allocate_buffer(
bool bottom_up,
std::optional<uint32_t> num_shards) {
uint64_t address = 0;
TT_FATAL(!allocator.disabled_allocs, "Allocation of new buffers has been disabled");
verify_safe_allocation(allocator);
switch (buffer_type) {
case BufferType::DRAM:
return allocator.descriptor.dram.alloc(
Expand Down
9 changes: 5 additions & 4 deletions tt_metal/impl/allocator/allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ DeviceAddr base_alloc(const AllocatorConfig & config, BankManager &bank_manager,

DeviceAddr allocate_buffer(Allocator &allocator, DeviceAddr size, DeviceAddr page_size, const BufferType &buffer_type, bool bottom_up, std::optional<uint32_t> num_shards = std::nullopt);

void disable_allocs(Allocator &allocator);
void mark_allocations_unsafe(Allocator &allocator);

void enable_allocs(Allocator &allocator);
void mark_allocations_safe(Allocator &allocator);

void deallocate_buffer(Allocator &allocator, DeviceAddr address, const BufferType &buffer_type);
void deallocate_buffers(Allocator &allocator);
Expand All @@ -114,8 +114,9 @@ void clear(Allocator &allocatator);

struct Allocator {
Allocator(const AllocatorConfig &alloc_config, const allocator::AllocDescriptor &alloc_descriptor);

bool disabled_allocs = false;
// Set to true if allocating a buffer is unsafe. This happens when a live trace on device can corrupt
// memory allocated by the user (memory used by trace is not tracked in the allocator once the trace is captured).
bool allocations_unsafe = false;
allocator::BankManager dram_manager;
allocator::BankManager l1_manager;
allocator::BankManager l1_small_manager;
Expand Down
16 changes: 8 additions & 8 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2864,7 +2864,7 @@ bool Device::close() {
tt_metal::detail::DumpDeviceProfileResults(this, true);

this->trace_buffer_pool_.clear();
this->EnableAllocs();
this->MarkAllocationsSafe();

this->deallocate_buffers();

Expand Down Expand Up @@ -3276,7 +3276,7 @@ bool Device::using_slow_dispatch() const {
void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
TT_FATAL(this->trace_buffer_pool_.count(tid) == 0, "Trace already exists for tid {} on device", tid);
TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid);
this->EnableAllocs();
this->MarkAllocationsSafe();
// Create an empty trace buffer here. This will get initialized in end_trace
this->trace_buffer_pool_.insert({tid, Trace::create_empty_trace_buffer()});
this->hw_command_queues_[cq_id]->record_begin(tid, this->trace_buffer_pool_[tid]->desc);
Expand All @@ -3295,7 +3295,7 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
trace_data.push_back(((uint32_t*)command_sequence.data())[i]);
}
Trace::initialize_buffer(this->command_queue(cq_id), this->trace_buffer_pool_[tid]);
this->DisableAllocs();
this->MarkAllocationsUnsafe();
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
Expand All @@ -3315,7 +3315,7 @@ void Device::release_trace(const uint32_t tid) {
uint32_t erased = this->trace_buffer_pool_.erase(tid);
// Only mark allocations as safe again once all captured traces are released
if (this->trace_buffer_pool_.empty()) {
this->EnableAllocs();
this->MarkAllocationsSafe();
}
}

Expand All @@ -3327,12 +3327,12 @@ std::shared_ptr<TraceBuffer> Device::get_trace(const uint32_t tid) {
}
}

void Device::DisableAllocs() {
tt::tt_metal::allocator::disable_allocs(*(this->allocator_));
void Device::MarkAllocationsUnsafe() {
tt::tt_metal::allocator::mark_allocations_unsafe(*(this->allocator_));
}

void Device::EnableAllocs() {
tt::tt_metal::allocator::enable_allocs(*(this->allocator_));
void Device::MarkAllocationsSafe() {
tt::tt_metal::allocator::mark_allocations_safe(*(this->allocator_));
}

void Device::generate_device_headers(const std::string &path) const
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ class Device {
bool distributed_dispatcher() const;

private:
void DisableAllocs();
void EnableAllocs();
void MarkAllocationsUnsafe();
void MarkAllocationsSafe();
std::unordered_map<uint32_t, std::shared_ptr<TraceBuffer>> trace_buffer_pool_;
};

Expand Down