Skip to content

Commit

Permalink
#0: testing
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Dec 13, 2024
1 parent 137fb5a commit da022c9
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 139 deletions.
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ Result conv2d(
groups,
opt_conv_op_block_config.act_block_h_ntiles,
input_width,
true,
is_non_tile_mul_width);
}
// if 1x1 conv w/ stride 1, convert input tensor to tile layout if required
Expand Down
218 changes: 93 additions & 125 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,100 +59,6 @@ bool check_non_tile_mul_width(
return is_non_tile_mul_width;
}

template <typename T>
ttnn::Tensor conv_weight_layout_convert(
const ttnn::Tensor& weight_tensor,
uint32_t input_channels_alignment,
DataType weights_bias_dtype,
uint32_t weight_block_h_ntiles,
uint32_t weight_block_w_ntiles,
const ParallelConfig& parallel_config,
T * device,
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
bool is_non_tile_mul_width) {
ttnn::Tensor weight_tensor_; // tensor to return
auto original_weights_shape = weight_tensor.get_shape();
uint32_t original_weights_out_channels = original_weights_shape[0];
uint32_t original_weights_in_channels = original_weights_shape[1];
uint32_t original_weights_window_h = original_weights_shape[2];
uint32_t original_weights_window_w = original_weights_shape[3];

bool is_conv1d = original_weights_window_w == 1 && input_width == 1;
bool is_depthwise_conv = groups == original_weights_out_channels && original_weights_in_channels == 1;

weight_tensor_ = weight_tensor;

// Convert weight tensor to 0 padded shape if groups > 1
if (!is_conv1d and groups > 1) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
}
else if (is_conv1d and groups > 1) {
if (is_depthwise_conv) {
weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype);
weight_block_h_ntiles = act_block_h_ntiles;
}
else{
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
}
}

auto weights_shape = weight_tensor_.get_shape();
uint32_t out_channels = weights_shape[0];
uint32_t in_channels = weights_shape[1];
uint32_t window_h = weights_shape[2];
uint32_t window_w = weights_shape[3];

uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config);
uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH);
uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment);
uint32_t out_channel_padding = out_channels_padded - out_channels;

tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
{out_channels_padded, in_channels_padded, window_h, window_w}));
if(is_non_tile_mul_width) {
weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
{round_up(out_channels, 32), round_up(in_channels, input_channels_alignment), window_h, window_w}));
out_channels_padded = tt::round_up(out_channels, 32);
}
if (weights_bias_dtype == DataType::BFLOAT8_B) {
TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32);
} else {
// TODO: fix the need to check this. We should be able to accept any datatype and convert
TT_ASSERT(weight_tensor_.get_dtype() == weights_bias_dtype);
}
weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

// for conv op, pad the weights to block shape
if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
} else if(parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded(
weight_tensor_, num_cores_channels, weights_bias_dtype);
} else {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
}

uint32_t weight_matrix_height = in_channels * window_h * window_w;
int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height;
TT_FATAL(weight_matrix_height_padding >= 0," Matrix Height Padding can't be negative");

auto target_shape = ttnn::Shape(std::array<uint32_t,4>{1, 1, weight_matrix_height, out_channels},
std::array<std::array<uint32_t, 2>, 4>{
std::array<uint32_t, 2>{0, 0},
std::array<uint32_t, 2>{0, 0},
std::array<uint32_t, 2>{0, weight_matrix_height_padding},
std::array<uint32_t, 2>{0, out_channel_padding}
});
weight_tensor_ = ttnn::reshape(weight_tensor_, target_shape);
return weight_tensor_;
}



template <typename T>
ttnn::Tensor conv_bias_layout_convert(
const ttnn::Tensor& bias_tensor,
Expand Down Expand Up @@ -289,15 +195,98 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
const bool parameters_on_device,
bool is_non_tile_mul_width) {

validate_weight_tensor(weight_tensor);
ttnn::Tensor weight_tensor_, bias_tensor_;
ttnn::Tensor weight_tensor_; // tensor to return
ttnn::Tensor bias_tensor_;

auto original_weights_shape = weight_tensor.get_shape();
uint32_t original_weights_out_channels = original_weights_shape[0];
uint32_t original_weights_in_channels = original_weights_shape[1];
uint32_t original_weights_window_h = original_weights_shape[2];
uint32_t original_weights_window_w = original_weights_shape[3];

bool is_conv1d = original_weights_window_w == 1 && input_width == 1;
bool is_depthwise_conv = groups == original_weights_out_channels && original_weights_in_channels == 1;

weight_tensor_ = weight_tensor;

auto weights_shape = weight_tensor.get_shape();
// Convert weight tensor to 0 padded shape if groups > 1
if (!is_conv1d and groups > 1) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
}
else if (is_conv1d and groups > 1) {
if (is_depthwise_conv) {
weight_tensor_ = convert_conv_weight_tensor_to_depthwise_layout(weight_tensor_, act_block_h_ntiles, weights_bias_dtype);
weight_block_h_ntiles = act_block_h_ntiles;
}
else{
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype);
}
}

auto weights_shape = weight_tensor_.get_shape();
uint32_t out_channels = weights_shape[0];
weight_tensor_ = conv_weight_layout_convert(weight_tensor, input_channels_alignment, weights_bias_dtype, weight_block_h_ntiles, weight_block_w_ntiles, parallel_config, device, groups, act_block_h_ntiles, input_width, is_non_tile_mul_width);
uint32_t in_channels = weights_shape[1];
uint32_t window_h = weights_shape[2];
uint32_t window_w = weights_shape[3];

uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config);
uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH);
uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment);
uint32_t out_channel_padding = out_channels_padded - out_channels;

tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
{out_channels_padded, in_channels_padded, window_h, window_w}));
if(is_non_tile_mul_width) {
weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array<uint32_t, 4>(
{round_up(out_channels, 32), round_up(in_channels, input_channels_alignment), window_h, window_w}));
out_channels_padded = tt::round_up(out_channels, 32);
}
if (weights_bias_dtype == DataType::BFLOAT8_B) {
TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32);
if (bias_tensor.has_value()) {
TT_ASSERT(bias_tensor.value().get_dtype() == DataType::FLOAT32);
}
} else {
// TODO: fix the need to check this. We should be able to accept any datatype and convert
TT_ASSERT(weight_tensor_.get_dtype() == weights_bias_dtype);
if (bias_tensor.has_value()) {
TT_ASSERT(bias_tensor.value().get_dtype() == weights_bias_dtype);
}
}
weight_tensor_ = ttnn::pad(weight_tensor_, weights_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0);

// for conv op, pad the weights to block shape
if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
} else if(parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded(
weight_tensor_, num_cores_channels, weights_bias_dtype);
} else {
weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout(
weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype);
}

uint32_t weight_matrix_height = in_channels * window_h * window_w;
int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height;
TT_FATAL(weight_matrix_height_padding >= 0," Matrix Height Padding can't be negative");

auto target_shape = ttnn::Shape(std::array<uint32_t,4>{1, 1, weight_matrix_height, out_channels},
std::array<std::array<uint32_t, 2>, 4>{
std::array<uint32_t, 2>{0, 0},
std::array<uint32_t, 2>{0, 0},
std::array<uint32_t, 2>{0, weight_matrix_height_padding},
std::array<uint32_t, 2>{0, out_channel_padding}
});
weight_tensor_ = ttnn::reshape(weight_tensor_, target_shape);

if(parameters_on_device)
weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt);

weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt);
if (bias_tensor.has_value()) {
bias_tensor_ = bias_tensor.value();
bool is_bias_tensor_is_on_device = ttnn::is_tensor_on_device_or_multidevice(bias_tensor_);
Expand Down Expand Up @@ -381,8 +370,10 @@ ttnn::Tensor prepare_conv_weights(
bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels);
std::optional<const ttnn::Tensor> bias_tensor = std::nullopt;
ttnn::Tensor weight_tensor_on_device = weight_tensor;
weight_tensor_on_device = conv_weight_layout_convert(
std::optional<ttnn::Tensor> bias_tensor_on_device = bias_tensor;
tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device(
weight_tensor,
bias_tensor,
conv_config.input_channels_alignment,
conv_config.weights_dtype,
opt_conv_op_block_config.act_block_w_ntiles,
Expand All @@ -392,6 +383,7 @@ ttnn::Tensor prepare_conv_weights(
groups,
opt_conv_op_block_config.act_block_h_ntiles,
input_width,
false,
is_non_tile_mul_width);

return weight_tensor_on_device;
Expand Down Expand Up @@ -565,6 +557,7 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
const bool parameters_on_device,
bool is_non_tile_mul_width);

template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases_and_move_to_device<MeshDevice>(
Expand All @@ -579,6 +572,7 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
const bool parameters_on_device,
bool is_non_tile_mul_width);

template ttnn::Tensor prepare_conv_bias<Device>(
Expand Down Expand Up @@ -617,32 +611,6 @@ template ttnn::Tensor prepare_conv_bias<MeshDevice>(
const std::optional<const Conv2dConfig>& conv_config_,
const std::optional<const DeviceComputeKernelConfig>& compute_config_);

template ttnn::Tensor conv_weight_layout_convert<MeshDevice>(
const ttnn::Tensor& weight_tensor,
uint32_t input_channels_alignment,
DataType weights_bias_dtype,
uint32_t weight_block_h_ntiles,
uint32_t weight_block_w_ntiles,
const ParallelConfig& parallel_config,
MeshDevice* device,
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
bool is_non_tile_mul_width);

template ttnn::Tensor conv_weight_layout_convert<Device>(
const ttnn::Tensor& weight_tensor,
uint32_t input_channels_alignment,
DataType weights_bias_dtype,
uint32_t weight_block_h_ntiles,
uint32_t weight_block_w_ntiles,
const ParallelConfig& parallel_config,
Device* device,
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
bool is_non_tile_mul_width);

template ttnn::Tensor conv_bias_layout_convert(
const ttnn::Tensor& bias_tensor,
DataType bias_dtype,
Expand Down
15 changes: 1 addition & 14 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,6 @@ namespace ttnn {
namespace operations::conv {
namespace conv2d {

template <typename T>
ttnn::Tensor conv_weight_layout_convert(
const ttnn::Tensor& weight_tensor,
uint32_t input_channels_alignment,
DataType weights_bias_dtype,
uint32_t weight_block_h_ntiles,
uint32_t weight_block_w_ntiles,
const sliding_window::ParallelConfig& parallel_config,
T * device,
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
bool is_non_tile_mul_width);

template <typename T>
ttnn::Tensor conv_bias_layout_convert(
const ttnn::Tensor& bias_tensor,
Expand Down Expand Up @@ -96,6 +82,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
uint32_t groups,
uint32_t act_block_h_ntiles,
uint32_t input_width,
const bool parameters_on_device=true,
bool is_non_tile_mul_width=false);

template <typename T>
Expand Down

0 comments on commit da022c9

Please sign in to comment.