#0: Use small vector instead of std::vector for shapes to optimize allocations (#14281)

* #0: Use small vector instead of std::vector for shapes

* #0: Try to fix python

* #0: Fixup

* #0: Remove VectorBase, cleanup

* #0: Rebase fix

* #0: Rebase fixes

* #0: Fix UB - access out of bounds in SqueezeOperation

* #0: Review fixes

* #0: Use tt::stl::Span

* #0: CMake cleanup

* #0: Replace std::span with tt::stl::Span in test_tiles.hpp
sminakov-tt authored Oct 28, 2024
1 parent 01174b0 commit d0e0787
Showing 113 changed files with 561 additions and 466 deletions.
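
At a glance, the commit swaps heap-allocating `std::vector` shape containers for a small-buffer-optimized vector. A minimal sketch of the idea, assuming `ttnn::SmallVector` is an alias over `boost::container::small_vector` (the alias and the inline capacity of 8 are assumptions, not shown in this diff):

```cpp
#include <boost/container/small_vector.hpp>
#include <cstdint>

namespace ttnn {
// Assumed alias: a vector that stores up to 8 elements inline and only
// falls back to the heap beyond that. Tensor shapes are almost always
// rank <= 8, so shape construction stops allocating.
template <typename T>
using SmallVector = boost::container::small_vector<T, 8>;
}  // namespace ttnn

int main() {
    ttnn::SmallVector<uint32_t> shape{1, 1, 32, 32};  // inline storage, no heap
    shape.push_back(64);                              // still inline (5 <= 8)
    return static_cast<int>(shape.size()) - 5;        // 0
}
```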
14 changes: 7 additions & 7 deletions best_practices.md
@@ -18,13 +18,13 @@ void write_buffer(queue_id cq_id, Tensor& dst, std::vector<std::shared_ptr<void>
void write_buffer(queue_id cq_id, Tensor& dst, const std::vector<std::shared_ptr<void>>& src, const std::optional<std::size_t>& transfer_size = std::nullopt); // Right!
```

## 2. Use `std::span` for Input Parameters
## 2. Use `tt::stl::Span` for Input Parameters

### Practice
Consider using `std::span` as input instead of `std::vector`. This allows `std::array` to be used as an argument as well.
Consider using `tt::stl::Span` as input instead of `std::vector`. This allows `std::array` to be used as an argument as well.

### Explanation
`std::span` is a lightweight view over a contiguous sequence of objects, such as arrays and vectors. It provides a safe and flexible way to handle array-like data structures without copying them.
`tt::stl::Span` is a lightweight view over a contiguous sequence of objects, such as arrays and vectors. It provides a safe and flexible way to handle array-like data structures without copying them.

### Motivation
- **Flexibility**: Enables functions to accept both `std::vector` and `std::array`.
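
As an aside on the practice above, a call-site sketch using `std::span` (which `tt::stl::Span` presumably mirrors) shows the flexibility bullet in action:

```cpp
#include <array>
#include <iostream>
#include <span>
#include <vector>

// Any contiguous sequence of ints binds to the same parameter.
void print_all(std::span<const int> data) {
    for (int v : data) std::cout << v << ' ';
    std::cout << '\n';
}

int main() {
    std::vector<int> vec{1, 2, 3};
    std::array<int, 3> arr{4, 5, 6};
    int raw[] = {7, 8, 9};
    print_all(vec);  // std::vector converts implicitly
    print_all(arr);  // std::array too
    print_all(raw);  // and a built-in array
}
```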
@@ -33,7 +33,7 @@ Consider using `std::span` as input instead of `std::vector`. This allows `std::
### Example
```
template <typename T>
void print_elements(std::span<T> data) {
void print_elements(tt::stl::Span<const T> data) {
for (const auto& element : data) {
std::cout << element << " ";
}
@@ -217,7 +217,7 @@ Use the Copy-and-Swap idiom to avoid duplicating code between different construc
### Explanation
The Copy-and-Swap idiom is a robust and elegant method to implement copy assignment operators. It leverages the copy constructor and the swap method to provide strong exception safety and reduce code duplication.

### Example
https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom


@@ -279,7 +279,7 @@ Prefer:
enum class ThreadingOption { SingleCore, MultiCore };
tensor = tt::tt_metal::tilize_with_val_padding(tensor, output_shape, 0, output_memory_config, dtype, ThreadingOption::MultiCore);
```
Also consider giving enums power-of-2 values to pass them all as a single argument, e.g.
```cpp
Options::FOO | Options::BAR
```
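
A minimal sketch of the power-of-2 pattern (the enum name, values, and operators are illustrative, not from this repository):

```cpp
#include <cstdint>

enum class Options : uint32_t {
    FOO = 1u << 0,
    BAR = 1u << 1,
    BAZ = 1u << 2,
};

constexpr Options operator|(Options a, Options b) {
    return static_cast<Options>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}

constexpr bool has_flag(Options value, Options flag) {
    return (static_cast<uint32_t>(value) & static_cast<uint32_t>(flag)) != 0;
}

int main() {
    Options opts = Options::FOO | Options::BAR;  // both flags in one argument
    return has_flag(opts, Options::BAR) ? 0 : 1;
}
```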
@@ -343,7 +343,7 @@ void doSomething(...) {
Prefer:
```cpp
void doSomething(...) {
    if (!contractCheck)
return;
// Do a lot of things
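
A self-contained version of the guard-clause shape above (the condition and the body are placeholders):

```cpp
#include <iostream>

void doSomething(bool contractCheck) {
    if (!contractCheck)
        return;  // bail out early instead of nesting the whole body in an if
    // Do a lot of things
    std::cout << "doing the work\n";
}

int main() {
    doSomething(false);  // returns immediately
    doSomething(true);   // does the work
}
```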
1 change: 1 addition & 0 deletions cmake/dependencies.cmake
@@ -8,6 +8,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/fetch_boost.cmake)

fetch_boost_library(core)
fetch_boost_library(smart_ptr)
fetch_boost_library(container)

add_library(span INTERFACE)
target_link_libraries(span INTERFACE Boost::core)
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/test_async_runtime.cpp
@@ -43,7 +43,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) {
Tensor np_tensor = ttnn::numpy::full<float>(input_shape.value, static_cast<float>(1), DataType::BFLOAT16)
.to(Layout::TILE)
.to(device);
std::vector<int64_t> reduce_dims = {3};
ttnn::SmallVector<int64_t> reduce_dims = {3};
Tensor np_out = ttnn::moreh_sum(np_tensor, reduce_dims, false, std::nullopt, std::nullopt, std::nullopt);
Tensor np_out_host = np_out.cpu();
const bfloat16* golden_output = std::get<owned_buffer::Buffer<bfloat16>>(std::get<OwnedStorage>(np_out_host.get_storage()).buffer).begin();
19 changes: 10 additions & 9 deletions tt_metal/common/test_tiles.hpp
@@ -11,6 +11,7 @@
#include <cstdint>
#include <vector>
#include <optional>
#include "tt_metal/tt_stl/span.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/common/assert.hpp"
#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp"
@@ -25,8 +26,8 @@ enum TensorLayout {
template <class T, template <typename...> typename BufferType>
std::vector<T> convert_to_tile_layout(
const BufferType<T>& data,
const std::optional<std::vector<uint32_t>>& tile_shape = std::nullopt,
const std::optional<const std::vector<uint32_t>>& face_shape = std::nullopt) {
std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt,
std::optional<tt::stl::Span<const uint32_t>> face_shape = std::nullopt) {
ZoneScoped;
std::vector<T> result;
result.reserve(data.size());
@@ -79,8 +80,8 @@ std::vector<T> convert_to_tile_layout(
template <class T, template <typename...> typename BufferTyp>
std::vector<T> convert_to_flat_layout(
const BufferTyp<T>& data,
const std::optional<std::vector<uint32_t>>& tile_shape = std::nullopt,
const std::optional<const std::vector<uint32_t>>& face_shape = std::nullopt) {
std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt,
std::optional<tt::stl::Span<const uint32_t>> face_shape = std::nullopt) {
ZoneScoped;
std::vector<T> result;
result.reserve(data.size());
@@ -115,7 +116,7 @@ std::vector<T> convert_to_flat_layout(

// Converts a 32-swizzled tilized row-major tensor to a linear 32-zero-padded row-major tensor
template <typename T, template <typename...> typename BufferType>
inline std::vector<T> untilize_nchw(const BufferType<T>& in, const std::vector<std::uint32_t>& shape, const std::optional<std::vector<uint32_t>>& tile_shape = std::nullopt) {
inline std::vector<T> untilize_nchw(const BufferType<T>& in, tt::stl::Span<const uint32_t> shape, std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt) {
ZoneScoped;
auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT;
auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH;
@@ -159,7 +160,7 @@ inline std::uint32_t round_up_to_tile(int val, int tile_val) { return (val + til

// Converts a linear non-zero-padded row-major tensor to zero-padded-32 32-swizzled tilized row-major tensor
template <typename T, template <typename...> typename BufferType>
inline std::vector<T> tilize_nchw(const BufferType<T>& in_rowmajor, const std::vector<std::uint32_t>& shape, const std::optional<std::vector<uint32_t>>& tile_shape = std::nullopt) {
inline std::vector<T> tilize_nchw(const BufferType<T>& in_rowmajor, tt::stl::Span<const uint32_t> shape, std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt) {
ZoneScoped;
int H = shape[shape.size() - 2], W = shape[shape.size() - 1];
auto batch_size = 1;
@@ -221,11 +222,11 @@ struct TensAddr {
template <typename T, template <typename...> typename BufferType>
inline std::vector<T> convert_layout(
const BufferType<T>& inp,
const std::vector<uint32_t>& shape,
tt::stl::Span<const uint32_t> shape,
TensorLayout inL,
TensorLayout outL,
const std::optional<std::vector<uint32_t>>& tile_shape = std::nullopt,
const std::optional<const std::vector<uint32_t>>& face_shape = std::nullopt) {
std::optional<tt::stl::Span<const uint32_t>> tile_shape = std::nullopt,
std::optional<const tt::stl::Span<const uint32_t>> face_shape = std::nullopt) {
ZoneScoped;
switch (inL) {
case TILED_SWIZZLED:
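
The signature changes in this header also remove a subtle cost: `const std::optional<std::vector<uint32_t>>&` forced callers to materialize a vector, while an optional span is a cheap view. A sketch using `std::span` in place of `tt::stl::Span` (assumed to behave the same):

```cpp
#include <cstdint>
#include <optional>
#include <span>

// Returns the first dimension of an optional shape view, or a fallback.
uint32_t first_dim_or(std::optional<std::span<const uint32_t>> shape,
                      uint32_t fallback) {
    return shape.has_value() ? shape.value()[0] : fallback;
}

int main() {
    const uint32_t tile[2] = {32, 32};
    uint32_t a = first_dim_or(std::span<const uint32_t>(tile), 16);  // 32
    uint32_t b = first_dim_or(std::nullopt, 16);                     // 16
    return static_cast<int>(a + b) - 48;                             // 0
}
```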
7 changes: 7 additions & 0 deletions ttnn/CMakeLists.txt
@@ -571,6 +571,7 @@ set(TTNN_PUBLIC_LINK_LIBRARIES
metal_header_directories
metal_common_libs
tt_metal
Boost::container
)
set(TTNN_PUBLIC_LINK_DIRS "")

@@ -626,6 +627,12 @@ target_compile_options(
-fno-var-tracking
)

if(WITH_PYTHON_BINDINGS)
target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=1)
else()
target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=0)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
target_compile_definitions(ttnn PUBLIC DISABLE_NAMESPACE_STATIC_ASSERT)
endif()
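
Defining `TTNN_WITH_PYTHON_BINDINGS` as an explicit 0 or 1 (rather than leaving it undefined) lets downstream code branch with `#if` safely. A hypothetical consumer; only the macro name comes from the diff above:

```cpp
#include <cstdio>

// Fallback for building this sketch standalone; in the real build the value
// comes from the CMake target_compile_definitions above.
#ifndef TTNN_WITH_PYTHON_BINDINGS
#define TTNN_WITH_PYTHON_BINDINGS 0
#endif

bool python_bindings_enabled() {
#if TTNN_WITH_PYTHON_BINDINGS
    return true;   // pybind11-backed code paths would compile here
#else
    return false;  // pure C++ build
#endif
}

int main() { std::printf("%d\n", python_bindings_enabled()); }
```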
12 changes: 6 additions & 6 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -64,7 +64,7 @@ void log_external_operation(
#endif

template <typename T>
Tensor create_owned_tensor(T* data_ptr, size_t num_elements, std::vector<uint32_t>& shape, DataType data_type, Layout layout, const std::optional<Tile>& optional_tile = std::nullopt)
Tensor create_owned_tensor(T* data_ptr, size_t num_elements, tt::stl::Span<const uint32_t> shape, DataType data_type, Layout layout, const std::optional<Tile>& optional_tile = std::nullopt)
{
auto data = std::vector(data_ptr, data_ptr + num_elements);
auto buffer = owned_buffer::create(std::move(data));
@@ -80,7 +80,7 @@ Tensor convert_torch_tensor_to_tt_tensor(
}

auto torch_dtype = torch_tensor.attr("dtype");
auto shape = py::cast<std::vector<uint32_t>>(torch_tensor.attr("shape"));
auto shape = py::cast<ttnn::SmallVector<uint32_t>>(torch_tensor.attr("shape"));

auto contiguous_torch_tensor = torch_tensor.attr("contiguous")();

@@ -251,7 +251,7 @@ Tensor convert_numpy_tensor_to_tt_tensor(
}

auto np_dtype = np_tensor.attr("dtype");
auto shape = py::cast<std::vector<uint32_t>>(np_tensor.attr("shape"));
auto shape = py::cast<ttnn::SmallVector<uint32_t>>(np_tensor.attr("shape"));

auto contiguous_np_tensor = np.attr("ascontiguousarray")(np_tensor);

@@ -1325,7 +1325,7 @@ void pytensor_module(py::module &m_tensor) {
)doc")
.def(
"unpad_from_tile",
[](const Tensor &self, const std::vector<uint32_t> &output_tensor_shape) {
[](const Tensor &self, const ttnn::SmallVector<uint32_t> &output_tensor_shape) {
return self.unpad_from_tile(ttnn::SimpleShape(output_tensor_shape));
},
R"doc(
@@ -1593,7 +1593,7 @@ void pytensor_module(py::module &m_tensor) {
)doc")
.def(
"reshape",
[](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, {N, C, H, W})); },
[](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector<int>{N, C, H, W})); },
R"doc(
Reshapes TT tensor
@@ -1613,7 +1613,7 @@
)doc")
.def(
"reshape",
[](Tensor &self, const std::vector<int32_t> &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); },
[](Tensor &self, const ttnn::SmallVector<int32_t> &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); },
R"doc(
Reshapes TT tensor
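
The `py::cast<ttnn::SmallVector<uint32_t>>` calls in this file only compile because a pybind11 type caster exists for `SmallVector`; that caster is not part of this diff. A minimal sketch of the same cast with `std::vector`, whose caster ships in `pybind11/stl.h`:

```cpp
#include <pybind11/embed.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <vector>

namespace py = pybind11;

int main() {
    py::scoped_interpreter guard;  // embedded interpreter, just for the demo
    py::object shape = py::eval("(1, 1, 32, 32)");
    auto dims = py::cast<std::vector<uint32_t>>(shape);  // tuple -> vector
    return static_cast<int>(dims.size()) - 4;            // 0
}
```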
@@ -144,7 +144,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
const auto [origin_h, origin_w] = origin_hw_vec.at(i);

// reader
const std::vector<uint32_t> reader_runtime_args{
const std::array reader_runtime_args{
input_addr,
static_cast<uint32_t>(is_dram(input)),
num_tiles,
@@ -154,12 +154,12 @@
SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args);

// writer
const std::vector<uint32_t> writer_runtime_args{
const std::array writer_runtime_args{
output_addr, static_cast<uint32_t>(is_dram(tmp_pow_sum)), tile_offset};
SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);

// compute
const std::vector<uint32_t> compute_runtime_args{
const std::array compute_runtime_args{
num_tiles,
p,
static_cast<uint32_t>(p_is_negative),
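
The runtime-arg hunks in these three moreh_clip_grad_norm files all make the same trade: `std::array` with class template argument deduction keeps a fixed-size argument pack on the stack, where `const std::vector` heap-allocated on every call. A sketch independent of the `SetRuntimeArgs` API:

```cpp
#include <array>
#include <cstdint>
#include <vector>

int main() {
    // Heap-allocates even though the size is known at compile time.
    const std::vector<uint32_t> vec_args{1u, 2u, 3u};
    // Lives entirely on the stack; element type and size are deduced (CTAD).
    const std::array arr_args{1u, 2u, 3u};
    return static_cast<int>(vec_args.size() + arr_args.size()) - 6;  // 0
}
```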
@@ -101,16 +101,16 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
const auto output_addr = total_norm.buffer()->address();

// reader
const std::vector<uint32_t> reader_runtime_args{
const std::array reader_runtime_args{
input_addr, static_cast<uint32_t>(is_dram(tmp_pow_sum)), num_tiles, *reinterpret_cast<uint32_t*>(&decimal)};
SetRuntimeArgs(program, reader_kernels_id, single_core, reader_runtime_args);

// writer
const std::vector<uint32_t> writer_runtime_args{output_addr, static_cast<uint32_t>(is_dram(total_norm))};
const std::array writer_runtime_args{output_addr, static_cast<uint32_t>(is_dram(total_norm))};
SetRuntimeArgs(program, writer_kernels_id, single_core, writer_runtime_args);

// compute
const std::vector<uint32_t> compute_runtime_args{num_tiles, p, static_cast<uint32_t>(p_is_negative)};
const std::array compute_runtime_args{num_tiles, p, static_cast<uint32_t>(p_is_negative)};
SetRuntimeArgs(program, compute_kernels_id, single_core, compute_runtime_args);

////////////////////////////////////////////////////////////////////////////
@@ -107,7 +107,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
const auto num_tiles = input.volume() / tt::constants::TILE_HW;

// reader
const std::vector<uint32_t> reader_runtime_args{
const std::array reader_runtime_args{
input_addr,
static_cast<uint32_t>(is_dram(input)),
clip_coef_clamped_addr,
@@ -116,11 +116,11 @@
SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args);

// writer
const std::vector<uint32_t> writer_runtime_args{input_addr, static_cast<uint32_t>(is_dram(input)), num_tiles};
const std::array writer_runtime_args{input_addr, static_cast<uint32_t>(is_dram(input)), num_tiles};
SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);

// compute
const std::vector<uint32_t> compute_runtime_args{num_tiles};
const std::array compute_runtime_args{num_tiles};
SetRuntimeArgs(program, compute_kernels_id, core, compute_runtime_args);
}

2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/graph/graph_trace_utils.cpp
@@ -25,7 +25,7 @@ ttnn::Shape parse_shape(std::string_view shape_string) {
std::string_view shape_values = shape_string.substr(start, end - start);

// Vector to hold the parsed shape values
std::vector<uint32_t> shape;
SmallVector<uint32_t> shape;
const char* str = shape_values.data();
const char* end_str = str + shape_values.size();

@@ -491,7 +491,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
block_config.act_block_h_ntiles % block_config.out_subblock_h_ntiles == 0,
"Out_block_h must be divisible by out_subblock_h!");
}
ttnn::Shape ashape_with_channels_padded(std::vector<uint32_t>({ashape[0], ashape[1], ashape[2], input_channels_padded}));
ttnn::Shape ashape_with_channels_padded(ttnn::SmallVector<uint32_t>({ashape[0], ashape[1], ashape[2], input_channels_padded}));
uint32_t conv_act_size_h = ashape_with_channels_padded[1];
uint32_t conv_act_size_w = ashape_with_channels_padded[2];
uint32_t conv_act_size_c = ashape_with_channels_padded[3];
@@ -183,7 +183,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl(
"Out_block_h must be divisible by out_subblock_h!");
}

ttnn::Shape ashape_with_channels_padded(std::vector<uint32_t>{ashape[0], ashape[1], ashape[2], input_channels_padded});
ttnn::Shape ashape_with_channels_padded({ashape[0], ashape[1], ashape[2], input_channels_padded});

uint32_t conv_act_size_h = ashape_with_channels_padded[1];
uint32_t conv_act_size_w = ashape_with_channels_padded[2];
16 changes: 8 additions & 8 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -94,14 +94,14 @@ Tensor to_layout_impl(

auto tensor = tensor_arg;

std::vector<uint32_t> output_shape;
SmallVector<uint32_t> output_shape;
if (layout == ttnn::TILE_LAYOUT and intended_shape.rank() < 2) {
output_shape.push_back(1);
tensor = ttnn::reshape(
tensor,
ttnn::Shape(
std::vector<std::uint32_t>{1, intended_shape[0]},
std::vector<std::uint32_t>{1, tensor_arg.get_shape().with_tile_padding()[0]}));
SmallVector<uint32_t>{1, intended_shape[0]},
SmallVector<uint32_t>{1, tensor_arg.get_shape().with_tile_padding()[0]}));
}
for (auto index = 0; index < intended_shape.rank(); ++index) {
output_shape.push_back(intended_shape[index]);
@@ -144,7 +144,7 @@ Tensor to_layout_impl(
output_memory_config =
tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type};
}
std::vector<uint32_t> output_tensor_end;
SmallVector<uint32_t> output_tensor_end;
for (auto index = 0; index < tensor.get_shape().rank(); ++index) {
output_tensor_end.push_back(tensor.get_shape()[index] - 1);
}
@@ -154,7 +154,7 @@
return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});

} else if (layout == ttnn::TILE_LAYOUT) {
std::vector<uint32_t> padded_output_shape;
SmallVector<uint32_t> padded_output_shape;

for (int index = 0; index < tensor.get_shape().rank(); ++index) {
if (index >= tensor.get_shape().rank() - 2) {
@@ -166,7 +166,7 @@
if (tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
// ttnn::tilize_with_val_padding doesn't support height sharded tensors
// workaround by applying padding and then tilizing
std::vector<std::pair<uint32_t, uint32_t>> padding = {
SmallVector<std::pair<uint32_t, uint32_t>> padding = {
{0, 0},
{0, 0},
{0, padded_output_shape[2] - output_shape[2]},
@@ -192,8 +192,8 @@
tensor = tensor.unpad_from_tile(tensor.get_logical_shape());
return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});
} else if (layout == ttnn::TILE_LAYOUT) {
std::vector<uint32_t> padded_output_shape;
std::vector<uint32_t> padded_input_start;
SmallVector<uint32_t> padded_output_shape;
SmallVector<uint32_t> padded_input_start;
for (int index = 0; index < tensor.get_shape().rank(); ++index) {
if (index >= tensor.get_shape().rank() - 2) {
padded_output_shape.push_back(ttnn::pad_to_multiple_of_tile_size(tensor.get_shape()[index]));
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
@@ -86,8 +86,8 @@ namespace data_movement {
while (output_tensor.get_shape().rank() > rank) {
const auto shape = output_tensor.get_shape();
const auto full_shape = output_tensor.get_shape().with_tile_padding();
std::vector<uint32_t> shape_vec{};
std::vector<uint32_t> full_shape_vec{};
SmallVector<uint32_t> shape_vec{};
SmallVector<uint32_t> full_shape_vec{};
// int i = 0;
// while(i < 3 and shape[i] == 1) i++;
for (int i = 1; i < shape.rank(); i++) {
@@ -27,7 +27,7 @@ void bind_fold_operation(py::module& module) {
)doc",
ttnn::pybind_overload_t{
[](const decltype(ttnn::fold)& op, const ttnn::Tensor& input, uint32_t stride_h, uint32_t stride_w,
bool use_transpose_as_fold, std::optional<std::vector<uint32_t>> output_shape, uint32_t pad_c, uint32_t pad_h, uint32_t pad_w, std::optional<CoreCoord> grid_size, std::optional<MemoryConfig> override_memory_config,
bool use_transpose_as_fold, std::optional<SmallVector<uint32_t>> output_shape, uint32_t pad_c, uint32_t pad_h, uint32_t pad_w, std::optional<CoreCoord> grid_size, std::optional<MemoryConfig> override_memory_config,
const uint8_t& queue_id)
-> ttnn::Tensor {
return op(queue_id, input, stride_h, stride_w, use_transpose_as_fold, output_shape, pad_c, pad_h, pad_w, grid_size, override_memory_config);