diff --git a/tests/ttnn/unit_tests/operations/test_linear.py b/tests/ttnn/unit_tests/operations/test_linear.py index 97df1e7b8ef..29956039dc1 100644 --- a/tests/ttnn/unit_tests/operations/test_linear.py +++ b/tests/ttnn/unit_tests/operations/test_linear.py @@ -262,3 +262,51 @@ def test_bloom_ff2_linear(device): ) assert ttnn.pearson_correlation_coefficient(torch_output, output) >= 0.9992 + + +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("m_size", [32, 64]) +@pytest.mark.parametrize("k_size", [1024, 2048]) +@pytest.mark.parametrize("n_size", [1024, 2048]) +@pytest.mark.parametrize("activation", [None, "relu"]) +def test_linear_by_passing_in_1D_systolic_array_program_config_and_optional_output_tensor( + device, batch_size, m_size, k_size, n_size, activation +): + torch.manual_seed(0) + + torch_input_tensor_a = torch.randn((batch_size, m_size, k_size), dtype=torch.bfloat16) + torch_input_tensor_b = torch.randn((k_size, n_size), dtype=torch.bfloat16) + torch_output_tensor = torch_input_tensor_a @ torch_input_tensor_b + if activation == "relu": + torch_output_tensor = torch.relu(torch_output_tensor) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + + torch_opt_output_tensor = torch.zeros_like(torch_output_tensor) + optional_output_tensor = ttnn.from_torch(torch_opt_output_tensor, layout=ttnn.TILE_LAYOUT, device=device) + + output_tensor = ttnn.linear( + input_tensor_a, + input_tensor_b, + activation=activation, + core_grid=device.core_grid, + ) + + output_tensor = ttnn.to_torch(output_tensor) + + ttnn.linear( + input_tensor_a, + input_tensor_b, + activation=activation, + optional_output_tensor=optional_output_tensor, + core_grid=device.core_grid, + ) + + optional_output_tensor = ttnn.to_torch(optional_output_tensor) + + assert len(output_tensor.shape) == len(torch_output_tensor.shape) == len(optional_output_tensor.shape) + assert output_tensor.shape == torch_output_tensor.shape == optional_output_tensor.shape + assert_with_pcc(torch_output_tensor, output_tensor, 0.997) + assert_with_pcc(torch_output_tensor, optional_output_tensor, 0.997) + assert_with_pcc(optional_output_tensor, output_tensor, 0.997) diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index c411ab46631..651709e0fad 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -2079,3 +2079,34 @@ def test_interleaved_input_sharded_output_matmul(device): output3 = ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=out_mem_config) output_tensor = ttnn.to_torch(output3) assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) + + +@pytest.mark.parametrize( + "n_size, c, m, k, n", + [ + (1, 1, 1024, 64, 512), + ], +) +def test_optional_output_argument(device, n_size, c, m, k, n): + torch.manual_seed(0) + + torch_input_tensor_a = torch.rand((n_size, c, m, k), dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand((n_size, c, k, n), dtype=torch.bfloat16) + torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b) + torch_opt_output_tensor = torch.zeros_like(torch_output_tensor) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + optional_output_tensor = 
ttnn.from_torch(torch_opt_output_tensor, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn.matmul(input_tensor_a, input_tensor_b) + output = ttnn.to_torch(output) + + ttnn.matmul(input_tensor_a, input_tensor_b, optional_output_tensor=optional_output_tensor) + optional_output_tensor = ttnn.to_torch(optional_output_tensor) + + assert len(output.shape) == len(torch_output_tensor.shape) == len(optional_output_tensor.shape) + assert output.shape == torch_output_tensor.shape == optional_output_tensor.shape + assert_with_pcc(torch_output_tensor, output, 0.999) + assert_with_pcc(torch_output_tensor, optional_output_tensor, 0.999) + assert_with_pcc(output, optional_output_tensor, 0.999) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp index 96b80a12e3b..9016ceb4835 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -20,7 +20,8 @@ namespace experimental { void AllGatherMatmul::validate( const std::vector& input_tensors, - const std::vector>& optional_input_tensors) const { + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) const { TT_ASSERT( input_tensors.size() == 4, "AllGatherMatmul requires 4 input tensors: [input, weight, all_gather_output, datacopy_output]"); @@ -33,7 +34,7 @@ void AllGatherMatmul::validate( this->all_gather_struct.validate({input_tensor}); // Matmul validate. - this->matmul_struct.validate({all_gather_output_tensor, weight_tensor}, optional_input_tensors); + this->matmul_struct.validate({all_gather_output_tensor, weight_tensor}, optional_input_tensors, {}); // All Gather Matmul validate TT_FATAL(this->all_gather_struct.dim == 3, "AllGatherMatmul requires dim=3 for the AllGather operaitons."); @@ -73,7 +74,7 @@ std::vector AllGatherMatmul::compute_output_specs(const std::v // Matmul shape ttnn::TensorSpec matmul_output_specs = - this->matmul_struct.compute_output_specs({input_tensors[1], input_tensors[2]})[0]; + this->matmul_struct.compute_output_specs({input_tensors[1], input_tensors[2]}, {})[0]; return {all_gather_output_shape, matmul_output_specs, datacopy_output_shape}; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp index c8af6cc9dd7..8afef89a6aa 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp @@ -41,7 +41,8 @@ struct AllGatherMatmul { /* General */ void validate( const std::vector& input_tensors, - const std::vector>& optional_input_tensors) const; + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors = {std::nullopt}) const; std::vector compute_output_specs(const std::vector& input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index f3d0b732e62..19535b430a8 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ 
b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -1003,7 +1003,10 @@ namespace operations { namespace matmul { Matmul create_matmul_struct( - const Tensor& input_tensor_a, const Tensor& input_tensor_b, const struct Matmul& parameters) { + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const struct Matmul& parameters, + const std::vector>& optional_output_tensors) { auto arch = input_tensor_a.device()->arch(); const bool has_user_grid = parameters.user_core_coord.has_value(); const bool has_program_config = parameters.program_config.has_value(); @@ -1022,16 +1025,48 @@ Matmul create_matmul_struct( bool broadcast_batch = parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config)); TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config"); + + const bool is_optional_output_tensor = + !optional_output_tensors.empty() && optional_output_tensors.at(0).has_value(); + std::optional output_dtype = parameters.output_dtype; + MemoryConfig output_mem_config = parameters.output_mem_config; + + if (is_optional_output_tensor) { + const auto& optional_output_tensor = optional_output_tensors.at(0); + if (output_mem_config == operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { + output_mem_config = optional_output_tensor->memory_config(); + } else { + TT_FATAL( + optional_output_tensor->memory_config() == output_mem_config, + "Memory config mismatch between optional output tensor {} & output tensor {}", + optional_output_tensor->memory_config(), + output_mem_config); + } + + if (output_dtype.has_value()) { + TT_FATAL( + optional_output_tensor->get_dtype() == output_dtype.value(), + "Type mismatch between optional output tensor {} & output tensor {}", + optional_output_tensor->get_dtype(), + output_dtype.value()); + } else { + output_dtype = optional_output_tensor->get_dtype(); + } + } else { + if (!output_dtype.has_value()) { + output_dtype = input_tensor_a.get_dtype(); + } + } + auto in0_tile = input_tensor_a.get_tensor_spec().tile(); auto in1_tile = input_tensor_b.get_tensor_spec().tile(); - tt::tt_metal::Tile output_tile = - get_output_tile(parameters.output_mem_config, in0_tile, in1_tile, parameters.output_tile); + tt::tt_metal::Tile output_tile = get_output_tile(output_mem_config, in0_tile, in1_tile, parameters.output_tile); return Matmul{ parameters.program_config, broadcast_batch, - parameters.output_mem_config, - parameters.output_dtype.value_or(input_tensor_a.get_dtype()), + output_mem_config, + output_dtype, kernel_config_val, parameters.untilize_out, parameters.user_core_coord, @@ -1047,9 +1082,11 @@ Tensor matmul( const Tensor& input_tensor_b, const std::optional& bias, const struct Matmul& parameters, - const uint8_t queue_id) { + const uint8_t queue_id, + const std::optional& optional_output_tensor) { std::vector> optional_input_tensors = {}; std::vector output_tensors; + if (bias.has_value()) { optional_input_tensors.push_back(bias.value()); output_tensors = { @@ -1068,21 +1105,23 @@ Tensor matmul( const auto& input_tensor_b = input_tensors.at(1); return operation::run( - create_matmul_struct(input_tensor_a, input_tensor_b, parameters), + create_matmul_struct(input_tensor_a, input_tensor_b, parameters, optional_output_tensors), {input_tensor_a, input_tensor_b}, optional_input_tensors, - {}, + optional_output_tensors, queue_id); }, {input_tensor_a, input_tensor_b}, output_tensors, - optional_input_tensors); + optional_input_tensors, + {optional_output_tensor}); return 
output_tensors.at(0); } void Matmul::validate( const std::vector& input_tensors, - const std::vector>& optional_input_tensors) const { + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) const { TT_FATAL(input_tensors.size() == 2, "Error"); const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); @@ -1113,6 +1152,28 @@ void Matmul::validate( a_shape[-1], b_shape[-2]); + const bool is_optional_output_tensor = !optional_output_tensors.empty() && optional_output_tensors.at(0).has_value(); + if (is_optional_output_tensor) { + const auto& optional_output_tensor_c = optional_output_tensors.at(0); + const auto& optional_output_tensor_shape = optional_output_tensor_c->get_logical_shape(); + const auto output_tensor_spec = this->compute_output_specs(input_tensors, {}).at(0); + TT_FATAL( + optional_output_tensor_shape == output_tensor_spec.logical_shape(), + "Shape of optional output tensor {} doesn't match output tensor {}", + optional_output_tensor_shape, + output_tensor_spec.logical_shape()); + TT_FATAL( + optional_output_tensor_c->get_dtype() == this->output_dtype.value(), + "Type mismatch between optional output tensor {} & output tensor {}", + optional_output_tensor_c->get_dtype(), + this->output_dtype.value()); + TT_FATAL( + optional_output_tensor_c->memory_config() == this->output_mem_config, + "Memory config mismatch between optional output tensor {} & output tensor {}", + optional_output_tensor_c->memory_config(), + this->output_mem_config); + } + TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated"); TT_FATAL(this->output_tile.has_value(), "Error: output_tile field should have been automatically populated"); if (this->bcast_batch.value()) { @@ -1562,7 +1623,18 @@ void Matmul::validate( chosen_program_config); } -std::vector Matmul::compute_output_specs(const std::vector& input_tensors) const { +std::vector Matmul::compute_output_specs( + const std::vector& input_tensors, const std::vector>& optional_output_tensors) const { + TT_FATAL( + optional_output_tensors.size() <= 1, + "At most one optional output tensor can be passed when computing Matmul's output specs"); + + const bool is_optional_output_tensor = !optional_output_tensors.empty() && optional_output_tensors.at(0).has_value(); + + if (is_optional_output_tensor) { + return {optional_output_tensors.at(0)->get_tensor_spec()}; + } + const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); const ttnn::SimpleShape input_shape_a = input_tensor_a.get_logical_shape(); @@ -1587,6 +1659,7 @@ std::vector Matmul::compute_output_specs(const std::vectoroutput_tile.value(); auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1]; auto output_layout = this->untilize_out ? 
Layout::ROW_MAJOR : Layout::TILE; + TT_FATAL(this->output_dtype.has_value(), "Error"); if (this->output_mem_config.is_sharded()) { MatmulProgramConfig chosen_program_config = get_program_config(input_tensor_a, input_tensor_b, this); @@ -1733,8 +1806,9 @@ std::vector Matmul::compute_output_specs(const std::vector Matmul::create_output_tensors(const std::vector& input_tensors) const { - return operation::default_create_output_tensors(*this, input_tensors, {}); +std::vector Matmul::create_output_tensors( + const std::vector& input_tensors, const std::vector>& optional_output_tensors) const { + return operation::default_create_output_tensors(*this, input_tensors, optional_output_tensors); } operation::ProgramWithCallbacks Matmul::create_program( diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp index 8815c6b5205..5324a0c2de2 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp @@ -177,10 +177,15 @@ struct Matmul { void validate( const std::vector& input_tensors, - const std::vector>& optional_input_tensors) const; - std::vector compute_output_specs(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; - tt::tt_metal::operation::ProgramWithCallbacks create_program( + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors = {std::nullopt}) const; + std::vector compute_output_specs( + const std::vector& input_tensors, + const std::vector>& optional_output_tensors = {std::nullopt}) const; + std::vector create_output_tensors( + const std::vector& input_tensors, + const std::vector>& optional_output_tensors = {std::nullopt}) const; + operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector& output_tensors) const; @@ -191,7 +196,10 @@ struct Matmul { }; Matmul create_matmul_struct( - const Tensor& input_tensor_a, const Tensor& input_tensor_b, const struct Matmul& parameters); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const struct Matmul& parameters, + const std::vector>& optional_output_tensors = {std::nullopt}); operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_helper( tt::tt_metal::Program& program, @@ -221,7 +229,8 @@ Tensor matmul( const Tensor& input_tensor_b, const std::optional& bias = std::nullopt, const struct Matmul& parameters = Matmul{}, - const uint8_t queue_id = 0); + const uint8_t queue_id = 0, + const std::optional& optional_output_tensor = std::nullopt); } // namespace matmul diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp index c160725ae4d..db4e5c1e6b9 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp @@ -43,7 +43,8 @@ ttnn::Tensor bound_matmul( const ttnn::Tensor& input_tensor_b, const std::optional& bias, const struct Matmul& parameters, - const uint8_t& queue_id) { + const uint8_t& queue_id, + std::optional& optional_output_tensor) { const auto& input_tensor_a_adjusted = parameters.transpose_a ? ttnn::transpose(input_tensor_a, -1, -2, input_tensor_a.memory_config()) : input_tensor_a; @@ -76,8 +77,13 @@ ttnn::Tensor bound_matmul( } } - auto output_tensor = - matmul(input_tensor_a_adjusted, input_tensor_b_adjusted, post_process_bias ? 
std::nullopt : bias, parameters); + auto output_tensor = matmul( + input_tensor_a_adjusted, + input_tensor_b_adjusted, + post_process_bias ? std::nullopt : bias, + parameters, + /*queue_id=*/0, + optional_output_tensor); if (post_process_bias) { output_tensor = ttnn::add(output_tensor, bias.value(), std::nullopt, parameters.output_mem_config); @@ -110,7 +116,8 @@ Tensor MatmulOperation::invoke( const std::optional& activation, const std::optional compute_kernel_config, const std::optional core_grid, - const std::optional& output_tile) { + const std::optional& output_tile, + std::optional optional_output_tensor) { std::optional user_core_coord; if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); @@ -133,7 +140,8 @@ Tensor MatmulOperation::invoke( transpose_a, transpose_b, output_tile}, - /*queue_id=*/0); + /*queue_id=*/0, + optional_output_tensor); } Tensor LinearOperation::invoke( @@ -148,7 +156,8 @@ Tensor LinearOperation::invoke( const std::optional& activation, const std::optional compute_kernel_config, const std::optional core_grid, - const std::optional& output_tile) { + const std::optional& output_tile, + std::optional optional_output_tensor) { std::optional user_core_coord; if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); @@ -173,7 +182,8 @@ Tensor LinearOperation::invoke( transpose_a, transpose_b, output_tile}, - /*queue_id=*/0); + /*queue_id=*/0, + optional_output_tensor); } } // namespace matmul diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp index 1db1436c7b3..b35c28219e3 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp @@ -33,7 +33,8 @@ ttnn::Tensor bound_matmul( const ttnn::Tensor& input_tensor_b, const std::optional& bias, const struct Matmul& parameters, - const uint8_t& queue_id); + const uint8_t& queue_id, + std::optional& optional_output_tensor); struct MatmulOperation { static Tensor invoke( @@ -47,7 +48,8 @@ struct MatmulOperation { const std::optional& activation = std::nullopt, const std::optional compute_kernel_config = std::nullopt, const std::optional core_grid = std::nullopt, - const std::optional& output_tile = std::nullopt); + const std::optional& output_tile = std::nullopt, + std::optional optional_output_tensor = std::nullopt); }; struct LinearOperation { static Tensor invoke( @@ -63,7 +65,8 @@ const std::optional& activation = std::nullopt, const std::optional compute_kernel_config = std::nullopt, const std::optional core_grid = std::nullopt, - const std::optional& output_tile = std::nullopt); + const std::optional& output_tile = std::nullopt, + std::optional optional_output_tensor = std::nullopt); }; } // namespace matmul diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp index de6d7348eb2..2bc4499ce17 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp @@ -276,6 +276,11 @@ void py_module(py::module& module) { compute_kernel_config (ttnn.DeviceComputeKernelConfig): the compute kernel configuration for the matmul operation. Defaults to `None`. core_grid (ttnn.CoreGrid): the grid on which to distribute the sharded tensor on (writes to the cores L1s). Defaults to `None`. output_tile (List of [int], optional): Specifies the output tile configuration. Defaults to `None`. 
+ optional_output_tensor (ttnn.Tensor, optional): User-provided on-device output tensor into which the result of matmul is written. Defaults to `None`. + If an optional output tensor is specified, the dtype and memory config arguments are resolved as follows: + if they are left at their defaults, they are taken from the optional output tensor; + if they are set explicitly, they must match the optional output tensor, otherwise an error is reported. + Returns: ttnn.Tensor: the output tensor. @@ -330,7 +335,8 @@ void py_module(py::module& module) { const std::optional& activation, const std::optional compute_kernel_config, const std::optional core_grid, - const std::optional& output_tile) -> ttnn::Tensor { + const std::optional& output_tile, + std::optional& optional_output_tensor) -> ttnn::Tensor { return self( input_tensor_a, input_tensor_b, @@ -342,7 +348,8 @@ activation, compute_kernel_config, core_grid, - output_tile); + output_tile, + optional_output_tensor); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), @@ -356,6 +363,7 @@ py::arg("compute_kernel_config") = std::nullopt, py::arg("core_grid") = std::nullopt, py::arg("output_tile") = std::nullopt, + py::arg("optional_output_tensor") = std::nullopt, }); bind_registered_operation( @@ -381,6 +389,7 @@ compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional): the compute kernel configuration for the matmul operation. Defaults to `None`. core_grid (ttnn.CoreGrid, optional): the grid on which to distribute the sharded tensor on (writes to the cores L1s). Defaults to `None`. output_tile (List of [int], optional): Specifies the output tile configuration. Defaults to `None`. + optional_output_tensor (ttnn.Tensor, optional): User-provided on-device output tensor into which the result of linear is written. Defaults to `None`. Returns: ttnn.Tensor: the output tensor. @@ -407,7 +416,8 @@ const std::optional& activation, const std::optional compute_kernel_config, const std::optional core_grid, - const std::optional& output_tile) -> ttnn::Tensor { + const std::optional& output_tile, + std::optional& optional_output_tensor) -> ttnn::Tensor { return self( input_tensor_a, input_tensor_b, @@ -420,7 +430,8 @@ activation, compute_kernel_config, core_grid, - output_tile); + output_tile, + optional_output_tensor); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), @@ -435,6 +446,7 @@ py::arg("compute_kernel_config") = std::nullopt, py::arg("core_grid") = std::nullopt, py::arg("output_tile") = std::nullopt, + py::arg("optional_output_tensor") = std::nullopt, }); }
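Usage sketch for reviewers: a minimal example of the new optional_output_tensor argument, mirroring test_optional_output_argument above. It assumes a single device opened with ttnn.open_device; tensor shapes and variable names are illustrative only.

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_a = torch.rand((1, 1, 1024, 64), dtype=torch.bfloat16)
torch_b = torch.rand((1, 1, 64, 512), dtype=torch.bfloat16)

input_a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
input_b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)

# Pre-allocate the destination on device. Per Matmul::validate above, its shape, dtype,
# and memory config must match the output spec computed for input_a @ input_b.
preallocated_out = ttnn.from_torch(
    torch.zeros((1, 1, 1024, 512), dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# The matmul result is written into preallocated_out; the call still returns an output tensor.
ttnn.matmul(input_a, input_b, optional_output_tensor=preallocated_out)
result = ttnn.to_torch(preallocated_out)

ttnn.close_device(device)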