Skip to content

Commit

Permalink
#9755: centralize data movement C++ implementation to the data_moveme…
Browse files Browse the repository at this point in the history
…nt/concat folder
  • Loading branch information
sjameelTT committed Jul 30, 2024
1 parent 7be6873 commit 67c15ae
Show file tree
Hide file tree
Showing 17 changed files with 186 additions and 64 deletions.
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,6 @@ the Type 1 contiguous representations.
Other Operations
================

.. autofunction:: tt_lib.tensor.concat

.. autofunction:: tt_lib.tensor.sum

.. autofunction:: tt_lib.tensor.lerp
Expand Down
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/permute/permute.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
Expand Down
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/split/split_tiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/split/split_last_dim_two_chunks_tiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/concat/multi_core/concat_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/concat/concat_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/repeat/multi_core/repeat_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/repeat/repeat_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nlp_tms/nlp_tms.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/deprecated/tt_dnn/op_library/concat/concat_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp"

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "ttnn/deprecated/tt_dnn/op_library/reshape/reshape_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/reduce/reduce_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/concat/concat_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp"

#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "ttnn/deprecated/tt_dnn/op_library/fold/fold_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/transpose/transpose_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/fill_rm/fill_rm_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/concat/concat_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/reduce/reduce_op.hpp"
Expand Down Expand Up @@ -67,21 +66,6 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "No"
)doc"
);
m_tensor.def("concat", &concat,
py::arg("input_tensors").noconvert(), py::arg("dim") = 0, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Concatenates shape of tensors ``arg0`` and ``arg1`` to new shape ``[W, Z, Y, X]`` along the specified dimension ``arg1``.
Input tensors must be on device, in ROW MAJOR or TILE layout, and have matching data type.
Output tensor will be on device, in same layout, and have same data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input_tensors", "Input tensors to concat", "List of Tensors", "Tensors of shape [W, Z, Y, X], where Y or X must be a multiple of 32 if they are the concat dim", "Yes"
"dim", "dimension of concat", "int", "", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("repeat", &tt::tt_metal::repeat,
py::arg("input"), py::arg("size"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/data_movement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "ttnn/tensor/types.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/concat/concat_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/repeat/repeat_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp"
#include "ttnn/operations/upsample/upsample_op.hpp"
Expand Down
16 changes: 8 additions & 8 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/core/core.hpp"

#include "ttnn/deprecated/tt_dnn/op_library/concat/concat_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp"

#include <ranges>

Expand All @@ -16,15 +16,15 @@ namespace ttnn {
namespace operations {
namespace data_movement {

struct Concat {
struct ConcatOperation {

// Wrapper for TTDNN
static inline ttnn::Tensor operator()(
uint8_t queue_id,
const std::vector<ttnn::Tensor>& input_tensors,
int dim,
const std::optional<MemoryConfig>& memory_config,
std::optional<ttnn::Tensor>& optional_output_tensor) {
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor=std::nullopt) {
TT_FATAL(input_tensors.size() > 0, "ttnn.concat: expected a non-empty list of Tensors!");
TT_FATAL(!optional_output_tensor.has_value(), "optional output tensor currently unsupported!");
const auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); // should match input tensor memory config when unpopulated but causes CI errors for now
Expand Down Expand Up @@ -84,7 +84,7 @@ struct Concat {
});
// Convert dim after unsqueeze
dim = dim + 4 - rank;
auto output_tensor = tt::tt_metal::concat(itensor, dim, mem_config);
auto output_tensor = concat_impl(itensor, dim, mem_config);
while (output_tensor.get_shape().rank() > rank) {
const auto shape = output_tensor.get_shape();
const auto full_shape = output_tensor.get_shape().with_tile_padding();
Expand All @@ -105,8 +105,8 @@ struct Concat {
static inline ttnn::Tensor operator()(
const std::vector<ttnn::Tensor>& input_tensors,
int dim,
const std::optional<MemoryConfig>& memory_config,
std::optional<ttnn::Tensor>& optional_output_tensor) {
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
constexpr uint8_t DefaultQueueId = 0;
return operator()(DefaultQueueId, input_tensors, dim, memory_config, optional_output_tensor);
}
Expand All @@ -116,6 +116,6 @@ struct Concat {
} // namespace operations

constexpr auto concat =
ttnn::register_operation_with_auto_launch_op<"ttnn::concat", ttnn::operations::data_movement::Concat>();
ttnn::register_operation_with_auto_launch_op<"ttnn::concat", ttnn::operations::data_movement::ConcatOperation>();

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.hpp"

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/auto_format.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/copy/copy_op.hpp"
#include "ttnn/run_operation.hpp"

using namespace tt::constants;

namespace ttnn::operations::data_movement {


// Choose the program-factory variant: sharded inputs take the dedicated
// sharded path, everything else uses the generic multi-core implementation.
// Only the first tensor is inspected; validate() guarantees all inputs agree.
ConcatOpParallelizationStrategy ConcatDeviceOperation::get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const {
    return input_tensors[0].is_sharded() ? ConcatOpParallelizationStrategy::SHARDED_MULTI_CORE
                                         : ConcatOpParallelizationStrategy::MULTI_CORE;
}

// Validates that all inputs are device-allocated, mutually compatible
// (same device, layout, dtype, rank, and shape outside the concat dim),
// and that sharded inputs satisfy the width-concat/row-major restrictions.
// Throws via TT_FATAL on any violation.
void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) const {
    const auto &first_input = input_tensors[0];
    tt::tt_metal::Shape shape_first = first_input.get_legacy_shape();
    TT_FATAL(this->dim < shape_first.rank(), "ConcatDeviceOperation dim specified is larger than input tensor rank.");
    // Zero out the concat dim so shapes can be compared on all other dims.
    shape_first[this->dim] = 0;
    bool shard_first = input_tensors[0].is_sharded();

    for (const Tensor &in_ref : input_tensors) {
        TT_FATAL(in_ref.buffer(), "Operand to concat needs to be allocated in a buffer on device.");
        TT_FATAL(in_ref.device(), "Operand to concat needs to be on device.");
        TT_FATAL(in_ref.device() == first_input.device(), "Operands to concat need to be on the same device.");
        TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts.");
        TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes.");
        tt::tt_metal::Shape curr_shape = in_ref.get_legacy_shape();
        TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal");
        // Same trick: ignore the concat dim when checking shape equality.
        curr_shape[this->dim] = 0;
        TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions.");
        if (in_ref.get_layout() == Layout::ROW_MAJOR && this->dim == shape_first.rank() - 1) {
            // Row-major last-dim concat copies whole rows; each row's byte
            // width must be a multiple of the buffer alignment.
            TT_FATAL(
                (in_ref.get_legacy_shape()[this->dim] * in_ref.element_size()) % in_ref.buffer()->alignment() == 0,
                "Current concat implementation requires aligned last dim when concatting on last dim");
        }
        // Mixed sharded/interleaved inputs are not supported.
        TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");
        if (shard_first) {
            TT_FATAL((in_ref.get_layout() == Layout::ROW_MAJOR), "Only row major supported for sharded concat.");
        }
    }
    if (shard_first) {
        // Sharded path only implements width (last-dim) concat with a sharded output.
        TT_FATAL(this->dim == shape_first.rank() - 1, "Only width concat on sharded tensors");
        TT_FATAL(this->output_mem_config.is_sharded(), "Output must be sharded if input is sharded");
    }
}

// The output shape matches the first input on every dim except the concat
// dim, which is the sum of that dim's extent over all inputs.
std::vector<tt::tt_metal::Shape> ConcatDeviceOperation::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
    tt::tt_metal::Shape output_shape = input_tensors[0].get_legacy_shape();
    uint32_t concat_extent = 0;
    for (const Tensor &tensor : input_tensors) {
        concat_extent += tensor.get_legacy_shape()[this->dim];
    }
    output_shape[this->dim] = concat_extent;
    return {output_shape};
}

// Allocates the output tensor, inheriting dtype and layout from the first
// input. Sharded outputs need an explicit create_device_tensor call since the
// generic helper only handles interleaved memory configs.
std::vector<Tensor> ConcatDeviceOperation::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
    const Tensor &reference = input_tensors.at(0);

    if (!this->output_mem_config.is_sharded()) {
        return operation::generic_create_output_tensors(
            *this, input_tensors, reference.get_dtype(), reference.get_layout(), this->output_mem_config);
    }
    return {create_device_tensor(
        this->compute_output_shapes(input_tensors).at(0),
        reference.get_dtype(),
        reference.get_layout(),
        reference.device(),
        this->output_mem_config)};
}

// Dispatches to the program factory matching the chosen parallelization
// strategy; MULTI_CORE is the fallback for any non-sharded input.
operation::ProgramWithCallbacks ConcatDeviceOperation::create_program(
    const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const {
    const auto strategy = this->get_parallelization_strategy(input_tensors);
    if (strategy == ConcatOpParallelizationStrategy::SHARDED_MULTI_CORE) {
        return detail::sharded_concat_multi_core(input_tensors, this->dim, output_tensors[0]);
    }
    return detail::concat_multi_core(input_tensors, this->dim, output_tensors[0]);
}

// Host-side concat entry point (ref: torch.cat semantics). Normalizes the
// concat dim, then either runs ConcatDeviceOperation directly (sharded
// inputs) or routes through run_with_autoformat with per-input layout/pad
// parameters (interleaved inputs). A single input degenerates to a
// memory-config move. Returns the concatenated tensor.
Tensor concat_impl(std::vector<Tensor> &input_tensors, const std::int64_t dim, const MemoryConfig &output_mem_config) {
    // Check emptiness before anything else: the equivalent TT_FATAL inside the
    // launched lambda only fires after input_tensors[0] has already been
    // dereferenced on the next line, which is undefined behavior for an empty
    // vector.
    TT_FATAL(input_tensors.size() > 0, "need 1 or more tensors");
    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensors[0]}))};
    operation::launch_op(
        [dim, output_mem_config](
            const std::vector<Tensor> &input_tensors,
            const std::vector<std::optional<const Tensor>> &optional_input_tensors,
            const std::vector<std::optional<Tensor>> &optional_output_tensors) -> std::vector<Tensor> {
            TT_FATAL(input_tensors.size() > 0, "need 1 or more tensors");
            // Concat of one tensor is a no-op aside from the requested memory config.
            if (input_tensors.size() == 1) {
                return {AutoFormat::move_tensor_to_mem_config(input_tensors[0], output_mem_config)};
            }
            uint32_t ref_rank = input_tensors[0].get_legacy_shape().rank();
            uint32_t normalized_dim = input_tensors[0].get_legacy_shape().get_normalized_index(dim);

            if (input_tensors[0].is_sharded()) {
                // Sharded inputs go straight to the device op; layout/shape
                // constraints are enforced in ConcatDeviceOperation::validate.
                return operation::run(ConcatDeviceOperation{normalized_dim, output_mem_config}, {input_tensors});
            } else {
                if (input_tensors[0].get_layout() == Layout::ROW_MAJOR && normalized_dim == ref_rank - 1) {
                    for (const auto &input_tensor : input_tensors) {
                        // NOTE(review): this indexes the shape with the raw `dim`
                        // (possibly negative) while the op itself uses
                        // `normalized_dim` — confirm Shape::operator[] normalizes
                        // negative indices before unifying the two.
                        TT_FATAL(
                            (input_tensor.get_legacy_shape()[dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() ==
                            0,
                            "Current concat implementation requires aligned last dim when concatting on last dim");
                    }
                }
                // Prefer TILE layout; fall back to ROW_MAJOR if any row-major
                // input cannot be tilized cleanly (rank < 2 or H/W not a
                // multiple of the tile size).
                Layout target_layout = Layout::TILE;
                for (const auto &input_tensor : input_tensors) {
                    if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
                        const auto &input_shape = input_tensor.get_legacy_shape();
                        if (input_shape.rank() < 2 || input_shape[-2] % TILE_HEIGHT != 0 ||
                            input_shape[-1] % TILE_WIDTH != 0) {
                            target_layout = Layout::ROW_MAJOR;
                            break;
                        }
                    }
                }
                // Build per-input autoformat params: pad up to tile shape when
                // tilizing, otherwise keep the original shape unpadded.
                std::vector<FormatParams> input_format_params;
                input_format_params.reserve(input_tensors.size());
                for (const auto &input_tensor : input_tensors) {
                    if (target_layout == Layout::ROW_MAJOR) {
                        input_format_params.push_back(FormatParams{
                            .pad_shape = input_tensor.get_legacy_shape(),
                            .pad_value = 0.0,
                            .target_layout = target_layout});
                    } else {
                        tt::tt_metal::Shape pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape());
                        input_format_params.push_back(
                            FormatParams{.pad_shape = pad_shape, .pad_value = 0.0, .target_layout = target_layout});
                    }
                }

                return operation::run_with_autoformat(
                    ConcatDeviceOperation{normalized_dim, output_mem_config}, {input_tensors}, {input_format_params}, {target_layout});
            }
        },
        input_tensors,
        output_tensors);
    return output_tensors.at(0);
}

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/run_operation.hpp"

namespace tt {

namespace tt_metal {
namespace ttnn::operations::data_movement {

enum class ConcatOpParallelizationStrategy { MULTI_CORE, SHARDED_MULTI_CORE };

struct Concat {
struct ConcatDeviceOperation {
uint32_t dim;
const MemoryConfig output_mem_config;
void validate(const std::vector<Tensor> &input_tensors) const;
Expand All @@ -24,18 +22,11 @@ struct Concat {
ConcatOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const;
};

operation::ProgramWithCallbacks sharded_concat_multi_core(
const std::vector<Tensor> &input_tensors, uint32_t dim, Tensor &output);
operation::ProgramWithCallbacks concat_multi_core(
const std::vector<Tensor> &input_tensors, const uint32_t dim, const Tensor &output);

// Ref: https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat
// Notes: Non-empty tensors provided must have the same shape, except in the cat dimension.
Tensor concat(
Tensor concat_impl(
std::vector<Tensor> &input_tensors,
const std::int64_t dim = 0,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

} // namespace tt
} // namespace ttnn::operations::data_movement
Loading

0 comments on commit 67c15ae

Please sign in to comment.