#15572: Rewrite of Reshape OP from scratch

Merged Dec 5, 2024 · 37 commits
5d343d4
#15269: reshape fully on device now
jvegaTT Nov 28, 2024
e2bbda5
#15558 edited comment to mention this issue
jvegaTT Nov 28, 2024
43b0bd2
#0: move tt_memmove to common library and ensure tilize/untilize is o…
jvegaTT Nov 29, 2024
8a94c1b
#0: added corrector for implied shape dimensions
jvegaTT Nov 29, 2024
4fd8c45
#13889: Added test to prove this issue is resolved
jvegaTT Nov 29, 2024
285d1fa
#12153: Adding test to verify that issue is resolved
jvegaTT Nov 29, 2024
7d40412
#15048: being more careful about bandaid for issues #15137 and #13338
jvegaTT Nov 29, 2024
7335c11
#14676: Adding test to verify that this issue is resolved
jvegaTT Nov 29, 2024
6110d47
#0: adding libraries for memmove to common
jvegaTT Nov 29, 2024
cbc2830
#14513: Adding test to prove issue is resolved
jvegaTT Nov 29, 2024
98c4b01
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Nov 29, 2024
055dc0f
#15269: added multi-core support
jvegaTT Nov 29, 2024
1d058c7
Merge branch 'jvega/reshape_rm_on_device' of github.com:tenstorrent/t…
jvegaTT Nov 29, 2024
fac56d8
#0: addressing PR comments
jvegaTT Nov 29, 2024
90b6255
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Nov 29, 2024
7ff95d2
#0: small oops
jvegaTT Nov 29, 2024
2edceb5
#15269: Host code optimizations
jvegaTT Dec 2, 2024
ad28b55
#15269: Move compute buffers to compile time
jvegaTT Dec 2, 2024
5501324
#15269: adding override_runtime_args_callback
jvegaTT Dec 2, 2024
109615f
#15269: improve the tt_memmove to use read or write noc as per user n…
jvegaTT Dec 2, 2024
631b1c5
#15269: improve the tt_memmove to use read or write datamover
jvegaTT Dec 2, 2024
10de9a2
#15269: add broken multi risk code
jvegaTT Dec 2, 2024
b709d9c
#15269: employing pow2 optimization
jvegaTT Dec 2, 2024
4fa2bf3
#15269: Added optimization to do one less copy on aligned only transfers
jvegaTT Dec 3, 2024
72f2d5c
#15269: added packet form of noc_async_read and write
jvegaTT Dec 3, 2024
09482f3
#15269: further small optimizations
jvegaTT Dec 3, 2024
3e30253
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Dec 3, 2024
323ace7
#15702: Added a skip for grayskull due to issue 15702
jvegaTT Dec 4, 2024
81c282d
#15269: updating mnist device perf targets
jvegaTT Dec 4, 2024
82c0eb7
#15269: updating other perf targets
jvegaTT Dec 4, 2024
d835144
#15269: removing broken unused kernel code
jvegaTT Dec 4, 2024
16dd8ad
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Dec 4, 2024
dad9a4f
#0: pre commit formatting change
jvegaTT Dec 4, 2024
0ea118b
#0 addressing artem PR review changes on 15572 PR
jvegaTT Dec 4, 2024
66a28c1
#15269: updating vgg device targets
jvegaTT Dec 4, 2024
0e411b0
#0: addressing austin PR review changes on 15572 PR
jvegaTT Dec 4, 2024
b8ce2b0
Merge branch 'main' into jvega/reshape_rm_on_device
jvegaTT Dec 4, 2024
2 changes: 1 addition & 1 deletion models/demos/bert_tiny/tests/test_performance.py
@@ -119,7 +119,7 @@ def test_perf_device_bare_metal(batch_size, expected_perf):
margin = 0.03

if is_wormhole_b0():
-        expected_perf = 4155.25
+        expected_perf = 3990.0
@davorchap (Collaborator), Dec 4, 2024:
Is this a perf drop or increase (is lower or higher better)?

@jvegaTT (Contributor, Author), Dec 4, 2024:
Perf increased, but the nature of the test makes it report a decrease. The device test uses tracy to record how long the model runs on the device. The really expensive on-host reshapes (which waste time moving data in and out of the device) now run on device instead, which is a significant end-to-end perf improvement. However, this also means device runtime goes up, since the device is doing more and the host is doing less, and the very expensive host-side memory copy is no longer reflected in the measurement. That is why we also ran full-model experiments and saw a significant perf improvement when measured holistically. We do, however, need to update these targets to reflect the new device runtime.
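The margin-based target check these perf tests apply can be sketched as follows. This is a hypothetical helper for illustration, not the actual test-harness code; the 0.03 margin matches the `margin` value in the diffs above.

```python
def perf_within_margin(measured: float, expected: float, margin: float = 0.03) -> bool:
    # Passes when the measured samples/s is within +/- margin of the target.
    return abs(measured - expected) <= expected * margin

# With the new bert_tiny target of 3990.0, a run at 3900.0 still passes,
# while a run far below the target fails.
print(perf_within_margin(3900.0, 3990.0))
print(perf_within_margin(3000.0, 3990.0))
```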

@davorchap (Collaborator):

freaking excellent, great job @jvegaTT !

else:
expected_perf = 3476.55

4 changes: 2 additions & 2 deletions models/demos/mnist/tests/test_perf_mnist.py
@@ -112,9 +112,9 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
num_iterations = 1
margin = 0.03
if is_grayskull():
-        expected_perf = 653017.5
+        expected_perf = 390000.0
elif is_wormhole_b0():
-        expected_perf = 1383185.64944
+        expected_perf = 900000.0

command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py::test_mnist"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
4 changes: 2 additions & 2 deletions models/demos/vgg/tests/test_perf_vgg.py
@@ -137,10 +137,10 @@ def test_perf_device_bare_metal_vgg(batch_size, model_name):
margin = 0.03

if model_name == "ttnn_vgg11":
-        expected_perf = 183 if is_grayskull() else 356
+        expected_perf = 36 if is_grayskull() else 104
command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py"
else:
-        expected_perf = 165 if is_grayskull() else 276
+        expected_perf = 34 if is_grayskull() else 90
command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py"

cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
26 changes: 26 additions & 0 deletions tests/ttnn/unit_tests/test_reshape.py
@@ -9,6 +9,7 @@
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull


@pytest.mark.parametrize("n", [16])
@@ -293,6 +294,10 @@ def test_reshape_tile_layout_only_change_shape(device):
((1, 1445, 192), (1445, 192)),
((1, 256), (1, 1, 256)),
((16, 1, 32), (16, 1, 32)),
((1, 32, 4608), (1, 32, 16, 3, 96)), # issue 13889
((2888, 49, 96), (8, 19, 19, 7, 7, 96)), # issue 12153
((128, 1, 1, 128), (128, 128)), # issue 14676
((5, 4, 208, 156), (3, 13, 8, 2080)), # issue 14513
],
)
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@@ -304,6 +309,26 @@ def test_reshape_tile_with_padding(input_shape, output_shape, layout, device):
ttnn_output = ttnn.reshape(input_tensor, output_shape)
assert layout == ttnn_output.layout
output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_result, output, 0.9999)
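Each of the regression shapes added above preserves the total element count, which is the invariant any reshape must satisfy. A quick standard-library check (a sketch for illustration, separate from the test suite):

```python
from math import prod

def can_reshape(src_shape, dst_shape):
    # A reshape is valid only when both shapes describe the same number of elements.
    return prod(src_shape) == prod(dst_shape)

# The newly added regression cases all satisfy the invariant.
print(can_reshape((1, 32, 4608), (1, 32, 16, 3, 96)))      # issue 13889
print(can_reshape((2888, 49, 96), (8, 19, 19, 7, 7, 96)))  # issue 12153
print(can_reshape((128, 1, 1, 128), (128, 128)))           # issue 14676
print(can_reshape((5, 4, 208, 156), (3, 13, 8, 2080)))     # issue 14513
```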


# issue 15048
def test_previously_failing_test(device):
src_shape = (1, 56, 56, 64)
target_shape = (1, 1, 56 * 56, 64)
torch_input_tensor = torch.randn(src_shape, dtype=torch.bfloat16)
torch_result = torch_input_tensor.reshape(target_shape)

input_tensor = ttnn.from_torch(
torch_input_tensor,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

ttnn_output = ttnn.reshape(input_tensor, target_shape)
output = ttnn.to_torch(ttnn_output)

assert_with_pcc(torch_result, output, 0.9999)

@@ -330,6 +355,7 @@ def test_reshape_host(input_shape, output_shape, device):


# required for Embedding
@skip_for_grayskull("avoid this test while issue 15702 is resolved")
@pytest.mark.parametrize(
"input_shape, output_shape",
[
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
@@ -104,6 +104,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/reshape_rm_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/reshape_view/device/host/reshape_rm_host_prep.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp
@@ -6,9 +6,81 @@
// It's best to copy and paste the functions in rather than include the header, as code size will likely explode.
// Best to separate into cpp/hpp at some point to avoid the code size explosion, but we need to figure out the
// linking issues.
#include <stdio.h>
#include <cstring>
#define MASK_64 0xFFFFFFFFFFFFFFC0
#define OFFSET_64 0x000000000000003F
#define MASK_16 0xFFFFFFFFFFFFFFF0
#define OFFSET_16 0x000000000000000F
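Each mask/offset pair splits an address into an aligned base and a sub-line offset. A quick illustration of the 64 B pair (plain Python, mirroring the constants above):

```python
MASK_64 = 0xFFFFFFFFFFFFFFC0    # clears the low 6 bits -> 64 B aligned base
OFFSET_64 = 0x000000000000003F  # keeps the low 6 bits -> offset within the 64 B line

addr = 0x1F4A
base = addr & MASK_64    # 0x1F40, rounded down to a 64 B boundary
off = addr & OFFSET_64   # 0x0A, the position inside that line
print(hex(base), hex(off), base + off == addr)
```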

namespace tt::data_movement::common {

template <uint32_t max_transfer_size, bool only_reads>
FORCE_INLINE void enhanced_noc_async_read(
const uint64_t src_noc_addr, const uint32_t dst_l1_addr, const uint32_t bytes) {
// If you do not know max_transfer_size at compile time, pass 0 for it.
// only_reads is true if we ONLY use noc_async_read and all calls to tt_memmove have use_read_datamover set to true
if constexpr (((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_reads) {
noc_async_read_one_packet(src_noc_addr, dst_l1_addr, bytes);
} else {
noc_async_read(src_noc_addr, dst_l1_addr, bytes);
}
}

template <uint32_t max_transfer_size, bool only_writes>
FORCE_INLINE void enhanced_noc_async_write(
const uint32_t src_l1_addr, const uint64_t dst_noc_addr, const uint32_t bytes) {
// If you do not know max_transfer_size at compile time, pass 0 for it.
// only_writes is true if we ONLY use noc_async_write and all calls to tt_memmove have use_read_datamover set to false
if constexpr (((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_writes) {
noc_async_write_one_packet(src_l1_addr, dst_noc_addr, bytes);
} else {
noc_async_write(src_l1_addr, dst_noc_addr, bytes);
}
}
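The constexpr dispatch in both helpers reduces to one predicate: take the single-packet fast path when the compile-time transfer bound fits in one NOC packet, or when the kernel only ever issues transfers in one direction. A sketch of that predicate (the NOC_MAX_BURST_SIZE value here is assumed for illustration; the real constant comes from the hardware headers):

```python
NOC_MAX_BURST_SIZE = 8192  # assumed value, for illustration only

def use_one_packet(max_transfer_size: int, only_one_direction: bool) -> bool:
    # Mirrors ((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_X:
    # a known small bound, or a single-direction kernel, permits the one-packet path.
    return (0 < max_transfer_size < NOC_MAX_BURST_SIZE) or only_one_direction
```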

template <bool guaranteed_16B_aligned, bool copy_async, bool use_read_datamover, uint32_t max_transfer_size>
FORCE_INLINE void tt_memmove(const uint32_t dst_l1_addr, const uint32_t src_l1_addr, const uint32_t bytes) {
// Function performs a memory copy between two L1 addresses in the local core.
// Uses the NOC datamover (noc_async_read or noc_async_write) when possible to copy the data.
// Set guaranteed_16B_aligned to true if the source and destination are externally guaranteed to be 16B aligned
// (dangerous). Set copy_async to true to perform the operation asynchronously; in that case you can add
// a noc_async_read_barrier (or noc_async_write_barrier) to synchronize later.
if constexpr (use_read_datamover) {
if constexpr (guaranteed_16B_aligned) {
enhanced_noc_async_read<max_transfer_size, false>(get_noc_addr(src_l1_addr), dst_l1_addr, bytes);
if constexpr (!copy_async) {
noc_async_read_barrier();
}
} else {
if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) {
enhanced_noc_async_read<max_transfer_size, false>(get_noc_addr(src_l1_addr), dst_l1_addr, bytes);
if constexpr (!copy_async) {
noc_async_read_barrier();
}
} else {
memmove((void*)(dst_l1_addr), (void*)(src_l1_addr), (size_t)(bytes));
}
}
} else {
if constexpr (guaranteed_16B_aligned) {
enhanced_noc_async_write<max_transfer_size, false>(src_l1_addr, get_noc_addr(dst_l1_addr), bytes);
if constexpr (!copy_async) {
noc_async_write_barrier();
}
} else {
if ((dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)) {
enhanced_noc_async_write<max_transfer_size, false>(src_l1_addr, get_noc_addr(dst_l1_addr), bytes);
if constexpr (!copy_async) {
noc_async_write_barrier();
}
} else {
memmove((void*)(dst_l1_addr), (void*)(src_l1_addr), (size_t)(bytes));
}
}
}
}
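Note that the runtime check deciding between the NOC path and a plain memmove does not require the addresses to be 16 B aligned, only that source and destination share the same offset within a 16 B line. A sketch of that test, with OFFSET_16 as defined above:

```python
OFFSET_16 = 0x0F

def noc_copy_allowed(dst_l1_addr: int, src_l1_addr: int) -> bool:
    # The NOC datamover can copy when src and dst are congruent modulo 16;
    # otherwise tt_memmove falls back to a software memmove.
    return (dst_l1_addr & OFFSET_16) == (src_l1_addr & OFFSET_16)
```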

// this function is useful for converting bfloat16 values to float32
FORCE_INLINE float bfloat16_to_float32(uint16_t bfloat16_data) {
uint32_t bits = static_cast<uint32_t>(bfloat16_data) << 16;
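The bfloat16 conversion above exploits the fact that bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so converting is a 16-bit left shift reinterpreted as a float. A minimal Python sketch of the same idea (not the kernel code itself):

```python
import struct

def bfloat16_to_float32(bfloat16_bits: int) -> float:
    # bfloat16 keeps the sign, exponent, and top 7 mantissa bits of a float32,
    # so shifting the 16-bit pattern into the high half reconstructs the value.
    bits = (bfloat16_bits & 0xFFFF) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(bfloat16_to_float32(0x3F80))  # 1.0
```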
@@ -0,0 +1,205 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0


/*
Kernel reads from a row-major tensor and writes to a row-major tensor.

Compile-time arguments
0. tensor_is_dram: 1 if the tensor is in DRAM, else 0
1. src_aligned_to_64: 1 if the source pages are 64 B aligned, else 0
2. src_aligned_to_16: 1 if the source pages are 16 B aligned, else 0
3. cb_id_in0: circular buffer holding one source page
4. cb_id_in1: circular buffer holding one destination page
5. source_page_size_bytes: source page size in bytes
6. dest_page_size_bytes: destination page size in bytes
7. source_page_is_pow_2: 1 if the source page size is a power of 2, else 0
8. source_page_pow_2: log base 2 of the source page size
9. dest_page_is_pow_2: 1 if the destination page size is a power of 2, else 0
10. dest_page_pow_2: log base 2 of the destination page size

Runtime arguments
0. src_addr: source address
1. dst_addr: destination address
2. source_read_size_bytes: padded source read size in bytes
   (needed when the tensor is in DRAM and the page size is not a multiple of 64 bytes)
3. read_start_page: first page to read
4. read_end_page: one past the last page to read
5. write_start_page: first page to write
6. write_start_offset: byte offset into the first written page
9. nop: 1 if this core has no work to do
*/
#include <stdint.h>
#include "dataflow_api.h"
#include "debug/dprint.h" // required in all kernels using DPRINT
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp"

void kernel_main() {
// We are guaranteed to be going from 2D to 2D

const uint32_t src_addr = get_arg_val<uint32_t>(0);
const uint32_t dst_addr = get_arg_val<uint32_t>(1);
//If DDR this is source_page_size_bytes + 64 (rounded up to next 64B), if L1 this is source_page_size_bytes + 16 (rounded up to next 16B)
const uint32_t source_read_size_bytes = get_arg_val<uint32_t>(2);
const uint32_t read_start_page = get_arg_val<uint32_t>(3);
const uint32_t read_end_page = get_arg_val<uint32_t>(4);
const uint32_t write_start_page = get_arg_val<uint32_t>(5);
const uint32_t write_start_offset = get_arg_val<uint32_t>(6);
const uint32_t nop = get_arg_val<uint32_t>(9);

constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1;
#define src_aligned_to_64 get_compile_time_arg_val(1) == 1
#define src_aligned_to_16 get_compile_time_arg_val(2) == 1
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3);
constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4);
constexpr uint32_t source_page_size_bytes = get_compile_time_arg_val(5);
constexpr uint32_t dest_page_size_bytes = get_compile_time_arg_val(6);
#define source_page_is_pow_2 get_compile_time_arg_val(7) == 1
constexpr uint32_t source_page_pow_2 = get_compile_time_arg_val(8);
#define dest_page_is_pow_2 get_compile_time_arg_val(9) == 1
constexpr uint32_t dest_page_pow_2 = get_compile_time_arg_val(10);

// Pages do not always split evenly across the core grid; if nop is set, this core has no work to do
if (nop == 1)
{
return;
}
#if source_page_is_pow_2
const InterleavedPow2AddrGen<tensor_is_dram> s = {
.bank_base_address = src_addr, .log_base_2_of_page_size = source_page_pow_2};
#else
const InterleavedAddrGen<tensor_is_dram> s = {
.bank_base_address = src_addr,
.page_size = source_page_size_bytes
};
#endif
#if dest_page_is_pow_2
const InterleavedPow2AddrGen<tensor_is_dram> d = {
.bank_base_address = dst_addr, .log_base_2_of_page_size = dest_page_pow_2};
#else
const InterleavedAddrGen<tensor_is_dram> d = {
.bank_base_address = dst_addr,
.page_size = dest_page_size_bytes
};
#endif
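The power-of-two address generators are selected because, when the page size is a power of two, the page-offset arithmetic reduces to a shift instead of a multiply (and the division inside the generic generator similarly becomes a mask/shift). A simplified sketch of the distinction; the real generators also fold in bank interleaving:

```python
def page_byte_offset_pow2(page_index: int, log2_page_size: int) -> int:
    # page_index * page_size, with the multiply replaced by a shift.
    return page_index << log2_page_size

def page_byte_offset_generic(page_index: int, page_size: int) -> int:
    # The generic generator must do a full multiply.
    return page_index * page_size
```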

uint32_t read_offset = 0;
uint32_t write_page = write_start_page;
uint32_t readable = 0;
uint32_t end_to_write = 0;
uint32_t transaction = 0;
uint32_t writable = dest_page_size_bytes - write_start_offset;
// cb_id_in0 is a CB with page size source_read_size_bytes, 1 page
// cb_id_in1 is a CB with page size dest_page_size_bytes + alignment_to_64, 1 page
cb_reserve_back(cb_id_in0, 1);
cb_reserve_back(cb_id_in1, 1);
const uint32_t source_buffer = get_write_ptr(cb_id_in0);
const uint32_t dest_buffer = get_write_ptr(cb_id_in1);
cb_push_back(cb_id_in1, 1);
cb_push_back(cb_id_in0, 1);

uint64_t dst_noc_addr = get_noc_addr(write_page, d);
uint64_t write_offset = (dst_noc_addr & OFFSET_16) + write_start_offset;
uint64_t begin_write_offset = write_offset;
constexpr bool can_be_clean = ((source_page_size_bytes % 16) == 0 && (dest_page_size_bytes % 16) == 0);
uint64_t dst_noc_addr_offset = 0;
for (uint32_t i = read_start_page; i < read_end_page; i++) {
//Read from source
uint64_t src_noc_addr = s.get_noc_addr(i,0);

#if (src_aligned_to_64 || ((!tensor_is_dram) && src_aligned_to_16))
// Aligned to 64 bytes, or aligned to 16 bytes with the tensor in L1
tt::data_movement::common::enhanced_noc_async_read<source_page_size_bytes, false>(
src_noc_addr, source_buffer, source_page_size_bytes);
read_offset = 0;
#elif (tensor_is_dram)
// DDR but not aligned to 64 (potentially also not aligned to 16)
tt::data_movement::common::enhanced_noc_async_read<(source_page_size_bytes + 128), false>(
src_noc_addr & MASK_64, source_buffer, source_read_size_bytes);
read_offset = src_noc_addr&OFFSET_64;
#else
// L1 but not aligned to 16
tt::data_movement::common::enhanced_noc_async_read<(source_page_size_bytes + 128), false>(
src_noc_addr & MASK_16, source_buffer, source_read_size_bytes);
read_offset = src_noc_addr&OFFSET_16;
#endif
readable = source_page_size_bytes;
noc_async_read_barrier();

//Write to dest
while (readable > 0)
{
noc_async_write_barrier();
if (readable < writable)
{
if constexpr (can_be_clean) {
tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable);
dst_noc_addr_offset = dst_noc_addr_offset + readable;
} else {
tt::data_movement::common::tt_memmove<false, true, false, dest_page_size_bytes>(
dest_buffer + write_offset, source_buffer + read_offset, readable);
if (i == read_end_page - 1) {
tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
dest_buffer + begin_write_offset, dst_noc_addr, end_to_write);
return;
}
write_offset = write_offset + readable;
end_to_write = end_to_write + readable;
}
writable = writable - readable;
readable = 0;

}
else if (readable == writable)
{
if constexpr (can_be_clean) {
tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, readable);
} else {
tt::data_movement::common::tt_memmove<false, false, false, dest_page_size_bytes>(
dest_buffer + write_offset, source_buffer + read_offset, readable);
tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes);
}
dst_noc_addr_offset = 0;

writable = dest_page_size_bytes;
readable = 0;
if (i == read_end_page - 1) {
return;
}
write_page++;
dst_noc_addr = get_noc_addr(write_page, d);
if constexpr (!can_be_clean) {
end_to_write = 0;
write_offset = dst_noc_addr & OFFSET_16;
begin_write_offset = write_offset;
}
}
else
{
if constexpr (can_be_clean) {
tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
source_buffer + read_offset, dst_noc_addr + dst_noc_addr_offset, writable);
} else {
tt::data_movement::common::tt_memmove<false, false, false, dest_page_size_bytes>(
dest_buffer + write_offset, source_buffer + read_offset, writable);
tt::data_movement::common::enhanced_noc_async_write<dest_page_size_bytes, false>(
dest_buffer + begin_write_offset, dst_noc_addr, dest_page_size_bytes);
}
// writable < readable
readable = readable - writable;
read_offset = read_offset + writable;
write_page++;
dst_noc_addr_offset = 0;
dst_noc_addr = get_noc_addr(write_page, d);
if constexpr (!can_be_clean) {
end_to_write = 0;
write_offset = dst_noc_addr & OFFSET_16;
begin_write_offset = write_offset;
}
writable = dest_page_size_bytes;
}
}
}
return;
}
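The padded source_read_size_bytes described in the kernel's runtime-argument comment (page size plus one alignment unit, rounded up to the next multiple of that unit: 64 B for DRAM, 16 B for L1) can be sketched as follows. This is an illustration of the rule as documented, not the actual host-side code:

```python
def padded_read_size(page_size_bytes: int, is_dram: bool) -> int:
    # DRAM reads are masked down to a 64 B boundary, so the read must be padded
    # by one alignment unit and rounded up; L1 uses 16 B alignment.
    align = 64 if is_dram else 16
    return (page_size_bytes + align + align - 1) & ~(align - 1)

print(padded_read_size(100, True))   # 192: 100 + 64 rounded up to 64 B
print(padded_read_size(100, False))  # 128: 100 + 16 rounded up to 16 B
```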