diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py
index f1eebbaa7cd..4ebd1c5e542 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py
@@ -15,6 +15,8 @@
 from models.demos.t3000.llama2_70b.tt.llama_common import precompute_freqs, freqs_to_rotation_matrix, gather_rotary_emb
 
+MAX_SEQ_LEN = 128 * 1024
+
 
 def get_rotation_mat(dhead, end, start_pos, seqlen, batch):
     cos, sin = precompute_freqs(dhead, end)
@@ -28,18 +30,48 @@ class TtLlamaRotary(torch.nn.Module):
     def __init__(
         self,
         device,
+        batch,
         head_dim: int,
+        mode: str,
         datatype=ttnn.bfloat16,
     ):
         super().__init__()
+
+        self.batch = batch
         self.head_dim = head_dim
         self.device = device
+        self.mode = mode
 
-        tile_width = 32
+        self.core_grid = device.compute_with_storage_grid_size()
+        num_cores = self.core_grid.x * self.core_grid.y
 
-        self.transformation_mat = ttnn.from_torch(
-            get_rot_transformation_mat(dhead=tile_width), device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype
-        )
+        if mode == "decode":
+            # Generate the cos/sin matrices needed for ttnn.embedding op
+            cos_matrix, sin_matrix = compute_gather_cos_sin(
+                dhead=head_dim, end=MAX_SEQ_LEN * 2, position_ids=torch.arange(MAX_SEQ_LEN)
+            )
+
+            self.cos_matrix = ttnn.from_torch(cos_matrix, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=datatype)
+            self.sin_matrix = ttnn.from_torch(sin_matrix, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=datatype)
+
+            # Generate the transformation matrix
+            trans_mat = get_rot_transformation_mat(dhead=ttnn.TILE_SIZE).repeat(
+                1, 1, num_cores, 1
+            )  # Repeat across all cores on device
+            trans_mat_mem_config = ttnn.create_sharded_memory_config(
+                shape=(1, 1, ttnn.TILE_SIZE * num_cores, ttnn.TILE_SIZE),
+                core_grid=ttnn.CoreGrid(y=self.core_grid.y, x=self.core_grid.x),
+                strategy=ttnn.ShardStrategy.HEIGHT,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+            )
+            self.transformation_mat = ttnn.from_torch(
+                trans_mat, device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype, memory_config=trans_mat_mem_config
+            )
+
+        else:
+            self.transformation_mat = ttnn.from_torch(
+                get_rot_transformation_mat(dhead=ttnn.TILE_SIZE), device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype
+            )
 
     def apply_rotary(self, x, cos, sin):
         # n_head = 8 for Q
@@ -54,11 +86,50 @@ def apply_rotary(self, x, cos, sin):
         )
 
         rotary_output = ttnn.experimental.rotary_embedding_llama(
-            x, cos, sin, self.transformation_mat, compute_kernel_config=compute_kernel_config
+            x,
+            cos,
+            sin,
+            self.transformation_mat,
+            is_decode_mode=self.mode == "decode",
+            compute_kernel_config=compute_kernel_config,
        )
 
         return rotary_output
 
+    def prepare_decode_cos_sin(self, position_ids):
+        assert isinstance(position_ids, torch.Tensor), "Position ids must be a torch tensor"
+
+        position_ids = position_ids.unsqueeze(-1)  # [batch, 1]
+        position_ids = ttnn.from_torch(
+            position_ids, device=self.device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32
+        )
+
+        cos = ttnn.embedding(position_ids, self.cos_matrix)  # [batch, head_dim, head_dim]
+        sin = ttnn.embedding(position_ids, self.sin_matrix)  # [batch, head_dim, head_dim]
+
+        cos = ttnn.reshape(cos, [1, position_ids.shape[0], 1, self.head_dim])  # [1, batch, 1, head_dim]
+        sin = ttnn.reshape(sin, [1, position_ids.shape[0], 1, self.head_dim])  # [1, batch, 1, head_dim]
+
+        cos = ttnn.to_layout(cos, ttnn.TILE_LAYOUT)
+        sin = ttnn.to_layout(sin, ttnn.TILE_LAYOUT)
+
+        grid = (
+            ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(self.batch, self.core_grid, row_wise=True))
+            .bounding_box()
+            .grid_size()
+        )
+        mem_config = ttnn.create_sharded_memory_config(
+            shape=(1, self.batch, ttnn.TILE_SIZE, self.head_dim),
+            core_grid=ttnn.CoreGrid(y=grid.y, x=grid.x),
+            strategy=ttnn.ShardStrategy.HEIGHT,
+            orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        )
+
+        cos = ttnn.interleaved_to_sharded(cos, mem_config)  # [1, 1 (= batch / shard_num_cores), 1[32], self.head_dim]
+        sin = ttnn.interleaved_to_sharded(sin, mem_config)  # [1, 1 (= batch / shard_num_cores), 1[32], self.head_dim]
+
+        return cos, sin
+
     def forward(self, xq, xk, cos, sin):
         xq = self.apply_rotary(xq, cos, sin)
         xk = self.apply_rotary(xk, cos, sin)
@@ -113,19 +184,31 @@ def run_test_rotary_embedding_llama(
 ):
     # Prepare input
     torch.manual_seed(0)
 
+    mode = "decode" if seq_len == 1 else "prefill"
+
     inp = [
         (torch.rand(batch, n_heads, seq_len, head_dim) * 2) - 1,
         (torch.rand(batch, n_kv_heads, seq_len, head_dim) * 2) - 1,
     ]
+
+    if mode == "decode":  # For decode, torch expects [1, n_heads, batch, head_dim]
+        inp = [x.permute(2, 1, 0, 3) for x in inp]
+
     freqs_cis = precompute_freqs_cis(
         # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
         # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
         head_dim,
-        max_seq_len * 2,
+        MAX_SEQ_LEN * 2 if mode == "decode" else max_seq_len * 2,  # In decode, precompute for all positions
     )  # torch.Size([8192, 64])
 
     start_pos = 0  # Must pick non-zero start pos to get non-zero freqs_cis
-    freqs_cis = freqs_cis[start_pos : start_pos + seq_len]
+
+    if mode == "decode":  # In decode, each user has a different position
+        position_ids = torch.arange(batch)  # TODO: Update to check other indices as well
+    else:
+        position_ids = slice(start_pos, start_pos + seq_len)
+
+    freqs_cis = freqs_cis[position_ids]
 
     # PyTorch Ground Truth output --------------------------------------------------------------------
     torch_xq = inp[0].transpose(1, 2)
@@ -139,17 +222,47 @@ def run_test_rotary_embedding_llama(
     pytorch_out = (torch_xq, torch_xk)
 
     # TT hardware / Modified PyTorch execution -------------------------------------------------------------
-    tt_model = TtLlamaRotary(device, head_dim, datatype)
+    tt_model = TtLlamaRotary(device, batch, head_dim, mode, datatype)
 
-    cos, sin = compute_gather_cos_sin(
-        dhead=head_dim, end=max_seq_len * 2, position_ids=torch.arange(start_pos, start_pos + seq_len)
-    )
-    tt_inp = [inp[0], inp[1], cos, sin]
-    tt_inp = [ttnn.from_torch(i, device=device, dtype=datatype, layout=ttnn.TILE_LAYOUT) for i in tt_inp]
+    if mode == "decode":
+        cos, sin = tt_model.prepare_decode_cos_sin(position_ids)
+
+        # For decode, TTNN expects inputs to be [1, batch, nh, dhead]
+        inp = [x.transpose(1, 2) for x in inp]
+
+        grid = (
+            ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(batch, tt_model.core_grid, row_wise=True))
+            .bounding_box()
+            .grid_size()
+        )
+        input_mem_config = ttnn.create_sharded_memory_config(
+            shape=(1, batch, ttnn.TILE_SIZE, head_dim),
+            core_grid=ttnn.CoreGrid(y=grid.y, x=grid.x),
+            strategy=ttnn.ShardStrategy.HEIGHT,
+            orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        )
+
+        tt_inp = [
+            ttnn.from_torch(i, device=device, dtype=datatype, memory_config=input_mem_config, layout=ttnn.TILE_LAYOUT)
+            for i in inp
+        ]
+        tt_inp += [cos, sin]  # Append cos and sin to the input list
+    else:
+        cos, sin = compute_gather_cos_sin(
+            dhead=head_dim,
+            end=max_seq_len * 2,
+            position_ids=torch.arange(start_pos, start_pos + seq_len),
+        )
+
+        tt_inp = [inp[0], inp[1], cos, sin]
+        tt_inp = [ttnn.from_torch(i, device=device, dtype=datatype, layout=ttnn.TILE_LAYOUT) for i in tt_inp]
 
     tt_out = tt_model(*tt_inp)
     tt_out = [ttnn.to_torch(tt_out_tensor) for tt_out_tensor in tt_out]
 
+    if mode == "decode":  # Swap back the n_head and batch dimensions to compare with torch output
+        tt_out = [x.transpose(1, 2) for x in tt_out]
+
     # check outputs ----------------------------------------------------------------------
     assert len(pytorch_out) == len(tt_out), "Lengths of pytorch and tt outputs do not match!"
     does_pass = True
@@ -191,6 +304,11 @@
         (1, 8192),
         (1, 16384),
         (1, 128 * 1024),
+        (64, 1),
+        (32, 1),
+        (16, 1),
+        (8, 1),
+        (1, 1),
     ),
     ids=(
         "prefill_32",
@@ -203,6 +321,11 @@
         "prefill_8k",
         "prefill_16k",
         "prefill_128k",
+        "decode_64",
+        "decode_32",
+        "decode_16",
+        "decode_8",
+        "decode_1",
     ),
 )
 @pytest.mark.parametrize(
@@ -235,12 +358,15 @@ def test_rotary_embedding_llama(
     if seq_len == 128 * 1024 and (n_heads, n_kv_heads, head_dim) != (8, 1, 128):
         pytest.skip("Only testing for (8, 1, 128) due to time constraints")
 
+    if seq_len == 1 and (n_heads > ttnn.TILE_SIZE or n_kv_heads > ttnn.TILE_SIZE):
+        pytest.skip("n_heads or n_kv_heads cannot be greater than ttnn.TILE_SIZE for decode mode")
+
     max_seq_len = max(4096, seq_len)
 
     run_test_rotary_embedding_llama(device, batch, seq_len, pcc, n_heads, n_kv_heads, head_dim, max_seq_len, datatype)
 
     # shift input/output tensor by creating very small tensor between loop
-    inp = torch.rand(1, 1, 32, 32)
+    inp = torch.randn(1, 1, 32, 32)
     test_tensor = (
         ttnn.Tensor(
             inp.reshape(-1).tolist(),
@@ -261,11 +387,21 @@
         (1, 2048),
         (1, 4096),
         (1, 8192),
+        (64, 1),
+        (32, 1),
+        (16, 1),
+        (8, 1),
+        (1, 1),
     ),
     ids=(
         "prefill_2k",
         "prefill_4k",
         "prefill_8k",
+        "decode_64",
+        "decode_32",
+        "decode_16",
+        "decode_8",
+        "decode_1",
     ),
 )
 @pytest.mark.parametrize(
@@ -291,6 +427,8 @@ def test_rotary_embedding_llama_with_program_cache(
 
     max_seq_len = max(4096, seq_len)
 
+    mode = "decode" if seq_len == 1 else "prefill"
+
     cache_tensors = []
     for _ in range(3):
         run_test_rotary_embedding_llama(
@@ -298,7 +436,7 @@
         )
 
         # shift input/output tensor by creating very small tensor between loop
-        inp = torch.rand(1, 1, 32, 32)
+        inp = torch.randn(1, 1, 32, 32)
         test_tensor = (
             ttnn.Tensor(
                 inp.reshape(-1).tolist(),
@@ -312,4 +450,7 @@
 
         cache_tensors.append(test_tensor)
 
-    assert device.num_program_cache_entries() == 2
+    if mode == "decode":
+        assert device.num_program_cache_entries() == 5  # 2 * Rope + embedding + reshape + to_layout
+    else:
+        assert device.num_program_cache_entries() == 2  # 2 * Rope
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp
new file mode 100644
index 00000000000..f6ccfba7910
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp
@@ -0,0 +1,118 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "compute_kernel_api/common.h"
+#include "compute_kernel_api/eltwise_binary.h"
+#include "compute_kernel_api/bcast.h"
+#include "compute_kernel_api/matmul.h"
+
+ALWI void ACQ() { acquire_dst(); }
+ALWI void REL() { release_dst(); }
+
+namespace NAMESPACE {
+void MAIN {
+
+    constexpr uint32_t onetile = 1;
+    constexpr uint32_t in_cb = get_compile_time_arg_val(0);
+    constexpr uint32_t cos_cb = get_compile_time_arg_val(1);
+    constexpr uint32_t sin_cb = get_compile_time_arg_val(2);
+    constexpr uint32_t trans_mat_cb = get_compile_time_arg_val(3);
+
+    constexpr uint32_t rotated_in_interm_cb = get_compile_time_arg_val(4);
+    constexpr uint32_t cos_interm_cb = get_compile_time_arg_val(5);
+    constexpr uint32_t sin_interm_cb = get_compile_time_arg_val(6);
+    constexpr uint32_t out_cb = get_compile_time_arg_val(7);
+    constexpr uint32_t Wt = get_compile_time_arg_val(8);
+    constexpr uint32_t Ht = get_compile_time_arg_val(9);  // How many rows (tiles) in n_heads dimension
+
+    mm_init();
+    binary_op_init_common(rotated_in_interm_cb, cos_cb);  // General Init for all binary ops
+
+    // Get the trans_mat
+    cb_reserve_back(trans_mat_cb, onetile);
+    cb_push_back(trans_mat_cb, onetile);
+    cb_wait_front(trans_mat_cb, onetile);
+
+    // Get the sin/cos matrices
+    // TODO: To parallelize across multiple batch, this should be in a batch loop
+    cb_reserve_back(sin_cb, Wt);
+    cb_reserve_back(cos_cb, Wt);
+
+    cb_push_back(sin_cb, Wt);
+    cb_push_back(cos_cb, Wt);
+
+    for (uint32_t ht = 0; ht < Ht; ht++) {  // Over n_heads_t dimension
+        cb_reserve_back(rotated_in_interm_cb, Wt);
+        cb_reserve_back(sin_interm_cb, Wt);
+        cb_reserve_back(cos_interm_cb, Wt);
+        cb_reserve_back(out_cb, Wt);
+
+        // Get the input
+        cb_reserve_back(in_cb, Wt);
+        cb_push_back(in_cb, Wt);
+        cb_wait_front(in_cb, Wt);
+
+        // Do the computation
+
+        // rotated = x @ trans_mat
+        mm_init_short(in_cb, trans_mat_cb);
+        ACQ();
+        for (uint32_t j = 0; j < Wt; ++j) {
+            matmul_tiles(in_cb, trans_mat_cb, j, 0, j, false);
+            pack_tile(j, rotated_in_interm_cb, j);
+        }
+        REL();
+        cb_push_back(rotated_in_interm_cb, Wt);
+        cb_wait_front(rotated_in_interm_cb, Wt);
+
+        mul_bcast_rows_init_short();
+        ACQ();
+        for (uint32_t j = 0; j < Wt; ++j) {
+            // sin_interim = rotated * sin
+            mul_tiles_bcast(rotated_in_interm_cb, sin_cb, j, j, j);
+            pack_tile(j, sin_interm_cb, j);
+        }
+        REL();
+        cb_push_back(sin_interm_cb, Wt);
+        cb_pop_front(rotated_in_interm_cb, Wt);
+
+        ACQ();
+        for (uint32_t j = 0; j < Wt; ++j) {
+            // cos_interim = x * cos
+            mul_tiles_bcast(in_cb, cos_cb, j, j, j);
+            pack_tile(j, cos_interm_cb, j);
+        }
+        REL();
+        cb_push_back(cos_interm_cb, Wt);
+        cb_pop_front(in_cb, Wt);  // Done with input
+
+        cb_wait_front(sin_interm_cb, Wt);
+        cb_wait_front(cos_interm_cb, Wt);
+        add_tiles_init();
+        ACQ();
+        for (uint32_t j = 0; j < Wt; ++j) {
+            // out = cos_interim + sin_interim
+            add_tiles(cos_interm_cb, sin_interm_cb, j, j, j);
+            pack_tile(j, out_cb, j);
+        }
+        REL();
+        cb_push_back(out_cb, Wt);
+        cb_pop_front(sin_interm_cb, Wt);
+        cb_pop_front(cos_interm_cb, Wt);
+    }
+
+    // Done with the sin/cos matrices, so remove from CB
+    cb_pop_front(sin_cb, Wt);
+    cb_pop_front(cos_cb, Wt);
+
+    // Done with the transformation matrix, so remove from CB
+    cb_pop_front(trans_mat_cb, onetile);
+}
+}  // NAMESPACE
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.cpp
index de5400db665..b9ba4ff4240 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.cpp
@@ -19,6 +19,7 @@ void RotaryEmbeddingLlama::validate(const std::vector<Tensor>& input_tensors) co
     const auto& sin = input_tensors.at(2);
     const auto& trans_mat = input_tensors.at(3);
     TT_FATAL(input_tensors.size() == 4, "Error");
+    auto ref_device = input_tensor.device();
 
     for (const auto& input : input_tensors) {
         TT_FATAL(input.storage_type() == StorageType::DEVICE, "Operands to rotary embedding need to be on device!");
@@ -27,42 +28,65 @@ void RotaryEmbeddingLlama::validate(const std::vector<Tensor>& input_tensors) co
         TT_FATAL((input.get_layout() == Layout::TILE), "Inputs to rotary embedding must be tilized");
     }
 
-    TT_FATAL(input_tensor.get_padded_shape()[-1] % TILE_WIDTH == 0, "Input X dim must be divisible into tiles");
-    uint32_t seq_len = input_tensor.get_padded_shape()[-2];
-    uint32_t B = input_tensor.get_padded_shape()[0];
-    uint32_t head_dim = input_tensor.get_padded_shape()[-1];
-
+    uint32_t head_dim = input_tensor.get_logical_shape()[-1];
     TT_FATAL(head_dim <= 128 || std::get<ttnn::WormholeComputeKernelConfig>(this->compute_kernel_config).fp32_dest_acc_en == false, "If head_dim is > 128, fp32_dest_acc_en must be False");
 
     // Check that head_dim is less than 256
     TT_FATAL(head_dim <= 256, "Head dim must be less than 256");
     // Check that head_dim is a multiple of 32
-    TT_FATAL(head_dim % 32 == 0, "Head dim must be a multiple of 32");
-    // Check datatypes
+    TT_FATAL(head_dim % TILE_WIDTH == 0, "Head dim must be a multiple of TILE_WIDTH");
+
     TT_FATAL(input_tensor.get_dtype() == cos.get_dtype() && cos.get_dtype() == sin.get_dtype() && sin.get_dtype() == trans_mat.get_dtype() && trans_mat.get_dtype() == DataType::BFLOAT16, "All input tensors must have dtype = bfloat16");
-    TT_FATAL(cos.get_dtype() == sin.get_dtype(), "Cos and Sin dtypes must match");
-    TT_FATAL(cos.get_padded_shape() == sin.get_padded_shape(), "Cos and Sin dims must match");
-    TT_FATAL(cos.get_padded_shape()[0] == 1 && cos.get_padded_shape()[1] == 1 && cos.get_padded_shape()[-1] == head_dim, "Cos dims must match input dims");
-
-    TT_FATAL(trans_mat.get_padded_shape()[0] == 1 && trans_mat.get_padded_shape()[1] == 1, "Transformation matrix must have 1st & 2nd dim equal to 1");
-    TT_FATAL(trans_mat.get_padded_shape()[-2] == TILE_HEIGHT, "Transformation matrix must have 3rd dim equal to TILE_HEIGHT");
-    TT_FATAL(trans_mat.get_padded_shape()[-1] == TILE_WIDTH, "Transformation matrix must have 4rd dim equal to TILE_WIDTH");
-
-
-    TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
-    TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
+    TT_FATAL(input_tensor.memory_config().memory_layout == this->output_mem_config.memory_layout, "Input tensor and output tensor must have same memory layout");
+
+    // Check that cos and sin have same dims
+    TT_FATAL(cos.get_logical_shape() == sin.get_logical_shape(), "Cos and Sin dims must match");
+
+    if (this->is_decode_mode) {  // Decode mode validation
+        uint32_t seq_len = input_tensor.get_logical_shape()[0];
+        TT_FATAL(seq_len == 1, "rotary_embedding_llama currently only supports sharded inputs in decode mode, and therefore, seq_len (in dim 0) must be 1.");
+
+        for (const auto& input : input_tensors) {
+            TT_FATAL((input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED), "Sharded inputs for RoPE must be HEIGHT_SHARDED.");
+        }
+
+        uint32_t num_cores = input_tensor.device()->compute_with_storage_grid_size().x * input_tensor.device()->compute_with_storage_grid_size().y;
+        uint32_t batch = input_tensor.get_logical_shape()[1];
+        TT_FATAL(batch <= num_cores, "In decode mode, RoPE is parallelized over batch dimension, and therefore, batch size must be less than or equal to the number of cores");
+
+        // Checks for cos and sin
+        TT_FATAL(batch == cos.get_logical_shape()[1], "Cos and Sin must have the same batch size as the input");
+        TT_FATAL(cos.shard_spec()->shape[0] == TILE_HEIGHT, "In decode mode, RoPE only supports n_heads (shard_shape[0]) less than or equal to TILE_HEIGHT");  // TODO: might be supported by kernel currently, but need to check with pytest
+
+        // Checks for transformation matrix
+        TT_FATAL(trans_mat.get_logical_shape()[0] == 1 && trans_mat.get_logical_shape()[1] == 1, "Transformation matrix must have 1st & 2nd dim equal to 1");
+        TT_FATAL(trans_mat.shard_spec()->shape[0] == TILE_HEIGHT, "Transformation matrix must have 3rd dim equal to TILE_HEIGHT");
+        TT_FATAL(trans_mat.shard_spec()->shape[1] == TILE_WIDTH, "Transformation matrix must have 4th dim equal to TILE_WIDTH");
+    } else {  // Prefill mode validation
+        TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Input tensor must be interleaved in prefill mode");
+
+        // Checks for cos and sin
+        TT_FATAL(cos.get_logical_shape()[0] == 1 && cos.get_logical_shape()[1] == 1 && cos.get_logical_shape()[-1] == head_dim, "Cos dims must match input dims");
+        TT_FATAL(input_tensor.memory_config().memory_layout == sin.memory_config().memory_layout, "Input tensor and sin tensor must have same memory layout");
+        TT_FATAL(input_tensor.memory_config().memory_layout == cos.memory_config().memory_layout, "Input tensor and cos tensor must have same memory layout");
+
+        // Checks for transformation matrix
+        TT_FATAL(trans_mat.get_logical_shape()[0] == 1 && trans_mat.get_logical_shape()[1] == 1, "Transformation matrix must have 1st & 2nd dim equal to 1");
+        TT_FATAL(trans_mat.get_logical_shape()[-2] == TILE_HEIGHT, "Transformation matrix must have 3rd dim equal to TILE_HEIGHT");
+        TT_FATAL(trans_mat.get_logical_shape()[-1] == TILE_WIDTH, "Transformation matrix must have 4th dim equal to TILE_WIDTH");
+    }
 }
 
-std::vector<tt::tt_metal::LegacyShape> RotaryEmbeddingLlama::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::SimpleShape> RotaryEmbeddingLlama::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    auto shape = input_tensor.get_legacy_shape();
+    auto shape = input_tensor.get_logical_shape();
     return {shape};
 }
 
 std::vector<Tensor> RotaryEmbeddingLlama::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    auto output_shape = this->compute_output_shapes(input_tensors)[0].logical_shape();
+    auto output_shape = this->compute_output_shapes(input_tensors)[0];
     return {create_device_tensor(
         output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), this->output_mem_config)};
 }
@@ -76,7 +100,11 @@ operation::ProgramWithCallbacks RotaryEmbeddingLlama::create_program(
     auto& output_tensor = output_tensors.at(0);
 
     // Works on single core as well
-    return rotary_embedding_llama_multi_core(input_tensor, cos, sin, trans_mat, output_tensor, this->compute_kernel_config);
+    if (this->is_decode_mode) {
+        return rotary_embedding_llama_multi_core_sharded(input_tensor, cos, sin, trans_mat, output_tensor, this->compute_kernel_config);
+    } else {
+        return rotary_embedding_llama_multi_core(input_tensor, cos, sin, trans_mat, output_tensor, this->compute_kernel_config);
+    }
 }
 
 }  // namespace tt_metal
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.hpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.hpp
index f228a10d0e4..0d1338934df 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_device_operation.hpp
@@ -15,12 +15,12 @@ namespace tt {
 namespace tt_metal {
 
 struct RotaryEmbeddingLlama {
-    const uint32_t seq_len;
+    const bool is_decode_mode;
     const MemoryConfig output_mem_config;
     const ttnn::DeviceComputeKernelConfig compute_kernel_config;
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
 
     operation::ProgramWithCallbacks create_program(
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.cpp
index 6c075f6b7c8..28bafefaf60 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.cpp
@@ -341,6 +341,182 @@ operation::ProgramWithCallbacks rotary_embedding_llama_multi_core(
     return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
 }
 
+
+operation::ProgramWithCallbacks rotary_embedding_llama_multi_core_sharded(
+    const Tensor &input,
+    const Tensor &cos,
+    const Tensor &sin,
+    const Tensor &trans_mat,
+    Tensor &output,
+    ttnn::DeviceComputeKernelConfig compute_kernel_config
+) {
+    Program program{};
+
+    const tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype());
+    const uint32_t input_single_tile_size = tt_metal::detail::TileSize(input_cb_data_format);
+
+    const tt::DataFormat cos_cb_data_format = tt_metal::datatype_to_dataformat_converter(cos.get_dtype());
+    const uint32_t cos_single_tile_size = tt_metal::detail::TileSize(cos_cb_data_format);
+
+    const tt::DataFormat sin_cb_data_format = tt_metal::datatype_to_dataformat_converter(sin.get_dtype());
+    const uint32_t sin_single_tile_size = tt_metal::detail::TileSize(sin_cb_data_format);
+
+    const tt::DataFormat trans_mat_cb_data_format = tt_metal::datatype_to_dataformat_converter(trans_mat.get_dtype());
+    const uint32_t trans_mat_single_tile_size = tt_metal::detail::TileSize(trans_mat_cb_data_format);
+
+    const tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
+    const uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_cb_data_format);
+
+    bool in_sharded = input.shard_spec().has_value();
+    bool out_sharded = output.shard_spec().has_value();
+    std::optional<ShardSpec> shard_spec = in_sharded ? input.shard_spec() : output.shard_spec();
+
+    const uint32_t batch = input.get_padded_shape()[1];
+    const uint32_t n_heads_t = shard_spec->shape[0] / constants::TILE_HEIGHT;
+    const uint32_t head_dim_t = shard_spec->shape[1] / constants::TILE_WIDTH;
+
+    tt_metal::Device *device = input.device();
+
+    auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] =
+        get_compute_kernel_config_args(device->arch(), compute_kernel_config);
+
+    CoreRange all_cores = shard_spec->grid.bounding_box();
+    uint32_t num_cores_x = all_cores.grid_size().x;
+    uint32_t num_cores_y = all_cores.grid_size().y;
+
+    const uint32_t num_input_tiles = n_heads_t * head_dim_t;
+    const uint32_t num_output_tiles = num_input_tiles;
+
+    // Parallelization
+    const uint32_t num_cores = num_cores_x * num_cores_y;
+    const uint32_t batch_parallel_factor = std::min(batch, num_cores);
+    const uint32_t batch_per_core = (batch + batch_parallel_factor - 1) / batch_parallel_factor;  // TODO: To make general, add support for batch_per_core > 1
+
+    const uint32_t num_sin_cos_rows_per_core = batch_per_core;
+    uint32_t num_cos_sin_tiles = head_dim_t * num_sin_cos_rows_per_core;
+
+    // Set up the CBs
+    auto src_buffer = input.buffer();
+    auto cos_buffer = cos.buffer();
+    auto sin_buffer = sin.buffer();
+    auto trans_mat_buffer = trans_mat.buffer();
+    auto dst_buffer = output.buffer();
+
+    uint32_t input_cb_index = CB::c_in0;
+    tt_metal::CircularBufferConfig cb_input_config =
+        tt_metal::CircularBufferConfig(
+            num_input_tiles * input_single_tile_size, {{input_cb_index, input_cb_data_format}})
+            .set_page_size(input_cb_index, input_single_tile_size)
+            .set_globally_allocated_address(*src_buffer);
+    auto cb_input = tt_metal::CreateCircularBuffer(program, all_cores, cb_input_config);
+
+    uint32_t cos_cb_index = CB::c_in1;
+    tt_metal::CircularBufferConfig cb_cos_config =
+        tt_metal::CircularBufferConfig(num_cos_sin_tiles * cos_single_tile_size, {{cos_cb_index, cos_cb_data_format}})
+            .set_page_size(cos_cb_index, cos_single_tile_size)
+            .set_globally_allocated_address(*cos_buffer);
+    auto cb_cos = tt_metal::CreateCircularBuffer(program, all_cores, cb_cos_config);
+
+    uint32_t sin_cb_index = CB::c_in2;
+    tt_metal::CircularBufferConfig cb_sin_config =
+        tt_metal::CircularBufferConfig(num_cos_sin_tiles * sin_single_tile_size, {{sin_cb_index, sin_cb_data_format}})
+            .set_page_size(sin_cb_index, sin_single_tile_size)
+            .set_globally_allocated_address(*sin_buffer);
+    auto cb_sin = tt_metal::CreateCircularBuffer(program, all_cores, cb_sin_config);
+
+    uint32_t trans_mat_cb_index = CB::c_in3;
+    // We only take one tile of trans_mat
+    uint32_t num_trans_mat_tiles = 1;
+    tt_metal::CircularBufferConfig cb_trans_mat_config =
+        tt_metal::CircularBufferConfig(num_trans_mat_tiles * trans_mat_single_tile_size, {{trans_mat_cb_index, trans_mat_cb_data_format}})
+            .set_page_size(trans_mat_cb_index, trans_mat_single_tile_size)
+            .set_globally_allocated_address(*trans_mat_buffer);
+    auto cb_trans_mat = tt_metal::CreateCircularBuffer(program, all_cores, cb_trans_mat_config);
+
+    uint32_t num_interm_tiles = head_dim_t;
+    uint32_t rotated_input_interm_cb_index = CB::c_intermed0;
+    tt_metal::CircularBufferConfig cb_rotated_input_interm_config =
+        tt_metal::CircularBufferConfig(
+            num_interm_tiles * input_single_tile_size, {{rotated_input_interm_cb_index, input_cb_data_format}})
+            .set_page_size(rotated_input_interm_cb_index, input_single_tile_size);
+    auto cb_rotated_input_interm = tt_metal::CreateCircularBuffer(program, all_cores, cb_rotated_input_interm_config);
+
+    uint32_t cos_interm_cb_index = CB::c_intermed1;
+    tt_metal::CircularBufferConfig cb_cos_interm_config =
+        tt_metal::CircularBufferConfig(
+            num_interm_tiles * input_single_tile_size, {{cos_interm_cb_index, cos_cb_data_format}})
+            .set_page_size(cos_interm_cb_index, cos_single_tile_size);
+    auto cb_cos_interm = tt_metal::CreateCircularBuffer(program, all_cores, cb_cos_interm_config);
+
+    uint32_t sin_interm_cb_index = CB::c_intermed2;
+    tt_metal::CircularBufferConfig cb_sin_interm_config =
+        tt_metal::CircularBufferConfig(
+            num_interm_tiles * input_single_tile_size, {{sin_interm_cb_index, sin_cb_data_format}})
+            .set_page_size(sin_interm_cb_index, sin_single_tile_size);
+    auto cb_sin_interm = tt_metal::CreateCircularBuffer(program, all_cores, cb_sin_interm_config);
+
+    uint32_t output_cb_index = CB::c_out0;  // output operands start at index 16
+    tt_metal::CircularBufferConfig cb_output_config =
+        tt_metal::CircularBufferConfig(
+            num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
+            .set_page_size(output_cb_index, output_single_tile_size)
+            .set_globally_allocated_address(*dst_buffer);
+    auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
+
+    // Set up the kernel
+    std::vector<uint32_t> compute_kernel_args = {
+        (std::uint32_t)input_cb_index,
+        (std::uint32_t)cos_cb_index,
+        (std::uint32_t)sin_cb_index,
+        (std::uint32_t)trans_mat_cb_index,
+        (std::uint32_t)rotated_input_interm_cb_index,
+        (std::uint32_t)cos_interm_cb_index,
+        (std::uint32_t)sin_interm_cb_index,
+        (std::uint32_t)output_cb_index,
+        (std::uint32_t)head_dim_t,
+        (std::uint32_t)n_heads_t,
+    };
+
+    auto rotary_embedding_kernel_id = tt_metal::CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp",
+        all_cores,
+        tt_metal::ComputeConfig{.math_fidelity=math_fidelity, .fp32_dest_acc_en=fp32_dest_acc_en, .compile_args = compute_kernel_args});
+
+    auto override_runtime_arguments_callback = [
+            cb_input,
+            cb_cos,
+            cb_sin,
+            cb_trans_mat,
+            cb_output
+        ](
+        const void *operation,
+        Program &program,
+        const std::vector<Tensor>& input_tensors,
+        const std::vector<std::optional<const Tensor>> &,
+        const std::vector<Tensor> &output_tensors) {
+
+        auto src_buffer = input_tensors.at(0).buffer();
+        auto cos_buffer = input_tensors.at(1).buffer();
+        auto sin_buffer = input_tensors.at(2).buffer();
+        auto trans_mat_buffer = input_tensors.at(3).buffer();
+        auto dst_buffer = output_tensors.at(0).buffer();
+
+        // Update the CB globally allocated addresses here
+        UpdateDynamicCircularBufferAddress(program, cb_input, *src_buffer);
+        UpdateDynamicCircularBufferAddress(program, cb_cos, *cos_buffer);
+        UpdateDynamicCircularBufferAddress(program, cb_sin, *sin_buffer);
+        UpdateDynamicCircularBufferAddress(program, cb_trans_mat, *trans_mat_buffer);
+        UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
+    };
+
+    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
+}
+
 }  // namespace tt_metal
 }  // namespace tt
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.hpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.hpp
index bfdcafcf1fa..76cbe1faeb8 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/rotary_embedding_llama_program_factory.hpp
@@ -16,5 +16,8 @@ namespace tt_metal {
 operation::ProgramWithCallbacks rotary_embedding_llama_multi_core(
     const Tensor &input, const Tensor &cos, const Tensor &sin, const Tensor &trans_mat, Tensor &output, ttnn::DeviceComputeKernelConfig compute_kernel_config);
 
+operation::ProgramWithCallbacks rotary_embedding_llama_multi_core_sharded(
+    const Tensor &input, const Tensor &cos, const Tensor &sin, const Tensor &trans_mat, Tensor &output, ttnn::DeviceComputeKernelConfig compute_kernel_config);
+
 }  // namespace tt_metal
 }  // namespace tt
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.cpp
index 39a0b6f027c..0ac240064c5 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.cpp
@@ -13,14 +13,14 @@ Tensor RotaryEmbeddingLlamaOperation::invoke(
     const Tensor &cos_cache,
     const Tensor &sin_cache,
     const Tensor& trans_mat,
+    const bool is_decode_mode,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor, cos_cache, sin_cache, trans_mat}))};
     operation::launch_op(
-        [memory_config, compute_kernel_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
+        [is_decode_mode, memory_config, compute_kernel_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
             auto& input_tensor = input_tensors.at(0);
-            uint32_t seq_len = input_tensor.get_legacy_shape()[-2];
             auto arch = input_tensor.storage_type() == StorageType::DEVICE ? input_tensor.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
             auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false);
@@ -31,7 +31,7 @@ Tensor RotaryEmbeddingLlamaOperation::invoke(
             }
 
             return operation::run(
-                RotaryEmbeddingLlama{seq_len, memory_config.value_or(default_memory_config), kernel_config_val}, input_tensors);
+                RotaryEmbeddingLlama{is_decode_mode, memory_config.value_or(default_memory_config), kernel_config_val}, input_tensors);
         },
         {input_tensor, cos_cache, sin_cache, trans_mat}, output_tensors);
     return output_tensors.at(0);
 }
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp
index d698dc4d45c..458a67f5e7c 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp
@@ -17,6 +17,7 @@ namespace operations::experimental::transformer {
         const Tensor& cos_cache,
         const Tensor& sin_cache,
         const Tensor& trans_mat,
+        const bool is_decode_mode = false,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
         const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
 };
diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama_pybind.cpp
index 1e4bed73d21..d0270c79c84 100644
--- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama_pybind.cpp
@@ -19,7 +19,7 @@ void py_bind_rotary_embedding_llama(pybind11::module& module) {
     ttnn::bind_registered_operation(
         module,
         ttnn::experimental::rotary_embedding_llama,
-        R"doc(rotary_embedding_llama(input_tensor: ttnn.Tensor, cos_cache: ttnn.Tensor, sin_cache: ttnn.Tensor, trans_mat: ttnn.Tensor, memory_config: MemoryConfig, compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor
+        R"doc(rotary_embedding_llama(input_tensor: ttnn.Tensor, cos_cache: ttnn.Tensor, sin_cache: ttnn.Tensor, trans_mat: ttnn.Tensor, is_decode_mode: bool, memory_config: MemoryConfig, compute_kernel_config: Optional[DeviceComputeKernelConfig]) -> ttnn.Tensor
 
             Applies the rotary embedding to the input_tensor tensor using the cos_cache and sin_cache tensors.
 
@@ -30,6 +30,7 @@ void py_bind_rotary_embedding_llama(pybind11::module& module) {
                 * :attr:`cos_cache`: Cosine Cache Tensor
                 * :attr:`sin_cache`: Sine Cache Tensor
                 * :attr:`trans_mat`: Transformation Matrix Tensor
+                * :attr:`is_decode_mode`: Specify mode of operation
                 * :attr:`memory_config`: Memory Config of the output tensor = DEFAULT_OUTPUT_MEMORY_CONFIG
                 * :attr:`compute_kernel_config`: Optional[DeviceComputeKernelConfig] = None
         )doc",
@@ -39,6 +40,7 @@
             py::arg("sin_cache"),
             py::arg("trans_mat"),
             py::kw_only(),
+            py::arg("is_decode_mode") = false,
            py::arg("memory_config") = std::nullopt,
            py::arg("compute_kernel_config") = std::nullopt});
 }
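
For reference, a minimal usage sketch (not part of the patch): it shows how the new keyword-only is_decode_mode argument exposed by the pybind change above is driven from Python, mirroring the decode path of the updated unit test. The tensors x, cos, sin and trans_mat are assumed to have been prepared (tilized, and height-sharded over batch for decode) exactly as that test prepares them; the helper name apply_rope is illustrative only.

    # Hypothetical helper mirroring TtLlamaRotary.apply_rotary in the test above.
    import ttnn

    def apply_rope(x, cos, sin, trans_mat, mode, compute_kernel_config=None):
        # Prefill keeps interleaved inputs; decode expects height-sharded
        # [1, batch, n_heads, head_dim] inputs, as enforced by the device-op validation.
        return ttnn.experimental.rotary_embedding_llama(
            x,
            cos,
            sin,
            trans_mat,
            is_decode_mode=(mode == "decode"),
            compute_kernel_config=compute_kernel_config,
        )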