From a4d01f9da55bc4ea372d57cbe39ffe1eb5123313 Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Sat, 14 Dec 2024 00:00:07 +0000
Subject: [PATCH] #0: Address more Stas comments

---
 .../data_movement/common/common.cpp                | 129 +++++++++++-------
 .../data_movement/common/common.hpp                |  13 +-
 .../reader_pad_dims_rm_sharded_stickwise.cpp       |   3 +
 ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp |  23 +---
 ttnn/cpp/ttnn/tensor/shape/shape_base.hpp          |   1 +
 5 files changed, 98 insertions(+), 71 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
index 666fbce0a037..43882245d0b8 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
@@ -56,81 +56,108 @@ ttnn::Tensor pad_to_tile_vol(
 }
 
 uint32_t wrap_index(int index, int size) { return index < 0 ? size + index : index; }
 
+std::array<uint32_t, 2> compute_block_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
+                                                          const tt::tt_metal::Layout& layout,
+                                                          const tt::tt_metal::CoreCoord& grid_size,
+                                                          const tt::tt_metal::ShardOrientation& orientation,
+                                                          const uint32_t total_num_cores) {
+    TT_FATAL(grid_size.y * grid_size.x == total_num_cores, "compute_block_sharded_shard_shape received a core grid shape that does not match the total number of cores");
+    auto adjusted_grid_size = grid_size;
+    if (orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
+        // for col major, we partition the width of the tensor along the height of the core grid
+        std::swap(adjusted_grid_size.x, adjusted_grid_size.y);
+    }
+
+    auto [tensor_height, tensor_width] = squeezed_tensor_hw;
+    auto tensor_height_padded_to_tile = layout == tt::tt_metal::Layout::TILE
+                                            ? tt::round_up(tensor_height, adjusted_grid_size.y * tt::constants::TILE_HEIGHT)
+                                            : tensor_height;
+    std::array<uint32_t, 2> shard_shape = {tt::div_up(tensor_height_padded_to_tile, adjusted_grid_size.y),
+                                           tt::div_up(tensor_width, adjusted_grid_size.x)};
+
+    return shard_shape;
+}
+
+std::array<uint32_t, 2> compute_width_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
+                                                          const uint32_t total_num_cores) {
+    return {squeezed_tensor_hw[0], tt::div_up(squeezed_tensor_hw[1], total_num_cores)};
+}
+
+std::array<uint32_t, 2> compute_height_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
+                                                           const tt::tt_metal::Layout& layout,
+                                                           const uint32_t total_num_cores) {
+    auto [tensor_height, tensor_width] = squeezed_tensor_hw;
+    auto squeezed_height_padded_to_tile = layout == tt::tt_metal::Layout::TILE
+                                              ? tt::round_up(tensor_height, total_num_cores * tt::constants::TILE_HEIGHT)
+                                              : tensor_height;
+    return {tt::div_up(squeezed_height_padded_to_tile, total_num_cores), tensor_width};
+}
+
 ttnn::MemoryConfig create_sharded_memory_config(
-    const ttnn::Shape& shape,
+    const ttnn::SimpleShape& logical_shape,
     const tt::tt_metal::CoreRangeSet& core_grid,
     const ShardStrategy& strategy,
     const tt::tt_metal::ShardOrientation& orientation,
-    bool halo,
-    bool use_height_and_width_as_shard_shape,
-    const tt::tt_metal::Layout& layout) {
-    auto is_tile_layout = layout == tt::tt_metal::Layout::TILE;
-
-    auto rank = shape.rank();
+    std::optional<std::array<uint32_t, 2>> shard_shape,
+    const tt::tt_metal::Layout& layout,
+    bool halo) {
+    auto rank = logical_shape.rank();
     TT_FATAL(rank >= 2, "rank of tensor to shard must be at least 2.");
 
-    auto tensor_memory_layout = ttnn::TensorMemoryLayout::BLOCK_SHARDED;
-    if (strategy == ShardStrategy::WIDTH) {
+    ttnn::TensorMemoryLayout tensor_memory_layout;
+    if (strategy == ShardStrategy::BLOCK) {
+        tensor_memory_layout = ttnn::TensorMemoryLayout::BLOCK_SHARDED;
+    } else if (strategy == ShardStrategy::WIDTH) {
         tensor_memory_layout = ttnn::TensorMemoryLayout::WIDTH_SHARDED;
     } else if (strategy == ShardStrategy::HEIGHT) {
         tensor_memory_layout = ttnn::TensorMemoryLayout::HEIGHT_SHARDED;
     }
 
-    auto shard_orientation = orientation;
-    auto shard_grid = core_grid;
+    auto height = logical_shape[-2];
+    auto width = logical_shape[-1];
+    std::array<uint32_t, 2> computed_shard_shape;
 
-    auto height = shape[-2];
-    auto width = shape[-1];
-    std::array<uint32_t, 2> shard_shape;
-
-    if (use_height_and_width_as_shard_shape) {
-        if (shard_orientation == tt::tt_metal::ShardOrientation::ROW_MAJOR) {
-            shard_shape = {height, width};
-        } else if (shard_orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
-            shard_shape = {width, height};
-        } else {
-            TT_THROW("Invalid shard orientation");
-        }
+    if (shard_shape.has_value()) {
+        computed_shard_shape = shard_shape.value();
     } else {
         uint32_t batch_size = 1;
         for (int i = 0; i < rank - 2; i++) {
-            batch_size *= shape[i];
+            batch_size *= logical_shape[i];
         }
 
         auto tensor_height = batch_size * height;
         auto tensor_width = width;
-        auto total_num_cores = shard_grid.num_cores();
-        auto grid_size = shard_grid.bounding_box().grid_size();
-
-        if (tensor_memory_layout == ttnn::TensorMemoryLayout::BLOCK_SHARDED) {
-            TT_ASSERT(grid_size.y * grid_size.x == total_num_cores, "Invalid CoreRangeSet for block sharding strategy");
-
-            if (shard_orientation == tt::tt_metal::ShardOrientation::ROW_MAJOR) {
-                auto tensor_height_padded =
-                    is_tile_layout ? tt::round_up(tensor_height, grid_size.y * 32) : tensor_height;
-                shard_shape = {tt::div_up(tensor_height_padded, grid_size.y), tt::div_up(tensor_width, grid_size.x)};
-            } else if (shard_orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
-                auto tensor_height_padded =
-                    is_tile_layout ? tt::round_up(tensor_height, grid_size.x * 32) : tensor_height;
-                shard_shape = {tt::div_up(tensor_height_padded, grid_size.x), tt::div_up(tensor_width, grid_size.y)};
-            } else {
-                TT_THROW("Invalid shard orientation");
-            }
-        } else if (tensor_memory_layout == ttnn::TensorMemoryLayout::HEIGHT_SHARDED) {
-            auto tensor_height_padded = is_tile_layout ? tt::round_up(tensor_height, total_num_cores) : tensor_height;
-            shard_shape = {tt::div_up(tensor_height_padded, total_num_cores), tensor_width};
-        } else if (tensor_memory_layout == ttnn::TensorMemoryLayout::WIDTH_SHARDED) {
-            shard_shape = {tensor_height, tt::div_up(tensor_width, total_num_cores)};
-        } else {
-            TT_THROW("Invalid sharding scheme");
+        std::array<uint32_t, 2> squeezed_tensor_hw{tensor_height, tensor_width};
+        auto total_num_cores = core_grid.num_cores();
+        CoreCoord grid_size = core_grid.bounding_box().grid_size();
+
+        switch (strategy) {
+            case ShardStrategy::BLOCK:
+                computed_shard_shape = compute_block_sharded_shard_shape(squeezed_tensor_hw, layout, grid_size, orientation, total_num_cores);
+                break;
+            case ShardStrategy::WIDTH:
+                computed_shard_shape = compute_width_sharded_shard_shape(squeezed_tensor_hw, total_num_cores);
+                break;
+            case ShardStrategy::HEIGHT:
+                computed_shard_shape = compute_height_sharded_shard_shape(squeezed_tensor_hw, layout, total_num_cores);
+                break;
+            default:
+                TT_ASSERT(false, "Invalid shard strategy");
         }
     }
 
-    if (is_tile_layout && shard_shape[0] % 32 != 0 && shard_shape[1] % 32 != 0) {
-        TT_THROW("For sharding tiled tensors, the shard shape must fit neatly into tiles.");
+    if (layout == tt::tt_metal::Layout::TILE) {
+        auto [shard_height, shard_width] = computed_shard_shape;
+        auto tile_divides_shard_height = shard_height % tt::constants::TILE_HEIGHT == 0;
+        auto tile_divides_shard_width = shard_width % tt::constants::TILE_WIDTH == 0;
+        TT_FATAL(tile_divides_shard_width && tile_divides_shard_height,
+                 "For sharding tiled tensors, the shard shape must fit neatly into tiles but "
+                 "create_sharded_memory_config got shard width {} and shard height {} while "
+                 "on this architecture we have tile width {} and tile height {}",
+                 shard_width, shard_height, tt::constants::TILE_WIDTH, tt::constants::TILE_HEIGHT);
     }
 
-    auto shard_spec = tt::tt_metal::ShardSpec(shard_grid, shard_shape, shard_orientation, halo);
+    auto shard_spec = tt::tt_metal::ShardSpec(core_grid, computed_shard_shape, orientation, halo);
     return ttnn::MemoryConfig(tensor_memory_layout, ttnn::BufferType::L1, shard_spec);
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
index 2630a6311303..789388284485 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
@@ -158,14 +158,19 @@ ttnn::Tensor pad_to_tile_vol(
 
 enum class ShardStrategy { BLOCK, HEIGHT, WIDTH };
 
+// Helper function for creating a sharded memory configuration for a tensor
+// based on its logical shape, a shard strategy and orientation, and a core
+// grid. Optionally, you may pass a preferred shard shape to use. If not
+// provided, the shard shape will be inferred from the tensor shape and the
+// shard strategy.
 ttnn::MemoryConfig create_sharded_memory_config(
-    const ttnn::Shape& shape,
+    const ttnn::SimpleShape& logical_shape,
     const tt::tt_metal::CoreRangeSet& core_grid,
     const ShardStrategy& strategy,
     const tt::tt_metal::ShardOrientation& orientation,
-    bool halo = false,
-    bool use_height_and_width_as_shard_shape = false,
-    const tt::tt_metal::Layout& layout = tt::tt_metal::Layout::ROW_MAJOR);
+    std::optional<std::array<uint32_t, 2>> shard_shape = std::nullopt,
+    const tt::tt_metal::Layout& layout = tt::tt_metal::Layout::ROW_MAJOR,
+    bool halo = false);
 
 std::pair<uint32_t, std::array<uint32_t, 2>> tensor_coord_to_height_sharded_coord(
     const std::span<const uint32_t>& tensor_shape,
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp
index 378c5a7f70b5..831707d0bf00 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp
@@ -40,6 +40,9 @@ void kernel_main() {
     // optimization (upcoming as of 12/12/2024) this might be worth
     // investigating.
 
+    // paulk says that an optimized loop will still be faster.
+    // TODO(jkruer): get paul's help optimizing this.
+
     // read the input stick into the padded output stick starting after the
     // front padding
     for (uint32_t i = 0; i < unpadded_stick_bytes; i++) {
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
index efe219953c0d..28ec0e1b2197 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
@@ -53,7 +53,7 @@ static ttnn::Tensor pad_impl(
     auto input_tensor_shape = input_tensor.get_shape();
     const auto rank = input_tensor_shape.rank();
 
-    TT_FATAL(rank == 4, "ttnn.pad: input tensor rank is not 4");
+    TT_ASSERT(rank == 4, "ttnn.pad: input tensor passed to pad_impl must have rank == 4.");
 
     using ShardStrategy = ttnn::operations::data_movement::ShardStrategy;
     using ShardOrientation = tt::tt_metal::ShardOrientation;
@@ -73,24 +73,19 @@
         auto width_distinct = [](const auto& shape, const auto& other_shape) { return shape[3] != other_shape[3]; };
 
         uint32_t input_w = input_logical_shape[3];
         uint32_t output_w = output_padded_shape[3];
 
         if (width_distinct(input_logical_shape, output_padded_shape)) {
-            ttnn::SmallVector<uint32_t> output_shape_width_padded{
-                input_logical_shape.begin(), input_logical_shape.end() - 1};
-            output_shape_width_padded.push_back(output_w);
-
+            std::array<uint32_t, 4> output_shape_width_padded{
+                input_logical_shape[0], input_logical_shape[1], input_logical_shape[2], output_w};
             auto width_pad_memory_config = create_sharded_memory_config(
-                ttnn::Shape{output_shape_width_padded},
+                ttnn::SimpleShape{output_shape_width_padded},
                 input_tensor.shard_spec()->grid,  // reuse input cores for now: FIXME: can we do better?
                                                   // it's complicated because we need the input shards to be local
                                                   // to the core holding the output shard currently.
                 ShardStrategy::HEIGHT,  // stay height sharded
-                ShardOrientation::ROW_MAJOR,
-                false,
-                false,
-                Layout::ROW_MAJOR);
+                ShardOrientation::ROW_MAJOR);
             output_memory_config = width_pad_memory_config;
 
             if (height_distinct(input_logical_shape, output_padded_shape)) {
@@ -119,13 +114,10 @@
                 "infinite recursion");
 
             auto height_pad_memory_config = create_sharded_memory_config(
-                ttnn::Shape{output_padded_shape},
+                ttnn::SimpleShape{output_padded_shape},
                 input_tensor.shard_spec()->grid,
                 ShardStrategy::HEIGHT,
-                ShardOrientation::ROW_MAJOR,
-                false,
-                false,
-                Layout::ROW_MAJOR);
+                ShardOrientation::ROW_MAJOR);
 
             // then pad height
             auto output_tensor_height_padded = pad_impl(
diff --git a/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp b/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp
index 113d5c3ec1ba..3863210a4f2d 100644
--- a/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp
+++ b/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp
@@ -24,6 +24,7 @@ class ShapeBase {
     explicit ShapeBase(const std::array<uint32_t, N>& arr) : value_(arr.begin(), arr.end()) { init(); }
+    explicit ShapeBase(std::span<const uint32_t> span) : value_(span.begin(), span.end()) { init(); }
 
     template <std::size_t N>
     bool operator==(const std::array<uint32_t, N>& other) const {
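
---

Reviewer note: below is a minimal usage sketch of the new
create_sharded_memory_config signature. The tensor shape, core grid, and
explicit shard shape are made-up examples, and the CoreRange/CoreRangeSet
constructor spellings are assumed; treat this as a sketch rather than part of
the patch.

    #include <array>
    #include <optional>
    #include <set>

    using namespace ttnn::operations::data_movement;

    ttnn::MemoryConfig example_config() {
        // hypothetical 1x1x1024x1024 row-major tensor on an 8x8 grid (64 cores)
        ttnn::SimpleShape logical_shape(std::array<uint32_t, 4>{1, 1, 1024, 1024});
        CoreRangeSet core_grid(std::set<CoreRange>{CoreRange({0, 0}, {7, 7})});

        // let the helper infer the shard shape: HEIGHT sharding over 64 cores
        // yields a {16, 1024} shard shape (1024 rows / 64 cores = 16 rows each).
        auto inferred = create_sharded_memory_config(
            logical_shape, core_grid, ShardStrategy::HEIGHT,
            tt::tt_metal::ShardOrientation::ROW_MAJOR);

        // or pass an explicit shard shape; this replaces the old
        // use_height_and_width_as_shard_shape flag.
        auto explicit_cfg = create_sharded_memory_config(
            logical_shape, core_grid, ShardStrategy::HEIGHT,
            tt::tt_metal::ShardOrientation::ROW_MAJOR,
            std::array<uint32_t, 2>{16, 1024});

        return explicit_cfg;
    }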