#13593: separated embedding into hpp and cpp. With uint32 untilize, this reaches 51% on sweeps and resolves most of the non-tensor-layout errors.
yugi957 committed Nov 5, 2024
1 parent e8dd5f7 commit 5b3497d
Showing 3 changed files with 86 additions and 50 deletions.
2 changes: 1 addition & 1 deletion ttnn/CMakeLists.txt
@@ -150,6 +150,7 @@ set(ALL_TTNN_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/unary.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/embedding.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward_pybind.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp
@@ -371,7 +372,6 @@ set(ALL_TTNN_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp

    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp
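
For context, this is the standard CMake move when an op's implementation leaves its header: the new translation unit must be appended to the aggregated source list, otherwise the out-of-line definitions are never compiled and every caller of EmbeddingOperation::invoke fails at link time. A minimal sketch of the pattern, using the real path from this diff with the surrounding entries elided:

# Sketch only: register the split-out translation unit so its definitions
# are built into the ttnn library. Omitting this entry leaves the symbols
# unresolved at link time.
set(ALL_TTNN_SRCS
    # ... other operation sources ...
    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/embedding.cpp
    # ... other operation sources ...
)
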
81 changes: 81 additions & 0 deletions ttnn/cpp/ttnn/operations/embedding/embedding.cpp
@@ -0,0 +1,81 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/embedding/embedding.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/operations/embedding/device/embedding_device_operation.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"

namespace ttnn::operations::embedding {

ttnn::Tensor EmbeddingOperation::invoke(
    uint8_t queue_id,
    const Tensor& input_tensor_arg,
    const Tensor& weight_arg,
    const std::optional<int>& pad_token,
    const Layout& layout,
    EmbeddingsType embeddings_type,
    const std::optional<const DataType> dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor) {
    if (pad_token.has_value()) {
        embeddings_type = EmbeddingsType::PADDED;
    }
    Tensor mutable_input_tensor = input_tensor_arg;
    Tensor mutable_weight = weight_arg;
    if (mutable_input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
        mutable_input_tensor = ttnn::to_layout(mutable_input_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
    }
    if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) {
        mutable_weight = ttnn::to_layout(mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
    }
    auto hidden_embedding_dim = mutable_weight.get_shape()[-1];
    auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1];
    auto weight = ttnn::unsqueeze_to_4D(mutable_weight);

    auto batch_size = mutable_input_tensor.get_shape()[0];
    auto sentence_size = mutable_input_tensor.get_shape()[-1];
    auto input_tensor =
        ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});

    bool fused_tilized = layout == ttnn::TILE_LAYOUT;

    // The tilized fast path is only usable when the requested layout is TILE_LAYOUT and the last
    // dimensions are tile-aligned: the input's last dim must be a multiple of TILE_HEIGHT and the
    // weight's last dim a multiple of TILE_WIDTH. Otherwise fall back to the row-major path.
    if (fused_tilized) {
        if (input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT != 0
            || weight.get_legacy_shape()[-1] % TILE_WIDTH != 0) {
            fused_tilized = false;
        }
    }

    auto embeddings = operation::run(
        Embeddings{
            .output_mem_config = memory_config.value_or(input_tensor.memory_config()),
            .tilized = fused_tilized,
            .embeddings_type = embeddings_type,
            .pad_token = pad_token,
            .output_dtype = dtype.value_or(weight.get_dtype())},
        {input_tensor, weight})
        .at(0);
    embeddings = ttnn::reshape(
        embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
    embeddings = ttnn::to_layout(embeddings, layout, std::nullopt, std::nullopt, (Device*)nullptr);
    return embeddings;
}

ttnn::Tensor EmbeddingOperation::invoke(
    const Tensor& input_tensor_arg,
    const Tensor& weight_arg,
    const std::optional<int>& pad_token,
    const Layout& layout,
    EmbeddingsType embeddings_type,
    const std::optional<const DataType> dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor) {
    return invoke(
        DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor);
}

} // namespace ttnn::operations::embedding
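
For orientation, here is a minimal call-site sketch for the newly out-of-line op. Only the invoke signature comes from this diff; the function name lookup_embeddings, the tensor names, and the shape comments are hypothetical:

// Sketch only: exercising EmbeddingOperation::invoke as declared in
// embedding.hpp. Construction of the device tensors is out of scope here.
#include "ttnn/operations/embedding/embedding.hpp"

ttnn::Tensor lookup_embeddings(const ttnn::Tensor& token_ids, const ttnn::Tensor& weights) {
    // Request a tiled result. When the last dims are not tile-aligned, the op
    // runs the un-fused row-major kernel (fused_tilized above flips to false)
    // and converts to the requested layout afterwards via ttnn::to_layout.
    return ttnn::operations::embedding::EmbeddingOperation::invoke(
        token_ids,                   // uint32 indices, e.g. [batch, seq_len]
        weights,                     // embedding table, e.g. [vocab, hidden]
        /*pad_token=*/std::nullopt,  // setting this switches to EmbeddingsType::PADDED
        /*layout=*/ttnn::TILE_LAYOUT);
}
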
53 changes: 4 additions & 49 deletions ttnn/cpp/ttnn/operations/embedding/embedding.hpp
@@ -4,11 +4,8 @@

#pragma once

#include "ttnn/common/constants.hpp"
#include "ttnn/operations/embedding/device/embedding_device_operation.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/operations/core/core.hpp"

namespace ttnn {

@@ -17,7 +14,7 @@ namespace operations {
namespace embedding {

struct EmbeddingOperation {
    static inline Tensor invoke(
    static ttnn::Tensor invoke(
        uint8_t queue_id,
        const Tensor& input_tensor_arg,
        const Tensor& weight_arg,
Expand All @@ -26,58 +23,16 @@ struct EmbeddingOperation {
        EmbeddingsType embeddings_type = EmbeddingsType::GENERIC,
        const std::optional<const DataType> dtype = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        if (pad_token.has_value()) {
            embeddings_type = EmbeddingsType::PADDED;
        }
        Tensor mutable_input_tensor = input_tensor_arg;
        Tensor mutable_weight = weight_arg;
        if (mutable_input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
            mutable_input_tensor = ttnn::to_layout(mutable_input_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
        }
        if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) {
            mutable_weight = ttnn::to_layout(mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
        }
        auto hidden_embedding_dim = mutable_weight.get_shape()[-1];
        auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1];
        auto weight = ttnn::unsqueeze_to_4D(mutable_weight);

        auto batch_size = mutable_input_tensor.get_shape()[0];
        auto sentence_size = mutable_input_tensor.get_shape()[-1];
        auto input_tensor =
            ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});

        bool fuzed_tilized = layout == ttnn::TILE_LAYOUT;

        // If layout is row major, OR if the input tensor is not a multiple of TILE_HEIGHT, then we cannot use tilized
        if(!fuzed_tilized || input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT) fuzed_tilized = false;
        if(!fuzed_tilized || weight.get_legacy_shape()[-1] % TILE_WIDTH) fuzed_tilized = false;

        auto embeddings = operation::run(
            Embeddings{
                .output_mem_config = memory_config.value_or(input_tensor.memory_config()),
                .tilized = fuzed_tilized,
                .embeddings_type = embeddings_type,
                .pad_token = pad_token,
                .output_dtype = dtype.value_or(weight.get_dtype())},
            {input_tensor, weight})
            .at(0);
        embeddings = ttnn::reshape(
            embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
        return embeddings;
    }
    static inline auto invoke(
        std::optional<Tensor> optional_output_tensor = std::nullopt);
    static ttnn::Tensor invoke(
        const Tensor& input_tensor_arg,
        const Tensor& weight_arg,
        const std::optional<int>& pad_token = std::nullopt,
        const Layout& layout = ttnn::ROW_MAJOR_LAYOUT,
        EmbeddingsType embeddings_type = EmbeddingsType::GENERIC,
        const std::optional<const DataType> dtype = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt
    ) {
        return invoke(DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor);
    }
        std::optional<Tensor> optional_output_tensor = std::nullopt);
};

} // namespace embedding
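
One C++ detail this split illustrates, worth noting because it is easy to get wrong: default arguments stay on the in-header declarations only, which is why the definitions in embedding.cpp above repeat the parameter lists without any = std::nullopt. Restating a default on an out-of-line definition is ill-formed. A minimal illustration with hypothetical names:

// sketch.hpp (hypothetical)
struct Op {
    static int run(int x, int scale = 2);  // default lives on the declaration
};

// sketch.cpp (hypothetical)
int Op::run(int x, int scale) {  // no default repeated here; that would not compile
    return x * scale;
}
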
