From 9cc021618f0eadf417ef010c67eeeb1662db44e3 Mon Sep 17 00:00:00 2001 From: Colman Glagovich <114512306+cglagovichTT@users.noreply.github.com> Date: Wed, 29 Jan 2025 09:36:40 -0500 Subject: [PATCH] #16557: Implement JointAttention (#17079) --- .../misc/test_scaled_dot_product_attention.py | 152 ++++ ttnn/CMakeLists.txt | 2 + .../transformer/sdpa/device/joint_sdpa_op.cpp | 204 ++++++ .../transformer/sdpa/device/joint_sdpa_op.hpp | 34 + .../device/joint_sdpa_program_factory.cpp | 651 ++++++++++++++++++ .../device/joint_sdpa_program_factory.hpp | 28 + .../device/kernels/compute/joint_sdpa.cpp | 197 ++++++ .../kernels/dataflow/dataflow_common.hpp | 137 +++- .../device/kernels/dataflow/joint_reader.cpp | 100 +++ .../device/kernels/dataflow/joint_writer.cpp | 92 +++ .../ttnn/operations/transformer/sdpa/sdpa.cpp | 61 +- .../ttnn/operations/transformer/sdpa/sdpa.hpp | 33 +- .../transformer/sdpa/sdpa_pybind.cpp | 80 ++- 13 files changed, 1766 insertions(+), 5 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.cpp create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.hpp create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/joint_sdpa.cpp create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp create mode 100644 ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_writer.cpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py index add1b78e339..9bc75655c85 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py @@ -564,3 +564,155 @@ def test_sdpa_chunked_iterate_batch( assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format( device.num_program_cache_entries() ) + + +def run_test_joint_sdpa( + device, + b, + nh, + seq_len, + joint_seq_len, + d, + q_chunk_size, + k_chunk_size, + dtype, + use_high_precision_compute=False, + grid_size=None, +): + torch.manual_seed(1234) + + program_config = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=grid_size or device.compute_with_storage_grid_size(), + q_chunk_size=q_chunk_size, + k_chunk_size=k_chunk_size, + exp_approx_mode=False, + ) + + if use_high_precision_compute: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) + else: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + + Q = fa_rand(b, nh, seq_len, d) + K = fa_rand(b, nh, seq_len, d) + V = fa_rand(b, nh, seq_len, d) + + joint_Q = fa_rand(b, nh, joint_seq_len, d) + joint_K = fa_rand(b, nh, joint_seq_len, d) + joint_V = fa_rand(b, nh, joint_seq_len, d) + + # Print shapes of all inputs along with input names + logger.debug(f"Q: {Q.shape}") + logger.debug(f"K: {K.shape}") + logger.debug(f"V: {V.shape}") + + tt_Q = ttnn.from_torch(Q, dtype=dtype, 
layout=ttnn.TILE_LAYOUT, device=device) + tt_K = ttnn.from_torch(K, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + tt_V = ttnn.from_torch(V, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + tt_joint_Q = ttnn.from_torch(joint_Q, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + tt_joint_K = ttnn.from_torch(joint_K, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + tt_joint_V = ttnn.from_torch(joint_V, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + tt_out, tt_joint_out = ttnn.transformer.joint_scaled_dot_product_attention( + tt_Q, + tt_K, + tt_V, + tt_joint_Q, + tt_joint_K, + tt_joint_V, + joint_strategy="rear", + program_config=program_config, + compute_kernel_config=compute_kernel_config, + ) + tt_out = ttnn.to_torch(tt_out) + tt_joint_out = ttnn.to_torch(tt_joint_out) + # Slice out any tile-padding + tt_out = tt_out[:, :, :seq_len, :] + tt_joint_out = tt_joint_out[:, :, :joint_seq_len, :] + logger.debug(f"tt_out: {tt_out.shape}") + logger.debug(f"tt_joint_out: {tt_joint_out.shape}") + + pt_Q = torch.cat([Q, joint_Q], dim=2) + pt_K = torch.cat([K, joint_K], dim=2) + pt_V = torch.cat([V, joint_V], dim=2) + gt = torch.nn.functional.scaled_dot_product_attention(pt_Q, pt_K, pt_V, is_causal=False) + gt_out = gt[:, :, :seq_len, :] + gt_joint_out = gt[:, :, seq_len:, :] + + for out, gt in [(tt_out, gt_out), (tt_joint_out, gt_joint_out)]: + out_pass, out_pcc = comp_pcc(gt, out, 0.994) + logger.debug(f"python vs pytorch: {out_pcc}") + logger.debug(f"mse: {((gt - out) ** 2).mean()}") + assert out_pass + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"]) +@pytest.mark.parametrize("q_chunk_size", [128, 512], ids=["q128", "q512"]) +@pytest.mark.parametrize("k_chunk_size", [128, 512], ids=["k128", "k512"]) +@pytest.mark.parametrize("b", [1, 2], ids=["b1", "b2"]) +@pytest.mark.parametrize("nh", [1, 3], ids=["nh1", "nh3"]) +@pytest.mark.parametrize( + "seq_len, joint_seq_len", + [ + (15, 19), + (2048, 256), + (3000, 100), + (20 * 1024 + 1, 118), + ], +) +@pytest.mark.parametrize( + "d", + [128], + ids=[ + "d128", + ], +) +def test_joint_sdpa(device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype): + if q_chunk_size == 512 and k_chunk_size == 512: + pytest.skip("OOM config.") + ttnn.device.DisablePersistentKernelCache() + run_test_joint_sdpa(device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype) + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"]) +@pytest.mark.parametrize("q_chunk_size", [128], ids=["q128"]) +@pytest.mark.parametrize("k_chunk_size", [128], ids=["k128"]) +@pytest.mark.parametrize("b", [1], ids=["b1"]) +@pytest.mark.parametrize("nh", [1], ids=["nh1"]) +@pytest.mark.parametrize( + "seq_len, joint_seq_len", + [ + (3000, 100), + ], +) +@pytest.mark.parametrize( + "d", + [128], + ids=[ + "d128", + ], +) +def test_joint_sdpa_program_cache( + device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype, use_program_cache +): + dummy_tensors = [] + for _ in range(3): + dummy_tensors.append( + 
ttnn.from_torch(fa_rand(b, nh, seq_len, d), dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device) + ) + run_test_joint_sdpa(device, b, nh, seq_len, joint_seq_len, d, q_chunk_size, k_chunk_size, dtype, dummy_tensors) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 22b6f4375ad..11b6ae2eda2 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -377,6 +377,8 @@ set(TTNN_OP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp new file mode 100644 index 00000000000..d2ffc6d364d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.cpp @@ -0,0 +1,204 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "joint_sdpa_op.hpp" + +#include "joint_sdpa_program_factory.hpp" +#include "ttnn/run_operation.hpp" +#include + +using namespace tt::tt_metal; + +namespace ttnn::operations::transformer { + +void JointScaledDotProductAttention::validate(const std::vector& input_tensors) const { + tt::log_info("Validating Joint SDPA inputs"); + TT_FATAL(input_tensors.size() == 6, "Must have 6 input tensors (Q, K, V, joint_Q, joint_K, joint_V)"); + + const auto& input_tensor_q = input_tensors.at(0); + const auto& input_tensor_k = input_tensors.at(1); + const auto& input_tensor_v = input_tensors.at(2); + const auto& joint_tensor_q = input_tensors.at(3); + const auto& joint_tensor_k = input_tensors.at(4); + const auto& joint_tensor_v = input_tensors.at(5); + + // Validate joint strategy is 'rear' + TT_FATAL(this->joint_strategy == "rear", "Joint strategy must be 'rear'. Got: {}", this->joint_strategy); + + // Validate all tensors have the same dtype + const auto dtype = input_tensor_q.get_dtype(); + for (const auto& tensor : input_tensors) { + TT_FATAL( + tensor.get_dtype() == dtype, + "All tensors must have the same dtype. 
Expected {}, got {}", + dtype, + tensor.get_dtype()); + } + + // Get shapes + const auto q_shape = input_tensor_q.get_logical_shape(); + const auto k_shape = input_tensor_k.get_logical_shape(); + const auto v_shape = input_tensor_v.get_logical_shape(); + const auto joint_q_shape = joint_tensor_q.get_logical_shape(); + const auto joint_k_shape = joint_tensor_k.get_logical_shape(); + const auto joint_v_shape = joint_tensor_v.get_logical_shape(); + + // Validate storage types and buffers + for (auto& tensor : input_tensors) { + TT_FATAL(tensor.storage_type() == StorageType::DEVICE, "Operands to Joint SDPA need to be on device"); + TT_FATAL(tensor.buffer() != nullptr, "Operands to Joint SDPA need to be allocated in buffers on device"); + TT_FATAL(tensor.get_layout() == Layout::TILE, "Inputs to Joint SDPA must be tilized"); + TT_FATAL( + tensor.get_dtype() == DataType::BFLOAT16 || tensor.get_dtype() == DataType::BFLOAT8_B, + "Inputs to Joint SDPA must be BF16 or BF8"); + TT_FATAL( + tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM, + "Operands to Joint SDPA need to be in DRAM"); + } + + // Validate input shapes match + TT_FATAL( + k_shape[0] == q_shape[0] && v_shape[0] == q_shape[0], + "Batch sizes must match. Got Q: {}, K: {}, V: {}", + q_shape[0], + k_shape[0], + v_shape[0]); + + // Validate joint input shapes match + TT_FATAL( + joint_k_shape[0] == joint_q_shape[0] && joint_v_shape[0] == joint_q_shape[0], + "Joint batch sizes must match. Got Q: {}, K: {}, V: {}", + joint_q_shape[0], + joint_k_shape[0], + joint_v_shape[0]); + + // Validate Q and joint Q have same batch size and num heads + TT_FATAL( + q_shape[0] == joint_q_shape[0], + "Q and joint Q must have same batch size. Got Q: {}, joint Q: {}", + q_shape[0], + joint_q_shape[0]); + + // Validate head dimensions match + TT_FATAL( + k_shape[3] == q_shape[3] && v_shape[3] == q_shape[3], + "Head dimensions must match. Got Q: {}, K: {}, V: {}", + q_shape[3], + k_shape[3], + v_shape[3]); + + TT_FATAL( + joint_k_shape[3] == joint_q_shape[3] && joint_v_shape[3] == joint_q_shape[3], + "Joint head dimensions must match. Got Q: {}, K: {}, V: {}", + joint_q_shape[3], + joint_k_shape[3], + joint_v_shape[3]); + + TT_FATAL( + q_shape[3] == joint_q_shape[3], + "Q and joint Q must have same head dimension. Got Q: {}, joint Q: {}", + q_shape[3], + joint_q_shape[3]); + + // Validate num_heads relationship + const auto nqh = q_shape[1]; + const auto nkv = k_shape[1]; + const auto joint_nqh = joint_q_shape[1]; + const auto joint_nkv = joint_k_shape[1]; + + TT_FATAL(nqh == nkv, "Q num_heads must be equal to K num_heads. Got Q: {}, K: {}", nqh, nkv); + + TT_FATAL( + joint_nqh == joint_nkv, + "Joint Q num_heads must be equal to Joint K num_heads. Got Q: {}, K: {}", + joint_nqh, + joint_nkv); + TT_FATAL( + joint_nkv == nkv, "Joint K num_heads must be equal to K num_heads. Got Joint K: {}, K: {}", joint_nkv, nkv); + + // Validate chunk sizes if program config is provided + auto q_chunk_size = this->get_q_chunk_size(); + auto k_chunk_size = this->get_k_chunk_size(); + + TT_FATAL( + q_chunk_size % tt::constants::TILE_WIDTH == 0, + "q_chunk_size must be divisible by TILE_SIZE. Got q_chunk_size: {}, TILE_SIZE: {}", + q_chunk_size, + tt::constants::TILE_WIDTH); + TT_FATAL( + k_chunk_size % tt::constants::TILE_WIDTH == 0, + "k_chunk_size must be divisible by TILE_SIZE. 
Got k_chunk_size: {}, TILE_SIZE: {}", + k_chunk_size, + tt::constants::TILE_WIDTH); + + // Validate padding: Only the sequence dimension may be padded + auto validate_padding = [](const Tensor& tensor) { + auto logical_shape = tensor.get_logical_shape(); + auto padded_shape = tensor.get_padded_shape(); + TT_FATAL(logical_shape[0] == padded_shape[0], "Padding is not supported on the batch dimension"); + TT_FATAL(logical_shape[1] == padded_shape[1], "Padding is not supported on the num_heads dimension"); + TT_FATAL(logical_shape[3] == padded_shape[3], "Padding is not supported on the head_dim dimension"); + }; + + for (const auto& tensor : input_tensors) { + validate_padding(tensor); + } +} + +std::uint32_t JointScaledDotProductAttention::get_q_chunk_size() const { + return this->program_config ? this->program_config->q_chunk_size : 32; +} + +std::uint32_t JointScaledDotProductAttention::get_k_chunk_size() const { + return this->program_config ? this->program_config->k_chunk_size : 32; +} + +std::vector JointScaledDotProductAttention::compute_output_specs( + const std::vector& input_tensors) const { + auto& input = input_tensors.at(0); + auto& joint_input = input_tensors.at(3); + return { + TensorSpec( + input.get_logical_shape(), TensorLayout(input.get_dtype(), PageConfig(Layout::TILE), output_mem_config)), + TensorSpec( + joint_input.get_logical_shape(), + TensorLayout(joint_input.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; +} + +operation::ProgramWithCallbacks JointScaledDotProductAttention::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + auto& input_tensor_q = input_tensors.at(0); + auto& input_tensor_k = input_tensors.at(1); + auto& input_tensor_v = input_tensors.at(2); + auto& joint_tensor_q = input_tensors.at(3); + auto& joint_tensor_k = input_tensors.at(4); + auto& joint_tensor_v = input_tensors.at(5); + auto& output_tensor = output_tensors.at(0); + auto& joint_output_tensor = output_tensors.at(1); + + auto scale = this->scale; + if (not scale.has_value()) { + scale = 1.0f / std::sqrt(static_cast(input_tensor_q.get_logical_shape()[-1])); + } + + std::size_t q_chunk_size = this->get_q_chunk_size(); + std::size_t k_chunk_size = this->get_k_chunk_size(); + + return detail::joint_sdpa( + input_tensor_q, + input_tensor_k, + input_tensor_v, + joint_tensor_q, + joint_tensor_k, + joint_tensor_v, + output_tensor, + joint_output_tensor, + scale, + q_chunk_size, + k_chunk_size, + this->compute_kernel_config, + this->program_config); +} + +} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.hpp new file mode 100644 index 00000000000..6f67131e55b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_op.hpp @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/operation.hpp" +#include "ttnn/operations/transformer/sdpa_config.hpp" +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::transformer { + +struct JointScaledDotProductAttention { + const std::string joint_strategy; + const std::optional scale; + const MemoryConfig output_mem_config; + const std::optional program_config; + const DeviceComputeKernelConfig compute_kernel_config; + + void validate(const std::vector& input_tensors) const; + + std::vector compute_output_specs(const std::vector& input_tensors) const; + + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; + + std::uint32_t get_q_chunk_size() const; + std::uint32_t get_k_chunk_size() const; +}; + +} // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.cpp new file mode 100644 index 00000000000..b886d408ea7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.cpp @@ -0,0 +1,651 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "joint_sdpa_program_factory.hpp" +#include "joint_sdpa_op.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include "ttnn/operations/math.hpp" +#include "ttnn/operation.hpp" + +using namespace tt::constants; +using namespace tt::tt_metal; + +namespace ttnn::operations::transformer::detail { + +// implementation of softmax with optional scale/mask (see the header for input_tensor more detailed description) +operation::ProgramWithCallbacks joint_sdpa( + const Tensor& input_tensor_q, + const Tensor& input_tensor_k, + const Tensor& input_tensor_v, + const Tensor& joint_tensor_q, + const Tensor& joint_tensor_k, + const Tensor& joint_tensor_v, + const Tensor& output_tensor, + const Tensor& joint_output_tensor, + std::optional scale, + std::size_t q_chunk_size, + std::size_t k_chunk_size, + DeviceComputeKernelConfig compute_kernel_config, + std::optional program_config) { + /* + Q: B x NH x N x DH + K: B x NH x N x DH + V: B x NH x N x DH + + Q_joint: B x NH x L x DH + K_joint: B x NH x L x DH + V_joint: B x NH x L x DH + */ + + const auto q_shape = input_tensor_q.get_logical_shape(); + const auto joint_q_shape = joint_tensor_q.get_logical_shape(); + const uint32_t B = q_shape[0], NH = q_shape[1], N = q_shape[2], DH = q_shape[3]; + const uint32_t L = joint_q_shape[2]; + + // Calculate padded sequence length + const uint32_t padded_Nq = tt::round_up(N, q_chunk_size); + const uint32_t padded_Nk = tt::round_up(N, k_chunk_size); + const uint32_t padded_Lq = tt::round_up(L, q_chunk_size); + const uint32_t padded_Lk = tt::round_up(L, k_chunk_size); + + const uint32_t padded_Nqt = padded_Nq / TILE_HEIGHT; + const uint32_t padded_Nkt = padded_Nk / TILE_HEIGHT; + const uint32_t padded_Lqt = padded_Lq / TILE_HEIGHT; + const uint32_t padded_Lkt = padded_Lk / TILE_HEIGHT; + + // Find unpadded sequence lengths in tiles + const uint32_t valid_Nt = tt::div_up(N, TILE_HEIGHT); + const uint32_t valid_Lt = tt::div_up(L, TILE_HEIGHT); + + // Compute kernel operates on concatenated Q and K + const uint32_t cat_Sq = padded_Nq + padded_Lq; + const uint32_t cat_Sk = padded_Nk + padded_Lk; + + 
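+    // For example, with the (seq_len=3000, joint_seq_len=100) test case above and
+    // q_chunk_size = k_chunk_size = 128 (TILE_HEIGHT = 32), these shapes work out to:
+    //   padded_Nq = round_up(3000, 128) = 3072,  padded_Lq = round_up(100, 128) = 128
+    //   valid_Nt  = div_up(3000, 32)    = 94,    valid_Lt  = div_up(100, 32)    = 4
+    //   cat_Sq    = 3072 + 128 = 3200, so the compute kernel sees 3200 / 32 = 100 query tile rows
+    //   and 3200 / 128 = 25 query chunks. Both K sequences carry tile padding here
+    //   (3072 != 3000 and 128 != 100), so the padded-K mask path below is exercised.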
const uint32_t cat_Sqt = cat_Sq / TILE_HEIGHT; + const uint32_t cat_Skt = cat_Sk / TILE_HEIGHT; + const uint32_t DHt = DH / TILE_WIDTH; + + // Kernel will need to know the tile-based shapes of both sets of tensors + // to create a representation of the concatenated tensors. + + // const std::vector q_tile_shape = {B, NH, padded_Nqt, DHt}; + // const std::vector k_tile_shape = {B, NH, padded_Nkt, DHt}; + // const std::vector joint_q_tile_shape = {B, NH, padded_Lqt, DHt}; + // const std::vector joint_k_tile_shape = {B, NH, padded_Lkt, DHt}; + + /* + For non-causal case we must provide a padded mask if the K sequence length has been padded + Note that we dont have this issue in non-causal case if Q is padded, since those pad tokens + don't affect attention of unpadded tokens. + In causal case, the causal mask takes care of masking K pad tokens. + */ + const bool use_joint_mask = (padded_Nk != N) || (padded_Lk != L); + + const uint32_t Sq_chunk_t = q_chunk_size / TILE_HEIGHT; + const uint32_t Sk_chunk_t = k_chunk_size / TILE_HEIGHT; + const uint32_t q_num_chunks = cat_Sq / q_chunk_size; + const uint32_t k_num_chunks = cat_Sk / k_chunk_size; + + tt::log_debug("B: {}", B); + tt::log_debug("NH: {}", NH); + tt::log_debug("N: {}", N); + tt::log_debug("L: {}", L); + tt::log_debug("DH: {}", DH); + + // Log padded dimensions + tt::log_debug("padded_Nq: {}", padded_Nq); + tt::log_debug("padded_Nk: {}", padded_Nk); + tt::log_debug("padded_Lq: {}", padded_Lq); + tt::log_debug("padded_Lk: {}", padded_Lk); + tt::log_debug("padded_Nqt: {}", padded_Nqt); + tt::log_debug("padded_Nkt: {}", padded_Nkt); + tt::log_debug("padded_Lqt: {}", padded_Lqt); + tt::log_debug("padded_Lkt: {}", padded_Lkt); + + // Log tile dimensions + tt::log_debug("DHt: {}", DHt); + tt::log_debug("valid_Nt: {}", valid_Nt); + tt::log_debug("valid_Lt: {}", valid_Lt); + + // Log chunking parameters + tt::log_debug("Sq_chunk_t: {}", Sq_chunk_t); + tt::log_debug("Sk_chunk_t: {}", Sk_chunk_t); + tt::log_debug("q_chunk_size: {}", q_chunk_size); + tt::log_debug("k_chunk_size: {}", k_chunk_size); + tt::log_debug("q_num_chunks: {}", q_num_chunks); + tt::log_debug("k_num_chunks: {}", k_num_chunks); + + // Log concatenated dimensions + tt::log_debug("cat_Sq: {}", cat_Sq); + tt::log_debug("cat_Sk: {}", cat_Sk); + tt::log_debug("cat_Sqt: {}", cat_Sqt); + tt::log_debug("cat_Skt: {}", cat_Skt); + + tt::log_debug("use_joint_mask: {}", use_joint_mask); + + Program program = CreateProgram(); + + IDevice* device = input_tensor_q.device(); + + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = + get_compute_kernel_config_args(device->arch(), compute_kernel_config); + + CoreCoord grid_size = program_config.has_value() ? program_config->compute_with_storage_grid_size + : device->compute_with_storage_grid_size(); + bool exp_approx_mode = + program_config.has_value() + ? (program_config->exp_approx_mode.has_value() ? program_config->exp_approx_mode.value() : true) + : true; + + auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + uint32_t num_cores = grid_size.x * grid_size.y; + + TT_FATAL( + num_cores <= device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y, + "Provided grid must not contain more cores than the device. 
Got {} cores, expected at most {} cores.", + num_cores, + device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y); + + // Parallelization scheme + // We will choose parallelization factors for batch, num_heads, and q_seq_len in that order + uint32_t batch_parallel_factor = std::min(B, num_cores); + uint32_t nh_parallel_factor = std::min(num_cores / batch_parallel_factor, NH); + uint32_t q_parallel_factor = std::min(num_cores / (batch_parallel_factor * nh_parallel_factor), q_num_chunks); + + TT_FATAL( + batch_parallel_factor * nh_parallel_factor * q_parallel_factor <= num_cores, + "Parallelism must not exceed number of cores. Got {}, expected at most {}.", + batch_parallel_factor * nh_parallel_factor * q_parallel_factor, + num_cores); + + tt::log_debug("Parallelization scheme:"); + tt::log_debug("batch_parallel_factor: {}", batch_parallel_factor); + tt::log_debug("nh_parallel_factor: {}", nh_parallel_factor); + tt::log_debug("q_parallel_factor: {}", q_parallel_factor); + + // Ceiling divide to allow for non-perfect divisions + const uint32_t batch_per_core = tt::div_up(B, batch_parallel_factor); + const uint32_t nh_per_core = tt::div_up(NH, nh_parallel_factor); + const uint32_t q_per_core = tt::div_up(q_num_chunks, q_parallel_factor); + + const uint32_t q_buffer_factor = (q_per_core > 1) ? 2 : 1; + + tt::log_debug("q_per_core: {}", q_per_core); + + // These tile capacity counts for CBs need to match the number of tiles expected by the kernel (softmax.cpp) + uint32_t q_tiles = Sq_chunk_t * DHt * q_buffer_factor; + uint32_t k_tiles = Sk_chunk_t * DHt * 2; // double buffer + uint32_t v_tiles = Sk_chunk_t * DHt * 2; // double buffer + uint32_t mask_tiles = Sq_chunk_t * Sk_chunk_t; + uint32_t qk_tiles = Sq_chunk_t * Sk_chunk_t; + uint32_t out_im_tiles = Sq_chunk_t * DHt; + uint32_t out0_t = Sq_chunk_t * DHt; + uint32_t scale_tiles = 1; + uint32_t statistics_tiles = Sq_chunk_t; // Single column of values in each iteration + + // log all values + tt::log_debug("q_tiles: {}", q_tiles); + tt::log_debug("k_tiles: {}", k_tiles); + tt::log_debug("v_tiles: {}", v_tiles); + tt::log_debug("mask_tiles: {}", mask_tiles); + tt::log_debug("qk_tiles: {}", qk_tiles); + tt::log_debug("out0_t: {}", out0_t); + tt::log_debug("scale_tiles: {}", scale_tiles); + tt::log_debug("statistics_tiles: {}", statistics_tiles); + + // Host code is responsible for determining matmul configuration + const uint32_t dst_size = fp32_dest_acc_en ? 4 : 8; + const uint32_t qk_in0_block_w = DHt; + // max of Sk_chunk_t and dst_size + const uint32_t qk_out_subblock_w = std::min(Sk_chunk_t, dst_size); + // If qk_out_subblock_w is full row of output, scale subblock_h so volume = dst_size. Otherwise it's 1 to maintain + // row-major intermediate buffer. + const uint32_t qk_out_subblock_h = + (qk_out_subblock_w == Sk_chunk_t) ? (std::min(Sq_chunk_t, dst_size / qk_out_subblock_w)) : 1; + + const uint32_t qk_in0_num_subblocks = Sq_chunk_t / qk_out_subblock_h; + const uint32_t qk_in1_num_subblocks = Sk_chunk_t / qk_out_subblock_w; + const uint32_t qk_num_blocks = DHt / qk_in0_block_w; + + // now for out0 + const uint32_t out_in0_block_w = Sk_chunk_t; + const uint32_t out_out_subblock_w = std::min(DHt, dst_size); + const uint32_t out_out_subblock_h = + (out_out_subblock_w == DHt) ? 
(std::min(Sq_chunk_t, dst_size / out_out_subblock_w)) : 1; + + const uint32_t out_in0_num_subblocks = Sq_chunk_t / out_out_subblock_h; + const uint32_t out_in1_num_subblocks = DHt / out_out_subblock_w; + const uint32_t out_num_blocks = Sk_chunk_t / out_in0_block_w; + + // log all values + tt::log_debug("dst_size: {}", dst_size); + tt::log_debug("qk_in0_block_w: {}", qk_in0_block_w); + tt::log_debug("qk_out_subblock_w: {}", qk_out_subblock_w); + tt::log_debug("qk_out_subblock_h: {}", qk_out_subblock_h); + tt::log_debug("qk_in0_num_subblocks: {}", qk_in0_num_subblocks); + tt::log_debug("qk_in1_num_subblocks: {}", qk_in1_num_subblocks); + tt::log_debug("qk_num_blocks: {}", qk_num_blocks); + tt::log_debug("out_in0_block_w: {}", out_in0_block_w); + tt::log_debug("out_out_subblock_w: {}", out_out_subblock_w); + tt::log_debug("out_out_subblock_h: {}", out_out_subblock_h); + tt::log_debug("out_in0_num_subblocks: {}", out_in0_num_subblocks); + tt::log_debug("out_in1_num_subblocks: {}", out_in1_num_subblocks); + tt::log_debug("out_num_blocks: {}", out_num_blocks); + + // Determine granularity for statistics computation + const uint32_t stats_granularity = std::min(Sq_chunk_t, dst_size); + // Find log2 of stats_granularity using std + const uint32_t log2_stats_granularity = std::log2(stats_granularity); + // Assert that this is a power of 2 + TT_FATAL( + stats_granularity == (1 << log2_stats_granularity), + "stats_granularity must be a power of 2. Got {}.", + stats_granularity); + + const uint32_t sub_exp_granularity = std::min(Sk_chunk_t, dst_size); + const uint32_t log2_sub_exp_granularity = std::log2(sub_exp_granularity); + TT_FATAL( + sub_exp_granularity == (1 << log2_sub_exp_granularity), + "sub_exp_granularity must be a power of 2. Got {}.", + sub_exp_granularity); + + const uint32_t mul_bcast_granularity = std::min(Sq_chunk_t * Sk_chunk_t, dst_size); + const uint32_t log2_mul_bcast_granularity = std::log2(mul_bcast_granularity); + TT_FATAL( + mul_bcast_granularity == (1 << log2_mul_bcast_granularity), + "mul_bcast_granularity must be a power of 2. Got {}.", + mul_bcast_granularity); + + uint32_t dht_granularity = std::min(DHt, dst_size); + uint32_t log2_dht_granularity = std::log2(dht_granularity); + // Sometimes DHt is not a power of 2, so granularity should be 1 + if (dht_granularity != (1 << log2_dht_granularity)) { + dht_granularity = 1; + log2_dht_granularity = 0; + } + TT_FATAL( + dht_granularity == (1 << log2_dht_granularity), + "dht_granularity must be a power of 2. Got {}.", + dht_granularity); + + // Log these + tt::log_debug("stats_granularity: {}", stats_granularity); + tt::log_debug("log2_stats_granularity: {}", log2_stats_granularity); + tt::log_debug("sub_exp_granularity: {}", sub_exp_granularity); + tt::log_debug("log2_sub_exp_granularity: {}", log2_sub_exp_granularity); + tt::log_debug("mul_bcast_granularity: {}", mul_bcast_granularity); + tt::log_debug("log2_mul_bcast_granularity: {}", log2_mul_bcast_granularity); + tt::log_debug("dht_granularity: {}", dht_granularity); + tt::log_debug("log2_dht_granularity: {}", log2_dht_granularity); + + // Reduce ops need to multiply by a scalar. 
We always want to multiply by 1.0f + class bfloat16 bfloat_identity_scalar(1.0f); + uint32_t packed_identity_scalar = pack_two_bfloat16_into_uint32({bfloat_identity_scalar, bfloat_identity_scalar}); + + union { + float f; + uint32_t u; + } scale_union; + scale_union.f = scale.value_or(1.0f); + + // log scale + tt::log_debug("scale: {}", scale_union.f); + + std::vector reader_compile_time_args = { + B, + NH, + DHt, + Sq_chunk_t, + Sk_chunk_t, + k_num_chunks, + valid_Nt, + valid_Lt, + padded_Nqt, + padded_Nkt, + padded_Lqt, + padded_Lkt, + num_cores, + }; + + // Calculate which K chunks contain the mask boundaries + // If a tensor does not require masking, set to MAX_UINT32. This avoids a + // bug in the mask generation code, which would mask a full, valid chunk + // with -inf. + const uint32_t mask_chunk_0 = + (padded_Nk != N) ? (padded_Nkt / Sk_chunk_t) - 1 : (uint32_t)(-1); // idx of last chunk in first sequence + const uint32_t mask_chunk_1 = + (padded_Lk != L) ? (cat_Skt / Sk_chunk_t) - 1 : (uint32_t)(-1); // idx of last chunk in second sequence + + std::vector writer_compile_time_args = { + B, + NH, + DHt, + Sq_chunk_t, + Sk_chunk_t, + k_num_chunks, + valid_Nt, + valid_Lt, + padded_Nqt, + padded_Nkt, + padded_Lqt, + padded_Lkt, + N, + L, + num_cores, + packed_identity_scalar, + scale_union.u, + (uint32_t)use_joint_mask, + mask_chunk_0, + mask_chunk_1, + }; + + std::vector compute_compile_time_args = { + B, + NH, + cat_Skt, + DHt, + Sq_chunk_t, + Sk_chunk_t, + k_num_chunks, + qk_in0_block_w, + qk_out_subblock_w, + qk_out_subblock_h, + qk_in0_num_subblocks, + qk_in1_num_subblocks, + qk_num_blocks, + out_in0_block_w, + out_out_subblock_w, + out_out_subblock_h, + out_in0_num_subblocks, + out_in1_num_subblocks, + out_num_blocks, + (uint32_t)use_joint_mask, + mask_chunk_0, + mask_chunk_1, + }; + + std::map defines; + defines["STATS_GRANULARITY"] = std::to_string(stats_granularity); + defines["LOG2_STATS_GRANULARITY"] = std::to_string(log2_stats_granularity); + defines["SUB_EXP_GRANULARITY"] = std::to_string(sub_exp_granularity); + defines["LOG2_SUB_EXP_GRANULARITY"] = std::to_string(log2_sub_exp_granularity); + defines["MUL_BCAST_GRANULARITY"] = std::to_string(mul_bcast_granularity); + defines["LOG2_MUL_BCAST_GRANULARITY"] = std::to_string(log2_mul_bcast_granularity); + defines["DHT_GRANULARITY"] = std::to_string(dht_granularity); + defines["LOG2_DHT_GRANULARITY"] = std::to_string(log2_dht_granularity); + defines["EXP_APPROX_MODE"] = std::to_string(exp_approx_mode); + + auto reader_kernels_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp", + core_grid, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, defines)); + + auto writer_kernels_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_writer.cpp", + core_grid, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args, defines)); + + auto compute_kernels_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/joint_sdpa.cpp", + core_grid, + tt::tt_metal::ComputeConfig{ + .math_fidelity = math_fidelity, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_compile_time_args, + .defines = defines}); + + // Create circular buffers + + tt::DataFormat q_df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor_q.get_dtype()); + tt::DataFormat k_df = 
tt::tt_metal::datatype_to_dataformat_converter(input_tensor_k.get_dtype()); + tt::DataFormat v_df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor_v.get_dtype()); + tt::DataFormat mask_df = tt::DataFormat::Float16_b; + tt::DataFormat out_df = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); + tt::DataFormat scalar_df = tt::DataFormat::Float16_b; + tt::DataFormat im_df = tt::DataFormat::Float16_b; // need to disable fp32 cbs (Issue #13364) fp32_dest_acc_en ? + // tt::DataFormat::Float32 : tt::DataFormat::Float16_b; + tt::DataFormat stats_df = im_df; + + uint32_t q_tile_size = tt::tt_metal::detail::TileSize(q_df); + uint32_t k_tile_size = tt::tt_metal::detail::TileSize(k_df); + uint32_t v_tile_size = tt::tt_metal::detail::TileSize(v_df); + uint32_t mask_tile_size = tt::tt_metal::detail::TileSize(mask_df); + uint32_t out_tile_size = tt::tt_metal::detail::TileSize(out_df); + uint32_t scalar_tile_size = tt::tt_metal::detail::TileSize(scalar_df); + uint32_t im_tile_size = tt::tt_metal::detail::TileSize(im_df); + uint32_t stats_tile_size = tt::tt_metal::detail::TileSize(stats_df); + + log_debug("q_data_format: {}", q_df); + log_debug("k_data_format: {}", k_df); + log_debug("v_data_format: {}", v_df); + log_debug("mask_data_format: {}", mask_df); + log_debug("out_data_format: {}", out_df); + log_debug("scalar_data_format: {}", scalar_df); + log_debug("intermediate_data_format: {}", im_df); + log_debug("statistics_data_format: {}", stats_df); + + // Q input + auto c_in0_config = CircularBufferConfig(q_tiles * q_tile_size, {{tt::CBIndex::c_0, q_df}}) + .set_page_size(tt::CBIndex::c_0, q_tile_size); + + auto cb_in0_id = CreateCircularBuffer(program, core_grid, c_in0_config); + // K input + auto c_in1_config = CircularBufferConfig(k_tiles * k_tile_size, {{tt::CBIndex::c_1, k_df}}) + .set_page_size(tt::CBIndex::c_1, k_tile_size); + auto cb_in1_id = CreateCircularBuffer(program, core_grid, c_in1_config); + // V input + auto c_in2_config = CircularBufferConfig(v_tiles * v_tile_size, {{tt::CBIndex::c_2, v_df}}) + .set_page_size(tt::CBIndex::c_2, v_tile_size); + auto cb_in2_id = CreateCircularBuffer(program, core_grid, c_in2_config); + + // Only create mask buffer if it's going to be used + if (use_joint_mask) { + // attn_mask input + auto c_in3_config = CircularBufferConfig(mask_tiles * mask_tile_size, {{tt::CB::c_in3, mask_df}}) + .set_page_size(tt::CB::c_in3, mask_tile_size); + auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config); + } + + // scale input + auto c_in4_config = CircularBufferConfig(scale_tiles * scalar_tile_size, {{tt::CBIndex::c_4, scalar_df}}) + .set_page_size(tt::CBIndex::c_4, scalar_tile_size); + auto cb_in4_id = CreateCircularBuffer(program, core_grid, c_in4_config); + + // identity scale input + auto c_in5_config = CircularBufferConfig(scale_tiles * scalar_tile_size, {{tt::CBIndex::c_5, scalar_df}}) + .set_page_size(tt::CBIndex::c_5, scalar_tile_size); + auto cb_in5_id = CreateCircularBuffer(program, core_grid, c_in5_config); + + // cb_qk_im + auto c_intermed0_config = CircularBufferConfig(qk_tiles * im_tile_size, {{tt::CBIndex::c_24, im_df}}) + .set_page_size(tt::CBIndex::c_24, im_tile_size); + auto cb_intermed0_id = CreateCircularBuffer(program, core_grid, c_intermed0_config); + + // cb_out_im + auto c_intermed1_config = CircularBufferConfig(out_im_tiles * im_tile_size, {{tt::CBIndex::c_25, im_df}}) + .set_page_size(tt::CBIndex::c_25, im_tile_size); + auto cb_intermed1_id = CreateCircularBuffer(program, core_grid, 
c_intermed1_config); + + // cb_out_accumulate_im + auto c_intermed2_config = CircularBufferConfig(out_im_tiles * im_tile_size, {{tt::CBIndex::c_26, im_df}}) + .set_page_size(tt::CBIndex::c_26, im_tile_size); + auto cb_intermed2_id = CreateCircularBuffer(program, core_grid, c_intermed2_config); + + // cb_cur_max + auto c_intermed3_config = CircularBufferConfig(statistics_tiles * stats_tile_size, {{tt::CBIndex::c_27, stats_df}}) + .set_page_size(tt::CBIndex::c_27, stats_tile_size); + auto cb_intermed3_id = CreateCircularBuffer(program, core_grid, c_intermed3_config); + + // cb_prev_max + auto c_intermed4_config = CircularBufferConfig(statistics_tiles * stats_tile_size, {{tt::CBIndex::c_28, stats_df}}) + .set_page_size(tt::CBIndex::c_28, stats_tile_size); + auto cb_intermed4_id = CreateCircularBuffer(program, core_grid, c_intermed4_config); + + // cb_cur_sum + auto c_intermed5_config = CircularBufferConfig(statistics_tiles * stats_tile_size, {{tt::CBIndex::c_29, stats_df}}) + .set_page_size(tt::CBIndex::c_29, stats_tile_size); + auto cb_intermed5_id = CreateCircularBuffer(program, core_grid, c_intermed5_config); + + // cb_prev_sum + auto c_intermed6_config = CircularBufferConfig(statistics_tiles * stats_tile_size, {{tt::CBIndex::c_30, stats_df}}) + .set_page_size(tt::CBIndex::c_30, stats_tile_size); + auto cb_intermed6_id = CreateCircularBuffer(program, core_grid, c_intermed6_config); + + // cb_exp_max_diff + auto c_intermed7_config = CircularBufferConfig(statistics_tiles * stats_tile_size, {{tt::CBIndex::c_31, stats_df}}) + .set_page_size(tt::CBIndex::c_31, stats_tile_size); + auto cb_intermed7_id = CreateCircularBuffer(program, core_grid, c_intermed7_config); + + // Output + auto c_out0_config = CircularBufferConfig(out0_t * out_tile_size, {{tt::CBIndex::c_16, out_df}}) + .set_page_size(tt::CBIndex::c_16, out_tile_size); + auto cb_out0_id = CreateCircularBuffer(program, core_grid, c_out0_config); + + uint32_t q_addr = input_tensor_q.buffer()->address(); + uint32_t k_addr = input_tensor_k.buffer()->address(); + uint32_t v_addr = input_tensor_v.buffer()->address(); + uint32_t joint_q_addr = joint_tensor_q.buffer()->address(); + uint32_t joint_k_addr = joint_tensor_k.buffer()->address(); + uint32_t joint_v_addr = joint_tensor_v.buffer()->address(); + uint32_t out_addr = output_tensor.buffer()->address(); + uint32_t joint_out_addr = joint_output_tensor.buffer()->address(); + + // Set reader rt args + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + uint32_t local_batch_start = (i / (nh_parallel_factor * q_parallel_factor)) * batch_per_core; + uint32_t local_batch_end = local_batch_start + batch_per_core; + uint32_t local_nh_start = ((i / q_parallel_factor) % nh_parallel_factor) * nh_per_core; + uint32_t local_nh_end = local_nh_start + nh_per_core; + uint32_t local_q_start = (i % q_parallel_factor) * q_per_core; + uint32_t local_q_end = local_q_start + q_per_core; + + // clamp all to max values for non-even partitioning + local_batch_start = std::min(local_batch_start, B); + local_batch_end = std::min(local_batch_end, B); + local_nh_start = std::min(local_nh_start, NH); + local_nh_end = std::min(local_nh_end, NH); + local_q_start = std::min(local_q_start, q_num_chunks); + local_q_end = std::min(local_q_end, q_num_chunks); + + // log the above + tt::log_debug("core: {}", i); + tt::log_debug("x={},y={}", core.x, core.y); + tt::log_debug("local_batch_start: {}", local_batch_start); + tt::log_debug("local_batch_end: {}", local_batch_end); + 
tt::log_debug("local_nh_start: {}", local_nh_start); + tt::log_debug("local_nh_end: {}", local_nh_end); + tt::log_debug("local_q_start: {}", local_q_start); + tt::log_debug("local_q_end: {}", local_q_end); + + SetRuntimeArgs( + program, + reader_kernels_id, + core, + {q_addr, + k_addr, + v_addr, + joint_q_addr, + joint_k_addr, + joint_v_addr, + local_batch_start, + local_batch_end, + local_nh_start, + local_nh_end, + local_q_start, + local_q_end}); + + // Writer args + SetRuntimeArgs( + program, + writer_kernels_id, + core, + {out_addr, + joint_out_addr, + local_batch_start, + local_batch_end, + local_nh_start, + local_nh_end, + local_q_start, + local_q_end}); + + // Compute args + SetRuntimeArgs( + program, + compute_kernels_id, + core, + {local_batch_start, local_batch_end, local_nh_start, local_nh_end, local_q_start, local_q_end}); + } + + auto override_runtime_arguments_callback = + [num_cores, grid_size, reader_kernels_id, writer_kernels_id, compute_kernels_id]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + // Get addresses for regular tensors + auto q_buffer = input_tensors.at(0).buffer(); + auto k_buffer = input_tensors.at(1).buffer(); + auto v_buffer = input_tensors.at(2).buffer(); + auto joint_q_buffer = input_tensors.at(3).buffer(); + auto joint_k_buffer = input_tensors.at(4).buffer(); + auto joint_v_buffer = input_tensors.at(5).buffer(); + + // Get addresses for output tensors + auto out_buffer = output_tensors.at(0).buffer(); + auto joint_out_buffer = output_tensors.at(1).buffer(); + + uint32_t q_addr = q_buffer->address(); + uint32_t k_addr = k_buffer->address(); + uint32_t v_addr = v_buffer->address(); + uint32_t joint_q_addr = joint_q_buffer->address(); + uint32_t joint_k_addr = joint_k_buffer->address(); + uint32_t joint_v_addr = joint_v_buffer->address(); + uint32_t out_addr = out_buffer->address(); + uint32_t joint_out_addr = joint_out_buffer->address(); + + auto& reader_args_by_core = GetRuntimeArgs(program, reader_kernels_id); + auto& writer_args_by_core = GetRuntimeArgs(program, writer_kernels_id); + auto& compute_args_by_core = GetRuntimeArgs(program, compute_kernels_id); + + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + auto& reader_args = reader_args_by_core[core.x][core.y]; + auto& writer_args = writer_args_by_core[core.x][core.y]; + auto& compute_args = compute_args_by_core[core.x][core.y]; + + // Update reader args + reader_args[0] = q_addr; + reader_args[1] = k_addr; + reader_args[2] = v_addr; + reader_args[3] = joint_q_addr; + reader_args[4] = joint_k_addr; + reader_args[5] = joint_v_addr; + + // Update writer args + writer_args[0] = out_addr; + writer_args[1] = joint_out_addr; + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace ttnn::operations::transformer::detail diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.hpp new file mode 100644 index 00000000000..285be3dfb2c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/joint_sdpa_program_factory.hpp @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/operation.hpp" +#include "ttnn/operations/transformer/sdpa_config.hpp" + +namespace ttnn::operations::transformer::detail { + +operation::ProgramWithCallbacks joint_sdpa( + const Tensor& input_tensor_q, + const Tensor& input_tensor_k, + const Tensor& input_tensor_v, + const Tensor& joint_tensor_q, + const Tensor& joint_tensor_k, + const Tensor& joint_tensor_v, + const Tensor& output_tensor, + const Tensor& joint_output_tensor, + std::optional scale, + std::size_t q_chunk_size, + std::size_t k_chunk_size, + DeviceComputeKernelConfig compute_kernel_config, + std::optional program_config); + +} // namespace ttnn::operations::transformer::detail diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/joint_sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/joint_sdpa.cpp new file mode 100644 index 00000000000..f7178265a14 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/joint_sdpa.cpp @@ -0,0 +1,197 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#define REDUCE_OP (PoolType::MAX) +#define REDUCE_DIM (ReduceDim::REDUCE_ROW) + +#include "compute_kernel_api.h" +#include "compute_common.hpp" + +namespace NAMESPACE { +void MAIN { + constexpr uint32_t B = get_compile_time_arg_val(0); + constexpr uint32_t NH = get_compile_time_arg_val(1); + constexpr uint32_t Skt = get_compile_time_arg_val(2); + constexpr uint32_t DHt = get_compile_time_arg_val(3); + constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(4); + constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(5); + constexpr uint32_t k_num_chunks = get_compile_time_arg_val(6); + + constexpr uint32_t qk_in0_block_w = get_compile_time_arg_val(7); + constexpr uint32_t qk_subblock_w = get_compile_time_arg_val(8); + constexpr uint32_t qk_subblock_h = get_compile_time_arg_val(9); + constexpr uint32_t qk_in0_num_subblocks = get_compile_time_arg_val(10); + constexpr uint32_t qk_in1_num_subblocks = get_compile_time_arg_val(11); + constexpr uint32_t qk_num_blocks = get_compile_time_arg_val(12); + constexpr uint32_t out_in0_block_w = get_compile_time_arg_val(13); + constexpr uint32_t out_subblock_w = get_compile_time_arg_val(14); + constexpr uint32_t out_subblock_h = get_compile_time_arg_val(15); + constexpr uint32_t out_in0_num_subblocks = get_compile_time_arg_val(16); + constexpr uint32_t out_in1_num_subblocks = get_compile_time_arg_val(17); + constexpr uint32_t out_num_blocks = get_compile_time_arg_val(18); + + constexpr bool use_joint_mask = get_compile_time_arg_val(19) == 1; + constexpr uint32_t mask_chunk_0 = get_compile_time_arg_val(20); + constexpr uint32_t mask_chunk_1 = get_compile_time_arg_val(21); + + uint32_t argidx = 0; + const uint32_t local_batch_start = get_arg_val(argidx++); + const uint32_t local_batch_end = get_arg_val(argidx++); + const uint32_t local_nh_start = get_arg_val(argidx++); + const uint32_t local_nh_end = get_arg_val(argidx++); + const uint32_t local_q_start = get_arg_val(argidx++); + const uint32_t local_q_end = get_arg_val(argidx++); + + constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt; + constexpr uint32_t k_chunk_tiles = Sk_chunk_t * DHt; + constexpr uint32_t qk_chunk_tiles = Sq_chunk_t * Sk_chunk_t; + constexpr uint32_t out_chunk_tiles = Sq_chunk_t * DHt; + + constexpr uint32_t cb_q_in = tt::CBIndex::c_0; + constexpr 
uint32_t cb_k_in = tt::CBIndex::c_1; + constexpr uint32_t cb_v_in = tt::CBIndex::c_2; + constexpr uint32_t cb_mask_in = tt::CBIndex::c_3; + constexpr uint32_t cb_scale_in = tt::CBIndex::c_4; + constexpr uint32_t cb_identity_scale_in = tt::CBIndex::c_5; + + constexpr uint32_t cb_qk_im = tt::CBIndex::c_24; + constexpr uint32_t cb_out_im_A = tt::CBIndex::c_25; + constexpr uint32_t cb_out_im_B = tt::CBIndex::c_26; + constexpr uint32_t cb_max_A = tt::CBIndex::c_27; + constexpr uint32_t cb_max_B = tt::CBIndex::c_28; + constexpr uint32_t cb_sum_A = tt::CBIndex::c_29; + constexpr uint32_t cb_sum_B = tt::CBIndex::c_30; + constexpr uint32_t cb_exp_max_diff = tt::CBIndex::c_31; + + constexpr uint32_t cb_out = tt::CBIndex::c_16; + + mm_init(); + + for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) { + for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) { + for (uint32_t q_chunk = local_q_start; q_chunk < local_q_end; ++q_chunk) { + // Set up ping pong buffers + uint32_t alias_prev_sum = cb_sum_A; + uint32_t alias_cur_sum = cb_sum_B; + uint32_t alias_prev_max = cb_max_A; + uint32_t alias_cur_max = cb_max_B; + uint32_t alias_mm2_prev_out = cb_out_im_A; + uint32_t alias_mm2_cur_out = cb_out_im_B; + + cb_wait_front(cb_q_in, q_chunk_tiles); + + for (uint32_t k_chunk = 0; k_chunk < k_num_chunks; ++k_chunk) { + /* QK = Q_CHUNK @ K_CHUNK */ + pack_reconfig_data_format(cb_qk_im); + matmul_blocks( + cb_q_in, + cb_k_in, + cb_qk_im, + Sq_chunk_t, + Sk_chunk_t, + DHt, + qk_num_blocks, + qk_in0_num_subblocks, + qk_in1_num_subblocks, + qk_in0_block_w, + qk_subblock_h, + qk_subblock_w, + true /*transpose*/); + + /* QK *= SCALE */ + mul_block_bcast_scalar_inplace(); + + if constexpr (use_joint_mask) { + if ((k_chunk == mask_chunk_0) || (k_chunk == mask_chunk_1)) { + /* QK += MASK */ + reconfig_data_format(cb_qk_im, cb_mask_in); + add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + } + } + + /* Compute max and sum for softmax */ + reconfig_data_format(cb_qk_im, cb_identity_scale_in); + reduce_c< + PoolType::MAX, + ReduceDim::REDUCE_ROW, + cb_qk_im, + cb_identity_scale_in, + Sq_chunk_t, + Sk_chunk_t>(alias_cur_max); + + if (k_chunk > 0) { + max_block_inplace(alias_cur_max, alias_prev_max); + } + + /* QK -= cb_cur_max */ + /* QK = exp(QK)*/ + sub_exp_block_bcast_cols_inplace(alias_cur_max); + + /* cb_cur_sum = sum(cb_qk_im, dim=-1) */ + reduce_c< + PoolType::SUM, + ReduceDim::REDUCE_ROW, + cb_qk_im, + cb_identity_scale_in, + Sq_chunk_t, + Sk_chunk_t>(alias_cur_sum); + + /* OUT_IM = QK @ V_CHUNK */ + matmul_blocks( + cb_qk_im, + cb_v_in, + alias_mm2_cur_out, + Sq_chunk_t, + DHt, + Sk_chunk_t, + out_num_blocks, + out_in0_num_subblocks, + out_in1_num_subblocks, + out_in0_block_w, + out_subblock_h, + out_subblock_w, + false /*transpose*/); + + cb_pop_front(cb_qk_im, qk_chunk_tiles); + reconfig_data_format(alias_prev_max, alias_cur_max); + + /* OUT_ACC += OUT_IM */ + if (k_chunk > 0) { + /* cb_exp_max_diff = torch.exp(cb_prev_max - cb_cur_max) */ + sub_exp_block(alias_prev_max, alias_cur_max, cb_exp_max_diff, Sq_chunk_t); + cb_pop_front(alias_prev_max, Sq_chunk_t); + + /* cb_prev_sum *= cb_exp_max_diff */ + mul_block_inplace(alias_prev_sum, cb_exp_max_diff, Sq_chunk_t); + /* cb_cur_sum += cb_prev_sum */ + add_block_inplace(alias_cur_sum, alias_prev_sum, Sq_chunk_t); + + /* cb_out_accumulate_im *= cb_exp_max_diff */ + mul_block_bcast_cols_inplace(alias_mm2_prev_out, cb_exp_max_diff); + add_block_inplace(alias_mm2_cur_out, alias_mm2_prev_out, out_chunk_tiles); + } + + // Swap ping-pong buffers + 
std::swap(alias_prev_sum, alias_cur_sum); + std::swap(alias_mm2_prev_out, alias_mm2_cur_out); + std::swap(alias_prev_max, alias_cur_max); + } + + /* cb_cur_sum = 1.0 / cb_cur_sum */ + recip_block_inplace(alias_prev_sum, Sq_chunk_t); + + /* cb_out_accumulate_im *= cb_cur_sum */ + mul_block_bcast_cols_inplace(alias_mm2_prev_out, alias_prev_sum); + pack_reconfig_data_format(cb_out); + copy_block(alias_mm2_prev_out, cb_out, out_chunk_tiles); + + cb_pop_front(cb_q_in, q_chunk_tiles); + cb_pop_front(alias_prev_max, Sq_chunk_t); + } + } + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/dataflow_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/dataflow_common.hpp index 26c2e6a7112..ef36c6c17ed 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/dataflow_common.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/dataflow_common.hpp @@ -28,10 +28,9 @@ uint32_t virtual_seq_tile_id_to_physical_tile_id( } class TensorTileShape { +public: uint32_t shape[4]; uint32_t strides[4]; - -public: // Constructor to initialize with 4D shape TensorTileShape(uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3) { shape[0] = d0; @@ -366,3 +365,137 @@ void generate_noncausal_padded_mask(uint32_t Sq_chunk_t, uint32_t Sk_chunk_t, ui noc_async_read_barrier(); cb_push_back(cb_mask_in, mask_size_tiles); } + +struct CatAddrGenerator { + InterleavedAddrGenFast first_reader; + InterleavedAddrGenFast second_reader; + TensorTileShape first_shape; + TensorTileShape second_shape; + uint32_t first_seq_padded; + uint32_t second_seq_padded; + + CatAddrGenerator( + const InterleavedAddrGenFast& first_reader, + TensorTileShape first_logical_shape, + uint32_t first_seq_padded, + const InterleavedAddrGenFast& second_reader, + TensorTileShape second_logical_shape, + uint32_t second_seq_padded) : + first_reader(first_reader), + second_reader(second_reader), + first_shape(first_logical_shape), + second_shape(second_logical_shape), + first_seq_padded(first_seq_padded), + second_seq_padded(second_seq_padded) {} + + uint32_t maybe_read_tile(uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3, uint32_t dst_addr) const { + if (d2 < first_shape.shape[2]) { + uint32_t tile_id = first_shape.id_of(d0, d1, d2, d3); + noc_async_read_tile(tile_id, first_reader, dst_addr); + return 1; + } else if (d2 >= first_seq_padded && (d2 - first_seq_padded) < second_shape.shape[2]) { + uint32_t adjusted_seq = d2 - first_seq_padded; + uint32_t tile_id = second_shape.id_of(d0, d1, adjusted_seq, d3); + noc_async_read_tile(tile_id, second_reader, dst_addr); + return 1; + } + return 0; + } + + uint32_t maybe_write_tile(uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3, uint32_t src_addr) const { + if (d2 < first_shape.shape[2]) { + uint32_t tile_id = first_shape.id_of(d0, d1, d2, d3); + noc_async_write_tile(tile_id, first_reader, src_addr); + return 1; + } else if (d2 >= first_seq_padded && (d2 - first_seq_padded) < second_shape.shape[2]) { + uint32_t adjusted_seq = d2 - first_seq_padded; + uint32_t tile_id = second_shape.id_of(d0, d1, adjusted_seq, d3); + noc_async_write_tile(tile_id, second_reader, src_addr); + return 1; + } + return 0; + } +}; + +struct Slice { + uint32_t d0; // batch dimension + uint32_t d1; // head dimension + uint32_t d2_start; // sequence start + uint32_t d2_end; // sequence end + uint32_t d3_start; // feature start + uint32_t d3_end; // feature end + + Slice(uint32_t d0, uint32_t d1, uint32_t d2_start, 
uint32_t d2_end, uint32_t d3_start, uint32_t d3_end) : + d0(d0), d1(d1), d2_start(d2_start), d2_end(d2_end), d3_start(d3_start), d3_end(d3_end) {} + + uint32_t get_d2_size() const { return d2_end - d2_start; } + uint32_t get_d3_size() const { return d3_end - d3_start; } +}; + +void read_block( + const CatAddrGenerator& cat_addr_generator, + const Slice& src_slice, + const uint32_t cb_id, + const uint32_t tile_bytes, + const uint32_t barrier_threshold, + const bool transpose) { + const uint32_t src_rows = src_slice.get_d2_size(); + const uint32_t src_cols = src_slice.get_d3_size(); + const uint32_t num_tiles = src_rows * src_cols; + cb_reserve_back(cb_id, num_tiles); + const uint32_t base_write_ptr = get_write_ptr(cb_id); + uint32_t outer_ptr_stride = transpose ? tile_bytes : src_cols * tile_bytes; + uint32_t inner_ptr_stride = transpose ? tile_bytes * src_rows : tile_bytes; + + uint32_t barrier_count = 0; + for (uint32_t row = 0; row < src_rows; ++row) { + uint32_t write_ptr = base_write_ptr + row * outer_ptr_stride; + for (uint32_t col = 0; col < src_cols; ++col) { + uint32_t did_read = cat_addr_generator.maybe_read_tile( + src_slice.d0, src_slice.d1, src_slice.d2_start + row, src_slice.d3_start + col, write_ptr); + + write_ptr += inner_ptr_stride; + barrier_count += did_read; + if (barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_id, num_tiles); +} + +void write_block( + const CatAddrGenerator& cat_addr_generator, + const Slice& dst_slice, + const uint32_t cb_id, + const uint32_t tile_bytes, + const uint32_t barrier_threshold) { + const uint32_t dst_rows = dst_slice.get_d2_size(); + const uint32_t dst_cols = dst_slice.get_d3_size(); + const uint32_t num_tiles = dst_rows * dst_cols; + const uint32_t base_read_ptr = get_read_ptr(cb_id); + uint32_t outer_ptr_stride = dst_cols * tile_bytes; + uint32_t inner_ptr_stride = tile_bytes; + + uint32_t barrier_count = 0; + + cb_wait_front(cb_id, num_tiles); + for (uint32_t row = 0; row < dst_rows; ++row) { + uint32_t read_ptr = base_read_ptr + row * outer_ptr_stride; + for (uint32_t col = 0; col < dst_cols; ++col) { + uint32_t did_write = cat_addr_generator.maybe_write_tile( + dst_slice.d0, dst_slice.d1, dst_slice.d2_start + row, dst_slice.d3_start + col, read_ptr); + read_ptr += inner_ptr_stride; + + barrier_count += did_write; + if (barrier_count == barrier_threshold) { + noc_async_writes_flushed(); + barrier_count = 0; + } + } + } + noc_async_write_barrier(); + cb_pop_front(cb_id, num_tiles); +} diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp new file mode 100644 index 00000000000..60c5b4ce8d1 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp
new file mode 100644
index 00000000000..60c5b4ce8d1
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_reader.cpp
@@ -0,0 +1,100 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <stdint.h>
+#include "dataflow_api.h"
+#include "dataflow_common.hpp"
+
+void kernel_main() {
+    constexpr uint32_t B = get_compile_time_arg_val(0);
+    constexpr uint32_t NH = get_compile_time_arg_val(1);
+    constexpr uint32_t DHt = get_compile_time_arg_val(2);
+    constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(3);
+    constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(4);
+    constexpr uint32_t k_num_chunks = get_compile_time_arg_val(5);
+    constexpr uint32_t valid_Nt = get_compile_time_arg_val(6);
+    constexpr uint32_t valid_Lt = get_compile_time_arg_val(7);
+    constexpr uint32_t padded_Nqt = get_compile_time_arg_val(8);
+    constexpr uint32_t padded_Nkt = get_compile_time_arg_val(9);
+    constexpr uint32_t padded_Lqt = get_compile_time_arg_val(10);
+    constexpr uint32_t padded_Lkt = get_compile_time_arg_val(11);
+    constexpr uint32_t num_cores = get_compile_time_arg_val(12);
+
+    uint32_t argidx = 0;
+    const uint32_t q_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t k_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t v_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t joint_q_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t joint_k_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t joint_v_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_batch_start = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_batch_end = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_nh_start = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_nh_end = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_q_start = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_q_end = get_arg_val<uint32_t>(argidx++);
+
+    constexpr bool is_dram = true;
+
+    constexpr uint32_t cb_q_in = tt::CBIndex::c_0;
+    constexpr uint32_t cb_k_in = tt::CBIndex::c_1;
+    constexpr uint32_t cb_v_in = tt::CBIndex::c_2;
+
+    constexpr uint32_t q_tile_bytes = get_tile_size(cb_q_in);
+    constexpr DataFormat q_data_format = get_dataformat(cb_q_in);
+    constexpr uint32_t k_tile_bytes = get_tile_size(cb_k_in);
+    constexpr DataFormat k_data_format = get_dataformat(cb_k_in);
+    constexpr uint32_t v_tile_bytes = get_tile_size(cb_v_in);
+    constexpr DataFormat v_data_format = get_dataformat(cb_v_in);
+
+    constexpr uint32_t barrier_threshold = get_barrier_read_threshold<q_tile_bytes, num_cores>();
+
+    const InterleavedAddrGenFast<is_dram> q_reader = {
+        .bank_base_address = q_addr, .page_size = q_tile_bytes, .data_format = q_data_format};
+    const InterleavedAddrGenFast<is_dram> k_reader = {
+        .bank_base_address = k_addr, .page_size = k_tile_bytes, .data_format = k_data_format};
+    const InterleavedAddrGenFast<is_dram> v_reader = {
+        .bank_base_address = v_addr, .page_size = v_tile_bytes, .data_format = v_data_format};
+    const InterleavedAddrGenFast<is_dram> joint_q_reader = {
+        .bank_base_address = joint_q_addr, .page_size = q_tile_bytes, .data_format = q_data_format};
+    const InterleavedAddrGenFast<is_dram> joint_k_reader = {
+        .bank_base_address = joint_k_addr, .page_size = k_tile_bytes, .data_format = k_data_format};
+    const InterleavedAddrGenFast<is_dram> joint_v_reader = {
+        .bank_base_address = joint_v_addr, .page_size = v_tile_bytes, .data_format = v_data_format};
+
+    const auto input_tile_logical = TensorTileShape(B, NH, valid_Nt, DHt);
+    const auto joint_tile_logical = TensorTileShape(B, NH, valid_Lt, DHt);
+    const auto cat_q_generator =
+        CatAddrGenerator(q_reader, input_tile_logical, padded_Nqt, joint_q_reader, joint_tile_logical, padded_Lqt);
+    const auto cat_k_generator =
+        CatAddrGenerator(k_reader, input_tile_logical, padded_Nkt, joint_k_reader, joint_tile_logical, padded_Lkt);
+    const auto cat_v_generator =
+        CatAddrGenerator(v_reader, input_tile_logical, padded_Nkt, joint_v_reader, joint_tile_logical, padded_Lkt);
+
+    for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) {
+        for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) {
+            for (uint32_t q_chunk = local_q_start; q_chunk < local_q_end; ++q_chunk) {
+                const auto q_row_start_tile = q_chunk * Sq_chunk_t;
+                const auto q_slice = Slice(nb, nq, q_row_start_tile, q_row_start_tile + Sq_chunk_t, 0, DHt);
+
+                read_block(
+                    cat_q_generator, q_slice, cb_q_in, q_tile_bytes, barrier_threshold, false /*transpose*/
+                );
+
+                for (uint32_t k_chunk = 0; k_chunk < k_num_chunks; ++k_chunk) {
+                    const auto kv_row_start_tile = k_chunk * Sk_chunk_t;
+                    const auto kv_slice = Slice(nb, nq, kv_row_start_tile, kv_row_start_tile + Sk_chunk_t, 0, DHt);
+
+                    read_block(
+                        cat_k_generator, kv_slice, cb_k_in, k_tile_bytes, barrier_threshold, true /*transpose*/
+                    );
+
+                    read_block(
+                        cat_v_generator, kv_slice, cb_v_in, v_tile_bytes, barrier_threshold, false /*transpose*/
+                    );
+                }
+            }
+        }
+    }
+}
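For each (batch, head, Q chunk) triple assigned to a core, the reader above fills the Q circular buffer once and then streams every K/V chunk of the virtual concatenated sequence, with K read transposed for the QK^T matmul. The schedule can be sketched in Python as a generator of read operations; this is illustrative pseudocode of the loop order only, not code from the patch.

def reader_schedule(batch_range, head_range, q_chunk_range, k_num_chunks):
    """Yield the order in which the reader kernel fills its circular buffers."""
    for nb in batch_range:
        for nq in head_range:
            for q_chunk in q_chunk_range:
                yield ("cb_q_in", nb, nq, q_chunk)       # one Q chunk per output chunk
                for k_chunk in range(k_num_chunks):
                    yield ("cb_k_in", nb, nq, k_chunk)   # K chunk, read transposed
                    yield ("cb_v_in", nb, nq, k_chunk)   # V chunk, read in row order

# Example: one batch, one head, two local Q chunks, three K chunks in total.
for op in reader_schedule(range(1), range(1), range(2), 3):
    print(op)

The per-core ranges (local_batch_start/end, local_nh_start/end, local_q_start/end) are runtime arguments, so the same kernel binary services whatever slice of the work the program factory assigns to each core.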
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_writer.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_writer.cpp
new file mode 100644
index 00000000000..6379f5f3c8a
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/joint_writer.cpp
@@ -0,0 +1,92 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "dataflow_api.h"
+#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp"
+#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"
+#include "dataflow_common.hpp"
+
+void kernel_main() {
+    constexpr uint32_t B = get_compile_time_arg_val(0);
+    constexpr uint32_t NH = get_compile_time_arg_val(1);
+    constexpr uint32_t DHt = get_compile_time_arg_val(2);
+    constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(3);
+    constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(4);
+    constexpr uint32_t k_num_chunks = get_compile_time_arg_val(5);
+    constexpr uint32_t valid_Nt = get_compile_time_arg_val(6);
+    constexpr uint32_t valid_Lt = get_compile_time_arg_val(7);
+    constexpr uint32_t padded_Nqt = get_compile_time_arg_val(8);
+    constexpr uint32_t padded_Nkt = get_compile_time_arg_val(9);
+    constexpr uint32_t padded_Lqt = get_compile_time_arg_val(10);
+    constexpr uint32_t padded_Lkt = get_compile_time_arg_val(11);
+    constexpr uint32_t unpadded_N = get_compile_time_arg_val(12);
+    constexpr uint32_t unpadded_L = get_compile_time_arg_val(13);
+    constexpr uint32_t num_cores = get_compile_time_arg_val(14);
+    constexpr uint32_t identity_scalar_packed = get_compile_time_arg_val(15);
+    constexpr uint32_t scale_val = get_compile_time_arg_val(16);
+    constexpr bool use_joint_mask = get_compile_time_arg_val(17) == 1;
+    constexpr uint32_t mask_chunk_0 = get_compile_time_arg_val(18);
+    constexpr uint32_t mask_chunk_1 = get_compile_time_arg_val(19);
+
+    uint32_t argidx = 0;
+    const uint32_t out_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t joint_out_addr = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_batch_start = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_batch_end = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_nh_start = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_nh_end = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_q_start = get_arg_val<uint32_t>(argidx++);
+    const uint32_t local_q_end = get_arg_val<uint32_t>(argidx++);
+
+    constexpr bool is_dram = true;
+    constexpr uint32_t cb_out = tt::CBIndex::c_16;
+    constexpr uint32_t cb_mask_in = tt::CBIndex::c_3;
+    constexpr uint32_t tile_bytes = get_tile_size(cb_out);
+    constexpr DataFormat data_format = get_dataformat(cb_out);
+
+    const InterleavedAddrGenFast<is_dram> out_writer = {
+        .bank_base_address = out_addr, .page_size = tile_bytes, .data_format = data_format};
+    const InterleavedAddrGenFast<is_dram> joint_out_writer = {
+        .bank_base_address = joint_out_addr, .page_size = tile_bytes, .data_format = data_format};
+
+    const auto output_tile_logical = TensorTileShape(B, NH, valid_Nt, DHt);
+    const auto joint_tile_logical = TensorTileShape(B, NH, valid_Lt, DHt);
+    const auto cat_out_generator =
+        CatAddrGenerator(out_writer, output_tile_logical, padded_Nqt, joint_out_writer, joint_tile_logical, padded_Lqt);
+
+    constexpr uint32_t barrier_threshold = get_barrier_read_threshold<tile_bytes, num_cores>();
+    uint32_t barrier_count = 0;
+
+    constexpr uint32_t cb_scale_in = tt::CBIndex::c_4;
+    constexpr uint32_t cb_identity_scale_in = tt::CBIndex::c_5;
+
+    generate_bcast_unary_scalar(cb_scale_in, scale_val);
+    generate_reduce_scaler(cb_identity_scale_in, identity_scalar_packed);
+
+    for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) {
+        for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) {
+            for (uint32_t q_chunk = local_q_start; q_chunk < local_q_end; ++q_chunk) {
+                if constexpr (use_joint_mask) {
+                    /*
+                    If `use_joint_mask`, then one or both of the input tensors are padded.
+                    We already know that the input tensors are padded up to Sk_chunk_t.
+                    Therefore, for the last K chunk of the first tensor and the last K chunk of the joint tensor,
+                    we should generate the vertical mask.
+                    */
+                    if (mask_chunk_0 != (uint32_t)(-1)) {
+                        generate_noncausal_padded_mask(Sq_chunk_t, Sk_chunk_t, unpadded_N);
+                    }
+                    if (mask_chunk_1 != (uint32_t)(-1)) {
+                        generate_noncausal_padded_mask(Sq_chunk_t, Sk_chunk_t, unpadded_L);
+                    }
+                }
+
+                const uint32_t out_row_start_tile = q_chunk * Sq_chunk_t;
+                const auto dst_slice = Slice(nb, nq, out_row_start_tile, out_row_start_tile + Sq_chunk_t, 0, DHt);
+
+                write_block(cat_out_generator, dst_slice, cb_out, tile_bytes, barrier_threshold);
+            }
+        }
+    }
+}
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp
index 32c28ea0ff2..ee52f3c299a 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -7,6 +7,7 @@
 #include

 #include "device/sdpa_op.hpp"
+#include "device/joint_sdpa_op.hpp"
 #include "ttnn/common/constants.hpp"
 #include "ttnn/run_operation.hpp"

@@ -124,4 +125,62 @@ ttnn::Tensor ExecuteChunkedScaledDotProductAttention::invoke(
         compute_kernel_config);
 }

+std::tuple<ttnn::Tensor, ttnn::Tensor> ExecuteJointAttention::invoke(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor_q,
+    const ttnn::Tensor& input_tensor_k,
+    const ttnn::Tensor& input_tensor_v,
+    const ttnn::Tensor& joint_tensor_q,
+    const ttnn::Tensor& joint_tensor_k,
+    const ttnn::Tensor& joint_tensor_v,
+    const std::string& joint_strategy,
+    SDPAProgramConfig program_config,
+    std::optional<float> scale,
+    std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
+    auto arch = input_tensor_q.storage_type() == StorageType::DEVICE
+                    ? input_tensor_q.device()->arch()
+                    : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
+    auto kernel_config_val = init_device_compute_kernel_config(
+        input_tensor_q.device()->arch(), compute_kernel_config, MathFidelity::HiFi2, true, false, false);
+
+    auto results = operation::run(
+        JointScaledDotProductAttention{
+            .joint_strategy = joint_strategy,
+            .scale = scale,
+            .output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+            .program_config = std::move(program_config),
+            .compute_kernel_config = kernel_config_val},
+        {input_tensor_q, input_tensor_k, input_tensor_v, joint_tensor_q, joint_tensor_k, joint_tensor_v},
+        {},
+        {},
+        queue_id);
+
+    return {results.at(0), results.at(1)};
+}
+
+std::tuple<ttnn::Tensor, ttnn::Tensor> ExecuteJointAttention::invoke(
+    const ttnn::Tensor& input_tensor_q,
+    const ttnn::Tensor& input_tensor_k,
+    const ttnn::Tensor& input_tensor_v,
+    const ttnn::Tensor& joint_tensor_q,
+    const ttnn::Tensor& joint_tensor_k,
+    const ttnn::Tensor& joint_tensor_v,
+    const std::string& joint_strategy,
+    SDPAProgramConfig program_config,
+    std::optional<float> scale,
+    std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
+    return invoke(
+        DefaultQueueId,
+        input_tensor_q,
+        input_tensor_k,
+        input_tensor_v,
+        joint_tensor_q,
+        joint_tensor_k,
+        joint_tensor_v,
+        joint_strategy,
+        std::move(program_config),
+        scale,
+        compute_kernel_config);
+}
+
 }  // namespace ttnn::operations::transformer
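Once the Python binding added below is in place, the host API above is expected to be callable roughly as follows. This is a hedged usage sketch, not a test from this patch: the device handling, tensor shapes, and chunk sizes are placeholders chosen for illustration.

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # device_id 0 is an assumption for this sketch

b, nh, N, L, dh = 1, 2, 1024, 128, 64   # illustrative shapes, not values from this patch

def to_tt(t):
    return ttnn.from_torch(t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

q, k, v = (to_tt(torch.randn(b, nh, N, dh)) for _ in range(3))
jq, jk, jv = (to_tt(torch.randn(b, nh, L, dh)) for _ in range(3))

# Assumed: chunk sizes and grid chosen to fit the target device; tune as needed.
program_config = ttnn.SDPAProgramConfig(
    compute_with_storage_grid_size=device.compute_with_storage_grid_size(),
    q_chunk_size=128,
    k_chunk_size=128,
    exp_approx_mode=False,
)

out, joint_out = ttnn.transformer.joint_scaled_dot_product_attention(
    q, k, v, jq, jk, jv,
    joint_strategy="rear",  # "rear" is the only supported strategy
    program_config=program_config,
)

out_torch = ttnn.to_torch(out)              # attention output for the original Q/K/V
joint_out_torch = ttnn.to_torch(joint_out)  # attention output for the joint Q/K/V

ttnn.close_device(device)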
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp
index 1f94b192158..abea10a2d59 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -61,6 +61,33 @@
         std::optional<DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
 };

+struct ExecuteJointAttention {
+    static std::tuple<ttnn::Tensor, ttnn::Tensor> invoke(
+        uint8_t queue_id,
+        const ttnn::Tensor& input_tensor_q,
+        const ttnn::Tensor& input_tensor_k,
+        const ttnn::Tensor& input_tensor_v,
+        const ttnn::Tensor& joint_tensor_q,
+        const ttnn::Tensor& joint_tensor_k,
+        const ttnn::Tensor& joint_tensor_v,
+        const std::string& joint_strategy,
+        SDPAProgramConfig program_config,
+        std::optional<float> scale = std::nullopt,
+        std::optional<DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+
+    static std::tuple<ttnn::Tensor, ttnn::Tensor> invoke(
+        const ttnn::Tensor& input_tensor_q,
+        const ttnn::Tensor& input_tensor_k,
+        const ttnn::Tensor& input_tensor_v,
+        const ttnn::Tensor& joint_tensor_q,
+        const ttnn::Tensor& joint_tensor_k,
+        const ttnn::Tensor& joint_tensor_v,
+        const std::string& joint_strategy,
+        SDPAProgramConfig program_config,
+        std::optional<float> scale = std::nullopt,
+        std::optional<DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+};
+
 }  // namespace operations::transformer

 namespace transformer {

@@ -73,6 +100,10 @@ constexpr auto chunked_scaled_dot_product_attention = ttnn::register_operation_w
     "ttnn::transformer::chunked_scaled_dot_product_attention",
     ttnn::operations::transformer::ExecuteChunkedScaledDotProductAttention>();

+constexpr auto joint_scaled_dot_product_attention = ttnn::register_operation_with_auto_launch_op<
+    "ttnn::transformer::joint_scaled_dot_product_attention",
+    ttnn::operations::transformer::ExecuteJointAttention>();
+
 }  // namespace transformer

 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp
index 8440728d920..9bde1cb8d49 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
 //
 // SPDX-License-Identifier: Apache-2.0

@@ -149,5 +149,83 @@ void py_bind_sdpa(py::module& module) {
             py::arg("compute_kernel_config").noconvert() = std::nullopt,
             py::arg("queue_id") = 0,
         });
+
+    auto joint_doc = R"doc(
+        JointAttention operation that efficiently performs non-causal attention over two
+        sets of query, key, and value tensors. Internally, these are concatenated in the sequence
+        dimension (joint_strategy = "rear"), then attention is computed once. The
+        output is split ("sliced") into two parts: one for the original Q/K/V chunk,
+        and one for the joint Q/K/V chunk.
+
+        This op handles optional padding via an attention mask to omit padded tokens from
+        both the "original" and "joint" sequences.
+
+        Args:
+            input_tensor_q (ttnn.Tensor): Original queries [b x nh x N x dh].
+            input_tensor_k (ttnn.Tensor): Original keys [b x nh x N x dh].
+            input_tensor_v (ttnn.Tensor): Original values [b x nh x N x dh].
+
+            joint_tensor_q (ttnn.Tensor): Joint queries [b x nh x L x dh].
+            joint_tensor_k (ttnn.Tensor): Joint keys [b x nh x L x dh].
+            joint_tensor_v (ttnn.Tensor): Joint values [b x nh x L x dh].
+
+        Keyword args:
+            joint_strategy (str): Strategy for joint attention. Must be "rear".
+            program_config (ttnn.SDPAProgramConfig)
+            scale (float, optional): Scale factor for QK^T. Defaults to None.
+            compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional): Defaults to None.
+            queue_id (int, optional): Command queue ID. Defaults to 0.
+
+        Returns:
+            (ttnn.Tensor, ttnn.Tensor):
+                - The attention output for the original Q/K/V, shape [b x nh x N x dh].
+                - The attention output for the joint Q/K/V, shape [b x nh x L x dh].
+        )doc";
+
+    using JointOperationType = decltype(ttnn::transformer::joint_scaled_dot_product_attention);
+
+    ttnn::bind_registered_operation(
+        module,
+        ttnn::transformer::joint_scaled_dot_product_attention,
+        joint_doc,
+        ttnn::pybind_overload_t{
+            [](const JointOperationType& self,
+               const ttnn::Tensor& input_tensor_q,
+               const ttnn::Tensor& input_tensor_k,
+               const ttnn::Tensor& input_tensor_v,
+               const ttnn::Tensor& joint_tensor_q,
+               const ttnn::Tensor& joint_tensor_k,
+               const ttnn::Tensor& joint_tensor_v,
+               const std::string& joint_strategy,
+               SDPAProgramConfig program_config,
+               std::optional<float> scale,
+               std::optional<DeviceComputeKernelConfig> compute_kernel_config,
+               uint8_t queue_id) {
+                auto outputs = self(
+                    queue_id,
+                    input_tensor_q,
+                    input_tensor_k,
+                    input_tensor_v,
+                    joint_tensor_q,
+                    joint_tensor_k,
+                    joint_tensor_v,
+                    joint_strategy,
+                    program_config,
+                    scale,
+                    compute_kernel_config);
+                return outputs;
+            },
+            py::arg("input_tensor_q").noconvert(),
+            py::arg("input_tensor_k").noconvert(),
+            py::arg("input_tensor_v").noconvert(),
+            py::arg("joint_tensor_q").noconvert(),
+            py::arg("joint_tensor_k").noconvert(),
+            py::arg("joint_tensor_v").noconvert(),
+            py::kw_only(),
+            py::arg("joint_strategy"),
+            py::arg("program_config").noconvert(),
+            py::arg("scale").noconvert() = std::nullopt,
+            py::arg("compute_kernel_config").noconvert() = std::nullopt,
+            py::arg("queue_id") = 0});
 }
 }  // namespace ttnn::operations::transformer
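As the docstring above describes, the operation is equivalent to concatenating the two Q/K/V sets along the sequence dimension, running one non-causal scaled dot product attention, and splitting the result back apart. A host-side reference of that behavior in PyTorch, purely illustrative and not part of the patch (the default 1/sqrt(dh) scaling stands in for the optional `scale` argument):

import torch

def joint_sdpa_reference(q, k, v, joint_q, joint_k, joint_v):
    """Reference for joint_strategy "rear": concatenate, attend once, split."""
    full_q = torch.cat([q, joint_q], dim=2)  # joint tokens appended after the original tokens
    full_k = torch.cat([k, joint_k], dim=2)
    full_v = torch.cat([v, joint_v], dim=2)
    out = torch.nn.functional.scaled_dot_product_attention(full_q, full_k, full_v, is_causal=False)
    n = q.shape[2]
    return out[:, :, :n, :], out[:, :, n:, :]  # (original-sequence part, joint part)

# Example with small illustrative shapes: b=1, nh=2, N=256, L=32, dh=64.
q, k, v = (torch.randn(1, 2, 256, 64) for _ in range(3))
jq, jk, jv = (torch.randn(1, 2, 32, 64) for _ in range(3))
out, joint_out = joint_sdpa_reference(q, k, v, jq, jk, jv)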