Skip to content

Commit

Permalink
Fix ttnn.reallocate when unaligned RM tensors are used
Browse files Browse the repository at this point in the history
  • Loading branch information
esmalTT committed Dec 20, 2024
1 parent eacc150 commit 936cac3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 30 deletions.
53 changes: 37 additions & 16 deletions tests/ttnn/unit_tests/operations/test_reallocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
from models.utility_functions import is_wormhole_b0, is_blackhole


@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="#7733: fix for sharding on whb0")
@pytest.mark.parametrize(
"mem_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG, ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG]
)
@pytest.mark.parametrize("mem_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("num_allocs", [1, 2, 3, 4])
def test_ttnn_reallocate(device, mem_config, num_allocs):
def test_reallocate_interleaved(device, mem_config, num_allocs):
width = 1024
height = 128
depth = 2
Expand All @@ -35,17 +32,6 @@ def test_ttnn_reallocate(device, mem_config, num_allocs):
}
)

# If sharded, create actual memory config
if mem_config == ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG:
shard_spec = ttnn.ShardSpec(
shard_grid, [batch * height * depth // 8, width], ttnn.ShardOrientation.ROW_MAJOR, False
)
mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
shard_spec,
)

torch_tensors = []
tensors = []
for i in range(num_allocs):
Expand All @@ -72,3 +58,38 @@ def test_ttnn_reallocate(device, mem_config, num_allocs):
assert new_address >= initial_address

assert_with_pcc(torch_tensors[-1], ttnn.to_torch(tensors[-1]), 0.9999)


@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("strategy", [ttnn.ShardStrategy.BLOCK, ttnn.ShardStrategy.HEIGHT])
@pytest.mark.parametrize(
    "input_shape, core_grid",
    (
        ([1, 1, 32, 32], ttnn.CoreGrid(x=1, y=1)),
        ([1, 1, 256, 256], ttnn.CoreGrid(x=2, y=2)),
        ([1, 1, 4, 34], ttnn.CoreGrid(x=1, y=1)),  # Checks unaligned RM shard
        ([2, 2, 128, 1024], ttnn.CoreGrid(x=4, y=4)),
    ),
)
def test_reallocate_sharded(device, input_shape, core_grid, strategy, layout):
    """Reallocate a sharded device tensor after freeing L1 space and check its data survives.

    A throwaway L1 tensor is allocated before the sharded input so that, once it is
    deallocated, ``ttnn.reallocate`` has room to move the input to a new address.
    """
    height_aligned = input_shape[-2] % ttnn.TILE_SIZE == 0
    width_aligned = input_shape[-1] % ttnn.TILE_SIZE == 0
    if layout == ttnn.TILE_LAYOUT and not (height_aligned and width_aligned):
        pytest.skip("Shards must be aligned with tile layout")

    sharded_config = ttnn.create_sharded_memory_config(
        input_shape, core_grid, strategy, ttnn.ShardOrientation.ROW_MAJOR
    )

    # Host-side reference data.
    torch_input_tensor = torch.rand(input_shape).to(dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=layout)

    # Filler allocation that occupies L1 ahead of the sharded input tensor.
    filler = ttnn.to_device(
        ttnn.from_torch(torch.rand([1, 1, 512, 512]), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
        device,
        ttnn.L1_MEMORY_CONFIG,
    )

    input_tensor = ttnn.to_device(input_tensor, device, sharded_config)

    ttnn.deallocate(filler)  # make L1 space for reallocation
    moved = ttnn.reallocate(input_tensor)

    assert_with_pcc(torch_input_tensor, ttnn.to_torch(moved), 1.0)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void kernel_main() {
constexpr uint32_t src_cb_id = get_compile_time_arg_val(0);
constexpr uint32_t dst_cb_id = get_compile_time_arg_val(1);

uint32_t src_cb_base_addr = get_write_ptr(src_cb_id); // TODO change to read
uint32_t src_cb_base_addr = get_read_ptr(src_cb_id);
uint32_t dst_cb_base_addr = get_write_ptr(dst_cb_id);

// Copy from top of src cb to top of dst cb (backwards)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,19 +215,10 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor& input, Ten
"Error");
const uint32_t src_cb_sharded = tt::CBIndex::c_0;
const uint32_t dst_cb_sharded = tt::CBIndex::c_1;
uint32_t tile_size_bytes = tile_size(cb_data_format);
uint32_t shard_shape_num_tiles = tt::div_up(shard_shape[0] * shard_shape[1], TILE_HEIGHT * TILE_WIDTH);
uint32_t total_size_bytes = 0;
uint32_t page_size_bytes = 0;
if ((shard_shape[0] * shard_shape[1]) % (TILE_HEIGHT * TILE_WIDTH) == 0) {
uint32_t tile_size_bytes = tile_size(cb_data_format);
total_size_bytes = shard_shape_num_tiles * tile_size_bytes;
page_size_bytes = tile_size_bytes;
} else {
uint32_t datum_size_bytes = datum_size(cb_data_format);
total_size_bytes = shard_shape[0] * shard_shape[1] * datum_size_bytes;
page_size_bytes = shard_shape[1] * datum_size_bytes;
}

uint32_t total_size_bytes = input.buffer()->aligned_size_per_bank();
uint32_t page_size_bytes = input.buffer()->aligned_page_size();

CircularBufferConfig src_cb_sharded_config =
CircularBufferConfig(total_size_bytes, {{src_cb_sharded, cb_data_format}})
.set_page_size(src_cb_sharded, page_size_bytes);
Expand All @@ -242,6 +233,7 @@ operation::ProgramWithCallbacks move_multi_core_sharded(const Tensor& input, Ten

auto input_buffer_address = input.buffer()->address();
auto output_buffer_address = output.buffer()->address();

TT_FATAL(
output_buffer_address > input_buffer_address,
"Expected output buffer to be allocated at a higher address than input buffer");
Expand Down

0 comments on commit 936cac3

Please sign in to comment.