diff --git a/tt_metal/impl/allocator/allocator.cpp b/tt_metal/impl/allocator/allocator.cpp
index 6800114e4aef..44ddcac06789 100644
--- a/tt_metal/impl/allocator/allocator.cpp
+++ b/tt_metal/impl/allocator/allocator.cpp
@@ -362,9 +362,21 @@ DeviceAddr base_alloc(
     return bank_manager.allocate_buffer(size, page_size, bottom_up, config.compute_grid_size, num_shards);
 }
 
-void disable_allocs(Allocator &allocator) { allocator.disabled_allocs = true; }
-
-void enable_allocs(Allocator &allocator) { allocator.disabled_allocs = false; }
+void mark_allocations_unsafe(Allocator &allocator) { allocator.allocations_unsafe = true; }
+
+void mark_allocations_safe(Allocator &allocator) { allocator.allocations_unsafe = false; }
+
+void verify_safe_allocation(Allocator& allocator) {
+    // Inform the user that it's unsafe to allocate buffers when a trace is live on device.
+    // If the user does this, they are meant to ensure that buffers allocated when a trace is active
+    // have a lifetime that ends before the trace is executed.
+    // Print the warning once per device, to ensure that user output is not clobbered.
+    thread_local static bool warning_generated = false;
+    if (allocator.allocations_unsafe and not warning_generated) {
+        log_warning("Allocating device buffers is unsafe due to the existence of an active trace. These buffers may be corrupted once a trace is executed.");
+        warning_generated = true;
+    }
+}
 
 uint64_t allocate_buffer(
     Allocator &allocator,
@@ -374,7 +386,7 @@ uint64_t allocate_buffer(
     bool bottom_up,
     std::optional<uint32_t> num_shards) {
     uint64_t address = 0;
-    TT_FATAL(!allocator.disabled_allocs, "Allocation of new buffers has been disabled");
+    verify_safe_allocation(allocator);
     switch (buffer_type) {
         case BufferType::DRAM:
             return allocator.descriptor.dram.alloc(
diff --git a/tt_metal/impl/allocator/allocator.hpp b/tt_metal/impl/allocator/allocator.hpp
index 90a1b3ec9bad..7d27f5a4ee73 100644
--- a/tt_metal/impl/allocator/allocator.hpp
+++ b/tt_metal/impl/allocator/allocator.hpp
@@ -101,9 +101,9 @@ DeviceAddr base_alloc(const AllocatorConfig & config, BankManager &bank_manager,
 DeviceAddr allocate_buffer(Allocator &allocator, DeviceAddr size, DeviceAddr page_size, const BufferType &buffer_type,
     bool bottom_up, std::optional<uint32_t> num_shards = std::nullopt);
 
-void disable_allocs(Allocator &allocator);
+void mark_allocations_unsafe(Allocator &allocator);
 
-void enable_allocs(Allocator &allocator);
+void mark_allocations_safe(Allocator &allocator);
 
 void deallocate_buffer(Allocator &allocator, DeviceAddr address, const BufferType &buffer_type);
 void deallocate_buffers(Allocator &allocator);
@@ -114,8 +114,9 @@ void clear(Allocator &allocatator);
 
 struct Allocator {
     Allocator(const AllocatorConfig &alloc_config, const allocator::AllocDescriptor &alloc_descriptor);
-
-    bool disabled_allocs = false;
+    // Set to true if allocating a buffer is unsafe. This happens when a live trace on device can corrupt
+    // memory allocated by the user (memory used by trace is not tracked in the allocator once the trace is captured).
+    bool allocations_unsafe = false;
     allocator::BankManager dram_manager;
     allocator::BankManager l1_manager;
     allocator::BankManager l1_small_manager;
diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp
index 09c5e0759d87..7046c9767595 100644
--- a/tt_metal/impl/device/device.cpp
+++ b/tt_metal/impl/device/device.cpp
@@ -2861,7 +2861,7 @@ bool Device::close() {
     tt_metal::detail::DumpDeviceProfileResults(this, true);
 
     this->trace_buffer_pool_.clear();
-    this->EnableAllocs();
+    this->MarkAllocationsSafe();
 
     this->deallocate_buffers();
 
@@ -3273,7 +3273,7 @@ bool Device::using_slow_dispatch() const {
 void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
     TT_FATAL(this->trace_buffer_pool_.count(tid) == 0, "Trace already exists for tid {} on device", tid);
     TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid);
-    this->EnableAllocs();
+    this->MarkAllocationsSafe();
     // Create an empty trace buffer here. This will get initialized in end_trace
     this->trace_buffer_pool_.insert({tid, Trace::create_empty_trace_buffer()});
     this->hw_command_queues_[cq_id]->record_begin(tid, this->trace_buffer_pool_[tid]->desc);
@@ -3292,7 +3292,7 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
         trace_data.push_back(((uint32_t*)command_sequence.data())[i]);
     }
     Trace::initialize_buffer(this->command_queue(cq_id), this->trace_buffer_pool_[tid]);
-    this->DisableAllocs();
+    this->MarkAllocationsUnsafe();
 }
 
 void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
@@ -3312,7 +3312,7 @@ void Device::release_trace(const uint32_t tid) {
     uint32_t erased = this->trace_buffer_pool_.erase(tid);
     // Only enable allocations once all captured traces are released
     if (this->trace_buffer_pool_.empty()) {
-        this->EnableAllocs();
+        this->MarkAllocationsSafe();
     }
 }
 
@@ -3324,12 +3324,12 @@ std::shared_ptr<TraceBuffer> Device::get_trace(const uint32_t tid) {
     }
 }
 
-void Device::DisableAllocs() {
-    tt::tt_metal::allocator::disable_allocs(*(this->allocator_));
+void Device::MarkAllocationsUnsafe() {
+    tt::tt_metal::allocator::mark_allocations_unsafe(*(this->allocator_));
 }
 
-void Device::EnableAllocs() {
-    tt::tt_metal::allocator::enable_allocs(*(this->allocator_));
+void Device::MarkAllocationsSafe() {
+    tt::tt_metal::allocator::mark_allocations_safe(*(this->allocator_));
 }
 
 void Device::generate_device_headers(const std::string &path) const
diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp
index 3ea349d5618a..5d6296c1c88c 100644
--- a/tt_metal/impl/device/device.hpp
+++ b/tt_metal/impl/device/device.hpp
@@ -336,8 +336,8 @@ class Device {
     bool distributed_dispatcher() const;
 
   private:
-    void DisableAllocs();
-    void EnableAllocs();
+    void MarkAllocationsUnsafe();
+    void MarkAllocationsSafe();
 
     std::unordered_map<uint32_t, std::shared_ptr<TraceBuffer>> trace_buffer_pool_;
 };