#13541: Allow block sharding for mm convs
Allow auto shard to select block sharding for convs that map to ttnn::matmul.
Work around the limitation of ttnn::tilize, which can't tilize block-sharded tensors,
by running ttnn::tilize while the tensor is still in DRAM interleaved layout.
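
Condensed, the new path looks like this (a sketch assembled from the same calls and variables as the diff below, not a drop-in snippet):

```cpp
// mm_conv: this conv will be lowered to ttnn::matmul (1x1 kernel, stride 1,
// no padding, no dilation, no groups). ttnn::tilize can't tilize
// BLOCK_SHARDED tensors (#13979), so tilize while the tensor is still
// DRAM interleaved...
if (mm_conv && input_tensor.layout() == Layout::ROW_MAJOR &&
    parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) {
    input_tensor =
        ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device());
}
// ...and only then move it into the (possibly block-sharded) memory config.
input_tensor = ttnn::to_memory_config(input_tensor, input_tensor_sharded_memory_config, std::nullopt);
```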
Pavle Josipovic authored and pavlejosipovic committed Oct 21, 2024
1 parent 03b9615 commit 3cffcc5
Showing 2 changed files with 36 additions and 14 deletions.
@@ -20,8 +20,6 @@
# Contains following params
# [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation]
[1, 32, 1, 28, 28, 3, 3, 1, 1, 0, 0, 1, True, 1],
[1, 32, 1, 28, 28, 3, 3, 1, 1, 0, 0, 1, True, 1],
[1, 100, 100, 14, 14, 3, 3, 1, 1, 1, 1, 100, False, 1],
[1, 1008, 1008, 14, 14, 3, 3, 2, 2, 1, 1, 21, False, 1],
[1, 1008, 1008, 7, 7, 3, 3, 1, 1, 1, 1, 21, False, 1],
[1, 1024, 1024, 10, 10, 3, 3, 1, 1, 1, 1, 1024, False, 1],
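
Reading the first row against the comment above: [1, 32, 1, 28, 28, 3, 3, 1, 1, 0, 0, 1, True, 1] is a batch-1 conv taking a 1-channel 28×28 input to 32 channels with a 3×3 kernel, stride 1, no padding, groups 1, a bias, and dilation 1. Per the hunk header (8 lines before, 6 after, no additions), two of the rows shown here are deletions; the duplicated first row is clearly one of them.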
48 changes: 36 additions & 12 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -5,10 +5,12 @@
#include "conv2d.hpp"
#include <sys/types.h>
#include <cstdint>
#include <optional>

#include "common/constants.hpp"
#include "impl/buffers/buffer_constants.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
#include "tt_metal/detail/reports/memory_reporter.hpp"
#include "tt_metal/common/work_split.hpp"
@@ -325,6 +327,16 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config(
.out_subblock_w_ntiles = out_subblock_w_ntiles};
}

static bool use_matmul_for_1x1_conv(
const std::array<uint32_t, 2>& kernel_size,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& padding,
const std::array<uint32_t, 2>& dilation,
uint32_t groups) {
return kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 && padding[0] == 0 &&
padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 && groups == 1;
}
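
A 1×1, stride-1, unpadded, undilated, ungrouped conv is exactly a matrix multiply: with the spatial dims flattened, the output [N·H·W, C_out] is the input [N·H·W, C_in] times the weights [C_in, C_out], so such convs can be routed to ttnn::matmul. Factoring the predicate into this helper deduplicates the inline expressions the hunks below delete and lets shard_or_reshard_tensor_if_required reuse it.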

// Implements a heuristic for selecting shard layout based on how many Tensix cores are available
// for each shard.
template <typename T>
@@ -360,18 +372,17 @@ static TensorMemoryLayout select_shard_spec(
const bool is_block_sharding_valid =
(kernel_size[0] == 3 && kernel_size[1] == 3 && (stride[0] == 1 || stride[0] == 2)) ||
(kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == 2);
const bool use_matmul_for_1x1_conv = kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] &&
stride[0] == 1 && padding[0] == 0 && padding[1] == 0 && dilation[0] == 1 &&
dilation[1] == 1 && groups == 1;
const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);

// 1d convs support only height sharding
const bool is_conv1d = weights_width == 1 && input_width == 1;

const uint32_t cc_height = get_core_count_for_sharding(TensorMemoryLayout::HEIGHT_SHARDED);
// matmul doesn't support width sharding
const uint32_t cc_width =
!use_matmul_for_1x1_conv && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::WIDTH_SHARDED) : 0;
!mm_conv && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::WIDTH_SHARDED) : 0;
const uint32_t cc_block =
is_block_sharding_valid && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::BLOCK_SHARDED) : 0;
(is_block_sharding_valid || mm_conv) && !is_conv1d ? get_core_count_for_sharding(TensorMemoryLayout::BLOCK_SHARDED) : 0;

uint32_t max_cc = cc_block;
TensorMemoryLayout shard_layout = TensorMemoryLayout::BLOCK_SHARDED;
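
The comparison itself is below the fold, but its shape is clear from the seeds above: each scheme gets a candidate core count and the layout that can use the most cores wins, with block sharding as the starting candidate. A hypothetical reconstruction, assuming plain greater-than comparisons (the actual code is not shown in this diff):

```cpp
// Assumed continuation (not visible in this hunk): keep the sharding
// scheme that spreads the conv across the most cores.
uint32_t max_cc = cc_block;
TensorMemoryLayout shard_layout = TensorMemoryLayout::BLOCK_SHARDED;
if (cc_height > max_cc) {
    max_cc = cc_height;
    shard_layout = TensorMemoryLayout::HEIGHT_SHARDED;
}
if (cc_width > max_cc) {
    max_cc = cc_width;
    shard_layout = TensorMemoryLayout::WIDTH_SHARDED;
}
```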
@@ -617,7 +628,14 @@ std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_requir
}
}

const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
if (input_tensor_on_device) {
if (mm_conv && input_tensor.layout() == Layout::ROW_MAJOR &&
parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) {
// Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout
input_tensor =
ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device());
}
auto resharded_input_tensor = ttnn::to_memory_config(
input_tensor, input_tensor_sharded_memory_config, std::nullopt);
if (conv_config.deallocate_activation) {
@@ -626,8 +644,16 @@
}
input_tensor = resharded_input_tensor;
} else {
input_tensor = ttnn::operations::core::to_device(
input_tensor, device, input_tensor_sharded_memory_config);
if (mm_conv && input_tensor.layout() == Layout::ROW_MAJOR &&
parallel_config.shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED) {
// Workaround #13979 ttnn::tilize doesn't support BLOCK_SHARDED layout
input_tensor = ttnn::to_device(input_tensor, device, std::nullopt);
input_tensor =
ttnn::to_layout(input_tensor, Layout::TILE, std::nullopt, std::nullopt, input_tensor.device());
input_tensor = ttnn::to_memory_config(input_tensor, input_tensor_sharded_memory_config, std::nullopt);
} else {
input_tensor = ttnn::to_device(input_tensor, device, input_tensor_sharded_memory_config);
}
}
}
return {input_tensor, parallel_config, needs_shard_or_reshard};
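
Both branches apply the same workaround. If the tensor is already on the device, it is tilized in place before ttnn::to_memory_config moves it into the sharded config; if it is still on the host, ttnn::to_device is first called without a memory config, which (per the commit message) leaves it DRAM interleaved so that the tilize can run before the final reshard.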
@@ -887,11 +913,9 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
input_width);
}
// if 1x1 conv w/ stride 1, convert input tensor to tile layout if required
bool use_matmul_for_1x1_conv = kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 &&
padding[0] == 0 && padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 &&
groups == 1;
const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups);
Tensor input_tensor_post_tm_out;
if (use_matmul_for_1x1_conv) {
if (mm_conv) {
input_tensor_post_tm_out = ttnn::to_layout(
input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device);
if (conv_config.deallocate_activation) {
@@ -911,7 +935,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
conv_config.fp32_dest_acc_enabled,
conv_config.packer_l1_accum_enabled);

if (!use_matmul_for_1x1_conv) {
if (!mm_conv) {
// call halo op
SlidingWindowConfig sliding_window_config = SlidingWindowConfig{
.batch_size = batch_size,
