Clean up and add support for Column major shard orientation.
Signed-off-by: Nilaykumar Patel <[email protected]>
nkpatel-tt committed Dec 13, 2024
1 parent cbcf042 commit 67eba82
Showing 2 changed files with 91 additions and 52 deletions.
10 changes: 6 additions & 4 deletions tests/ttnn/unit_tests/operations/test_upsample.py
@@ -10,7 +10,7 @@
import torch
import torch.nn as nn
import ttnn
from models.utility_functions import skip_for_grayskull, skip_for_blackhole
from models.utility_functions import skip_for_grayskull, skip_for_blackhole, is_grayskull
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


@@ -119,11 +119,14 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
[1, 64, 132, 19],
],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize("scale_h", [2, 3])
@pytest.mark.parametrize("scale_w", [2, 3])
@pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK])
def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR])
def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy, shard_orientation):
if is_grayskull() and (scale_h > 2 or scale_w > 2):
pytest.skip("Skipping test because it won't fit in L1!")

## input shape is N C H W
batch_size, num_channels, height, width = input_shape
torch.manual_seed(0)
@@ -191,7 +194,6 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
# )

shard_grid = get_shard_grid_from_num_cores(device, ncores)
shard_orientation = ttnn.ShardOrientation.ROW_MAJOR

if shard_strategy == ttnn.ShardStrategy.BLOCK:
tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
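The remainder of the test is collapsed in this view; it builds a sharded memory config from these parameters and runs the op. A minimal sketch of that wiring, assuming the usual ttnn sharding API (shard_height, shard_width and torch_input are placeholders, not values taken from the diff):

# Sketch only: shard_grid comes from get_shard_grid_from_num_cores(device, ncores) as above;
# shard_orientation is the new ROW_MAJOR / COL_MAJOR test parameter.
shard_spec = ttnn.ShardSpec(shard_grid, (shard_height, shard_width), shard_orientation, False)
in_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.BufferType.L1, shard_spec)

input_tensor = ttnn.from_torch(torch_input, device=device, memory_config=in_sharded_mem_config)
output_tensor = ttnn.upsample(input_tensor, (scale_h, scale_w))
torch_output = ttnn.to_torch(output_tensor)

On the test side, the only behavioural change is that shard_orientation is now swept over both orientations instead of being hard-coded to ROW_MAJOR.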
133 changes: 85 additions & 48 deletions (second changed file: the C++ upsample multi-core program factory; its path is not shown in this view)
@@ -2,10 +2,9 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <math.h>
#include <cstdint>
#include <vector>

#include "buffers/buffer.hpp"
#include "buffers/buffer_constants.hpp"
#include "common/assert.hpp"
#include "common/core_coord.hpp"
@@ -21,63 +20,98 @@ using namespace tt::tt_metal;
namespace ttnn::operations::upsample {
using namespace tt;

static Tensor create_config_tensor_block_sharded(
Device *device,
uint32_t input_nsticks_per_core,
static Tensor create_config_tensor(
Device* device,
ShardSpec shard_spec,
const uint32_t batch_size,
const uint32_t in_h,
const uint32_t in_w,
const uint32_t scale_factor_h,
const uint32_t scale_factor_w,
uint32_t ncores_x,
bool is_height_sharded) {
std::vector<uint16_t> config_vector;
const uint32_t ncores_x,
const bool is_height_sharded,
const bool is_col_major) {
uint16_t in_core = 0, curr_stick = 0;
uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
const uint32_t input_nsticks_per_core = shard_spec.shape[0];

std::vector<std::vector<int>> core_range;
auto ranges = shard_spec.grid.ranges();
// In case of height sharding with shards arranged in column-major order, get the cores where shards are placed.
if (is_col_major && is_height_sharded) {
for (auto i = 0; i < ranges.size(); i++) {
auto range = ranges[i];
for (auto x = range.start_coord.x; x <= range.end_coord.x; x++) {
for (auto y = range.start_coord.y; y <= range.end_coord.y; y++) {
core_range.push_back({x, y});
}
}
}
}

std::vector<uint16_t> logical_core_to_stick_map;
size_t logical_core_to_stick_map_entry_size = 3;
size_t row_size = logical_core_to_stick_map_entry_size * in_w;
// Create a map of cores and their respective stick offsets in the input
for (uint32_t b = 0; b < batch_size; ++b) {
for (uint32_t h = 0; h < in_h; ++h) {
for (uint32_t w = 0; w < in_w; ++w, ++curr_stick) {
if (curr_stick == input_nsticks_per_core) curr_stick = 0, ++in_core;
config_vector.push_back(in_core);
config_vector.push_back(curr_stick);
if (curr_stick == input_nsticks_per_core) {
curr_stick = 0, ++in_core;
}
if (is_height_sharded && is_col_major) {
logical_core_to_stick_map.push_back(core_range[in_core][0]);
logical_core_to_stick_map.push_back(core_range[in_core][1]);
} else {
logical_core_to_stick_map.push_back(in_core);
logical_core_to_stick_map.push_back(0);
}
logical_core_to_stick_map.push_back(curr_stick);
}
for (uint32_t j = 1; j < scale_factor_h; ++j) {
logical_core_to_stick_map.insert(
logical_core_to_stick_map.end(),
logical_core_to_stick_map.end() - row_size,
logical_core_to_stick_map.end());
}
size_t row_size = 2 * in_w, initial_size = config_vector.size();
for (uint32_t j = 1; j < scale_factor_h; ++j)
config_vector.insert(config_vector.end(), config_vector.end() - row_size, config_vector.end());
}
}

std::vector<uint16_t> temp_config_vector;
std::vector<uint16_t> config_vector;

// Based on core calculate physical dimentions of cores
// Based on the logical core index, calculate the physical location of each core
CoreCoord core_coords;
if (is_height_sharded) {
for (size_t j = 0; j < config_vector.size(); j += 2) {
core_coords = device->worker_core_from_logical_core(CoreCoord(config_vector[j] % ncores_x, config_vector[j] / ncores_x));
temp_config_vector.push_back(core_coords.x);
temp_config_vector.push_back(core_coords.y);
temp_config_vector.push_back(config_vector[j + 1]);
temp_config_vector.push_back(0);
for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) {
CoreCoord core_coords;
if (is_col_major) {
core_coords = device->worker_core_from_logical_core(
CoreCoord(logical_core_to_stick_map[j], logical_core_to_stick_map[j + 1]));
} else {
core_coords = device->worker_core_from_logical_core(
CoreCoord(logical_core_to_stick_map[j] % ncores_x, logical_core_to_stick_map[j] / ncores_x));
}
config_vector.push_back(core_coords.x);
config_vector.push_back(core_coords.y);
config_vector.push_back(logical_core_to_stick_map[j + 2]);
config_vector.push_back(0);
}
} else {
for (uint32_t i = 0; i < ncores_x; i++) {
for (size_t j = 0; j < config_vector.size(); j += 2) {
core_coords = device->worker_core_from_logical_core(CoreCoord(i, config_vector[j]));
temp_config_vector.push_back(core_coords.x);
temp_config_vector.push_back(core_coords.y);
temp_config_vector.push_back(config_vector[j + 1]);
temp_config_vector.push_back(0);
for (size_t i = 0; i < ncores_x; i++) {
for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) {
core_coords = device->worker_core_from_logical_core(CoreCoord(i, logical_core_to_stick_map[j]));
config_vector.push_back(core_coords.x);
config_vector.push_back(core_coords.y);
config_vector.push_back(logical_core_to_stick_map[j + 2]);
config_vector.push_back(0);
}
}
}
Shape config_shape({temp_config_vector.size() / elems_per_core, elems_per_core});
auto config_buffer = owned_buffer::create<uint16_t>(std::move(temp_config_vector));
uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
Shape config_shape({config_vector.size() / elems_per_core, elems_per_core});
auto config_buffer = owned_buffer::create<uint16_t>(std::move(config_vector));
return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
}
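For reference, a rough pure-Python model of what this helper computes for the height-sharded case (illustrative only: the names are invented, the shard grid is assumed to be a single full rectangle, the final translation to physical worker coordinates via worker_core_from_logical_core is omitted, and the BLOCK-sharded branch is not modelled):

def height_sharded_config(batch_size, in_h, in_w, scale_h, nsticks_per_core, grid_w, grid_h, col_major):
    # Order in which shards are assigned to logical cores on a grid_w x grid_h grid.
    if col_major:
        cores = [(x, y) for x in range(grid_w) for y in range(grid_h)]
    else:
        cores = [(x, y) for y in range(grid_h) for x in range(grid_w)]
    entries, core, stick = [], 0, 0
    for _ in range(batch_size):
        for _ in range(in_h):
            row = []
            for _ in range(in_w):
                if stick == nsticks_per_core:   # move on to the next core's shard
                    stick, core = 0, core + 1
                row.append((*cores[core], stick, 0))  # (core_x, core_y, stick offset, pad)
                stick += 1
            entries.extend(row * scale_h)       # each input row feeds scale_h output rows
    return entries

Walking the cores column-major (x outer, y inner) mirrors what the new core_range lookup above does for HEIGHT_SHARDED inputs with COL_MAJOR orientation, while the row-major path keeps the old decomposition of the flat core index into (index % ncores_x, index / ncores_x).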


operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) {
Program program = CreateProgram();
Device *device = input.device();
@@ -126,7 +160,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&

// TODO: Support non-multiple case
TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core);
TT_FATAL(shard_spec.orientation == ShardOrientation::ROW_MAJOR, "Input tensor is expected to have ROW_MAJOR shard orientation, got {}", shard_spec.orientation);

// CBs

@@ -164,29 +197,33 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&

// create config tensor
Tensor config_tensor;
if((input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) || (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
config_tensor = create_config_tensor_block_sharded(
device,
shard_spec.shape[0],
input.legacy_shape()[0],
input.legacy_shape()[1],
in_w,
scale_factor_h,
scale_factor_w,
ncores_x,
input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED);
if ((input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) ||
(input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
config_tensor = create_config_tensor(
device,
shard_spec,
input.legacy_shape()[0],
input.legacy_shape()[1],
in_w,
scale_factor_h,
scale_factor_w,
ncores_x,
input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED,
shard_spec.orientation == ShardOrientation::COL_MAJOR);
} else {
TT_THROW("Unsupported sharding layout");
}
auto shard_shape = std::array<uint32_t, 2>({1, (uint32_t)config_tensor.get_shape()[-1]});
auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED ? (shard_spec.orientation == ShardOrientation::COL_MAJOR ? ShardOrientation::ROW_MAJOR : ShardOrientation::COL_MAJOR) : ShardOrientation::ROW_MAJOR;
auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED
? ShardOrientation::COL_MAJOR
: shard_spec.orientation;
ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation, false);
MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1_SMALL, config_shard_spec};
MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1, config_shard_spec};
auto config_tensor_device = config_tensor.to(device, memory_config);
tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer());

tt::DataFormat config_df = tt::DataFormat::RawUInt16;
Buffer *config_buffer = config_tensor_device.buffer();
Buffer* config_buffer = config_tensor_device.buffer();
auto config_buffer_page_size = config_buffer->page_size();
uint32_t config_cb_id = tt::CB::c_in2;
auto config_cb_config = CircularBufferConfig(config_buffer_page_size, {{config_cb_id, config_df}})
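To make the config-tensor sizing concrete, a small worked example with made-up numbers (none of these values come from the diff): for a height-sharded input with 16 sticks per core, scale_factor_h = 2 and 32 cores,

input_nsticks_per_core, scale_factor_h, ncores = 16, 2, 32
total_sticks = input_nsticks_per_core * ncores                     # batch * in_h * in_w = 512
elems_per_core = 4 * scale_factor_h * input_nsticks_per_core       # 128 uint16 values per core
total_values = total_sticks * scale_factor_h * 4                   # 4096 values in config_vector
config_shape = (total_values // elems_per_core, elems_per_core)    # (32, 128): one row per core

so each core receives one 128-value page, matching the {1, config_shape[-1]} shard shape the config tensor is re-sharded with onto the input grid (now placed in regular L1 rather than L1_SMALL).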
