From 37c0bb07c10cb1907b6533ea31b2d05e93b308cb Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Wed, 4 Dec 2024 09:00:27 -0800 Subject: [PATCH] #0: Port all Misc ops to use TensorSpec (#15509) ### Ticket ### Problem description We need to migrate all ops to use `compute_output_specs` with TensorSpec, instead of older `compute_output_shapes` ### What's changed Migrated all misc ops to TensorSpec Minor infra upgrades to support the migration. ### Checklist - [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12057375995) - [x] [(Single-card) Demo tests](https://github.com/tenstorrent/tt-metal/actions/runs/12133160610) - [x] [(Single-card) Model perf tests](https://github.com/tenstorrent/tt-metal/actions/runs/12133156506) - [x] [(Single-card) Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/runs/12146782009) - [x] New/Existing tests provide coverage for changes --- .../unit_tests/gtests/test_ccl_on_galaxy.cpp | 2 +- ttnn/cpp/ttnn/operation.hpp | 64 +++++++++------ .../full/device/full_device_operation.cpp | 18 ++--- .../full/device/full_device_operation.hpp | 4 +- .../device/full_like_device_operation.cpp | 17 ++-- .../device/full_like_device_operation.hpp | 4 +- .../device/index_fill_device_operation.cpp | 12 +-- .../device/index_fill_device_operation.hpp | 4 +- .../kv_cache/device/update_cache_op.cpp | 3 +- .../kv_cache/device/update_cache_op.hpp | 2 +- .../groupnorm/device/groupnorm_op.cpp | 24 +++--- .../groupnorm/device/groupnorm_op.hpp | 2 +- .../layernorm/device/layernorm_op.cpp | 58 ++++++++------ .../layernorm/device/layernorm_op.hpp | 2 +- .../device/layernorm_post_all_gather_op.cpp | 14 ++-- .../device/layernorm_post_all_gather_op.hpp | 3 +- .../device/layernorm_pre_all_gather_op.cpp | 11 +-- .../device/layernorm_pre_all_gather_op.hpp | 3 +- .../softmax/device/softmax_op.cpp | 15 ++-- .../softmax/device/softmax_op.hpp | 2 +- .../pool/downsample/device/downsample_op.cpp | 28 +++---- .../pool/downsample/device/downsample_op.hpp | 3 +- .../pool/generic/device/pool_op.cpp | 30 +++---- .../pool/generic/device/pool_op.hpp | 4 +- .../pool/upsample/device/upsample_op.cpp | 79 ++++++++----------- .../pool/upsample/device/upsample_op.hpp | 3 +- .../reduction/argmax/device/argmax_op.cpp | 19 +++-- .../reduction/argmax/device/argmax_op.hpp | 3 +- .../reduction/moe/device/moe_op.cpp | 19 +++-- .../reduction/moe/device/moe_op.hpp | 3 +- .../reduction/prod/device/prod_nc_op.cpp | 2 +- .../reduction/prod/device/prod_nc_op.hpp | 2 +- .../reduction/prod/device/prod_op_all.cpp | 11 +-- .../reduction/prod/device/prod_op_all.hpp | 3 +- .../reduction/topk/device/topk_op.cpp | 27 +++++-- .../reduction/topk/device/topk_op.hpp | 3 +- .../transformer/sdpa/device/sdpa_op.cpp | 11 +-- .../transformer/sdpa/device/sdpa_op.hpp | 4 +- .../sdpa_decode/device/sdpa_decode_op.cpp | 12 +-- .../sdpa_decode/device/sdpa_decode_op.hpp | 4 +- ttnn/cpp/ttnn/run_operation.cpp | 4 +- 41 files changed, 265 insertions(+), 273 deletions(-) diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 87e0c2b5d57..6ab86fc8c8e 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -31,7 +31,7 @@ std::vector run_operation( static_assert( operation::detail::is_device_operation(), "ttnn::run_operation can only dispatch Device Operations!"); // Create output tensor vector by examining the number of output shapes created by the device operation - auto output_shapes = operation::DeviceOperation(devop).compute_output_shapes(input_tensors); + auto output_shapes = operation::DeviceOperation(devop).compute_output_shapes(input_tensors, {}); size_t output_shapes_size = 0; if (std::holds_alternative>(output_shapes)) { output_shapes_size = std::get>(output_shapes).size(); diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp index 99c1c3a4fa8..710f805e604 100644 --- a/ttnn/cpp/ttnn/operation.hpp +++ b/ttnn/cpp/ttnn/operation.hpp @@ -298,9 +298,15 @@ constexpr bool implements_compute_output_shapes() { template using has_compute_output_specs_t = decltype(std::declval().compute_output_specs(std::declval()...)); +template +constexpr bool implements_compute_output_specs_with_optional_output_tensors() { + return std::experimental::is_detected_v; +} + template constexpr bool implements_compute_output_specs() { - return std::experimental::is_detected_v; + return std::experimental::is_detected_v || + implements_compute_output_specs_with_optional_output_tensors(); } template @@ -441,8 +447,9 @@ struct DeviceOperation final { } // TODO: Rename into compute_output_specs in later PR - inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const { - return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors); + inline const ComputedShapes compute_output_shapes( + const Tensors& input_tensors, const OptionalTensors& output_tensors) const { + return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors, output_tensors); } inline const OutputTensors create_output_tensors( @@ -581,23 +588,28 @@ struct DeviceOperation final { "Operation must implement either validate or validate_with_output_tensors"); } }}, - compute_output_shapes_impl_{[](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes { - const auto& operation = *reinterpret_cast*>(&storage); - if constexpr ( - detail::implements_compute_output_shapes() and detail::implements_compute_output_specs()) { - static_assert( - tt::stl::concepts::always_false_v, - "Operation cannot implement both compute_output_shapes and compute_output_specs"); - } else if constexpr (detail::implements_compute_output_shapes()) { - return operation.compute_output_shapes(input_tensors); - } else if constexpr (detail::implements_compute_output_specs()) { - return operation.compute_output_specs(input_tensors); - } else { - static_assert( - tt::stl::concepts::always_false_v, - "Operation must implement either compute_output_shapes or compute_output_specs"); - } - }}, + compute_output_shapes_impl_{ + [](const storage_t& storage, + const Tensors& input_tensors, + const OptionalTensors& output_tensors) -> const ComputedShapes { + const auto& operation = *reinterpret_cast*>(&storage); + if constexpr ( + detail::implements_compute_output_shapes() and detail::implements_compute_output_specs()) { + static_assert( + tt::stl::concepts::always_false_v, + "Operation cannot implement both compute_output_shapes and compute_output_specs"); + } else if constexpr (detail::implements_compute_output_shapes()) { + return operation.compute_output_shapes(input_tensors); + } else if constexpr (detail::implements_compute_output_specs_with_optional_output_tensors()) { + return operation.compute_output_specs(input_tensors, output_tensors); + } else if constexpr (detail::implements_compute_output_specs()) { + return operation.compute_output_specs(input_tensors); + } else { + static_assert( + tt::stl::concepts::always_false_v, + "Operation must implement either compute_output_shapes or compute_output_specs"); + } + }}, create_output_tensors_impl_{ [](const storage_t& storage, const Tensors& input_tensors, @@ -605,13 +617,15 @@ struct DeviceOperation final { const auto& operation = *reinterpret_cast*>(&storage); if constexpr (detail::implements_create_output_tensors_with_optional_output_tensors()) { static_assert( - detail::implements_compute_output_shapes(), - "Operation must implement compute_output_shapes if it implements create_output_tensors"); + detail::implements_compute_output_shapes() || detail::implements_compute_output_specs(), + "Operation must implement compute_output_shapes or compute_output_specs if it implements " + "create_output_tensors"); return operation.create_output_tensors(input_tensors, output_tensors); } else if constexpr (detail::implements_create_output_tensors()) { static_assert( - detail::implements_compute_output_shapes(), - "Operation must implement compute_output_shapes if it implements create_output_tensors"); + detail::implements_compute_output_shapes() || detail::implements_compute_output_specs(), + "Operation must implement compute_output_shapes or compute_output_specs if it implements " + "create_output_tensors"); return operation.create_output_tensors(input_tensors); } else if constexpr (detail::implements_compute_output_specs()) { return detail::default_create_output_tensors(operation, input_tensors, output_tensors); @@ -811,7 +825,7 @@ struct DeviceOperation final { const Tensors&, const std::vector>&, const OptionalTensors&); - const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&); + const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&); const OutputTensors (*create_output_tensors_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&); CacheableProgram (*create_program_impl_)( diff --git a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp index 7aaffe18c6b..9a8a94d46c7 100644 --- a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp @@ -50,20 +50,18 @@ void FullOperation::validate_on_program_cache_hit( validate_inputs(operation_attributes, tensor_args); }; -FullOperation::shape_return_value_t FullOperation::compute_output_shapes( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return SimpleShape(operation_attributes.shape); +FullOperation::spec_return_value_t FullOperation::compute_output_specs( + const operation_attributes_t& operation_attributes, const tensor_args_t&) { + return TensorSpec( + SimpleShape(operation_attributes.shape), + TensorLayout( + operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config)); }; FullOperation::tensor_return_value_t FullOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - auto output_shape = compute_output_shapes(operation_attributes, tensor_args); - return create_device_tensor( - output_shape, - operation_attributes.dtype, - operation_attributes.layout, - tensor_args.any.device(), - operation_attributes.memory_config); + auto output_spec = compute_output_specs(operation_attributes, tensor_args); + return create_device_tensor(output_spec, tensor_args.any.device()); } std::tuple FullOperation::invoke( diff --git a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp index a7efe354c34..7cf76d2c7ca 100644 --- a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp @@ -24,7 +24,7 @@ struct FullOperation { const Tensor& any; }; - using shape_return_value_t = SimpleShape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct ProgramFactory { @@ -54,7 +54,7 @@ struct FullOperation { static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke( diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp index 531509e382d..be79166991a 100644 --- a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp @@ -42,21 +42,18 @@ void FullLikeOperation::validate_on_program_cache_hit( validate(operation_attributes, tensor_args); } -FullLikeOperation::shape_return_value_t FullLikeOperation::compute_output_shapes( +FullLikeOperation::spec_return_value_t FullLikeOperation::compute_output_specs( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return tensor_args.input.get_logical_shape(); + return TensorSpec( + tensor_args.input.get_logical_shape(), + TensorLayout( + operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config)); } FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto output_shape = compute_output_shapes(operation_attributes, tensor_args); - const auto& input = tensor_args.input; - return create_device_tensor( - output_shape, - operation_attributes.dtype, - operation_attributes.layout, - input.device(), - operation_attributes.memory_config); + const auto output_spec = compute_output_specs(operation_attributes, tensor_args); + return create_device_tensor(output_spec, tensor_args.input.device()); } std::tuple FullLikeOperation::invoke( diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp index 7b8341ef7db..905ba6dbe93 100644 --- a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp @@ -26,7 +26,7 @@ struct FullLikeOperation { const Tensor& input; }; - using shape_return_value_t = SimpleShape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct ProgramFactory { @@ -54,7 +54,7 @@ struct FullLikeOperation { static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); static void validate(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke( diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp index 05555467a83..79840d66a99 100644 --- a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp @@ -36,16 +36,16 @@ void IndexFillOperation::validate_on_program_cache_hit( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { validate(operation_attributes, tensor_args); } -IndexFillOperation::shape_return_value_t IndexFillOperation::compute_output_shapes( +IndexFillOperation::spec_return_value_t IndexFillOperation::compute_output_specs( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return tensor_args.input.get_logical_shape(); + return TensorSpec( + tensor_args.input.get_logical_shape(), + tensor_args.input.get_tensor_spec().tensor_layout().with_memory_config(operation_attributes.memory_config)); } IndexFillOperation::tensor_return_value_t IndexFillOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto output_shape = compute_output_shapes(operation_attributes, tensor_args); - const auto& input = tensor_args.input; - return create_device_tensor( - output_shape, input.dtype(), input.layout(), input.device(), operation_attributes.memory_config); + const auto output_spec = compute_output_specs(operation_attributes, tensor_args); + return create_device_tensor(output_spec, tensor_args.input.device()); } std::tuple IndexFillOperation::invoke( const Tensor& input, diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp index 757c88962b2..75dd75d3c86 100644 --- a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp @@ -21,7 +21,7 @@ struct IndexFillOperation { const Tensor& input; const Tensor& index; }; - using shape_return_value_t = SimpleShape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct MultiCore { struct shared_variables_t { @@ -46,7 +46,7 @@ struct IndexFillOperation { static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); static void validate(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke( const Tensor& input, diff --git a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp index d75671f4af9..c2795db919c 100644 --- a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp +++ b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp @@ -95,8 +95,7 @@ void UpdateCache::validate(const std::vector& input_tensors) const { } } -std::vector UpdateCache::compute_output_shapes( - const std::vector& input_tensors) const { +std::vector UpdateCache::compute_output_specs(const std::vector&) const { // Do nothing because it's an in-place operation return {}; } diff --git a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp index c742ac748c4..431a64ed338 100644 --- a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp +++ b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp @@ -34,7 +34,7 @@ struct UpdateCache { UpdateCacheOpParallelizationStrategy get_parallelization_strategy(const std::vector& input_tensors) const; void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp index 018017e490d..f962c112c6e 100644 --- a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp @@ -79,23 +79,23 @@ void GroupNorm::validate( TT_FATAL(input_mask.value().get_legacy_shape()[3] % TILE_WIDTH == 0, "Error"); } } -std::vector GroupNorm::compute_output_shapes(const std::vector& input_tensors) const { - return {input_tensors.at(0).get_logical_shape()}; +std::vector GroupNorm::compute_output_specs(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + if (this->program_config.inplace) { + return {input_tensor.get_tensor_spec()}; + } + auto mem_config = this->output_mem_config; + mem_config.shard_spec = input_tensor.shard_spec(); + return {TensorSpec( + input_tensor.get_logical_shape(), + TensorLayout(program_config.out_data_format, PageConfig(program_config.output_layout), mem_config))}; } std::vector GroupNorm::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); if (this->program_config.inplace) { - return {input_tensors.at(0)}; - } else { - auto mem_config = this->output_mem_config; - mem_config.shard_spec = input_tensor.shard_spec(); - return {create_device_tensor( - this->compute_output_shapes(input_tensors).at(0), - program_config.out_data_format, - this->program_config.output_layout, - input_tensor.device(), - mem_config)}; + return {input_tensor}; } + return {create_device_tensor(this->compute_output_specs(input_tensors).at(0), input_tensor.device())}; } operation::ProgramWithCallbacks GroupNorm::create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp index 8503453446c..5fd2878e7e1 100644 --- a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp @@ -58,7 +58,7 @@ struct GroupNorm { void validate( const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp index 9d91075c6a1..6a023e99adc 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp @@ -203,22 +203,19 @@ void LayerNorm::validate( }, this->program_config); } -std::vector LayerNorm::compute_output_shapes(const std::vector& input_tensors) const { - auto output_shape = input_tensors.at(0).get_logical_shape(); +std::vector LayerNorm::compute_output_specs(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + auto output_shape = input_tensor.get_logical_shape(); if (this->distributed_norm_stage == DistributedLayerNormStage::PRE_ALL_GATHER) { uint32_t num_tiles_w = this->norm_type == LayerNormType::LAYERNORM ? 2 : 1; output_shape[3] = num_tiles_w * TILE_WIDTH; } - return {output_shape}; -} -std::vector LayerNorm::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); + return std::visit( - [&](const auto& program_config) -> std::vector { + [&](const auto& program_config) -> std::vector { using ProgramConfigType = std::decay_t; if constexpr (std::is_same_v) { if (this->distributed_norm_stage == DistributedLayerNormStage::PRE_ALL_GATHER) { - auto output_shape = this->compute_output_shapes(input_tensors).at(0); auto shard_spec = input_tensor.shard_spec().value(); shard_spec.shape[1] = output_shape[3]; @@ -227,26 +224,37 @@ std::vector LayerNorm::create_output_tensors(const std::vector& shard_spec.grid = core_range_set; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; - return {create_device_tensor( - output_shape, DataType::BFLOAT16, Layout::TILE, input_tensor.device(), mem_config)}; - } else { - if (program_config.inplace) { - return {input_tensor}; - } else { - auto mem_config = this->output_mem_config; - mem_config.shard_spec = input_tensor.shard_spec().value(); - return {create_device_tensor( - this->compute_output_shapes(input_tensors).at(0), - input_tensors.at(0).get_dtype(), - Layout::TILE, - input_tensor.device(), - mem_config)}; - } + return {TensorSpec( + output_shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_config))}; + } + + if (program_config.inplace) { + return {input_tensor.get_tensor_spec()}; } + + auto mem_config = this->output_mem_config; + mem_config.shard_spec = input_tensor.shard_spec().value(); + return {TensorSpec( + output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), mem_config))}; } else { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config); + return {TensorSpec( + output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; + } + }, + this->program_config); +} +std::vector LayerNorm::create_output_tensors(const std::vector& input_tensors) const { + return std::visit( + [&](const auto& program_config) -> std::vector { + using ProgramConfigType = std::decay_t; + if constexpr (std::is_same_v) { + if (this->distributed_norm_stage != DistributedLayerNormStage::PRE_ALL_GATHER && + program_config.inplace) { + return {input_tensors.at(0)}; + } } + auto output_spec = compute_output_specs(input_tensors)[0]; + return {create_device_tensor(output_spec, input_tensors.at(0).device())}; }, this->program_config); } diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp index b821e87d244..7e654e32168 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp @@ -52,7 +52,7 @@ struct LayerNorm { void validate( const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp index d7d1dda809d..dec7a4005e4 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp @@ -96,15 +96,11 @@ void LayerNormPostAllGather::validate( } } -std::vector LayerNormPostAllGather::compute_output_shapes( - const std::vector& input_tensors) const { - return {input_tensors.at(0).get_logical_shape()}; -} - -std::vector LayerNormPostAllGather::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->memory_config); +std::vector LayerNormPostAllGather::compute_output_specs(const std::vector& input_tensors) const { + auto& input_tensor = input_tensors.at(0); + return {TensorSpec( + input_tensor.get_logical_shape(), + TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), memory_config))}; } operation::ProgramWithCallbacks LayerNormPostAllGather::create_program( diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp index c1410ab96ba..a4edff38e7e 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp @@ -35,8 +35,7 @@ struct LayerNormPostAllGather { void validate( const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, const std::vector>& optional_input_tensors, diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp index 2b3f2a5c4c2..dceb72e35e2 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp @@ -36,8 +36,7 @@ void LayerNormPreAllGather::validate(const std::vector& input_tensors) c TT_FATAL(tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); } -std::vector LayerNormPreAllGather::compute_output_shapes( - const std::vector& input_tensors) const { +std::vector LayerNormPreAllGather::compute_output_specs(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); auto output_shape = input_tensors.at(0).get_logical_shape(); @@ -47,13 +46,7 @@ std::vector LayerNormPreAllGather::compute_output_shapes( } output_shape[3] = num_tiles_w * TILE_WIDTH; - return {output_shape}; -} - -std::vector LayerNormPreAllGather::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - return operation::generic_create_output_tensors( - *this, input_tensors, this->dtype, Layout::TILE, input_tensor.memory_config()); + return {TensorSpec(output_shape, TensorLayout(dtype, PageConfig(Layout::TILE), input_tensor.memory_config()))}; } operation::ProgramWithCallbacks LayerNormPreAllGather::create_program( diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp index 303f285603c..299f0328952 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp @@ -28,8 +28,7 @@ struct LayerNormPreAllGather { const DeviceComputeKernelConfig compute_kernel_config; void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp index 0ca94a69259..deadbf2423e 100644 --- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp @@ -110,17 +110,22 @@ void Softmax::validate( } } -std::vector Softmax::compute_output_shapes(const std::vector& input_tensors) const { - return {input_tensors.at(0).get_legacy_shape()}; +std::vector Softmax::compute_output_specs(const std::vector& input_tensors) const { + auto& input_tensor = input_tensors.at(0); + if (this->inplace) { + return {input_tensor.get_tensor_spec()}; + } + return {TensorSpec( + input_tensor.get_logical_shape(), + TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; } std::vector Softmax::create_output_tensors(const std::vector& input_tensors) const { if (this->inplace) { return {input_tensors.at(0)}; - } else { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config); } + + return {create_device_tensor(compute_output_specs(input_tensors)[0], input_tensors.at(0).device())}; } operation::ProgramWithCallbacks Softmax::create_program( diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp index 121ed08720f..8f7f1fde1ae 100644 --- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp @@ -30,7 +30,7 @@ struct Softmax { void validate( const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp index 6e94c79887f..1d9abd72cb4 100644 --- a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp @@ -30,27 +30,17 @@ void Downsample::validate(const std::vector& input_tensors) const { input_tensor_a.memory_config().memory_layout); } -std::vector Downsample::compute_output_shapes( - const std::vector& input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - TT_ASSERT(input_tensor_a.get_legacy_shape()[0] == 1 && input_tensor_a.get_legacy_shape()[1] == 1); - uint32_t input_height = input_tensor_a.get_legacy_shape()[2]; +std::vector Downsample::compute_output_specs(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + TT_ASSERT(input_tensor.get_legacy_shape()[0] == 1 && input_tensor.get_legacy_shape()[1] == 1); + uint32_t input_height = input_tensor.get_legacy_shape()[2]; auto [img_batch_size, img_height, img_width, img_stride_h, img_stride_w] = this->downsample_params; TT_ASSERT(input_height >= img_batch_size * img_height * img_width); - uint32_t output_height_unpadded = img_batch_size * ceil((double)img_height / (double)img_stride_h) * - ceil((double)img_width / (double)img_stride_w); - uint32_t output_height = tt::round_up(output_height_unpadded, TILE_HEIGHT); - uint32_t output_width = input_tensor_a.get_legacy_shape()[3]; - auto output_padding = - Padding({{0, 0}, {0, 0}, {0, (output_height - output_height_unpadded)}, {0, 0}}, Padding::PadValue::Any); - auto output_tensor_shape = tt::tt_metal::LegacyShape({1, 1, output_height, output_width}, output_padding); - log_debug(tt::LogOp, "Downsample output shape: {}", output_tensor_shape); - return {output_tensor_shape}; -} + uint32_t output_height = img_batch_size * ceil((double)img_height / (double)img_stride_h) * + ceil((double)img_width / (double)img_stride_w); + uint32_t output_width = input_tensor.get_legacy_shape()[3]; + auto output_shape = SimpleShape({1, 1, output_height, output_width}); -std::vector Downsample::create_output_tensors(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - auto output_shape = this->compute_output_shapes(input_tensors).at(0); auto [num_cores_height_sliced, num_cores_width_sliced] = detail::get_num_cores_height_width_sliced( input_tensor.shard_spec().value().grid, input_tensor.memory_config().memory_layout, @@ -64,7 +54,7 @@ std::vector Downsample::create_output_tensors(const std::vector& input_tensor.shard_spec().value().grid, std::array{{output_shard_height, output_shard_width}}, input_tensor.shard_spec().value().orientation}; - return {create_device_tensor(output_shape, this->dtype, Layout::TILE, input_tensor.device(), mem_config)}; + return {TensorSpec(output_shape, TensorLayout(dtype, PageConfig(Layout::TILE), mem_config))}; } operation::ProgramWithCallbacks Downsample::create_program( diff --git a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp index f8418c8ad89..eae1360e2e4 100644 --- a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp +++ b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp @@ -19,8 +19,7 @@ struct Downsample { std::array downsample_params; DataType dtype; void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp index aa7734dcc9f..98a972b0fc7 100644 --- a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp @@ -50,7 +50,7 @@ void Pool2D::validate_on_program_cache_hit(const operation_attributes_t& op_attr return validate_pool2d(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_); } -Pool2D::shape_return_value_t Pool2D::compute_output_shapes( +Pool2D::spec_return_value_t Pool2D::compute_output_specs( const operation_attributes_t& op_attr, const tensor_args_t& tensors) { auto& input = tensors.input_tensor_; auto& sliding_window_config = op_attr.sliding_window_config_; @@ -71,7 +71,6 @@ Pool2D::shape_return_value_t Pool2D::compute_output_shapes( // need to pad the last dim to TILE_WIDTH uint32_t out_c = input_shape[3]; uint32_t out_c_padded = tt::round_up(out_c, (out_c <= 16) ? 16 : tt::constants::TILE_WIDTH); - uint32_t out_pagesize = out_c_padded * datum_size(datatype_to_dataformat_converter(input.get_dtype())); uint32_t out_nhw = sliding_window_config.batch_size * out_h * out_w; uint32_t out_nhw_padded = @@ -82,34 +81,29 @@ Pool2D::shape_return_value_t Pool2D::compute_output_shapes( const auto padding = Padding( {{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}}, Padding::PadValue::NegativeInfinity); - auto out_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding)); - return out_shape; -} - -Pool2D::tensor_return_value_t Pool2D::create_output_tensors( - const operation_attributes_t& op_attr, const tensor_args_t& tensors) { - auto& input = tensors.input_tensor_; - auto& sliding_window_config = op_attr.sliding_window_config_; - auto& out_mem_config = op_attr.memory_config_; - auto& output_dtype = op_attr.output_dtype_; + auto output_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding)); - Shape output_shape = compute_output_shapes(op_attr, tensors); auto mem_config = out_mem_config; if (mem_config.shard_spec.has_value()) { mem_config.shard_spec->shape[1] = input.shard_spec()->shape[1]; } else { uint32_t ncores = input.shard_spec().value().num_cores(); TT_FATAL(ncores == sliding_window_config.num_cores_nhw, "Number of cores should match"); - uint32_t nbatch = output_shape[0]; - uint32_t out_nhw_padded = output_shape[0] * output_shape[1] * output_shape[2]; - uint32_t out_nhw_per_core = out_nhw_padded / ncores; + uint32_t out_nhw_per_core = output_shape[0] * output_shape[1] * output_shape[2] / ncores; CoreRangeSet shard_grid = sliding_window_config.core_range_set; std::array shard_shape = {out_nhw_per_core, input.get_legacy_shape()[-1]}; mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false}; } - // return create_device_tensor(output_shape, input.get_dtype(), input.get_layout(), input.device(), mem_config); - return create_device_tensor(output_shape, output_dtype, input.get_layout(), input.device(), mem_config); + return TensorSpec( + output_shape.logical_shape(), + TensorLayout::fromLegacyPaddedShape(output_dtype, PageConfig(input.get_layout()), mem_config, output_shape)); +} + +Pool2D::tensor_return_value_t Pool2D::create_output_tensors( + const operation_attributes_t& op_attr, const tensor_args_t& tensors) { + auto output_spec = compute_output_specs(op_attr, tensors); + return create_device_tensor(output_spec, tensors.input_tensor_.device()); } tt::stl::hash::hash_t Pool2D::compute_program_hash( diff --git a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.hpp b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.hpp index 8391ae7924d..0ab122cea54 100644 --- a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.hpp +++ b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.hpp @@ -35,7 +35,7 @@ struct Pool2D { const Tensor& input_tensor_; }; - using shape_return_value_t = ttnn::Shape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct MultiCore { @@ -66,7 +66,7 @@ struct Pool2D { static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static Tensor create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static tt::stl::hash::hash_t compute_program_hash(const operation_attributes_t&, const tensor_args_t&); static operation::OpPerformanceModel create_op_performance_model( diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp index 3a313e803cc..3e0c54f5d33 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp @@ -45,7 +45,7 @@ void UpSample::validate(const std::vector& input_tensors) const { } } -std::vector UpSample::compute_output_shapes(const std::vector& input_tensors) const { +std::vector UpSample::compute_output_specs(const std::vector& input_tensors) const { // NOTE1: data is packed in { N, H , W, C } // NOTE2: Mapping it into in 2D format should be {N*H*W, C} // NOTE3: Assuming output data type is same as input @@ -56,55 +56,46 @@ std::vector UpSample::compute_output_shapes(const std uint32_t out_h = input_shape[1] * scale_factor_h_; uint32_t out_w = input_shape[2] * scale_factor_w_; uint32_t out_c = input_shape[3]; - const ttnn::SmallVector out_dims({out_n, out_h, out_w, out_c}); // in the NHWC format - return {tt::tt_metal::LegacyShape{out_dims}}; -} + auto output_shape = ttnn::SimpleShape({out_n, out_h, out_w, out_c}); -std::vector UpSample::create_output_tensors(const std::vector& inputs) const { - const auto& input = inputs.at(0); if (output_mem_config_.is_sharded()) { - if (input.memory_config().is_sharded()) { - auto mem_config = output_mem_config_; - auto input_shard_spec = input.memory_config().shard_spec.value(); - auto output_shape = compute_output_shapes(inputs).at(0); - if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - auto ncores = input_shard_spec.num_cores(); - std::array output_shard_shape = { - div_up(output_shape[0] * output_shape[1] * output_shape[2], ncores), output_shape[-1]}; - auto output_shard_spec = input_shard_spec; - output_shard_spec.shape = output_shard_shape; - mem_config.shard_spec = output_shard_spec; - log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); - log_debug(LogOp, "output_shard_spec: {}", output_shard_spec); - return {create_device_tensor( - output_shape, input.get_dtype(), input.get_layout(), input.device(), mem_config)}; - } else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - auto shard_grid = input_shard_spec.grid.ranges(); - TT_FATAL(shard_grid.size() == 1, "Block sharded input should have only one CoreRange"); - auto core_range = *shard_grid.begin(); - uint32_t ncores_w = core_range.end_coord.x + 1; - uint32_t ncores_h = core_range.end_coord.y + 1; - // std::array output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2] / - // ncores_h, output_shape[-1] / ncores_w}; auto output_shard_spec = input_shard_spec; - // output_shard_spec.shape = output_shard_shape; - // mem_config.shard_spec = output_shard_spec; - auto output_shard_spec = mem_config.shard_spec.value(); - auto output_shard_shape = output_shard_spec.shape; - log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h); - log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); - return {create_device_tensor( - output_shape, input.get_dtype(), input.get_layout(), input.device(), mem_config)}; - } else { - TT_THROW("input memory config is not HEIGHT or BLOCK sharded"); - } - } else { + if (!input.memory_config().is_sharded()) { TT_THROW("Output memory config is sharded but input memory config is not sharded"); } - } else { - return operation::generic_create_output_tensors( - *this, inputs, input.get_dtype(), input.get_layout(), output_mem_config_); + auto mem_config = output_mem_config_; + auto input_shard_spec = input.memory_config().shard_spec.value(); + if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + auto ncores = input_shard_spec.num_cores(); + std::array output_shard_shape = { + div_up(output_shape[0] * output_shape[1] * output_shape[2], ncores), output_shape[-1]}; + auto output_shard_spec = input_shard_spec; + output_shard_spec.shape = output_shard_shape; + mem_config.shard_spec = output_shard_spec; + log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); + log_debug(LogOp, "output_shard_spec: {}", output_shard_spec); + return { + TensorSpec(output_shape, TensorLayout(input.get_dtype(), PageConfig(input.get_layout()), mem_config))}; + } + if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + auto shard_grid = input_shard_spec.grid.ranges(); + TT_FATAL(shard_grid.size() == 1, "Block sharded input should have only one CoreRange"); + auto core_range = *shard_grid.begin(); + uint32_t ncores_w = core_range.end_coord.x + 1; + uint32_t ncores_h = core_range.end_coord.y + 1; + auto output_shard_spec = mem_config.shard_spec.value(); + auto output_shard_shape = output_shard_spec.shape; + log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h); + log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); + return { + TensorSpec(output_shape, TensorLayout(input.get_dtype(), PageConfig(input.get_layout()), mem_config))}; + } + + TT_THROW("input memory config is not HEIGHT or BLOCK sharded"); } + + return { + TensorSpec(output_shape, TensorLayout(input.get_dtype(), PageConfig(input.get_layout()), output_mem_config_))}; } operation::ProgramWithCallbacks UpSample::create_program( diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp index f07c9f8b472..be183519d98 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp @@ -20,8 +20,7 @@ struct UpSample { const DeviceComputeKernelConfig compute_kernel_config_; void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; UpSampleParallelizationStrategy get_parallelization_strategy(const std::vector& input_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp index ea1094216b5..53bd54fdb44 100644 --- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp @@ -48,13 +48,20 @@ void ArgMax::validate_with_output_tensors( TT_FATAL(input_shape[1] == 1, "dim 1 must be 1"); } -std::vector ArgMax::compute_output_shapes(const std::vector& input_tensors) const { +std::vector ArgMax::compute_output_specs( + const std::vector& input_tensors, const std::vector>& output_tensors) const { + if (output_tensors.at(0).has_value()) { + return {output_tensors.at(0)->get_tensor_spec()}; + } + + const auto& input_tensor = input_tensors[0]; + ttnn::SimpleShape output_shape({1, 1, 1, 1}); if (this->dim.has_value()) { auto input_shape = input_tensors[0].get_logical_shape(); - return {ttnn::SimpleShape{input_shape[0], input_shape[1], 1, input_shape[2]}}; - } else { - return {ttnn::SimpleShape{1, 1, 1, 1}}; + output_shape = ttnn::SimpleShape{input_shape[0], input_shape[1], 1, input_shape[2]}; } + return { + TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input_tensor.get_layout()), output_mem_config))}; } std::vector ArgMax::create_output_tensors( @@ -63,9 +70,7 @@ std::vector ArgMax::create_output_tensors( return {output_tensors.at(0).value()}; } - const auto& input_tensor = input_tensors[0]; - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config); + return {create_device_tensor(compute_output_specs(input_tensors, output_tensors)[0], input_tensors[0].device())}; } operation::ProgramWithCallbacks ArgMax::create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp index dce8dcb8181..d0b6fa0b858 100644 --- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp @@ -20,7 +20,8 @@ struct ArgMax { void validate_with_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs( + const std::vector& input_tensors, const std::vector>& output_tensors) const; std::vector create_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp index 1bfb6453dc5..3cdb60eed74 100644 --- a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp @@ -43,11 +43,19 @@ void MoeDeviceOperation::validate_with_output_tensors( TT_FATAL(expert_shape[-2] == 32, "Expert shape inner dim must be equal to 32, got {}", expert_shape[-2]); } -std::vector MoeDeviceOperation::compute_output_shapes( - const std::vector& input_tensors) const { - auto output_shape = input_tensors.at(0).get_logical_shape(); +std::vector MoeDeviceOperation::compute_output_specs( + const std::vector& input_tensors, const std::vector>& output_tensors) const { + if (output_tensors.size() == 1) { + if (output_tensors.at(0).has_value()) { + return {output_tensors[0]->get_tensor_spec()}; + } + } + + auto& input_tensor = input_tensors.at(0); + auto output_shape = input_tensor.get_logical_shape(); output_shape[-1] = 1; - return {output_shape}; + return { + TensorSpec(output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; } std::vector MoeDeviceOperation::create_output_tensors( @@ -57,8 +65,7 @@ std::vector MoeDeviceOperation::create_output_tensors( return {output_tensors[0].value()}; } } - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config); + return {create_device_tensor(compute_output_specs(input_tensors, output_tensors)[0], input_tensors.at(0).device())}; } operation::ProgramWithCallbacks MoeDeviceOperation::create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp index 8a95de49817..f380dbb0a87 100644 --- a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp @@ -17,7 +17,8 @@ struct MoeDeviceOperation { void validate_with_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs( + const std::vector& input_tensors, const std::vector>& output_tensors) const; std::vector create_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp index 3fb869370a4..1bd009cccc1 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp @@ -44,7 +44,7 @@ std::vector Prod::create_output_tensors(const std::vector& input return {}; } -std::vector Prod::compute_output_shapes(const std::vector& inputs) const { +std::vector Prod::compute_output_specs(const std::vector&) const { // Inplace return {}; } diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp index 3406689741b..e061c1ddeb1 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp @@ -23,7 +23,7 @@ using namespace tt_metal; struct Prod { int64_t dim; void validate(const std::vector& inputs) const; - std::vector compute_output_shapes(const std::vector& inputs) const; + std::vector compute_output_specs(const std::vector& inputs) const; std::vector create_output_tensors(const std::vector& inputs) const; operation::ProgramWithCallbacks create_program( const std::vector& inputs, std::vector& outputs) const; diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp index b4e4c7e0a42..42da8fbc1fd 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp @@ -26,14 +26,11 @@ void Prod_op::validate(const std::vector& input_tensors) const { TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Error"); } -std::vector Prod_op::compute_output_shapes(const std::vector& input_tensors) const { - return {input_tensors.at(0).get_logical_shape()}; -} - -std::vector Prod_op::create_output_tensors(const std::vector& input_tensors) const { +std::vector Prod_op::compute_output_specs(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config); + return {TensorSpec( + input_tensor.get_logical_shape(), + TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; } operation::ProgramWithCallbacks Prod_op::create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp index 066ffdb9a6c..aea785c5f25 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp @@ -22,8 +22,7 @@ struct Prod_op { const MemoryConfig output_mem_config; const DataType output_dtype; // TODO: Uplift output_dtype as an option for general dot/bmm void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp index adbe2f6101f..8d96b25ffec 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp @@ -85,10 +85,22 @@ void TopK::validate_with_output_tensors( } } -std::vector TopK::compute_output_shapes(const std::vector& input_tensors) const { +std::vector TopK::compute_output_specs( + const std::vector& input_tensors, const std::vector>& output_tensors) const { + if (output_tensors.size() == 2) { + if (output_tensors.at(0).has_value() && output_tensors.at(1).has_value()) { + return {output_tensors[0]->get_tensor_spec(), output_tensors[1]->get_tensor_spec()}; + } + } + const auto& input_tensor = input_tensors.at(0); auto output_shape = input_tensors.at(0).get_logical_shape(); output_shape[-1] = this->k; - return {output_shape, output_shape}; + + auto values_spec = + TensorSpec(output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config)); + auto index_spec = + TensorSpec(output_shape, TensorLayout(DataType::UINT16, PageConfig(Layout::TILE), output_mem_config)); + return {values_spec, index_spec}; } std::vector TopK::create_output_tensors( @@ -98,13 +110,12 @@ std::vector TopK::create_output_tensors( return {output_tensors[0].value(), output_tensors[1].value()}; } } + auto output_specs = compute_output_specs(input_tensors, output_tensors); const auto& input_tensor = input_tensors.at(0); - const auto shapes = compute_output_shapes(input_tensors); - auto values_tensor = create_device_tensor( - shapes[0], input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), this->output_mem_config); - auto index_tensor = - create_device_tensor(shapes[1], DataType::UINT16, Layout::TILE, input_tensor.device(), this->output_mem_config); - return {values_tensor, index_tensor}; + return { + create_device_tensor(output_specs[0], input_tensor.device()), + create_device_tensor(output_specs[1], input_tensor.device()), + }; } operation::ProgramWithCallbacks TopK::create_program( diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp index 473da98ee4d..dac6dbc6766 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp @@ -20,7 +20,8 @@ struct TopK { void validate_with_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; + std::vector compute_output_specs( + const std::vector& input_tensors, const std::vector>& output_tensors) const; std::vector create_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp index b08783fa5b5..5b3981bedfe 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp @@ -101,14 +101,11 @@ void ScaledDotProductAttention::validate( } } -std::vector ScaledDotProductAttention::compute_output_shapes( +std::vector ScaledDotProductAttention::compute_output_specs( const std::vector& input_tensors) const { - return {input_tensors.at(0).get_legacy_shape()}; -} - -std::vector ScaledDotProductAttention::create_output_tensors(const std::vector& input_tensors) const { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config); + auto& input = input_tensors.at(0); + return {TensorSpec( + input.get_logical_shape(), TensorLayout(input.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; } operation::ProgramWithCallbacks ScaledDotProductAttention::create_program( diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp index 812b88ea222..dfe88389084 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp @@ -24,9 +24,7 @@ struct ScaledDotProductAttention { const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp index ae63c4e50f4..617a6b95b95 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp @@ -202,15 +202,11 @@ void ScaledDotProductAttentionDecode::validate( } } -std::vector ScaledDotProductAttentionDecode::compute_output_shapes( +std::vector ScaledDotProductAttentionDecode::compute_output_specs( const std::vector& input_tensors) const { - return {input_tensors.at(0).get_logical_shape()}; -} - -std::vector ScaledDotProductAttentionDecode::create_output_tensors( - const std::vector& input_tensors) const { - return operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config); + auto& input = input_tensors.at(0); + return {TensorSpec( + input.get_logical_shape(), TensorLayout(input.get_dtype(), PageConfig(Layout::TILE), output_mem_config))}; } operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp index b892dd775f4..30c5c4ed999 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp @@ -28,9 +28,7 @@ struct ScaledDotProductAttentionDecode { const std::vector& input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - - std::vector create_output_tensors(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp index dac090cdbab..4e188ff2a04 100644 --- a/ttnn/cpp/ttnn/run_operation.cpp +++ b/ttnn/cpp/ttnn/run_operation.cpp @@ -380,7 +380,7 @@ Tensors run_with_autoformat( } } - auto output_specs = operation.compute_output_shapes(input_tensors); + auto output_specs = operation.compute_output_shapes(input_tensors, optional_output_tensors); auto output_tensors = run( std::move(operation), formatted_input_tensors, @@ -453,7 +453,7 @@ Tensors run_with_autoformat( } } - auto output_specs = operation.compute_output_shapes(input_tensors); + auto output_specs = operation.compute_output_shapes(input_tensors, optional_output_tensors); auto output_tensors = run( std::move(operation), formatted_input_tensors,