From 37f4fb70468fc1c891d3b52542ef93fd8b55224a Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Wed, 27 Nov 2024 19:53:14 +0000
Subject: [PATCH 1/5] #0: Port all misc ops to tensorspec

---
 .../unit_tests/gtests/test_ccl_on_galaxy.cpp | 152 ++++++++++--------
 ttnn/cpp/ttnn/operation.hpp | 69 ++++----
 .../full/device/full_device_operation.cpp | 18 +--
 .../full/device/full_device_operation.hpp | 4 +-
 .../device/full_like_device_operation.cpp | 17 +-
 .../device/full_like_device_operation.hpp | 4 +-
 .../device/index_fill_device_operation.cpp | 12 +-
 .../device/index_fill_device_operation.hpp | 4 +-
 .../kv_cache/device/update_cache_op.cpp | 3 +-
 .../kv_cache/device/update_cache_op.hpp | 2 +-
 .../groupnorm/device/groupnorm_op.cpp | 24 +--
 .../groupnorm/device/groupnorm_op.hpp | 2 +-
 .../layernorm/device/layernorm_op.cpp | 47 +++---
 .../layernorm/device/layernorm_op.hpp | 2 +-
 .../device/layernorm_post_all_gather_op.cpp | 14 +-
 .../device/layernorm_post_all_gather_op.hpp | 3 +-
 .../device/layernorm_pre_all_gather_op.cpp | 11 +-
 .../device/layernorm_pre_all_gather_op.hpp | 3 +-
 .../softmax/device/softmax_op.cpp | 15 +-
 .../softmax/device/softmax_op.hpp | 2 +-
 .../pool/downsample/device/downsample_op.cpp | 28 ++--
 .../pool/downsample/device/downsample_op.hpp | 3 +-
 .../maxpool/device/max_pool2d_device_op.cpp | 34 +---
 .../maxpool/device/max_pool2d_device_op.hpp | 4 +-
 .../pool/upsample/device/upsample_op.cpp | 83 +++++-----
 .../pool/upsample/device/upsample_op.hpp | 3 +-
 .../reduction/argmax/device/argmax_op.cpp | 19 ++-
 .../reduction/argmax/device/argmax_op.hpp | 3 +-
 .../reduction/moe/device/moe_op.cpp | 19 ++-
 .../reduction/moe/device/moe_op.hpp | 3 +-
 .../reduction/prod/device/prod_nc_op.cpp | 2 +-
 .../reduction/prod/device/prod_nc_op.hpp | 2 +-
 .../reduction/prod/device/prod_op_all.cpp | 11 +-
 .../reduction/prod/device/prod_op_all.hpp | 3 +-
 .../reduction/topk/device/topk_op.cpp | 27 +++-
 .../reduction/topk/device/topk_op.hpp | 3 +-
 .../transformer/sdpa/device/sdpa_op.cpp | 11 +-
 .../transformer/sdpa/device/sdpa_op.hpp | 4 +-
 .../sdpa_decode/device/sdpa_decode_op.cpp | 12 +-
 .../sdpa_decode/device/sdpa_decode_op.hpp | 4 +-
 ttnn/cpp/ttnn/run_operation.cpp | 4 +-
 41 files changed, 342 insertions(+), 348 deletions(-)

diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
index 67c0e772c2a..a25cd5df7b0 100644
--- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
+++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
@@ -18,18 +18,20 @@ using namespace tt_metal;
 
 // We use this to dispatch a single device operation asynchronously
 // Needed to reproduce the deadlock scenario with a very specific pattern of commands
-// This can go away once device_operation::run will be made async and ccl op is moved to the new tmp-based DeviceOperation
+// This can go away once device_operation::run will be made async and ccl op is moved to the new tmp-based
+// DeviceOperation
 namespace async_detail {
-template 
+template 
 std::vector<Tensor> run_operation(
     uint8_t cq_id,
     OpConfig devop,
     const operation::Tensors& input_tensors,
     const operation::OptionalConstTensors& optional_input_tensors = {},
     const operation::OptionalTensors& optional_output_tensors = {}) {
-    static_assert(operation::detail::is_device_operation<OpConfig>(), "ttnn::run_operation can only dispatch Device Operations!");
+    static_assert(
+        operation::detail::is_device_operation<OpConfig>(), "ttnn::run_operation can only dispatch Device Operations!");
     // Create output tensor vector by examining the number of output shapes created by the device operation
-    auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors);
+    auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors, {});
     size_t output_shapes_size = 0;
     if (std::holds_alternative<std::vector<ttnn::SimpleShape>>(output_shapes)) {
         output_shapes_size = std::get<std::vector<ttnn::SimpleShape>>(output_shapes).size();
@@ -44,71 +46,69 @@ std::vector<Tensor> run_operation(
     // Send the operation to the async engine, which will populate the output tensors.
     for (auto worker : outputs.at(0).workers) {
         tt::tt_metal::operation::launch_op(
-            [devop, worker, cq_id] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
-                return operation::run(std::move(devop), input_tensors, optional_input_tensors, optional_output_tensors, cq_id);
-            }, input_tensors, outputs, optional_input_tensors, optional_output_tensors);
+            [devop, worker, cq_id](
+                const std::vector<Tensor>& input_tensors,
+                const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+                const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
+                return operation::run(
+                    std::move(devop), input_tensors, optional_input_tensors, optional_output_tensors, cq_id);
+            },
+            input_tensors,
+            outputs,
+            optional_input_tensors,
+            optional_output_tensors);
     }
     return outputs;
 }
-} // namespace async_detail
+}  // namespace async_detail
 
-bool is_tg_system()
-{
+bool is_tg_system() {
     const bool is_galaxy_system = tt::Cluster::instance().is_galaxy_cluster();
     const size_t num_mmio_devices = tt::Cluster::instance().number_of_pci_devices();
     const size_t num_devices = tt::Cluster::instance().number_of_user_devices();
     return is_galaxy_system && (num_mmio_devices == 4) && (num_devices == 32);
 }
 
-bool is_tgg_system()
-{
+bool is_tgg_system() {
     const bool is_galaxy_system = tt::Cluster::instance().is_galaxy_cluster();
     const size_t num_mmio_devices = tt::Cluster::instance().number_of_pci_devices();
     const size_t num_devices = tt::Cluster::instance().number_of_user_devices();
     return is_galaxy_system && (num_mmio_devices == 8) && (num_devices == 64);
 }
 
-ttnn::MeshShape get_mesh_shape()
-{
+ttnn::MeshShape get_mesh_shape() {
     ttnn::MeshShape shape;
-    if (is_tg_system())
-    {
+    if (is_tg_system()) {
         shape = {8, 4};
-    }
-    else {
+    } else {
         TT_FATAL(is_tgg_system(), "Unsupported Galaxy system");
         shape = {8, 8};
     }
     return shape;
 }
 
-void validate_num_tunnels_and_tunnel_depth()
-{
+void validate_num_tunnels_and_tunnel_depth() {
     const uint32_t num_devices_in_tunnel = tt::Cluster::instance().get_mmio_device_max_tunnel_depth(0);
     const uint32_t num_mmio_devices = tt::Cluster::instance().number_of_pci_devices();
     const uint32_t cluster_tunnel_count = tt::Cluster::instance().get_mmio_device_tunnel_count(0);
-    TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4, detected tunnel depth of {}", num_devices_in_tunnel);
+    TT_FATAL(
+        num_devices_in_tunnel == 4,
+        "Expected Galaxy to have tunnel depth of 4, detected tunnel depth of {}",
+        num_devices_in_tunnel);
 
     const uint32_t num_tunnels = num_mmio_devices * cluster_tunnel_count;
-    if (is_tg_system())
-    {
+    if (is_tg_system()) {
         TT_FATAL(num_tunnels == 8, "Expected 8 tunnels in a TG system, detected {} tunnels", num_tunnels);
-    }
-    else if (is_tgg_system())
-    {
+    } else if (is_tgg_system()) {
         TT_FATAL(num_tunnels == 16, "Expected 16 tunnels in a TGG system, detected {} tunnels", num_tunnels);
     }
 }
 
-std::shared_ptr<bfloat16 []> create_container_for_readback_data(const uint32_t buf_size_datums)
-{
-    if (is_tg_system())
-    {
-        return std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums * 4]);
-    }
-    else
-    {
+std::shared_ptr<bfloat16[]> create_container_for_readback_data(const uint32_t buf_size_datums) {
+    if (is_tg_system()) {
+        return std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums * 4]);
+    } else {
         TT_FATAL(is_tgg_system(), "Unsupported Galaxy system");
-        return std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums * 8]);
+        return std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums * 8]);
     }
 }
 
@@ -119,18 +119,17 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
     validate_num_tunnels_and_tunnel_depth();
 
     ttnn::MeshShape mesh_shape = get_mesh_shape();
-    std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
+    std::shared_ptr<MeshDevice> mesh =
+        ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
 
     // Setup input data and output data containers
     MemoryConfig mem_cfg = MemoryConfig{
-        .memory_layout = TensorMemoryLayout::INTERLEAVED,
-        .buffer_type = BufferType::DRAM,
-        .shard_spec = std::nullopt};
+        .memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt};
     ttnn::SimpleShape shape{1, 1, 32, 16384};
     const uint32_t buf_size_datums = 32 * 16384;
     const uint32_t datum_size_bytes = 2;
-    auto host_data = std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums]);
-    std::shared_ptr<bfloat16 []> readback_data = create_container_for_readback_data(buf_size_datums);
+    auto host_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums]);
+    std::shared_ptr<bfloat16[]> readback_data = create_container_for_readback_data(buf_size_datums);
     const uint32_t outer_loops = 200;
 
     // Input to CCL is a tensor of 1s. The output should contain 4x the amount of data, but all values should be 1.
@@ -167,12 +166,22 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
         uint32_t receiver_device_id = device_ids[(dev_idx) + 1 % num_devices_in_row];
         uint32_t sender_device_id = device_ids[(dev_idx + num_devices_in_row - 1) % num_devices_in_row];
         auto all_gather_op = ttnn::AllGather{
-            3, 2, num_devices_in_row, dev_idx, std::nullopt, std::nullopt, receiver_device_id, sender_device_id, input_tensor.memory_config(), ttnn::ccl::Topology::Linear};
+            3,
+            2,
+            num_devices_in_row,
+            dev_idx,
+            std::nullopt,
+            std::nullopt,
+            receiver_device_id,
+            sender_device_id,
+            input_tensor.memory_config(),
+            ttnn::ccl::Topology::Linear};
         // Send CCL to this device. All CCLs will complete simultaneously.
         output_tensors.push_back(async_detail::run_operation(0, all_gather_op, {input_tensor}).at(0));
-        // Expose deadlock: After the CCL is sent to the first device in the tunnel, send enough data to it to backpressure prefetch_h. This will block the
-        // demux, which will prevent the CCL from being sent to additional chips. If the CCL has been tagged as having multi-device dependencies, deadlock should
-        // get bypassed.
+        // Expose deadlock: After the CCL is sent to the first device in the tunnel, send enough data to it to
+        // backpressure prefetch_h. This will block the demux, which will prevent the CCL from being sent to
+        // additional chips. If the CCL has been tagged as having multi-device dependencies, deadlock should get
+        // bypassed.
         if (!dev_idx) {
             ttnn::write_buffer(0, input_tensor, {host_data});
         }
@@ -180,7 +189,9 @@
     }
     // Readback data and verify correctness.
     for (auto& tensor : output_tensors) {
-        ASSERT_EQ(tensor.get_shape(), ttnn::Shape(LegacyShape({1, 1, 32, static_cast<uint32_t>(16384 * device_ids.size())})));
+        ASSERT_EQ(
+            tensor.get_shape(),
+            ttnn::Shape(LegacyShape({1, 1, 32, static_cast<uint32_t>(16384 * device_ids.size())})));
         ttnn::read_buffer(0, tensor, {readback_data});
         for (int j = 0; j < device_ids.size() * 32 * 16384; j++) {
             ASSERT_EQ(readback_data[j].to_float(), 1);
@@ -198,17 +209,20 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
     validate_num_tunnels_and_tunnel_depth();
 
     ttnn::MeshShape mesh_shape = get_mesh_shape();
-    std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
-    // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
-    // first tunnel (forward path).
+    std::shared_ptr<MeshDevice> mesh =
+        ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
+    // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks
+    // when we send CCLs to the first tunnel (forward path).
     auto view = ttnn::MeshDeviceView(*mesh);
-    std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
-    std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
+    std::vector<Device*> ring_devices = view.get_devices_on_row(0);  // Tunnel 0
+    std::vector<Device*> ring_devices_1 =
+        view.get_devices_on_column(mesh_shape.second - 1);  // Orthogonal to tunnel .. no deadlocks
     ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
-    std::vector<Device*> ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
+    std::vector<Device*> ring_devices_2 =
+        view.get_devices_on_row(7);  // Tunnel 7 .. potential deadlocks with lack of buffering
     std::reverse(ring_devices_2.begin(), ring_devices_2.end());
     ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
-    std::vector<Device*> ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
+    std::vector<Device*> ring_devices_3 = view.get_devices_on_column(0);  // Orthogonal to tunnel .. no deadlocks
     std::reverse(ring_devices_3.begin(), ring_devices_3.end());
     ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);
 
@@ -218,15 +232,13 @@
 
     // Setup input data and output data containers
     MemoryConfig mem_cfg = MemoryConfig{
-        .memory_layout = TensorMemoryLayout::INTERLEAVED,
-        .buffer_type = BufferType::DRAM,
-        .shard_spec = std::nullopt};
+        .memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt};
     ttnn::SimpleShape shape{1, 2, 256, static_cast<uint32_t>(256 * ring_devices.size())};
     const uint32_t buf_size_datums = 2 * 256 * 256 * ring_devices.size();
     const uint32_t datum_size_bytes = 2;
     // Output of reduce scatter is input_numel / num_devices_used_in_scatter_op
-    auto host_data = std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums]);
-    auto readback_data = std::shared_ptr<bfloat16 []>(new bfloat16[buf_size_datums / ring_devices.size()]);
+    auto host_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums]);
+    auto readback_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums / ring_devices.size()]);
     uint32_t scatter_dim = 3;
     uint32_t outer_loops = 500;
 
@@ -265,16 +277,24 @@
         uint32_t receiver_device_id = device_ids[(dev_idx + 1) % ring_devices.size()];
         uint32_t sender_device_id = device_ids[(dev_idx + ring_devices.size() - 1) % ring_devices.size()];
         auto all_gather_op = ttnn::ReduceScatter{
-            ttnn::operations::binary::BinaryOpType::ADD, scatter_dim, 1, static_cast<uint32_t>(ring_devices.size()), dev_idx, receiver_device_id, sender_device_id, input_tensor.memory_config(), ttnn::ccl::Topology::Ring};
+            ttnn::operations::binary::BinaryOpType::ADD,
+            scatter_dim,
+            1,
+            static_cast<uint32_t>(ring_devices.size()),
+            dev_idx,
+            receiver_device_id,
+            sender_device_id,
+            input_tensor.memory_config(),
+            ttnn::ccl::Topology::Ring};
         // Send CCL to this device. All CCLs will complete simultaneously.
        output_tensors.push_back(async_detail::run_operation(0, all_gather_op, {input_tensor}).at(0));
-        // Expose deadlock: After the CCL is sent to a device in the first tunnel, send enough data to it to backpressure prefetch_h. This will block the
-        // demux, which will prevent the CCL from being sent to additional chips on the tunnel. If the CCL has been tagged as having multi-device dependencies, deadlock should
-        // get bypassed.
-        // if (dev_idx < 3) {
-            for (int j = 0; j < 16; j++) {
-                ttnn::write_buffer(0, input_tensor, {host_data});
-            }
+        // Expose deadlock: After the CCL is sent to a device in the first tunnel, send enough data to it to
+        // backpressure prefetch_h. This will block the demux, which will prevent the CCL from being sent to
+        // additional chips on the tunnel. If the CCL has been tagged as having multi-device dependencies, deadlock
+        // should get bypassed.
+        // if (dev_idx < 3) {
+            for (int j = 0; j < 16; j++) {
+                ttnn::write_buffer(0, input_tensor, {host_data});
+            }
         // }
         dev_idx++;
     }
diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp
index f0a6b030c2e..710f805e604 100644
--- a/ttnn/cpp/ttnn/operation.hpp
+++ b/ttnn/cpp/ttnn/operation.hpp
@@ -298,9 +298,15 @@ constexpr bool implements_compute_output_shapes() {
 template <class T, class... Args>
 using has_compute_output_specs_t = decltype(std::declval<T>().compute_output_specs(std::declval<Args>()...));
 
+template <class T>
+constexpr bool implements_compute_output_specs_with_optional_output_tensors() {
+    return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&, const OptionalTensors&>;
+}
+
 template <class T>
 constexpr bool implements_compute_output_specs() {
-    return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&>;
+    return std::experimental::is_detected_v<has_compute_output_specs_t, T, const Tensors&> ||
+           implements_compute_output_specs_with_optional_output_tensors<T>();
 }
 
 template 
@@ -396,8 +402,9 @@ constexpr bool implements_get_parallelization_strategy() {
 
 template <typename ConcreteOperation>
 auto default_create_output_tensors(
-    const ConcreteOperation& operation, const Tensors& input_tensors, const OptionalTensors& optional_output_tensors)
-    -> ProgramOutputTensors<ConcreteOperation> {
+    const ConcreteOperation& operation,
+    const Tensors& input_tensors,
+    const OptionalTensors& optional_output_tensors) -> ProgramOutputTensors<ConcreteOperation> {
     using OutputTensors = ProgramOutputTensors<ConcreteOperation>;
 
     OutputTensors output_tensors;
@@ -440,8 +447,9 @@ struct DeviceOperation final {
     }
 
     // TODO: Rename into compute_output_specs in later PR
-    inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const {
-        return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors);
+    inline const ComputedShapes compute_output_shapes(
+        const Tensors& input_tensors, const OptionalTensors& output_tensors) const {
+        return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors, output_tensors);
     }
 
     inline const OutputTensors create_output_tensors(
@@ -580,23 +588,28 @@ struct DeviceOperation final {
                     "Operation must implement either validate or validate_with_output_tensors");
                 }
             }},
-        compute_output_shapes_impl_{[](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes {
-            const auto& operation = *reinterpret_cast<const T*>(&storage);
-            if constexpr (
-                detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
-                static_assert(
-                    tt::stl::concepts::always_false_v<T>,
-                    "Operation cannot implement both compute_output_shapes and compute_output_specs");
-            } else if constexpr (detail::implements_compute_output_shapes<T>()) {
-                return operation.compute_output_shapes(input_tensors);
-            } else if constexpr (detail::implements_compute_output_specs<T>()) {
-                return operation.compute_output_specs(input_tensors);
-            } else {
-                static_assert(
-                    tt::stl::concepts::always_false_v<T>,
-                    "Operation must implement either compute_output_shapes or compute_output_specs");
-            }
-        }},
+        compute_output_shapes_impl_{
+            [](const storage_t& storage,
+               const Tensors& input_tensors,
+               const OptionalTensors& output_tensors) -> const ComputedShapes {
+                const auto& operation = *reinterpret_cast<const T*>(&storage);
+                if constexpr (
+                    detail::implements_compute_output_shapes<T>() and detail::implements_compute_output_specs<T>()) {
+                    static_assert(
+                        tt::stl::concepts::always_false_v<T>,
+                        "Operation cannot implement both compute_output_shapes and compute_output_specs");
+                } else if constexpr (detail::implements_compute_output_shapes<T>()) {
+                    return operation.compute_output_shapes(input_tensors);
+                } else if constexpr (detail::implements_compute_output_specs_with_optional_output_tensors<T>()) {
+                    return operation.compute_output_specs(input_tensors, output_tensors);
+                } else if constexpr (detail::implements_compute_output_specs<T>()) {
+                    return operation.compute_output_specs(input_tensors);
+                } else {
+                    static_assert(
+                        tt::stl::concepts::always_false_v<T>,
+                        "Operation must implement either compute_output_shapes or compute_output_specs");
+                }
+            }},
         create_output_tensors_impl_{
             [](const storage_t& storage,
                const Tensors& input_tensors,
@@ -604,13 +617,15 @@ struct DeviceOperation final {
                 const auto& operation = *reinterpret_cast<const T*>(&storage);
                 if constexpr (detail::implements_create_output_tensors_with_optional_output_tensors<T>()) {
                     static_assert(
-                        detail::implements_compute_output_shapes<T>(),
-                        "Operation must implement compute_output_shapes if it implements create_output_tensors");
+                        detail::implements_compute_output_shapes<T>() || detail::implements_compute_output_specs<T>(),
+                        "Operation must implement compute_output_shapes or compute_output_specs if it implements "
+                        "create_output_tensors");
                     return operation.create_output_tensors(input_tensors, output_tensors);
                 } else if constexpr (detail::implements_create_output_tensors<T>()) {
                     static_assert(
-                        detail::implements_compute_output_shapes<T>(),
-                        "Operation must implement compute_output_shapes if it implements create_output_tensors");
+                        detail::implements_compute_output_shapes<T>() || detail::implements_compute_output_specs<T>(),
+                        "Operation must implement compute_output_shapes or compute_output_specs if it implements "
+                        "create_output_tensors");
                     return operation.create_output_tensors(input_tensors);
                 } else if constexpr (detail::implements_compute_output_specs<T>()) {
                     return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
@@ -810,7 +825,7 @@ struct DeviceOperation final {
         const Tensors&,
         const std::vector<std::optional<const Tensor>>&,
         const OptionalTensors&);
-    const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&);
+    const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);
     const OutputTensors (*create_output_tensors_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);
 
     CacheableProgram<OutputTensors> (*create_program_impl_)(
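With the dispatcher change above, an old-infrastructure op can expose `compute_output_specs` in either of two forms; the two-argument overload lets preallocated outputs short-circuit spec computation (see the ArgMax and MoE diffs below). A minimal sketch of the two detected signatures, not lifted verbatim from any single op in this patch:

```cpp
// Form detected by implements_compute_output_specs():
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;

// Form detected by implements_compute_output_specs_with_optional_output_tensors():
std::vector<TensorSpec> compute_output_specs(
    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
```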
diff --git a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp
index 7aaffe18c6b..9a8a94d46c7 100644
--- a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp
@@ -50,20 +50,18 @@ void FullOperation::validate_on_program_cache_hit(
     validate_inputs(operation_attributes, tensor_args);
 };
 
-FullOperation::shape_return_value_t FullOperation::compute_output_shapes(
-    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    return SimpleShape(operation_attributes.shape);
+FullOperation::spec_return_value_t FullOperation::compute_output_specs(
+    const operation_attributes_t& operation_attributes, const tensor_args_t&) {
+    return TensorSpec(
+        SimpleShape(operation_attributes.shape),
+        TensorLayout(
+            operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config));
 };
 
 FullOperation::tensor_return_value_t FullOperation::create_output_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
-    return create_device_tensor(
-        output_shape,
-        operation_attributes.dtype,
-        operation_attributes.layout,
-        tensor_args.any.device(),
-        operation_attributes.memory_config);
+    auto output_spec = compute_output_specs(operation_attributes, tensor_args);
+    return create_device_tensor(output_spec, tensor_args.any.device());
 }
 
 std::tuple<FullOperation::operation_attributes_t, FullOperation::tensor_args_t> FullOperation::invoke(
diff --git a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp
index a7efe354c34..7cf76d2c7ca 100644
--- a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp
@@ -24,7 +24,7 @@ struct FullOperation {
         const Tensor& any;
     };
 
-    using shape_return_value_t = SimpleShape;
+    using spec_return_value_t = TensorSpec;
     using tensor_return_value_t = Tensor;
 
     struct ProgramFactory {
@@ -54,7 +54,7 @@ struct FullOperation {
     static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
-    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
+    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
     static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
 
     static std::tuple<operation_attributes_t, tensor_args_t> invoke(
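The Full port above is the template for the rest of the patch: `compute_output_specs` packs shape, dtype, layout, and memory config into a single `TensorSpec`, and `create_output_tensors` collapses to the spec-based `create_device_tensor` overload. Roughly (the `device` pointer and concrete values are for illustration only):

```cpp
TensorSpec spec(
    SimpleShape({1, 1, 32, 32}),
    TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), MemoryConfig{}));
Tensor output = create_device_tensor(spec, device);  // replaces the five-argument overload
```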
diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
index 531509e382d..be79166991a 100644
--- a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
@@ -42,21 +42,18 @@ void FullLikeOperation::validate_on_program_cache_hit(
     validate(operation_attributes, tensor_args);
 }
 
-FullLikeOperation::shape_return_value_t FullLikeOperation::compute_output_shapes(
+FullLikeOperation::spec_return_value_t FullLikeOperation::compute_output_specs(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    return tensor_args.input.get_logical_shape();
+    return TensorSpec(
+        tensor_args.input.get_logical_shape(),
+        TensorLayout(
+            operation_attributes.dtype, PageConfig(operation_attributes.layout), operation_attributes.memory_config));
 }
 
 FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
-    const auto& input = tensor_args.input;
-    return create_device_tensor(
-        output_shape,
-        operation_attributes.dtype,
-        operation_attributes.layout,
-        input.device(),
-        operation_attributes.memory_config);
+    const auto output_spec = compute_output_specs(operation_attributes, tensor_args);
+    return create_device_tensor(output_spec, tensor_args.input.device());
 }
 
 std::tuple<FullLikeOperation::operation_attributes_t, FullLikeOperation::tensor_args_t> FullLikeOperation::invoke(
diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
index 7b8341ef7db..905ba6dbe93 100644
--- a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
@@ -26,7 +26,7 @@ struct FullLikeOperation {
         const Tensor& input;
     };
 
-    using shape_return_value_t = SimpleShape;
+    using spec_return_value_t = TensorSpec;
     using tensor_return_value_t = Tensor;
 
     struct ProgramFactory {
@@ -54,7 +54,7 @@ struct FullLikeOperation {
     static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
     static void validate(const operation_attributes_t&, const tensor_args_t&);
-    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
+    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
     static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
 
     static std::tuple<operation_attributes_t, tensor_args_t> invoke(
diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp
index 05555467a83..79840d66a99 100644
--- a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp
@@ -36,16 +36,16 @@ void IndexFillOperation::validate_on_program_cache_hit(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     validate(operation_attributes, tensor_args);
 }
-IndexFillOperation::shape_return_value_t IndexFillOperation::compute_output_shapes(
+IndexFillOperation::spec_return_value_t IndexFillOperation::compute_output_specs(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    return tensor_args.input.get_logical_shape();
+    return TensorSpec(
+        tensor_args.input.get_logical_shape(),
+        tensor_args.input.get_tensor_spec().tensor_layout().with_memory_config(operation_attributes.memory_config));
 }
 IndexFillOperation::tensor_return_value_t IndexFillOperation::create_output_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
-    const auto& input = tensor_args.input;
-    return create_device_tensor(
-        output_shape, input.dtype(), input.layout(), input.device(), operation_attributes.memory_config);
+    const auto output_spec = compute_output_specs(operation_attributes, tensor_args);
+    return create_device_tensor(output_spec, tensor_args.input.device());
 }
 std::tuple<IndexFillOperation::operation_attributes_t, IndexFillOperation::tensor_args_t> IndexFillOperation::invoke(
     const Tensor& input,
diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp
index 757c88962b2..75dd75d3c86 100644
--- a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.hpp
@@ -21,7 +21,7 @@ struct IndexFillOperation {
         const Tensor& input;
         const Tensor& index;
     };
-    using shape_return_value_t = SimpleShape;
+    using spec_return_value_t = TensorSpec;
     using tensor_return_value_t = Tensor;
     struct MultiCore {
         struct shared_variables_t {
@@ -46,7 +46,7 @@ struct IndexFillOperation {
     static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
     static void validate(const operation_attributes_t&, const tensor_args_t&);
-    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
+    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
     static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
     static std::tuple<operation_attributes_t, tensor_args_t> invoke(
         const Tensor& input,
diff --git a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp
index d75671f4af9..c2795db919c 100644
--- a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp
+++ b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.cpp
@@ -95,8 +95,7 @@ void UpdateCache::validate(const std::vector<Tensor>& input_tensors) const {
     }
 }
 
-std::vector<ttnn::SimpleShape> UpdateCache::compute_output_shapes(
-    const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> UpdateCache::compute_output_specs(const std::vector<Tensor>&) const {
     // Do nothing because it's an in-place operation
     return {};
 }
diff --git a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp
index c742ac748c4..431a64ed338 100644
--- a/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp
+++ b/ttnn/cpp/ttnn/operations/kv_cache/device/update_cache_op.hpp
@@ -34,7 +34,7 @@ struct UpdateCache {
     UpdateCacheOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor>& input_tensors) const;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
 
     operation::ProgramWithCallbacks create_program(
diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
index 018017e490d..f962c112c6e 100644
--- a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp
@@ -79,23 +79,23 @@ void GroupNorm::validate(
         TT_FATAL(input_mask.value().get_legacy_shape()[3] % TILE_WIDTH == 0, "Error");
     }
 }
-std::vector<ttnn::SimpleShape> GroupNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_logical_shape()};
+std::vector<TensorSpec> GroupNorm::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    const auto& input_tensor = input_tensors.at(0);
+    if (this->program_config.inplace) {
+        return {input_tensor.get_tensor_spec()};
+    }
+    auto mem_config = this->output_mem_config;
+    mem_config.shard_spec = input_tensor.shard_spec();
+    return {TensorSpec(
+        input_tensor.get_logical_shape(),
+        TensorLayout(program_config.out_data_format, PageConfig(program_config.output_layout), mem_config))};
 }
 std::vector<Tensor> GroupNorm::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
     if (this->program_config.inplace) {
-        return {input_tensors.at(0)};
-    } else {
-        auto mem_config = this->output_mem_config;
-        mem_config.shard_spec = input_tensor.shard_spec();
-        return {create_device_tensor(
-            this->compute_output_shapes(input_tensors).at(0),
-            program_config.out_data_format,
-            this->program_config.output_layout,
-            input_tensor.device(),
-            mem_config)};
+        return {input_tensor};
     }
+    return {create_device_tensor(this->compute_output_specs(input_tensors).at(0), input_tensor.device())};
 }
 operation::ProgramWithCallbacks GroupNorm::create_program(
     const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp
index 8503453446c..5fd2878e7e1 100644
--- a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.hpp
@@ -58,7 +58,7 @@ struct GroupNorm {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
index 9d91075c6a1..e95bee08c1d 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
@@ -203,22 +203,19 @@ void LayerNorm::validate(
         },
         this->program_config);
 }
-std::vector<ttnn::SimpleShape> LayerNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    auto output_shape = input_tensors.at(0).get_logical_shape();
+std::vector<TensorSpec> LayerNorm::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    const auto& input_tensor = input_tensors.at(0);
+    auto output_shape = input_tensor.get_logical_shape();
     if (this->distributed_norm_stage == DistributedLayerNormStage::PRE_ALL_GATHER) {
         uint32_t num_tiles_w = this->norm_type == LayerNormType::LAYERNORM ? 2 : 1;
         output_shape[3] = num_tiles_w * TILE_WIDTH;
     }
-    return {output_shape};
-}
-std::vector<Tensor> LayerNorm::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor = input_tensors.at(0);
+
     return std::visit(
-        [&](const auto& program_config) -> std::vector<Tensor> {
+        [&](const auto& program_config) -> std::vector<TensorSpec> {
             using ProgramConfigType = std::decay_t<decltype(program_config)>;
             if constexpr (std::is_same_v<ProgramConfigType, LayerNormShardedMultiCoreProgramConfig>) {
                 if (this->distributed_norm_stage == DistributedLayerNormStage::PRE_ALL_GATHER) {
-                    auto output_shape = this->compute_output_shapes(input_tensors).at(0);
                     auto shard_spec = input_tensor.shard_spec().value();
                     shard_spec.shape[1] = output_shape[3];
 
@@ -227,29 +224,29 @@ std::vector<Tensor> LayerNorm::create_output_tensors(const std::vector<Tensor>&
                     shard_spec.grid = core_range_set;
                     auto mem_config = this->output_mem_config;
                     mem_config.shard_spec = shard_spec;
-                    return {create_device_tensor(
-                        output_shape, DataType::BFLOAT16, Layout::TILE, input_tensor.device(), mem_config)};
-                } else {
-                    if (program_config.inplace) {
-                        return {input_tensor};
-                    } else {
-                        auto mem_config = this->output_mem_config;
-                        mem_config.shard_spec = input_tensor.shard_spec().value();
-                        return {create_device_tensor(
-                            this->compute_output_shapes(input_tensors).at(0),
-                            input_tensors.at(0).get_dtype(),
-                            Layout::TILE,
-                            input_tensor.device(),
-                            mem_config)};
-                    }
+                    return {TensorSpec(
+                        output_shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_config))};
+                }
+
+                if (program_config.inplace) {
+                    return {input_tensor.get_tensor_spec()};
                 }
+
+                auto mem_config = this->output_mem_config;
+                mem_config.shard_spec = input_tensor.shard_spec().value();
+                return {TensorSpec(
+                    output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), mem_config))};
             } else {
-                return operation::generic_create_output_tensors(
-                    *this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
+                return {TensorSpec(
+                    output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
             }
         },
         this->program_config);
 }
+std::vector<Tensor> LayerNorm::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+    auto output_spec = compute_output_specs(input_tensors)[0];
+    return {create_device_tensor(output_spec, input_tensors.at(0).device())};
+}
 operation::ProgramWithCallbacks LayerNorm::create_program(
     const std::vector<Tensor>& input_tensors,
     const std::vector<std::optional<const Tensor>>& optional_input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp
index b821e87d244..7e654e32168 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.hpp
@@ -52,7 +52,7 @@ struct LayerNorm {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp
index d7d1dda809d..dec7a4005e4 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp
@@ -96,15 +96,11 @@ void LayerNormPostAllGather::validate(
     }
 }
 
-std::vector<ttnn::SimpleShape> LayerNormPostAllGather::compute_output_shapes(
-    const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_logical_shape()};
-}
-
-std::vector<Tensor> LayerNormPostAllGather::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor = input_tensors.at(0);
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->memory_config);
+std::vector<TensorSpec> LayerNormPostAllGather::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    auto& input_tensor = input_tensors.at(0);
+    return {TensorSpec(
+        input_tensor.get_logical_shape(),
+        TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), memory_config))};
 }
 
 operation::ProgramWithCallbacks LayerNormPostAllGather::create_program(
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp
index c1410ab96ba..a4edff38e7e 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.hpp
@@ -35,8 +35,7 @@ struct LayerNormPostAllGather {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp
index 2b3f2a5c4c2..dceb72e35e2 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp
@@ -36,8 +36,7 @@ void LayerNormPreAllGather::validate(const std::vector<Tensor>& input_tensors) c
     TT_FATAL(tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!");
 }
 
-std::vector<ttnn::SimpleShape> LayerNormPreAllGather::compute_output_shapes(
-    const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> LayerNormPreAllGather::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
     auto output_shape = input_tensors.at(0).get_logical_shape();
 
@@ -47,13 +46,7 @@ std::vector<ttnn::SimpleShape> LayerNormPreAllGather::compute_output_shapes(
     }
     output_shape[3] = num_tiles_w * TILE_WIDTH;
 
-    return {output_shape};
-}
-
-std::vector<Tensor> LayerNormPreAllGather::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor = input_tensors.at(0);
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, this->dtype, Layout::TILE, input_tensor.memory_config());
+    return {TensorSpec(output_shape, TensorLayout(dtype, PageConfig(Layout::TILE), input_tensor.memory_config()))};
 }
 
 operation::ProgramWithCallbacks LayerNormPreAllGather::create_program(
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp
index 303f285603c..299f0328952 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.hpp
@@ -28,8 +28,7 @@ struct LayerNormPreAllGather {
     const DeviceComputeKernelConfig compute_kernel_config;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
 };
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
index 0ca94a69259..deadbf2423e 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
@@ -110,17 +110,22 @@ void Softmax::validate(
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> Softmax::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_legacy_shape()};
+std::vector<TensorSpec> Softmax::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    auto& input_tensor = input_tensors.at(0);
+    if (this->inplace) {
+        return {input_tensor.get_tensor_spec()};
+    }
+    return {TensorSpec(
+        input_tensor.get_logical_shape(),
+        TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
 }
 
 std::vector<Tensor> Softmax::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
     if (this->inplace) {
         return {input_tensors.at(0)};
-    } else {
-        return operation::generic_create_output_tensors(
-            *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config);
     }
+
+    return {create_device_tensor(compute_output_specs(input_tensors)[0], input_tensors.at(0).device())};
 }
 
 operation::ProgramWithCallbacks Softmax::create_program(
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
index 121ed08720f..8f7f1fde1ae 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
@@ -30,7 +30,7 @@ struct Softmax {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
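Softmax above also shows the in-place convention that GroupNorm and LayerNorm follow in this patch: the output's spec is simply the input's own spec, and `create_output_tensors` returns the input tensor unchanged. Schematically, for a hypothetical in-place op:

```cpp
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& inputs) const {
    return {inputs.at(0).get_tensor_spec()};  // in-place: echo the input's spec
}
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& inputs) const {
    return {inputs.at(0)};  // hand the input back; no new allocation
}
```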
diff --git a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp
index 6e94c79887f..1d9abd72cb4 100644
--- a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.cpp
@@ -30,27 +30,17 @@ void Downsample::validate(const std::vector<Tensor>& input_tensors) const {
         input_tensor_a.memory_config().memory_layout);
 }
 
-std::vector<tt::tt_metal::LegacyShape> Downsample::compute_output_shapes(
-    const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor_a = input_tensors.at(0);
-    TT_ASSERT(input_tensor_a.get_legacy_shape()[0] == 1 && input_tensor_a.get_legacy_shape()[1] == 1);
-    uint32_t input_height = input_tensor_a.get_legacy_shape()[2];
+std::vector<TensorSpec> Downsample::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    const auto& input_tensor = input_tensors.at(0);
+    TT_ASSERT(input_tensor.get_legacy_shape()[0] == 1 && input_tensor.get_legacy_shape()[1] == 1);
+    uint32_t input_height = input_tensor.get_legacy_shape()[2];
     auto [img_batch_size, img_height, img_width, img_stride_h, img_stride_w] = this->downsample_params;
     TT_ASSERT(input_height >= img_batch_size * img_height * img_width);
-    uint32_t output_height_unpadded = img_batch_size * ceil((double)img_height / (double)img_stride_h) *
-                                      ceil((double)img_width / (double)img_stride_w);
-    uint32_t output_height = tt::round_up(output_height_unpadded, TILE_HEIGHT);
-    uint32_t output_width = input_tensor_a.get_legacy_shape()[3];
-    auto output_padding =
-        Padding({{0, 0}, {0, 0}, {0, (output_height - output_height_unpadded)}, {0, 0}}, Padding::PadValue::Any);
-    auto output_tensor_shape = tt::tt_metal::LegacyShape({1, 1, output_height, output_width}, output_padding);
-    log_debug(tt::LogOp, "Downsample output shape: {}", output_tensor_shape);
-    return {output_tensor_shape};
-}
+    uint32_t output_height = img_batch_size * ceil((double)img_height / (double)img_stride_h) *
+                             ceil((double)img_width / (double)img_stride_w);
+    uint32_t output_width = input_tensor.get_legacy_shape()[3];
+    auto output_shape = SimpleShape({1, 1, output_height, output_width});
 
-std::vector<Tensor> Downsample::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor = input_tensors.at(0);
-    auto output_shape = this->compute_output_shapes(input_tensors).at(0);
     auto [num_cores_height_sliced, num_cores_width_sliced] = detail::get_num_cores_height_width_sliced(
         input_tensor.shard_spec().value().grid,
         input_tensor.memory_config().memory_layout,
@@ -64,7 +54,7 @@ std::vector<Tensor> Downsample::create_output_tensors(const std::vector<Tensor>&
         input_tensor.shard_spec().value().grid,
         std::array<uint32_t, 2>{{output_shard_height, output_shard_width}},
         input_tensor.shard_spec().value().orientation};
-    return {create_device_tensor(output_shape, this->dtype, Layout::TILE, input_tensor.device(), mem_config)};
+    return {TensorSpec(output_shape, TensorLayout(dtype, PageConfig(Layout::TILE), mem_config))};
 }
 
 operation::ProgramWithCallbacks Downsample::create_program(
diff --git a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp
index f8418c8ad89..eae1360e2e4 100644
--- a/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp
+++ b/ttnn/cpp/ttnn/operations/pool/downsample/device/downsample_op.hpp
@@ -19,8 +19,7 @@ struct Downsample {
     std::array<uint32_t, 5> downsample_params;
     DataType dtype;
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
 };
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
index 253218a226e..e083618f536 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
@@ -49,7 +49,7 @@ void MaxPool2D::validate_on_program_cache_hit(const operation_attributes_t& op_a
     return validate_maxpool(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_);
 }
 
-MaxPool2D::shape_return_value_t MaxPool2D::compute_output_shapes(
+MaxPool2D::spec_return_value_t MaxPool2D::compute_output_specs(
     const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
     auto& input = tensors.input_tensor_;
     auto& sliding_window_config = op_attr.sliding_window_config_;
@@ -65,34 +65,12 @@ MaxPool2D::shape_return_value_t MaxPool2D::compute_output_shapes(
     uint32_t out_h = sliding_window_config.get_output_shape()[1];
     uint32_t out_w = sliding_window_config.get_output_shape()[2];
 
-    bool is_out_tiled = output_dtype == DataType::BFLOAT8_B;
-
-    // need to pad the last dim to TILE_WIDTH
     uint32_t out_c = input_shape[3];
-    uint32_t out_c_padded = ceil_multiple_of(out_c, (out_c <= 16) ? 16 : tt::constants::TILE_WIDTH);
-    uint32_t out_pagesize = out_c_padded * datum_size(datatype_to_dataformat_converter(input.get_dtype()));
     uint32_t out_nhw = sliding_window_config.batch_size * out_h * out_w;
-    uint32_t out_nhw_padded =
-        tt::round_up(out_nhw, (is_out_tiled ? tt::constants::TILE_HEIGHT : 1) * sliding_window_config.num_cores_nhw);
-    // {1, 1, N * H * W, C}
-    const ttnn::SmallVector<uint32_t> out_dims({1, 1, out_nhw_padded, out_c_padded});
-    const auto padding = Padding(
-        {{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}},
-        Padding::PadValue::NegativeInfinity);
-    auto out_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding));
-    return out_shape;
-}
-
-MaxPool2D::tensor_return_value_t MaxPool2D::create_output_tensors(
-    const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
-    auto& input = tensors.input_tensor_;
-    auto& sliding_window_config = op_attr.sliding_window_config_;
-    auto& out_mem_config = op_attr.memory_config_;
-    auto& output_dtype = op_attr.output_dtype_;
+    auto output_shape = ttnn::SimpleShape({1, 1, out_nhw, out_c});
 
-    Shape output_shape = compute_output_shapes(op_attr, tensors);
     auto mem_config = out_mem_config;
     if (mem_config.shard_spec.has_value()) {
         mem_config.shard_spec->shape[1] = input.shard_spec()->shape[1];
@@ -107,8 +85,12 @@ MaxPool2D::tensor_return_value_t MaxPool2D::create_output_tensors(
         mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false};
     }
 
-    // return create_device_tensor(output_shape, input.get_dtype(), input.get_layout(), input.device(), mem_config);
-    return create_device_tensor(output_shape, output_dtype, input.get_layout(), input.device(), mem_config);
+    return TensorSpec(output_shape, TensorLayout(output_dtype, input.tensor_spec().page_config(), mem_config));
+}
+
+MaxPool2D::tensor_return_value_t MaxPool2D::create_output_tensors(
+    const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
+    return create_device_tensor(compute_output_specs(op_attr, tensors), tensors.input_tensor_.device());
 }
 
 tt::stl::hash::hash_t MaxPool2D::compute_program_hash(
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp
index f6ddce38a20..3b67e9b19ae 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp
@@ -32,7 +32,7 @@ struct MaxPool2D {
         const Tensor& input_tensor_;
     };
 
-    using shape_return_value_t = ttnn::Shape;
+    using spec_return_value_t = TensorSpec;
     using tensor_return_value_t = Tensor;
 
     struct MultiCore {
@@ -63,7 +63,7 @@ struct MaxPool2D {
     static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
-    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
+    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
     static Tensor create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
     static tt::stl::hash::hash_t compute_program_hash(const operation_attributes_t&, const tensor_args_t&);
     static operation::OpPerformanceModel create_op_performance_model(
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
index 3a313e803cc..7aa66dc0fd7 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
@@ -45,7 +45,7 @@ void UpSample::validate(const std::vector<Tensor>& input_tensors) const {
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> UpSample::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> UpSample::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     // NOTE1: data is packed in { N, H , W, C }
     // NOTE2: Mapping it into in 2D format should be {N*H*W, C}
     // NOTE3: Assuming output data type is same as input
@@ -56,55 +56,50 @@ std::vector<tt::tt_metal::LegacyShape> UpSample::compute_output_shapes(
     uint32_t out_h = input_shape[1] * scale_factor_h_;
     uint32_t out_w = input_shape[2] * scale_factor_w_;
     uint32_t out_c = input_shape[3];
-    const ttnn::SmallVector<uint32_t> out_dims({out_n, out_h, out_w, out_c});  // in the NHWC format
-    return {tt::tt_metal::LegacyShape{out_dims}};
-}
+    auto output_shape = ttnn::SimpleShape({out_n, out_h, out_w, out_c});
 
-std::vector<Tensor> UpSample::create_output_tensors(const std::vector<Tensor>& inputs) const {
-    const auto& input = inputs.at(0);
     if (output_mem_config_.is_sharded()) {
-        if (input.memory_config().is_sharded()) {
-            auto mem_config = output_mem_config_;
-            auto input_shard_spec = input.memory_config().shard_spec.value();
-            auto output_shape = compute_output_shapes(inputs).at(0);
-            if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
-                auto ncores = input_shard_spec.num_cores();
-                std::array<uint32_t, 2> output_shard_shape = {
-                    div_up(output_shape[0] * output_shape[1] * output_shape[2], ncores), output_shape[-1]};
-                auto output_shard_spec = input_shard_spec;
-                output_shard_spec.shape = output_shard_shape;
-                mem_config.shard_spec = output_shard_spec;
-                log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
-                log_debug(LogOp, "output_shard_spec: {}", output_shard_spec);
-                return {create_device_tensor(
-                    output_shape, input.get_dtype(), input.get_layout(), input.device(), mem_config)};
-            } else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
-                auto shard_grid = input_shard_spec.grid.ranges();
-                TT_FATAL(shard_grid.size() == 1, "Block sharded input should have only one CoreRange");
-                auto core_range = *shard_grid.begin();
-                uint32_t ncores_w = core_range.end_coord.x + 1;
-                uint32_t ncores_h = core_range.end_coord.y + 1;
-                // std::array<uint32_t, 2> output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2] /
-                // ncores_h, output_shape[-1] / ncores_w}; auto output_shard_spec = input_shard_spec;
-                // output_shard_spec.shape = output_shard_shape;
-                // mem_config.shard_spec = output_shard_spec;
-                auto output_shard_spec = mem_config.shard_spec.value();
-                auto output_shard_shape = output_shard_spec.shape;
-                log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h);
-                log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
-                return {create_device_tensor(
-                    output_shape, input.get_dtype(), input.get_layout(), input.device(), mem_config)};
-            } else {
-                TT_THROW("input memory config is not HEIGHT or BLOCK sharded");
-            }
-        } else {
+        if (!input.memory_config().is_sharded()) {
             TT_THROW("Output memory config is sharded but input memory config is not sharded");
         }
-    } else {
-        return operation::generic_create_output_tensors(
-            *this, inputs, input.get_dtype(), input.get_layout(), output_mem_config_);
+        auto mem_config = output_mem_config_;
+        auto input_shard_spec = input.memory_config().shard_spec.value();
+        if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
+            auto ncores = input_shard_spec.num_cores();
+            std::array<uint32_t, 2> output_shard_shape = {
+                div_up(output_shape[0] * output_shape[1] * output_shape[2], ncores), output_shape[-1]};
+            auto output_shard_spec = input_shard_spec;
+            output_shard_spec.shape = output_shard_shape;
+            mem_config.shard_spec = output_shard_spec;
+            log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
+            log_debug(LogOp, "output_shard_spec: {}", output_shard_spec);
+            return {
+                TensorSpec(output_shape, TensorLayout(input.get_dtype(), PageConfig(input.get_layout()), mem_config))};
+        }
+        if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
+            auto shard_grid = input_shard_spec.grid.ranges();
+            TT_FATAL(shard_grid.size() == 1, "Block sharded input should have only one CoreRange");
+            auto core_range = *shard_grid.begin();
+            uint32_t ncores_w = core_range.end_coord.x + 1;
+            uint32_t ncores_h = core_range.end_coord.y + 1;
+            // std::array<uint32_t, 2> output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2] /
+            // ncores_h, output_shape[-1] / ncores_w}; auto output_shard_spec = input_shard_spec;
+            // output_shard_spec.shape = output_shard_shape;
+            // mem_config.shard_spec = output_shard_spec;
+            auto output_shard_spec = mem_config.shard_spec.value();
+            auto output_shard_shape = output_shard_spec.shape;
+            log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h);
+            log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
+            return {
+                TensorSpec(output_shape, TensorLayout(input.get_dtype(), PageConfig(input.get_layout()), mem_config))};
+        }
+
+        TT_THROW("input memory config is not HEIGHT or BLOCK sharded");
     }
+
+    return {
+        TensorSpec(output_shape, TensorLayout(input.get_dtype(), PageConfig(input.get_layout()), output_mem_config_))};
 }
 
 operation::ProgramWithCallbacks UpSample::create_program(
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp
index f07c9f8b472..be183519d98 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp
@@ -20,8 +20,7 @@ struct UpSample {
     const DeviceComputeKernelConfig compute_kernel_config_;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
     UpSampleParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor>& input_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
index ea1094216b5..53bd54fdb44 100644
--- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
@@ -48,13 +48,20 @@ void ArgMax::validate_with_output_tensors(
     TT_FATAL(input_shape[1] == 1, "dim 1 must be 1");
 }
 
-std::vector<ttnn::SimpleShape> ArgMax::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> ArgMax::compute_output_specs(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
+    if (output_tensors.at(0).has_value()) {
+        return {output_tensors.at(0)->get_tensor_spec()};
+    }
+
+    const auto& input_tensor = input_tensors[0];
+    ttnn::SimpleShape output_shape({1, 1, 1, 1});
     if (this->dim.has_value()) {
         auto input_shape = input_tensors[0].get_logical_shape();
-        return {ttnn::SimpleShape{input_shape[0], input_shape[1], 1, input_shape[2]}};
-    } else {
-        return {ttnn::SimpleShape{1, 1, 1, 1}};
+        output_shape = ttnn::SimpleShape{input_shape[0], input_shape[1], 1, input_shape[2]};
     }
+    return {
+        TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input_tensor.get_layout()), output_mem_config))};
 }
 
 std::vector<Tensor> ArgMax::create_output_tensors(
@@ -63,9 +70,7 @@ std::vector<Tensor> ArgMax::create_output_tensors(
         return {output_tensors.at(0).value()};
     }
 
-    const auto& input_tensor = input_tensors[0];
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config);
+    return {create_device_tensor(compute_output_specs(input_tensors, output_tensors)[0], input_tensors[0].device())};
 }
 
 operation::ProgramWithCallbacks ArgMax::create_program(
diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp
index dce8dcb8181..d0b6fa0b858 100644
--- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp
+++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.hpp
@@ -20,7 +20,8 @@ struct ArgMax {
 
     void validate_with_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(
+        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     std::vector<Tensor> create_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
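ArgMax introduces the two-argument compute_output_specs overload that MoE and TopK below adopt as well: a caller-supplied preallocated output is authoritative, and only in its absence does the op derive a spec from its inputs. A sketch of just that control flow (a free function for illustration; the ops implement it as members, as in the hunks above):

// Sketch: the precedence rule, with the derivation itself passed in.
std::vector<TensorSpec> resolve_output_specs(
    const std::vector<std::optional<Tensor>>& output_tensors, const TensorSpec& derived_spec) {
    if (!output_tensors.empty() && output_tensors.at(0).has_value()) {
        // Reuse the preallocated tensor's spec so validation and program
        // creation see exactly the tensor the caller will get back.
        return {output_tensors.at(0)->get_tensor_spec()};
    }
    return {derived_spec};  // fall back to the op-derived spec
}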
diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp
index 1bfb6453dc5..3cdb60eed74 100644
--- a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.cpp
@@ -43,11 +43,19 @@ void MoeDeviceOperation::validate_with_output_tensors(
     TT_FATAL(expert_shape[-2] == 32, "Expert shape inner dim must be equal to 32, got {}", expert_shape[-2]);
 }
 
-std::vector<ttnn::SimpleShape> MoeDeviceOperation::compute_output_shapes(
-    const std::vector<Tensor>& input_tensors) const {
-    auto output_shape = input_tensors.at(0).get_logical_shape();
+std::vector<TensorSpec> MoeDeviceOperation::compute_output_specs(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
+    if (output_tensors.size() == 1) {
+        if (output_tensors.at(0).has_value()) {
+            return {output_tensors[0]->get_tensor_spec()};
+        }
+    }
+
+    auto& input_tensor = input_tensors.at(0);
+    auto output_shape = input_tensor.get_logical_shape();
     output_shape[-1] = 1;
-    return {output_shape};
+    return {
+        TensorSpec(output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
 }
 
 std::vector<Tensor> MoeDeviceOperation::create_output_tensors(
@@ -57,8 +65,7 @@ std::vector<Tensor> MoeDeviceOperation::create_output_tensors(
             return {output_tensors[0].value()};
         }
     }
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config);
+    return {create_device_tensor(compute_output_specs(input_tensors, output_tensors)[0], input_tensors.at(0).device())};
 }
 
 operation::ProgramWithCallbacks MoeDeviceOperation::create_program(
diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp
index 8a95de49817..f380dbb0a87 100644
---
 a/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp
+++ b/ttnn/cpp/ttnn/operations/reduction/moe/device/moe_op.hpp
@@ -17,7 +17,8 @@ struct MoeDeviceOperation {
 
     void validate_with_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(
+        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     std::vector<Tensor> create_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp
index 3fb869370a4..1bd009cccc1 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp
@@ -44,7 +44,7 @@ std::vector<Tensor> Prod::create_output_tensors(const std::vector<Tensor>& input
     return {};
 }
 
-std::vector<ttnn::SimpleShape> Prod::compute_output_shapes(const std::vector<Tensor>& inputs) const {
+std::vector<TensorSpec> Prod::compute_output_specs(const std::vector<Tensor>&) const {
     // Inplace
     return {};
 }
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp
index 3406689741b..e061c1ddeb1 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp
@@ -23,7 +23,7 @@ using namespace tt_metal;
 struct Prod {
     int64_t dim;
     void validate(const std::vector<Tensor>& inputs) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& inputs) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& inputs) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& inputs) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) const;
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp
index b4e4c7e0a42..42da8fbc1fd 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.cpp
@@ -26,14 +26,11 @@ void Prod_op::validate(const std::vector<Tensor>& input_tensors) const {
     TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Error");
 }
 
-std::vector<ttnn::SimpleShape> Prod_op::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_logical_shape()};
-}
-
-std::vector<Tensor> Prod_op::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> Prod_op::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
+    return {TensorSpec(
+        input_tensor.get_logical_shape(),
+        TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
 }
 
 operation::ProgramWithCallbacks Prod_op::create_program(
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp
index 066ffdb9a6c..aea785c5f25 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_op_all.hpp
@@ -22,8 +22,7 @@ struct Prod_op {
     const MemoryConfig output_mem_config;
     const DataType output_dtype;
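The two prod variants above mark opposite ends of the interface: Prod (the NC reduction) is in-place, so its compute_output_specs legitimately returns an empty vector, while Prod_op derives a fresh tiled spec and drops create_output_tensors entirely. For spec-only ops, output allocation presumably moves into the generic runtime path; a hedged sketch of what that caller side could look like (allocate_outputs is hypothetical, not from this patch):

// Assumption: the runtime allocates one tensor per advertised spec.
std::vector<Tensor> allocate_outputs(const Prod_op& op, const std::vector<Tensor>& inputs) {
    std::vector<Tensor> outputs;
    for (const auto& spec : op.compute_output_specs(inputs)) {
        outputs.push_back(create_device_tensor(spec, inputs.at(0).device()));
    }
    return outputs;  // stays empty for in-place ops like Prod (NC)
}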
  // TODO: Uplift output_dtype as an option for general dot/bmm
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
 };
diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
index adbe2f6101f..8d96b25ffec 100644
--- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
@@ -85,10 +85,22 @@ void TopK::validate_with_output_tensors(
     }
 }
 
-std::vector<ttnn::SimpleShape> TopK::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> TopK::compute_output_specs(
+    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
+    if (output_tensors.size() == 2) {
+        if (output_tensors.at(0).has_value() && output_tensors.at(1).has_value()) {
+            return {output_tensors[0]->get_tensor_spec(), output_tensors[1]->get_tensor_spec()};
+        }
+    }
+    const auto& input_tensor = input_tensors.at(0);
     auto output_shape = input_tensors.at(0).get_logical_shape();
     output_shape[-1] = this->k;
-    return {output_shape, output_shape};
+
+    auto values_spec =
+        TensorSpec(output_shape, TensorLayout(input_tensor.get_dtype(), PageConfig(Layout::TILE), output_mem_config));
+    auto index_spec =
+        TensorSpec(output_shape, TensorLayout(DataType::UINT16, PageConfig(Layout::TILE), output_mem_config));
+    return {values_spec, index_spec};
 }
 
 std::vector<Tensor> TopK::create_output_tensors(
@@ -98,13 +110,12 @@ std::vector<Tensor> TopK::create_output_tensors(
             return {output_tensors[0].value(), output_tensors[1].value()};
         }
     }
+    auto output_specs = compute_output_specs(input_tensors, output_tensors);
     const auto& input_tensor = input_tensors.at(0);
-    const auto shapes = compute_output_shapes(input_tensors);
-    auto values_tensor = create_device_tensor(
-        shapes[0], input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), this->output_mem_config);
-    auto index_tensor =
-        create_device_tensor(shapes[1], DataType::UINT16, Layout::TILE, input_tensor.device(), this->output_mem_config);
-    return {values_tensor, index_tensor};
+    return {
+        create_device_tensor(output_specs[0], input_tensor.device()),
+        create_device_tensor(output_specs[1], input_tensor.device()),
+    };
 }
 
 operation::ProgramWithCallbacks TopK::create_program(
diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp
index 473da98ee4d..dac6dbc6766 100644
--- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp
+++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.hpp
@@ -20,7 +20,8 @@ struct TopK {
 
     void validate_with_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(
+        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     std::vector<Tensor> create_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
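TopK is the one two-output op in this batch, and the spec form makes its dtype split explicit: values inherit the input dtype, while indices are pinned to UINT16 (which presumes the reduced dimension's indices fit in 16 bits). With a hypothetical bfloat16 input and k = 32, the resulting pair looks like:

// Both outputs share the {..., 32} shape and memory config; only dtype differs.
auto values_spec = TensorSpec(out_shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_config));
auto index_spec = TensorSpec(out_shape, TensorLayout(DataType::UINT16, PageConfig(Layout::TILE), mem_config));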
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
index 5a4048afc5a..0708eb6645d 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
@@ -116,14 +116,11 @@ void ScaledDotProductAttention::validate(
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> ScaledDotProductAttention::compute_output_shapes(
+std::vector<TensorSpec> ScaledDotProductAttention::compute_output_specs(
     const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_legacy_shape()};
-}
-
-std::vector<Tensor> ScaledDotProductAttention::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config);
+    auto& input = input_tensors.at(0);
+    return {TensorSpec(
+        input.get_logical_shape(), TensorLayout(input.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
 }
 
 operation::ProgramWithCallbacks ScaledDotProductAttention::create_program(
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
index 812b88ea222..dfe88389084 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
@@ -24,9 +24,7 @@ struct ScaledDotProductAttention {
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
 
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
 
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
index ae63c4e50f4..617a6b95b95 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
@@ -202,15 +202,11 @@ void ScaledDotProductAttentionDecode::validate(
     }
 }
 
-std::vector<ttnn::SimpleShape> ScaledDotProductAttentionDecode::compute_output_shapes(
+std::vector<TensorSpec> ScaledDotProductAttentionDecode::compute_output_specs(
     const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_logical_shape()};
-}
-
-std::vector<Tensor> ScaledDotProductAttentionDecode::create_output_tensors(
-    const std::vector<Tensor>& input_tensors) const {
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config);
+    auto& input = input_tensors.at(0);
+    return {TensorSpec(
+        input.get_logical_shape(), TensorLayout(input.get_dtype(), PageConfig(Layout::TILE), output_mem_config))};
 }
 
 operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program(
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp
index b892dd775f4..30c5c4ed999 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp
@@ -28,9 +28,7 @@ struct ScaledDotProductAttentionDecode {
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
 
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
 
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
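One behavioral nuance in the SDPA hunks: the removed compute_output_shapes returned get_legacy_shape() (the tile-padded shape), while the new spec is built from get_logical_shape(). The assumption here is that TensorLayout re-derives tile padding from the TILE page config, so the padded form no longer needs to travel in the shape itself. A hypothetical illustration:

// logical shape {1, 8, 100, 64} in TILE layout
// legacy (padded) shape {1, 8, 128, 64}: rows rounded up to TILE_HEIGHT = 32
// TensorSpec records the logical {1, 8, 100, 64}; padding is layout-derived.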
diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp
index dac090cdbab..4e188ff2a04 100644
--- a/ttnn/cpp/ttnn/run_operation.cpp
+++ b/ttnn/cpp/ttnn/run_operation.cpp
@@ -380,7 +380,7 @@ Tensors run_with_autoformat(
         }
     }
 
-    auto output_specs = operation.compute_output_shapes(input_tensors);
+    auto output_specs = operation.compute_output_shapes(input_tensors, optional_output_tensors);
     auto output_tensors = run(
         std::move(operation),
         formatted_input_tensors,
@@ -453,7 +453,7 @@ Tensors run_with_autoformat(
         }
     }
 
-    auto output_specs = operation.compute_output_shapes(input_tensors);
+    auto output_specs = operation.compute_output_shapes(input_tensors, optional_output_tensors);
    auto output_tensors = run(
         std::move(operation),
         formatted_input_tensors,

From cb53bc03ebf6baa4bcdef22176b963b9d6f60cae Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Tue, 3 Dec 2024 04:43:41 +0000
Subject: [PATCH 2/5] #0: Review fix

---
 ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
index 7aa66dc0fd7..3e0c54f5d33 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
@@ -83,10 +83,6 @@ std::vector<TensorSpec> UpSample::compute_output_specs(const std::vector<Tensor>
         auto core_range = *shard_grid.begin();
         uint32_t ncores_w = core_range.end_coord.x + 1;
         uint32_t ncores_h = core_range.end_coord.y + 1;
-        // std::array<uint32_t, 2> output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2] /
-        // ncores_h, output_shape[-1] / ncores_w}; auto output_shard_spec = input_shard_spec;
-        // output_shard_spec.shape = output_shard_shape;
-        // mem_config.shard_spec = output_shard_spec;
         auto output_shard_spec = mem_config.shard_spec.value();
         auto output_shard_shape = output_shard_spec.shape;
         log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h);

From 2676d2b6ad596f412e997fe6b6fd18c311e7c9e3 Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Tue, 3 Dec 2024 13:39:02 +0000
Subject: [PATCH 3/5] #0: Fix layernorm

---
 .../layernorm/device/layernorm_op.cpp         | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
index e95bee08c1d..6a023e99adc 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
@@ -244,8 +244,19 @@ std::vector<TensorSpec> LayerNorm::compute_output_specs(const std::vector<Tensor
         this->program_config);
 }
 
 std::vector<Tensor> LayerNorm::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    auto output_spec = compute_output_specs(input_tensors)[0];
-    return {create_device_tensor(output_spec, input_tensors.at(0).device())};
+    return std::visit(
+        [&](const auto& program_config) -> std::vector<Tensor> {
+            using ProgramConfigType = std::decay_t<decltype(program_config)>;
+            if constexpr (std::is_same_v<ProgramConfigType, LayerNormShardedMultiCoreProgramConfig>) {
+                if (this->distributed_norm_stage != DistributedLayerNormStage::PRE_ALL_GATHER &&
+                    program_config.inplace) {
+                    return {input_tensors.at(0)};
+                }
+            }
+            auto output_spec = compute_output_specs(input_tensors)[0];
+            return {create_device_tensor(output_spec, input_tensors.at(0).device())};
+        },
+        this->program_config);
 }
 
 operation::ProgramWithCallbacks LayerNorm::create_program(
     const std::vector<Tensor>& input_tensors,
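The layernorm fix restores the sharded in-place path using the standard variant-dispatch idiom: std::visit over the program-config variant plus if constexpr on the alternative's type. A self-contained toy version of the idiom (the config types are invented for illustration; they are not ttnn's program configs):

#include <type_traits>
#include <variant>

struct ShardedConfig { bool inplace = false; };
struct InterleavedConfig {};
using ProgramConfig = std::variant<ShardedConfig, InterleavedConfig>;

// True when the op may hand back its input instead of allocating an output.
bool reuses_input(const ProgramConfig& config) {
    return std::visit(
        [](const auto& c) -> bool {
            using T = std::decay_t<decltype(c)>;
            if constexpr (std::is_same_v<T, ShardedConfig>) {
                return c.inplace;  // only the sharded config carries the flag
            } else {
                return false;  // every other alternative allocates normally
            }
        },
        config);
}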
From 9d259290b714d1be87bb840e8db85274a14d5341 Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Tue, 3 Dec 2024 14:00:28 +0000
Subject: [PATCH 4/5] #0: Fix pool

---
 .../pool/generic/device/pool_op.cpp           | 22 ++++++++++++++----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
index 4eaedd3f725..1e60cd13c13 100644
--- a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
@@ -66,11 +66,22 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
     uint32_t out_h = sliding_window_config.get_output_shape()[1];
     uint32_t out_w = sliding_window_config.get_output_shape()[2];
 
+    bool is_out_tiled = output_dtype == DataType::BFLOAT8_B;
+
+    // need to pad the last dim to TILE_WIDTH
     uint32_t out_c = input_shape[3];
+    uint32_t out_c_padded = tt::round_up(out_c, (out_c <= 16) ? 16 : tt::constants::TILE_WIDTH);
     uint32_t out_nhw = sliding_window_config.batch_size * out_h * out_w;
+    uint32_t out_nhw_padded =
+        tt::round_up(out_nhw, (is_out_tiled ? tt::constants::TILE_HEIGHT : 1) * sliding_window_config.num_cores_nhw);
+
     // {1, 1, N * H * W, C}
-    auto output_shape = ttnn::SimpleShape({1, 1, out_nhw, out_c});
+    const ttnn::SmallVector<uint32_t> out_dims({1, 1, out_nhw_padded, out_c_padded});
+    const auto padding = Padding(
+        {{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}},
+        Padding::PadValue::NegativeInfinity);
+    auto output_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding));
 
     auto mem_config = out_mem_config;
     if (mem_config.shard_spec.has_value()) {
@@ -78,20 +89,21 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
     } else {
         uint32_t ncores = input.shard_spec().value().num_cores();
         TT_FATAL(ncores == sliding_window_config.num_cores_nhw, "Number of cores should match");
-        uint32_t nbatch = output_shape[0];
-        uint32_t out_nhw_padded = output_shape[0] * output_shape[1] * output_shape[2];
         uint32_t out_nhw_per_core = out_nhw_padded / ncores;
         CoreRangeSet shard_grid = sliding_window_config.core_range_set;
         std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_legacy_shape()[-1]};
         mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false};
     }
 
-    return TensorSpec(output_shape, TensorLayout(output_dtype, input.tensor_spec().page_config(), mem_config));
+    return TensorSpec(
+        output_shape.logical_shape(),
+        TensorLayout::fromLegacyPaddedShape(output_dtype, PageConfig(input.get_layout()), mem_config, output_shape));
 }
 
 Pool2D::tensor_return_value_t Pool2D::create_output_tensors(
     const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
-    return create_device_tensor(compute_output_specs(op_attr, tensors), tensors.input_tensor_.device());
+    auto output_spec = compute_output_specs(op_attr, tensors);
+    return create_device_tensor(output_spec, tensors.input_tensor_.device());
 }
 
 tt::stl::hash::hash_t Pool2D::compute_program_hash(
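Patch 4 brings the generic pool op's padding in line with maxpool. The arithmetic, worked through with hypothetical values:

// out_c = 24  -> tt::round_up(24, 32) = 32   (channels padded to TILE_WIDTH)
// out_c = 16  -> tt::round_up(16, 16) = 16   (the small-channel case stays 16)
// out_nhw = 1000 with 8 NHW cores and a tiled (BFLOAT8_B) output:
//   tt::round_up(1000, 32 * 8) = tt::round_up(1000, 256) = 1024
// Padding rows are filled with -inf so they can never win a max reduction.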
From 0b8a34e5cda15531b31cf74aeead8dfd0b06d562 Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Tue, 3 Dec 2024 19:25:24 +0000
Subject: [PATCH 5/5] #0: Pool fix

---
 ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
index 1e60cd13c13..98a972b0fc7 100644
--- a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
@@ -89,7 +89,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
     } else {
         uint32_t ncores = input.shard_spec().value().num_cores();
         TT_FATAL(ncores == sliding_window_config.num_cores_nhw, "Number of cores should match");
-        uint32_t out_nhw_per_core = out_nhw_padded / ncores;
+        uint32_t out_nhw_per_core = output_shape[0] * output_shape[1] * output_shape[2] / ncores;
         CoreRangeSet shard_grid = sliding_window_config.core_range_set;
         std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_legacy_shape()[-1]};
         mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false};
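The final one-liner keeps the per-core shard height consistent with the padded shape computed in patch 4. Continuing the hypothetical numbers above:

// output_shape = {1, 1, 1024, 32} (padded), 8 cores:
// out_nhw_per_core = 1 * 1 * 1024 / 8 = 128 rows per core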