Fix ttnn.reallocate when unaligned RM tensors are used #16192

Merged (1 commit, Dec 20, 2024)
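For context, a minimal sketch of the scenario this change targets: reallocating an L1-sharded, row-major tensor whose last dimension (34 elements) is not tile-aligned. It mirrors the new test added below, assumes an already-open ttnn `device` handle, and is illustrative only, not part of the diff.

import torch
import ttnn

# Illustrative only; `device` is assumed to be an open ttnn device handle.
shape = [1, 1, 4, 34]  # rows of 34 bfloat16 values, so row-major shards are not tile-aligned
sharded_config = ttnn.create_sharded_memory_config(
    shape, ttnn.CoreGrid(x=1, y=1), ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR
)

torch_tensor = torch.rand(shape).to(dtype=torch.bfloat16)
tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT)
tensor = ttnn.to_device(tensor, device, sharded_config)

# Before this fix, the sharded move used by reallocate sized its circular
# buffers from the raw shard shape, which under-counts for unaligned RM shards.
tensor = ttnn.reallocate(tensor)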
53 changes: 37 additions & 16 deletions tests/ttnn/unit_tests/operations/test_reallocate.py
@@ -12,12 +12,9 @@
 from models.utility_functions import is_wormhole_b0, is_blackhole


-@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="#7733: fix for sharding on whb0")
-@pytest.mark.parametrize(
-    "mem_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG, ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG]
-)
+@pytest.mark.parametrize("mem_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
 @pytest.mark.parametrize("num_allocs", [1, 2, 3, 4])
-def test_ttnn_reallocate(device, mem_config, num_allocs):
+def test_reallocate_interleaved(device, mem_config, num_allocs):
     width = 1024
     height = 128
     depth = 2
@@ -35,17 +32,6 @@ def test_ttnn_reallocate(device, mem_config, num_allocs):
         }
     )

-    # If sharded, creat actual memory config
-    if mem_config == ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG:
-        shard_spec = ttnn.ShardSpec(
-            shard_grid, [batch * height * depth // 8, width], ttnn.ShardOrientation.ROW_MAJOR, False
-        )
-        mem_config = ttnn.MemoryConfig(
-            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
-            ttnn.BufferType.L1,
-            shard_spec,
-        )
-
     torch_tensors = []
     tensors = []
     for i in range(num_allocs):
@@ -72,3 +58,38 @@
         assert new_address >= initial_address

     assert_with_pcc(torch_tensors[-1], ttnn.to_torch(tensors[-1]), 0.9999)
+
+
+@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("strategy", [ttnn.ShardStrategy.BLOCK, ttnn.ShardStrategy.HEIGHT])
+@pytest.mark.parametrize(
+    "input_shape, core_grid",
+    (
+        ([1, 1, 32, 32], ttnn.CoreGrid(x=1, y=1)),
+        ([1, 1, 256, 256], ttnn.CoreGrid(x=2, y=2)),
+        ([1, 1, 4, 34], ttnn.CoreGrid(x=1, y=1)),  # Checks unaligned RM shard
+        ([2, 2, 128, 1024], ttnn.CoreGrid(x=4, y=4)),
+    ),
+)
+def test_reallocate_sharded(device, input_shape, core_grid, strategy, layout):
+    if (input_shape[-1] % ttnn.TILE_SIZE != 0 or input_shape[-2] % ttnn.TILE_SIZE != 0) and layout == ttnn.TILE_LAYOUT:
+        pytest.skip("Shards must be aligned with tile layout")
+
+    input_memory_config = ttnn.create_sharded_memory_config(
+        input_shape, core_grid, strategy, ttnn.ShardOrientation.ROW_MAJOR
+    )
+
+    torch_input_tensor = torch.rand(input_shape).to(dtype=torch.bfloat16)
+    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=layout)
+
+    dummy_tensor = torch.rand([1, 1, 512, 512])
+    dummy_tensor = ttnn.from_torch(dummy_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT)
+    dummy_tensor = ttnn.to_device(dummy_tensor, device, ttnn.L1_MEMORY_CONFIG)
+
+    input_tensor = ttnn.to_device(input_tensor, device, input_memory_config)
+
+    ttnn.deallocate(dummy_tensor)  # make L1 space for reallocation
+    output_tensor = ttnn.reallocate(input_tensor)
+
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(torch_input_tensor, output_tensor, 1.0)
@@ -18,7 +18,7 @@ void kernel_main() {
     constexpr uint32_t src_cb_id = get_compile_time_arg_val(0);
     constexpr uint32_t dst_cb_id = get_compile_time_arg_val(1);

-    uint32_t src_cb_base_addr = get_write_ptr(src_cb_id); // TODO change to read
+    uint32_t src_cb_base_addr = get_read_ptr(src_cb_id);
Contributor: love seeing TODOs being removed

Contributor Author: seemed like a low hanging one 🥲
     uint32_t dst_cb_base_addr = get_write_ptr(dst_cb_id);

     // Copy from top of src cb to top of dst cb (backwards)
@@ -215,19 +215,10 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor& input, Ten
         "Error");
     const uint32_t src_cb_sharded = tt::CBIndex::c_0;
     const uint32_t dst_cb_sharded = tt::CBIndex::c_1;
-    uint32_t tile_size_bytes = tile_size(cb_data_format);
-    uint32_t shard_shape_num_tiles = tt::div_up(shard_shape[0] * shard_shape[1], TILE_HEIGHT * TILE_WIDTH);
-    uint32_t total_size_bytes = 0;
-    uint32_t page_size_bytes = 0;
-    if ((shard_shape[0] * shard_shape[1]) % (TILE_HEIGHT * TILE_WIDTH) == 0) {
-        uint32_t tile_size_bytes = tile_size(cb_data_format);
-        total_size_bytes = shard_shape_num_tiles * tile_size_bytes;
-        page_size_bytes = tile_size_bytes;
-    } else {
-        uint32_t datum_size_bytes = datum_size(cb_data_format);
-        total_size_bytes = shard_shape[0] * shard_shape[1] * datum_size_bytes;
-        page_size_bytes = shard_shape[1] * datum_size_bytes;
-    }
+
+    uint32_t total_size_bytes = input.buffer()->aligned_size_per_bank();
+    uint32_t page_size_bytes = input.buffer()->aligned_page_size();
+
     CircularBufferConfig src_cb_sharded_config =
         CircularBufferConfig(total_size_bytes, {{src_cb_sharded, cb_data_format}})
             .set_page_size(src_cb_sharded, page_size_bytes);
@@ -242,6 +233,7 @@

     auto input_buffer_address = input.buffer()->address();
     auto output_buffer_address = output.buffer()->address();
+
     TT_FATAL(
         output_buffer_address > input_buffer_address,
         "Expected output buffer to be allocated at a higher address than input buffer");
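A rough, hypothetical illustration (not tt-metal code) of why the circular-buffer sizing above now comes from the buffer's aligned sizes rather than from shard_shape[1] * datum_size_bytes: on device, each row-major page is padded up to the memory alignment, so the raw product under-counts the space a shard actually occupies. The 16-byte L1 alignment used here is an assumption for the sake of the example.

def aligned_page_size(shard_width_elems: int, datum_size_bytes: int, alignment_bytes: int = 16) -> int:
    """Round one row-major page (one shard row) up to the assumed device alignment."""
    raw = shard_width_elems * datum_size_bytes
    return -(-raw // alignment_bytes) * alignment_bytes  # ceiling to a multiple of alignment_bytes

# The unaligned case from the new test, [1, 1, 4, 34] in bfloat16:
# raw page = 34 * 2 = 68 bytes, padded page = 80 bytes under the assumed alignment.
print(aligned_page_size(34, 2))  # 80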