Skip to content

Commit

Permalink
#15496 Change Tensor serialization to serialize TensorSpec with flatb…
Browse files Browse the repository at this point in the history
…uffer (#17748)

### Ticket
#15496 
#16067

### Problem description
Currently TensorSpec isn't being serialized properly, which causes
issues in some cases.
In particular, it causes bugs in `as_tensor` with transposed tiles.

### What's changed
Introduce flatbuffer schema for TensorSpec serialization.
Added conversion code between TensorSpec and the corresponding flatbuffer struct.
Heavily modified serialization code to preserve compatibility with the
old format, but serialize TensorSpec with flatbuffer in newer versions.
Changed fstream I/O to fread/fwrite to improve performance.

### Checklist
- [x] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13213319656)
CI passes
- [x] [Model
regression](https://github.com/tenstorrent/tt-metal/actions/runs/13209781898)
- [x] [Device performance
regression](https://github.com/tenstorrent/tt-metal/actions/runs/13209784841)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sminakov-tt authored Feb 8, 2025
1 parent 99a6252 commit 15ffcc8
Show file tree
Hide file tree
Showing 12 changed files with 764 additions and 235 deletions.
4 changes: 4 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ set(TTNN_PUBLIC_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR} # ${PROJECT_SOURCE_DIR}/ttnn
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/deprecated # symlink to tt_eager; should become native folder once merge complete
${CMAKE_CURRENT_SOURCE_DIR}/cpp
${CMAKE_CURRENT_BINARY_DIR}/flatbuffers
)
set(TTNN_PUBLIC_LINK_LIBRARIES
metal_common_libs
Expand All @@ -689,6 +690,7 @@ set(TTNN_PUBLIC_LINK_LIBRARIES
xtensor
xtensor-blas
xtl
FlatBuffers::FlatBuffers
)
set(TTNN_PUBLIC_LINK_DIRS "")

Expand Down Expand Up @@ -803,6 +805,8 @@ endforeach(
${TTNN_SUBLIBRARIES}
)

# Generate the C++ header from the tensor flatbuffer schema and add it to the
# tensor sources so the ttnn_tensor sublibrary depends on (and rebuilds after)
# schema changes.
GENERATE_FBS_HEADER(${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/tensor/flatbuffer/tensor_types.fbs)
list(APPEND TENSOR_SRCS ${FBS_GENERATED_HEADER_FILE})
add_ttnn_sublibrary(ttnn_tensor ${TENSOR_SRCS})
add_ttnn_sublibrary(ttnn_ccl ${CCL_TTNN_SRCS})
add_ttnn_sublibrary(ttnn_ccl_exp ${CCL_EXPERIMENTAL_TTNN_SRCS})
Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ set(TENSOR_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/layout/page_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/tensor_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xtensor/partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/tensor_types_to_flatbuffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/tensor_types_from_flatbuffer.cpp
CACHE INTERNAL
"Tensor sources to reuse in ttnn build"
)
103 changes: 103 additions & 0 deletions ttnn/cpp/ttnn/tensor/flatbuffer/tensor_types.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Flatbuffer schema for serializing TensorSpec and the types it contains.
// NOTE(review): this is a wire format — field order and enum values must stay
// stable across releases; only append new fields/values at the end.
namespace ttnn.flatbuffer;

// A single (x, y) core coordinate on the device grid.
table CoreCoord {
x: int;
y: int;
}

// Inclusive rectangular range of cores, from `start` to `end`.
table CoreRange {
start: CoreCoord;
end: CoreCoord;
}

// A set of core ranges (may be non-contiguous).
table CoreRangeSet {
ranges: [CoreRange];
}

// Tile dimensions plus the transpose flag (the original motivation for this
// schema — see `as_tensor` with transposed tiles).
table Tile {
tile_shape_h: uint32;
tile_shape_w: uint32;
transpose_tile: bool;
}

// Mirrors the runtime TensorMemoryLayout enum; values are fixed for compatibility.
enum TensorMemoryLayout: ushort {
Interleaved = 0,
SingleBank = 1,
HeightSharded = 2,
WidthSharded = 3,
BlockSharded = 4,
}

// Mirrors the runtime BufferType enum; values are fixed for compatibility.
enum BufferType: ushort {
DRAM = 0,
L1 = 1,
SystemMemory = 2,
L1Small = 3,
Trace = 4,
}

// Row-major vs column-major order of shards across the grid.
enum ShardOrientation : ubyte {
RowMajor = 0,
ColMajor = 1,
}

// Whether the shard shape is expressed in physical or logical coordinates.
enum ShardMode : ubyte {
Physical,
Logical,
}

// 2-D shard dimensions (height x width).
table ShardShape {
height: uint32;
width: uint32;
}

// Sharding description: grid of cores, shard shape, orientation and mode.
// `physical_shard_shape` is optional and only set when it differs from the
// logical (shape_h, shape_w) pair.
table ShardSpec {
grid: CoreRangeSet;
shape_h: uint32;
shape_w: uint32;
orientation: ShardOrientation;
shard_mode: ShardMode;
physical_shard_shape: ShardShape;
}

// Mirrors the runtime DataType enum; values are fixed for compatibility.
enum DataType : ubyte {
BFloat16 = 0,
Float32 = 1,
UInt32 = 2,
BFloat8B = 3,
BFloat4B = 4,
UInt8 = 5,
UInt16 = 6,
Int32 = 7,
Invalid = 8
}

// Page configuration: either row-major (no payload) or tiled (carries a Tile).
table RowMajorPageConfig {}
table TilePageConfig {
tile: Tile;
}

union PageConfig {
row_major: RowMajorPageConfig,
tile: TilePageConfig,
}

// Memory placement of the tensor; `shard_spec` is optional.
table MemoryConfig {
memory_layout: TensorMemoryLayout;
buffer_type: BufferType;
shard_spec: ShardSpec;
}

// Full layout description matching the runtime TensorLayout.
table TensorLayout {
data_type: DataType;
page_config: PageConfig;
memory_config: MemoryConfig;
alignment: [uint32];
}

// Root object: logical shape plus layout — everything needed to restore a TensorSpec.
table TensorSpec {
shape: [uint32];
tensor_layout: TensorLayout;
}

root_type TensorSpec;
132 changes: 132 additions & 0 deletions ttnn/cpp/ttnn/tensor/flatbuffer/tensor_types_from_flatbuffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tensor_types_from_flatbuffer.hpp"

namespace ttnn {

// Translates the serialized buffer-type tag back into the runtime BufferType.
// Any value outside the known set (e.g. from a newer schema) is rejected.
BufferType from_flatbuffer(flatbuffer::BufferType fb_type) {
    switch (fb_type) {
        case flatbuffer::BufferType::Trace: return BufferType::TRACE;
        case flatbuffer::BufferType::L1Small: return BufferType::L1_SMALL;
        case flatbuffer::BufferType::SystemMemory: return BufferType::SYSTEM_MEMORY;
        case flatbuffer::BufferType::L1: return BufferType::L1;
        case flatbuffer::BufferType::DRAM: return BufferType::DRAM;
    }
    TT_THROW("Unsupported BufferType from flatbuffer.");
}

// Translates the serialized memory-layout tag back into TensorMemoryLayout.
// Unknown values fall through to TT_THROW.
TensorMemoryLayout from_flatbuffer(flatbuffer::TensorMemoryLayout fb_layout) {
    switch (fb_layout) {
        case flatbuffer::TensorMemoryLayout::BlockSharded: return TensorMemoryLayout::BLOCK_SHARDED;
        case flatbuffer::TensorMemoryLayout::WidthSharded: return TensorMemoryLayout::WIDTH_SHARDED;
        case flatbuffer::TensorMemoryLayout::HeightSharded: return TensorMemoryLayout::HEIGHT_SHARDED;
        case flatbuffer::TensorMemoryLayout::SingleBank: return TensorMemoryLayout::SINGLE_BANK;
        case flatbuffer::TensorMemoryLayout::Interleaved: return TensorMemoryLayout::INTERLEAVED;
    }
    TT_THROW("Unsupported TensorMemoryLayout from flatbuffer.");
}

// Translates the serialized data-type tag back into the runtime DataType.
// Unknown values fall through to TT_THROW.
DataType from_flatbuffer(flatbuffer::DataType fb_type) {
    switch (fb_type) {
        case flatbuffer::DataType::Invalid: return DataType::INVALID;
        case flatbuffer::DataType::Int32: return DataType::INT32;
        case flatbuffer::DataType::UInt16: return DataType::UINT16;
        case flatbuffer::DataType::UInt8: return DataType::UINT8;
        case flatbuffer::DataType::BFloat4B: return DataType::BFLOAT4_B;
        case flatbuffer::DataType::BFloat8B: return DataType::BFLOAT8_B;
        case flatbuffer::DataType::UInt32: return DataType::UINT32;
        case flatbuffer::DataType::Float32: return DataType::FLOAT32;
        case flatbuffer::DataType::BFloat16: return DataType::BFLOAT16;
    }
    TT_THROW("Unsupported DataType from flatbuffer.");
}

// Rebuilds a MemoryConfig from its serialized form. `shard_spec` is an
// optional table in the schema: a null pointer means "no shard spec".
MemoryConfig from_flatbuffer(const flatbuffer::MemoryConfig* config) {
    std::optional<ShardSpec> shard_spec = std::nullopt;
    if (const auto* fb_shard_spec = config->shard_spec()) {
        shard_spec = from_flatbuffer(fb_shard_spec);
    }
    return MemoryConfig{
        from_flatbuffer(config->memory_layout()),
        from_flatbuffer(config->buffer_type()),
        std::move(shard_spec),
    };
}

// Translates the serialized shard orientation back into ShardOrientation.
ShardOrientation from_flatbuffer(flatbuffer::ShardOrientation fb_orientation) {
    switch (fb_orientation) {
        case flatbuffer::ShardOrientation::ColMajor: return ShardOrientation::COL_MAJOR;
        case flatbuffer::ShardOrientation::RowMajor: return ShardOrientation::ROW_MAJOR;
    }
    TT_THROW("Unsupported ShardOrientation from flatbuffer.");
}

// Translates the serialized shard mode back into ShardMode.
ShardMode from_flatbuffer(flatbuffer::ShardMode fb_mode) {
    switch (fb_mode) {
        case flatbuffer::ShardMode::Logical: return ShardMode::LOGICAL;
        case flatbuffer::ShardMode::Physical: return ShardMode::PHYSICAL;
    }
    TT_THROW("Unsupported ShardMode from flatbuffer.");
}

// Rebuilds a ShardSpec. When a physical shard shape was serialized, the
// (grid, shape, physical_shape, orientation) constructor is used; otherwise
// the (grid, shape, orientation, mode) constructor applies.
// NOTE(review): shard_mode is decoded but unused on the physical-shard-shape
// path — presumably that constructor fixes the mode itself; confirm against
// the ShardSpec constructors.
ShardSpec from_flatbuffer(const flatbuffer::ShardSpec* spec) {
    const CoreRangeSet grid = from_flatbuffer(spec->grid());
    const std::array<uint32_t, 2> logical_shape = {spec->shape_h(), spec->shape_w()};
    const ShardOrientation orientation = from_flatbuffer(spec->orientation());
    const ShardMode mode = from_flatbuffer(spec->shard_mode());
    const auto* fb_physical = spec->physical_shard_shape();
    if (fb_physical != nullptr) {
        const std::array<uint32_t, 2> physical_shape = {fb_physical->height(), fb_physical->width()};
        return ShardSpec(grid, logical_shape, physical_shape, orientation);
    }
    return ShardSpec(grid, logical_shape, orientation, mode);
}

// Deserializes a single core coordinate.
CoreCoord from_flatbuffer(const flatbuffer::CoreCoord* fb_coord) {
    const auto x = fb_coord->x();
    const auto y = fb_coord->y();
    return CoreCoord{x, y};
}

// Deserializes a core range from its two corner coordinates.
CoreRange from_flatbuffer(const flatbuffer::CoreRange* fb_range) {
    const CoreCoord start{fb_range->start()->x(), fb_range->start()->y()};
    const CoreCoord end{fb_range->end()->x(), fb_range->end()->y()};
    return CoreRange{start, end};
}

// Deserializes a CoreRangeSet by converting each contained CoreRange.
// Reuses the CoreRange overload instead of duplicating the corner-unpacking
// logic inline (keeps the two converters consistent), and reserves the output
// vector up front to avoid reallocations.
CoreRangeSet from_flatbuffer(const flatbuffer::CoreRangeSet* core_range_set) {
    std::vector<CoreRange> ranges;
    ranges.reserve(core_range_set->ranges()->size());
    for (const auto* range : *core_range_set->ranges()) {
        ranges.push_back(from_flatbuffer(range));
    }
    return CoreRangeSet{ranges};
}

// Rebuilds a TensorLayout from its serialized parts. The page config is a
// flatbuffer union: either row-major (no payload) or tiled (carries the Tile
// shape and transpose flag). Restoration goes through restore_from_serialized
// so the stored alignment is taken as-is rather than recomputed.
TensorLayout from_flatbuffer(const flatbuffer::TensorLayout* layout) {
    const auto make_page_config = [layout]() -> PageConfig {
        const auto config_type = layout->page_config_type();
        if (config_type == flatbuffer::PageConfig::row_major) {
            return PageConfig(Layout::ROW_MAJOR);
        }
        if (config_type == flatbuffer::PageConfig::tile) {
            const auto* flat_tile = layout->page_config_as_tile()->tile();
            Tile tile(std::array{flat_tile->tile_shape_h(), flat_tile->tile_shape_w()}, flat_tile->transpose_tile());
            return PageConfig(Layout::TILE, tile);
        }
        TT_THROW("Unsupported PageConfig type from flatbuffer.");
    };

    const auto* fb_alignment = layout->alignment();
    return TensorLayout::restore_from_serialized(
        from_flatbuffer(layout->data_type()),
        make_page_config(),
        from_flatbuffer(layout->memory_config()),
        Alignment(SmallVector<uint32_t>(fb_alignment->cbegin(), fb_alignment->cend())));
}

// Rebuilds a TensorSpec from its serialized logical shape and tensor layout.
TensorSpec from_flatbuffer(const flatbuffer::TensorSpec* spec) {
    const auto* fb_shape = spec->shape();
    Shape shape(SmallVector<uint32_t>(fb_shape->cbegin(), fb_shape->cend()));
    return TensorSpec(std::move(shape), from_flatbuffer(spec->tensor_layout()));
}

} // namespace ttnn
26 changes: 26 additions & 0 deletions ttnn/cpp/ttnn/tensor/flatbuffer/tensor_types_from_flatbuffer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "tensor_types_generated.h"
#include "ttnn/tensor/types.hpp"
#include "ttnn/tensor/tensor_spec.hpp"

namespace ttnn {

// Converters from the serialized (flatbuffer) representations back into the
// corresponding runtime types. Enum overloads take the value directly; table
// overloads take a pointer into the flatbuffer and expect it to be non-null.
// (Parameter names on the CoreRange/CoreRangeSet overloads were copy-pasted
// as `fb_coord`; renamed here to describe the actual argument.)
BufferType from_flatbuffer(flatbuffer::BufferType type);
TensorMemoryLayout from_flatbuffer(flatbuffer::TensorMemoryLayout layout);
DataType from_flatbuffer(flatbuffer::DataType type);
ShardOrientation from_flatbuffer(flatbuffer::ShardOrientation orientation);
ShardMode from_flatbuffer(flatbuffer::ShardMode mode);
CoreCoord from_flatbuffer(const flatbuffer::CoreCoord* fb_coord);
CoreRange from_flatbuffer(const flatbuffer::CoreRange* fb_range);
CoreRangeSet from_flatbuffer(const flatbuffer::CoreRangeSet* fb_range_set);
ShardSpec from_flatbuffer(const flatbuffer::ShardSpec* spec);
MemoryConfig from_flatbuffer(const flatbuffer::MemoryConfig* config);
TensorLayout from_flatbuffer(const flatbuffer::TensorLayout* layout);
TensorSpec from_flatbuffer(const flatbuffer::TensorSpec* spec);

}  // namespace ttnn
Loading

0 comments on commit 15ffcc8

Please sign in to comment.