Skip to content

Commit

Permalink
Add support for page size > max prefetch cmd size for interleaved buf…
Browse files Browse the repository at this point in the history
…fers (#17677)

#16861 

This PR adds support for interleaved buffers to have page sizes which
are greater than the max prefetch command size.

### Checklist
- [x] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml)
CI passes
(https://github.com/tenstorrent/tt-metal/actions/runs/13524331831)
- [ ] [Blackhole Post
commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml)
CI passes (if applicable)
- [ ] [Model
regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml)
CI passes (if applicable)
- [ ] [Device performance
regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml)
CI passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models
tests](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
CI passes (if applicable)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sagarwalTT authored Mar 3, 2025
1 parent e3d9950 commit 07cee7c
Show file tree
Hide file tree
Showing 15 changed files with 592 additions and 191 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/common/command_queue_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,4 @@ class CommandQueueMultiDeviceFixture : public DispatchFixture {
};

// Fixture tag for multi-device command-queue tests that exercise programs.
class CommandQueueMultiDeviceProgramFixture : public CommandQueueMultiDeviceFixture {};
// Fixture tag for multi-device command-queue tests that exercise buffers.
class CommandQueueMultiDeviceBufferFixture : public CommandQueueMultiDeviceFixture {};
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,7 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestPageLargerThanAndUnalignedToTran
}
}

TEST_F(CommandQueueSingleCardBufferFixture, TestPageLargerThanMaxPrefetchCommandSize) {
constexpr uint32_t num_round_robins = 1;
TEST_F(CommandQueueSingleCardBufferFixture, TestSinglePageLargerThanMaxPrefetchCommandSize) {
for (IDevice* device : devices_) {
CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
const uint32_t max_prefetch_command_size = DispatchMemMap::get(dispatch_core_type).max_prefetch_command_size();
Expand All @@ -529,8 +528,38 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestPageLargerThanMaxPrefetchCommand
}
}

TEST_F(CommandQueueSingleCardBufferFixture, TestUnalignedPageLargerThanMaxPrefetchCommandSize) {
constexpr uint32_t num_round_robins = 1;
// Round-trips a many-page DRAM buffer whose page size exceeds the prefetcher's
// max command size, forcing each page to be split across partial writes/reads.
TEST_F(CommandQueueSingleCardBufferFixture, TestMultiplePagesLargerThanMaxPrefetchCommandSize) {
    constexpr uint32_t kNumPages = 1024;
    constexpr uint32_t kExtraBytes = 2048;
    for (IDevice* dev : devices_) {
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(dev->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();
        const TestBufferConfig cfg = {
            .num_pages = kNumPages, .page_size = max_cmd_size + kExtraBytes, .buftype = BufferType::DRAM};
        local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(dev, dev->command_queue(), cfg);
    }
}

// Writes and reads back a sub-region of a DRAM buffer whose page size exceeds
// the max prefetch command size, verifying partial-page dispatch on sub-buffers.
TEST_F(CommandQueueSingleCardBufferFixture, TestMultiplePagesLargerThanMaxPrefetchCommandSizeSubBuffer) {
    for (IDevice* device : devices_) {
        tt::log_info("Running On Device {}", device->id());
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();

        // 40-page buffer; test a 5-page region starting at page 30.
        const uint32_t page_size = max_cmd_size + 2048;
        const BufferRegion region(30 * page_size, 5 * page_size);
        auto buffer = Buffer::create(device, 40 * page_size, page_size, BufferType::DRAM);

        auto expected = local_test_functions::generate_arange_vector(region.size);
        EnqueueWriteSubBuffer(device->command_queue(), *buffer, expected, region, false);
        vector<uint32_t> readback;
        EnqueueReadSubBuffer(device->command_queue(), *buffer, readback, region, true);
        EXPECT_EQ(expected, readback);
    }
}

TEST_F(CommandQueueSingleCardBufferFixture, TestSingleUnalignedPageLargerThanMaxPrefetchCommandSize) {
for (IDevice* device : devices_) {
CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
const uint32_t max_prefetch_command_size = DispatchMemMap::get(dispatch_core_type).max_prefetch_command_size();
Expand All @@ -540,6 +569,37 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestUnalignedPageLargerThanMaxPrefet
}
}

// Same as the aligned variant, but the page size (max + 4) is deliberately not
// PCIe-aligned, exercising the padded last-partial-page path for DRAM buffers.
TEST_F(CommandQueueSingleCardBufferFixture, TestMultipleUnalignedPagesLargerThanMaxPrefetchCommandSize) {
    for (IDevice* dev : devices_) {
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(dev->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();
        const TestBufferConfig cfg = {
            .num_pages = 1024, .page_size = max_cmd_size + 4, .buftype = BufferType::DRAM};
        local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(dev, dev->command_queue(), cfg);
    }
}

// Sub-buffer round-trip with an unaligned oversized page (max + 4), covering
// the combination of partial-page splitting and region offsets.
TEST_F(CommandQueueSingleCardBufferFixture, TestMultipleUnalignedPagesLargerThanMaxPrefetchCommandSizeSubBuffer) {
    for (IDevice* device : devices_) {
        tt::log_info("Running On Device {}", device->id());
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();

        // 40-page buffer; test a 5-page region starting at page 30.
        const uint32_t page_size = max_cmd_size + 4;
        const BufferRegion region(30 * page_size, 5 * page_size);
        auto buffer = Buffer::create(device, 40 * page_size, page_size, BufferType::DRAM);

        auto expected = local_test_functions::generate_arange_vector(region.size);
        EnqueueWriteSubBuffer(device->command_queue(), *buffer, expected, region, false);
        vector<uint32_t> readback;
        EnqueueReadSubBuffer(device->command_queue(), *buffer, readback, region, true);
        EXPECT_EQ(expected, readback);
    }
}

TEST_F(CommandQueueSingleCardBufferFixture, TestNon32BAlignedPageSizeForDram) {
TestBufferConfig config = {.num_pages = 1250, .page_size = 200, .buftype = BufferType::DRAM};

Expand All @@ -557,16 +617,6 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestNon32BAlignedPageSizeForDram2) {
}
}

// A buffer page that cannot fit in the consumer CB must be rejected on host.
TEST_F(CommandQueueSingleCardBufferFixture, TestPageSizeTooLarge) {
    // Should throw a host error due to the page size not fitting in the consumer CB
    constexpr uint32_t kOversizedPage = 250880 * 2;
    const TestBufferConfig cfg = {.num_pages = 1024, .page_size = kOversizedPage, .buftype = BufferType::DRAM};

    for (IDevice* dev : devices_) {
        EXPECT_ANY_THROW((local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(
            dev, dev->command_queue(), cfg)));
    }
}

// Requires enqueue write buffer
TEST_F(CommandQueueSingleCardBufferFixture, TestWrapHostHugepageOnEnqueueReadBuffer) {
for (IDevice* device : this->devices_) {
Expand Down Expand Up @@ -981,20 +1031,6 @@ TEST_F(MultiCommandQueueSingleDeviceBufferFixture, TestNon32BAlignedPageSizeForD
local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(this->device_, cqs, config));
}

// Multi-CQ variant: an oversized page must be rejected on host on both queues.
TEST_F(MultiCommandQueueSingleDeviceBufferFixture, TestPageSizeTooLarge) {
    if (this->arch_ == tt::ARCH::WORMHOLE_B0) {
        GTEST_SKIP(); // This test hanging on wormhole b0
    }
    // Should throw a host error due to the page size not fitting in the consumer CB
    const TestBufferConfig cfg = {.num_pages = 1024, .page_size = 250880 * 2, .buftype = BufferType::DRAM};

    vector<std::reference_wrapper<CommandQueue>> cqs = {
        this->device_->command_queue(0), this->device_->command_queue(1)};
    EXPECT_ANY_THROW(
        local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(this->device_, cqs, cfg));
}

TEST_F(MultiCommandQueueSingleDeviceBufferFixture, TestIssueMultipleReadWriteCommandsForOneBuffer) {
uint32_t page_size = 2048;
uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device_->id());
Expand All @@ -1010,6 +1046,21 @@ TEST_F(MultiCommandQueueSingleDeviceBufferFixture, TestIssueMultipleReadWriteCom
local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(this->device_, cqs, config));
}

// Runs the unaligned oversized-page round-trip on every device in the mesh,
// driving it through the multi-queue helper with a single command queue.
TEST_F(CommandQueueMultiDeviceBufferFixture, TestMultipleUnalignedPagesLargerThanMaxPrefetchCommandSize) {
    for (IDevice* device : devices_) {
        tt::log_info("Running On Device {}", device->id());
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();
        const TestBufferConfig cfg = {
            .num_pages = 50, .page_size = max_cmd_size + 4, .buftype = BufferType::DRAM};

        vector<std::reference_wrapper<CommandQueue>> cqs = {device->command_queue(0)};
        EXPECT_TRUE(local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer_multi_queue(device, cqs, cfg));
    }
}

} // end namespace dram_tests

namespace l1_tests {
Expand Down Expand Up @@ -1112,6 +1163,36 @@ TEST_F(CommandQueueSingleCardBufferFixture, TestMultipleNonOverlappingWritesShar
}
}

// L1 round-trip with pages larger than the max prefetch command size.
TEST_F(CommandQueueSingleCardBufferFixture, TestMultiplePagesLargerThanMaxPrefetchCommandSizeForL1) {
    for (IDevice* dev : devices_) {
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(dev->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();
        const TestBufferConfig cfg = {
            .num_pages = 30, .page_size = max_cmd_size + 2048, .buftype = BufferType::L1};
        local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(dev, dev->command_queue(), cfg);
    }
}

// L1 round-trip of a single page that is both oversized and unaligned (max + 4).
TEST_F(CommandQueueSingleCardBufferFixture, TestSingleUnalignedPageLargerThanMaxPrefetchCommandSizeForL1) {
    for (IDevice* dev : devices_) {
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(dev->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();
        const TestBufferConfig cfg = {
            .num_pages = 1, .page_size = max_cmd_size + 4, .buftype = BufferType::L1};
        local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(dev, dev->command_queue(), cfg);
    }
}

// L1 round-trip of many pages that are oversized and unaligned (max + 4).
TEST_F(CommandQueueSingleCardBufferFixture, TestMultipleUnalignedPagesLargerThanMaxPrefetchCommandSizeForL1) {
    for (IDevice* dev : devices_) {
        const CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(dev->id());
        const uint32_t max_cmd_size = DispatchMemMap::get(core_type).max_prefetch_command_size();
        const TestBufferConfig cfg = {
            .num_pages = 30, .page_size = max_cmd_size + 4, .buftype = BufferType::L1};
        local_test_functions::test_EnqueueWriteBuffer_and_EnqueueReadBuffer(dev, dev->command_queue(), cfg);
    }
}

TEST_F(CommandQueueSingleCardBufferFixture, TestMultipleNonOverlappingReadsShardedSubBufferForL1) {
const uint32_t page_size = 64;
const uint32_t buffer_size = 16 * page_size;
Expand Down
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class Buffer final {
uint32_t num_dev_pages() const;

BufferType buffer_type() const { return buffer_type_; }
HalMemType memory_type() const;
CoreType core_type() const;

bool is_l1() const;
Expand Down
2 changes: 2 additions & 0 deletions tt_metal/api/tt-metalium/cq_commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ struct CQDispatchWriteHostCmd {
uint32_t length;
} __attribute__((packed));

constexpr uint16_t CQ_DISPATCH_CMD_PAGED_WRITE_MAX_PAGE_INDEX = 0xFFFF;

struct CQDispatchWritePagedCmd {
uint8_t is_dram; // one flag, false=l1
uint16_t start_page;
Expand Down
7 changes: 4 additions & 3 deletions tt_metal/api/tt-metalium/dispatch_settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ class DispatchSettings {

static constexpr uint32_t EVENT_PADDED_SIZE = 16;

// When page size of buffer to write/read exceeds MAX_PREFETCH_COMMAND_SIZE, the PCIe aligned page size is broken
// down into equal sized partial pages BASE_PARTIAL_PAGE_SIZE denotes the initial partial page size to use, it is
// incremented by PCIe alignment until page size can be evenly split
// When the page size of a buffer to write/read exceeds the max prefetch command size, the PCIe-aligned page size is
// broken down into equal sized partial pages. BASE_PARTIAL_PAGE_SIZE is incremented until the partial page size
// is PCIe-aligned. If the resulting partial page size doesn't evenly divide the full page size, the last partial
// page size is padded appropriately.
static constexpr uint32_t BASE_PARTIAL_PAGE_SIZE = 4096;

static_assert(
Expand Down
9 changes: 9 additions & 0 deletions tt_metal/api/tt-metalium/hal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class Hal {
std::vector<DeviceAddr> dram_bases_;
std::vector<uint32_t> dram_sizes_;
std::vector<uint32_t> mem_alignments_;
std::vector<uint32_t> mem_alignments_with_pcie_;
uint32_t num_nocs_;
uint32_t noc_addr_node_id_bits_;
uint32_t noc_coord_reg_offset_;
Expand Down Expand Up @@ -249,6 +250,8 @@ class Hal {
uint32_t get_dev_size(HalDramMemAddrType addr_type) const;

uint32_t get_alignment(HalMemType memory_type) const;
// Returns an alignment that is aligned with PCIE and the given memory type
uint32_t get_common_alignment_with_pcie(HalMemType memory_type) const;

bool get_supports_cbs(uint32_t programmable_core_type_index) const;

Expand Down Expand Up @@ -346,6 +349,12 @@ inline uint32_t Hal::get_alignment(HalMemType memory_type) const {
return this->mem_alignments_[index];
}

// Returns an alignment that satisfies both PCIe alignment and the alignment of
// the given memory type (precomputed per HalMemType in mem_alignments_with_pcie_).
inline uint32_t Hal::get_common_alignment_with_pcie(HalMemType memory_type) const {
    uint32_t index = utils::underlying_type<HalMemType>(memory_type);
    // Bounds-check the vector that is actually indexed below (the original
    // asserted against mem_alignments_.size(), a different container).
    TT_ASSERT(index < this->mem_alignments_with_pcie_.size());
    return this->mem_alignments_with_pcie_[index];
}

// True if the programmable core type at the given index supports circular buffers.
// NOTE(review): index is not bounds-checked here — callers must pass a valid
// programmable core type index.
inline bool Hal::get_supports_cbs(uint32_t programmable_core_type_index) const {
    return this->core_info_[programmable_core_type_index].supports_cbs_;
}
Expand Down
14 changes: 10 additions & 4 deletions tt_metal/distributed/mesh_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,17 @@ void MeshCommandQueue::read_shard_from_device(
}
}
} else {
auto dispatch_params = buffer_dispatch::initialize_interleaved_buf_read_dispatch_params(
*shard_view, id_, expected_num_workers_completed_, region);
buffer_dispatch::BufferReadDispatchParamsVariant dispatch_params_variant =
buffer_dispatch::initialize_interleaved_buf_read_dispatch_params(
*shard_view, id_, expected_num_workers_completed_, region);

buffer_dispatch::BufferReadDispatchParams* dispatch_params = std::visit(
[](auto& val) { return static_cast<buffer_dispatch::BufferReadDispatchParams*>(&val); },
dispatch_params_variant);

buffer_dispatch::copy_interleaved_buffer_to_completion_queue(
dispatch_params, *shard_view, sub_device_ids, this->dispatch_core_type());
if (dispatch_params.pages_per_txn > 0) {
*dispatch_params, *shard_view, sub_device_ids, this->dispatch_core_type());
if (dispatch_params->pages_per_txn > 0) {
auto read_descriptor = std::get<tt::tt_metal::ReadBufferDescriptor>(
*buffer_dispatch::generate_interleaved_buffer_read_descriptor(dst, dispatch_params, *shard_view));
buffer_dispatch::copy_completion_queue_data_into_user_space(
Expand Down
11 changes: 11 additions & 0 deletions tt_metal/impl/buffers/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <mutex>
#include <utility>
#include <buffer_constants.hpp>
#include "hal.hpp"
#include "umd/device/tt_soc_descriptor.h"
#include "fmt/base.h"
#include <reflection.hpp>
Expand Down Expand Up @@ -489,6 +490,16 @@ uint32_t Buffer::num_dev_pages() const {
return this->shard_spec().num_pages() * this->num_cores().value();
}

// Maps this buffer's type onto the HAL memory space it lives in.
// Throws for buffer types that are neither DRAM nor L1.
HalMemType Buffer::memory_type() const {
    if (this->is_dram()) {
        return HalMemType::DRAM;
    }
    if (this->is_l1()) {
        return HalMemType::L1;
    }
    TT_THROW("Unknown HAL memory type for {} buffer type", this->buffer_type());
}

CoreType Buffer::core_type() const {
switch (this->buffer_type_) {
case BufferType::DRAM:
Expand Down
Loading

0 comments on commit 07cee7c

Please sign in to comment.