diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 873296beece..df20d1c9a78 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -150,6 +150,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/unary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/embedding.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp @@ -371,7 +372,6 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding.cpp b/ttnn/cpp/ttnn/operations/embedding/embedding.cpp new file mode 100644 index 00000000000..bc8fc0a3a71 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/embedding/embedding.cpp @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/embedding/embedding.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/common/constants.hpp" +#include "ttnn/operations/embedding/device/embedding_device_operation.hpp" +#include "ttnn/run_operation.hpp" +#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp" + +namespace ttnn::operations::embedding{ + +ttnn::Tensor EmbeddingOperation::invoke( + uint8_t queue_id, + const Tensor& input_tensor_arg, + const Tensor& weight_arg, + const std::optional<int>& pad_token, + const Layout& layout, + EmbeddingsType embeddings_type, + const std::optional<const DataType> dtype, + const std::optional<MemoryConfig>& memory_config, + std::optional<Tensor> optional_output_tensor) { + if (pad_token.has_value()) { + embeddings_type = EmbeddingsType::PADDED; + } + Tensor mutable_input_tensor = input_tensor_arg; + Tensor mutable_weight = weight_arg; + if (mutable_input_tensor.get_layout() == ttnn::TILE_LAYOUT) { + mutable_input_tensor = ttnn::to_layout(mutable_input_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); + } + if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) { + mutable_weight = ttnn::to_layout(mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); + } + auto hidden_embedding_dim = mutable_weight.get_shape()[-1]; + auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1]; + auto weight = ttnn::unsqueeze_to_4D(mutable_weight); + + auto batch_size = mutable_input_tensor.get_shape()[0]; + auto sentence_size = mutable_input_tensor.get_shape()[-1]; + auto input_tensor = + ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}}); + + bool fused_tilized = layout == ttnn::TILE_LAYOUT; + + // If layout is row major, OR if the input tensor is not a multiple of TILE_HEIGHT, then we cannot use tilized + if (fused_tilized) { + if (input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT != 0 + ||
weight.get_legacy_shape()[-1] % TILE_WIDTH != 0) { + fused_tilized = false; + } + } + + auto embeddings = operation::run( + Embeddings{ + .output_mem_config = memory_config.value_or(input_tensor.memory_config()), + .tilized = fused_tilized, + .embeddings_type = embeddings_type, + .pad_token = pad_token, + .output_dtype = dtype.value_or(weight.get_dtype())}, + {input_tensor, weight}) + .at(0); + embeddings = ttnn::reshape( + embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}}); + embeddings = ttnn::to_layout(embeddings, layout, std::nullopt, std::nullopt, (Device*)nullptr); + return embeddings; +} +ttnn::Tensor EmbeddingOperation::invoke( + const Tensor& input_tensor_arg, + const Tensor& weight_arg, + const std::optional<int>& pad_token, + const Layout& layout, + EmbeddingsType embeddings_type, + const std::optional<const DataType> dtype, + const std::optional<MemoryConfig>& memory_config, + std::optional<Tensor> optional_output_tensor + ) { + return invoke(DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor); +} + +} // namespace ttnn::operations::embedding diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding.hpp b/ttnn/cpp/ttnn/operations/embedding/embedding.hpp index 796df708eaa..248838025e3 100644 --- a/ttnn/cpp/ttnn/operations/embedding/embedding.hpp +++ b/ttnn/cpp/ttnn/operations/embedding/embedding.hpp @@ -4,11 +4,8 @@ #pragma once -#include "ttnn/common/constants.hpp" #include "ttnn/operations/embedding/device/embedding_device_operation.hpp" -#include "ttnn/run_operation.hpp" #include "ttnn/decorators.hpp" -#include "ttnn/operations/core/core.hpp" namespace ttnn { @@ -17,7 +14,7 @@ namespace operations { namespace embedding { struct EmbeddingOperation { - static inline Tensor invoke( + static ttnn::Tensor invoke( uint8_t queue_id, const Tensor& input_tensor_arg, const Tensor& weight_arg, @@ -26,47 +23,8 @@ struct EmbeddingOperation { EmbeddingsType embeddings_type =
EmbeddingsType::GENERIC, const std::optional<const DataType> dtype = std::nullopt, const std::optional<MemoryConfig>& memory_config = std::nullopt, - std::optional<Tensor> optional_output_tensor = std::nullopt) { - if (pad_token.has_value()) { - embeddings_type = EmbeddingsType::PADDED; - } - Tensor mutable_input_tensor = input_tensor_arg; - Tensor mutable_weight = weight_arg; - if (mutable_input_tensor.get_layout() == ttnn::TILE_LAYOUT) { - mutable_input_tensor = ttnn::to_layout(mutable_input_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); - } - if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) { - mutable_weight = ttnn::to_layout(mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); - } - auto hidden_embedding_dim = mutable_weight.get_shape()[-1]; - auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1]; - auto weight = ttnn::unsqueeze_to_4D(mutable_weight); - - auto batch_size = mutable_input_tensor.get_shape()[0]; - auto sentence_size = mutable_input_tensor.get_shape()[-1]; - auto input_tensor = - ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}}); - - bool fuzed_tilized = layout == ttnn::TILE_LAYOUT; - - // If layout is row major, OR if the input tensor is not a multiple of TILE_HEIGHT, then we cannot use tilized - if(!fuzed_tilized || input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT) fuzed_tilized = false; - if(!fuzed_tilized || weight.get_legacy_shape()[-1] % TILE_WIDTH) fuzed_tilized = false; - - auto embeddings = operation::run( - Embeddings{ - .output_mem_config = memory_config.value_or(input_tensor.memory_config()), - .tilized = fuzed_tilized, - .embeddings_type = embeddings_type, - .pad_token = pad_token, - .output_dtype = dtype.value_or(weight.get_dtype())}, - {input_tensor, weight}) - .at(0); - embeddings = ttnn::reshape( - embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}}); - return embeddings; - } - static
inline auto invoke( + std::optional<Tensor> optional_output_tensor = std::nullopt); + static ttnn::Tensor invoke( const Tensor& input_tensor_arg, const Tensor& weight_arg, const std::optional<int>& pad_token = std::nullopt, @@ -74,10 +32,7 @@ struct EmbeddingOperation { EmbeddingsType embeddings_type = EmbeddingsType::GENERIC, const std::optional<const DataType> dtype = std::nullopt, const std::optional<MemoryConfig>& memory_config = std::nullopt, - std::optional<Tensor> optional_output_tensor = std::nullopt - ) { - return invoke(DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor); - } + std::optional<Tensor> optional_output_tensor = std::nullopt); }; } // namespace embedding