Skip to content

Commit

Permalink
#13690: Allow buffers to be allocated when a trace is live on device
Browse files Browse the repository at this point in the history
  - Instead of asserting out, print a warning informing the user that
    this can lead to data corruption
  - Allows interleaving of traced and untraced workloads (this is safe
    as long as untraced outputs/intermediates are fully consumed before
    a trace is run)
  • Loading branch information
tt-asaigal committed Oct 10, 2024
1 parent d17cc67 commit c83a7c2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
20 changes: 16 additions & 4 deletions tt_metal/impl/allocator/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,21 @@ DeviceAddr base_alloc(
return bank_manager.allocate_buffer(size, page_size, bottom_up, config.compute_grid_size, num_shards);
}

// Sets the allocator's disabled_allocs flag. allocate_buffer checks this flag
// with TT_FATAL, so allocation attempts fail while it is set.
void disable_allocs(Allocator &allocator) {
    allocator.disabled_allocs = true;
}

// Clears the allocator's disabled_allocs flag, so the TT_FATAL check in
// allocate_buffer passes again.
void enable_allocs(Allocator &allocator) {
    allocator.disabled_allocs = false;
}
// Sets the allocator's allocations_unsafe flag. While the flag is set,
// verify_safe_allocation warns (rather than asserts) when a buffer is allocated.
void mark_allocations_unsafe(Allocator &allocator) {
    allocator.allocations_unsafe = true;
}

// Clears the allocator's allocations_unsafe flag, so verify_safe_allocation
// no longer has a reason to warn on allocation.
void mark_allocations_safe(Allocator &allocator) {
    allocator.allocations_unsafe = false;
}

// Warns the user when a buffer is allocated while the allocator is flagged as
// unsafe (i.e. a trace is live on device). Allocation is still permitted: the
// user is expected to make sure that any buffer allocated while a trace is
// active has a lifetime that ends before the trace is executed.
//
// The warning is emitted at most once per calling thread. The guard flag is a
// function-local thread_local, so it is shared by every allocator used from
// that thread and is never reset. This keeps user output from being flooded.
void verify_safe_allocation(Allocator& allocator) {
    static thread_local bool already_warned = false;

    // Nothing to report: either allocations are safe, or this thread has
    // already been told about the hazard.
    if (!allocator.allocations_unsafe || already_warned) {
        return;
    }

    log_warning("Allocating device buffers is unsafe due to the existence of an active trace. These buffers may be corrupted once a trace is executed.");
    already_warned = true;
}

uint64_t allocate_buffer(
Allocator &allocator,
Expand All @@ -374,7 +386,7 @@ uint64_t allocate_buffer(
bool bottom_up,
std::optional<uint32_t> num_shards) {
uint64_t address = 0;
TT_FATAL(!allocator.disabled_allocs, "Allocation of new buffers has been disabled");
verify_safe_allocation(allocator);
switch (buffer_type) {
case BufferType::DRAM:
return allocator.descriptor.dram.alloc(
Expand Down
9 changes: 5 additions & 4 deletions tt_metal/impl/allocator/allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ DeviceAddr base_alloc(const AllocatorConfig & config, BankManager &bank_manager,

DeviceAddr allocate_buffer(Allocator &allocator, DeviceAddr size, DeviceAddr page_size, const BufferType &buffer_type, bool bottom_up, std::optional<uint32_t> num_shards = std::nullopt);

void disable_allocs(Allocator &allocator);
void mark_allocations_unsafe(Allocator &allocator);

void enable_allocs(Allocator &allocator);
void mark_allocations_safe(Allocator &allocator);

void deallocate_buffer(Allocator &allocator, DeviceAddr address, const BufferType &buffer_type);
void deallocate_buffers(Allocator &allocator);
Expand All @@ -114,8 +114,9 @@ void clear(Allocator &allocatator);

struct Allocator {
Allocator(const AllocatorConfig &alloc_config, const allocator::AllocDescriptor &alloc_descriptor);

bool disabled_allocs = false;
// Set to true if allocating a buffer is unsafe. This happens when a live trace on device can corrupt
// memory allocated by the user (memory used by trace is not tracked in the allocator once the trace is captured).
bool allocations_unsafe = false;
allocator::BankManager dram_manager;
allocator::BankManager l1_manager;
allocator::BankManager l1_small_manager;
Expand Down
16 changes: 8 additions & 8 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2861,7 +2861,7 @@ bool Device::close() {
tt_metal::detail::DumpDeviceProfileResults(this, true);

this->trace_buffer_pool_.clear();
this->EnableAllocs();
this->MarkAllocationsSafe();

this->deallocate_buffers();

Expand Down Expand Up @@ -3273,7 +3273,7 @@ bool Device::using_slow_dispatch() const {
void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
TT_FATAL(this->trace_buffer_pool_.count(tid) == 0, "Trace already exists for tid {} on device", tid);
TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid);
this->EnableAllocs();
this->MarkAllocationsSafe();
// Create an empty trace buffer here. This will get initialized in end_trace
this->trace_buffer_pool_.insert({tid, Trace::create_empty_trace_buffer()});
this->hw_command_queues_[cq_id]->record_begin(tid, this->trace_buffer_pool_[tid]->desc);
Expand All @@ -3292,7 +3292,7 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
trace_data.push_back(((uint32_t*)command_sequence.data())[i]);
}
Trace::initialize_buffer(this->command_queue(cq_id), this->trace_buffer_pool_[tid]);
this->DisableAllocs();
this->MarkAllocationsUnsafe();
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
Expand All @@ -3312,7 +3312,7 @@ void Device::release_trace(const uint32_t tid) {
uint32_t erased = this->trace_buffer_pool_.erase(tid);
// Only mark allocations as safe again once all captured traces are released
if (this->trace_buffer_pool_.empty()) {
this->EnableAllocs();
this->MarkAllocationsSafe();
}
}

Expand All @@ -3324,12 +3324,12 @@ std::shared_ptr<TraceBuffer> Device::get_trace(const uint32_t tid) {
}
}

void Device::DisableAllocs() {
tt::tt_metal::allocator::disable_allocs(*(this->allocator_));
void Device::MarkAllocationsUnsafe() {
tt::tt_metal::allocator::mark_allocations_unsafe(*(this->allocator_));
}

void Device::EnableAllocs() {
tt::tt_metal::allocator::enable_allocs(*(this->allocator_));
void Device::MarkAllocationsSafe() {
tt::tt_metal::allocator::mark_allocations_safe(*(this->allocator_));
}

void Device::generate_device_headers(const std::string &path) const
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ class Device {
bool distributed_dispatcher() const;

private:
void DisableAllocs();
void EnableAllocs();
void MarkAllocationsUnsafe();
void MarkAllocationsSafe();
std::unordered_map<uint32_t, std::shared_ptr<TraceBuffer>> trace_buffer_pool_;
};

Expand Down

0 comments on commit c83a7c2

Please sign in to comment.