#14324 Add Uint32 Support for Untilize and Tilize
* #0: add uint32 support to untilize and tilize
ntarafdar authored and o2buzzle committed Nov 4, 2024
1 parent 9829450 commit b82aee9
Showing 21 changed files with 168 additions and 58 deletions.
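
At the user level this commit lets row-major uint32 tensors go through tilize/untilize, and tilize_with_val_padding now takes its pad value as either a float or a uint32 rather than a plain float; the many call-site changes below from 0 to an explicit 0.0 presumably keep the float path selected for bfloat16 activations. A minimal usage sketch (hypothetical example, not part of the commit; it assumes the Python binding accepts an integer pad value for ttnn.uint32 tensors, and the shapes and device setup are illustrative):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Row-major uint32 input with non-tile-aligned dims.
torch_input = torch.randint(0, 100, (1, 1, 30, 60), dtype=torch.int32)
x = ttnn.from_torch(torch_input, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Pad up to tile-aligned dims (multiples of 32) and tilize; an integer pad value
# is assumed to exercise the uint32 path added here.
x = ttnn.tilize_with_val_padding(x, [1, 1, 32, 64], 0, memory_config=ttnn.DRAM_MEMORY_CONFIG)

ttnn.close_device(device)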
Original file (path not shown)

@@ -749,7 +749,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -787,7 +787,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -980,7 +980,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -1018,7 +1018,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
Original file (path not shown)

@@ -1148,7 +1148,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -1186,7 +1186,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
Original file (path not shown)

@@ -745,7 +745,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -783,7 +783,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -962,7 +962,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -1000,7 +1000,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
Original file (path not shown)

@@ -775,7 +775,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -813,7 +813,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -988,7 +988,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -1026,7 +1026,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
Original file (path not shown)

@@ -833,7 +833,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -873,7 +873,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -1062,7 +1062,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
@@ -1102,7 +1102,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
    x = ttnn.tilize_with_val_padding(
        x,
        padded_shape,
-       0,
+       0.0,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=self.model_config["ACTIVATIONS_DTYPE"],
    )
Original file (path not shown)

@@ -806,7 +806,7 @@ def gen_tilize_with_val_padding_args(
    ]
    output_tensor_shape[-2] = nearest_32(output_tensor_shape[-2])
    output_tensor_shape[-1] = nearest_32(output_tensor_shape[-1])
-   pad_value = random.uniform(-100, 100)
+   pad_value = float(random.uniform(-100, 100))
    # Cast to bfloat16 then back to float for exact match
    pad_value = torch.Tensor([pad_value]).to(torch.bfloat16).to(torch.float).item()
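
The cast through bfloat16 above exists so that the host-side expected pad value is exactly representable in the dtype the device writes; a small illustration (57.3 is just an example value, not from the source):

import torch

# bfloat16 stores only 7 explicit mantissa bits, so 57.3 rounds to 57.25;
# round-tripping on the host makes the expected pad value match the device output exactly.
pad_value = torch.Tensor([57.3]).to(torch.bfloat16).to(torch.float).item()
print(pad_value)  # 57.25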
Original file (path not shown)

@@ -28,7 +28,7 @@
        "input_mem_config": [ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)],
        "output_mem_config": ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        "output_tensor_shape": [1, 1, 128, 7328],
-       "pad_value": 10,
+       "pad_value": 10.0,
    },
)
]
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/test_reshape.py
@@ -327,3 +327,26 @@ def test_reshape_host(input_shape, output_shape, device):
    output = ttnn.to_torch(ttnn_output)

    assert_with_pcc(torch_result, output, 0.9999)
+
+
+# required for Embedding
+@pytest.mark.parametrize(
+    "input_shape, output_shape",
+    [
+        ((1, 12), (12, 1)),
+        ((1, 32), (32, 1)),
+        ((64, 32), (1, 1, 64, 32)),
+    ],
+)
+def test_reshape_int(input_shape, output_shape, device):
+    torch_input_tensor = torch.randint(0, 100, input_shape)
+    torch_result = torch_input_tensor.reshape(output_shape)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
+    )
+    ttnn_output = ttnn.reshape(input_tensor, output_shape)
+
+    output = ttnn.to_torch(ttnn_output)
+
+    assert_with_pcc(torch_result, output, 0.9999)
15 changes: 14 additions & 1 deletion ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -174,8 +174,21 @@ Tensor to_layout_impl(
        tensor = ttnn::pad(0, tensor, padding, 0, true, std::nullopt);
        return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
    } else {
+       PadValue pad_value_variant;
+       if (tensor.get_dtype() == ttnn::DataType::BFLOAT16 or tensor.get_dtype() == ttnn::DataType::FLOAT32) {
+           pad_value_variant = 0.0f;
+       }
+       else {
+           pad_value_variant = (uint32_t) 0;
+       }
+
        tensor = ttnn::tilize_with_val_padding(
-           tensor, padded_output_shape, 0, output_memory_config, dtype, use_multicore_tilize);
+           tensor, padded_output_shape,
+           pad_value_variant,
+           output_memory_config,
+           dtype,
+           use_multicore_tilize
+           );
    }

    return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape}));
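
PadValue itself is not shown in this diff; judging from the assignments above (0.0f and (uint32_t) 0) it is presumably a float/uint32 variant declared in the newly included tilize_with_val_padding_common.hpp. The same dtype-driven choice, sketched as a hypothetical Python helper (not part of the commit):

import ttnn

def default_pad_value(dtype):
    # Mirrors the dtype check in to_layout_impl: a float pad for floating-point
    # tensors, an integer pad for everything else (e.g. uint32).
    return 0.0 if dtype in (ttnn.bfloat16, ttnn.float32) else 0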
Original file (path not shown)

@@ -7,6 +7,8 @@

void kernel_main() {

+   constexpr uint32_t bytes_per_tile_row = get_compile_time_arg_val(3);
+
    // Constexpr
    constexpr uint32_t cb_id_in0 = 0;
    constexpr uint32_t tile_height = 32;
@@ -33,7 +35,7 @@ void kernel_main() {
    // that the stick size dictates tiles c, but stick size
    // doesn't necessarily need to be divisible by tiles c...
    // this is only the case really for tilize
-   const uint32_t num_tiles_block_c = block_row_size / 64; // Assuming 2 bytes per datum, there are 64 bytes per tile row
+   const uint32_t num_tiles_block_c = block_row_size / bytes_per_tile_row; // Assuming 2 bytes per datum, there are 64 bytes per tile row

    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
    #define stick_size_is_pow2 get_compile_time_arg_val(1) == 1
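
The hard-coded / 64 assumed 2-byte (bfloat16) datums; a tile row always holds 32 datums, so the host side presumably passes 32 * element_size as compile-time arg 3: 64 bytes for bfloat16 and 128 for uint32. A sketch of that arithmetic (hypothetical helper, not the actual program factory code):

def bytes_per_tile_row(datum_size_bytes: int) -> int:
    # A tile row is 32 datums wide, so its byte width scales with the element size.
    return 32 * datum_size_bytes

assert bytes_per_tile_row(2) == 64    # bfloat16
assert bytes_per_tile_row(4) == 128   # uint32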
Original file (path not shown)

@@ -17,6 +17,8 @@ void kernel_main() {
    constexpr uint32_t cb_id_in0 = 0;
    constexpr uint32_t tile_height = 32;

+   constexpr uint32_t tile_row_shift_bits = get_compile_time_arg_val(3);
+
    const uint32_t src_addr = get_arg_val<uint32_t>(0);
    const uint32_t unpadded_X_size = get_arg_val<uint32_t>(1);
    const uint32_t padded_X_size = get_arg_val<uint32_t>(2);
@@ -25,7 +27,7 @@
    const uint32_t n_block_reps = get_arg_val<uint32_t>(5);

    const uint32_t num_tiles_per_row =
-       padded_X_size >> 6; // means / 64, assuming bfloat16, there are 64 bytes per tile row
+       padded_X_size >> tile_row_shift_bits; // means / 64, assuming bfloat16, there are 64 bytes per tile row

    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
    #define stick_size_is_pow2 get_compile_time_arg_val(1) == 1
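
This reader takes the same information as a shift amount instead of a divisor, since the per-row byte counts are powers of two: padded_X_size >> 6 for bfloat16 becomes >> tile_row_shift_bits, presumably 6 for bfloat16 and 7 for uint32. A companion sketch to the one above (hypothetical helper):

def tile_row_shift_bits(datum_size_bytes: int) -> int:
    # log2(32 * datum_size): shifting right by this divides by the tile-row byte width.
    return (32 * datum_size_bytes).bit_length() - 1

assert tile_row_shift_bits(2) == 6   # bfloat16, divide by 64
assert tile_row_shift_bits(4) == 7   # uint32, divide by 128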
Original file (path not shown)

@@ -16,7 +16,7 @@ void TilizeWithValPadding::validate(const std::vector<Tensor>& input_tensors) co
    TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands need to be on device!");
    TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!");
    TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Can only tilize row major data");
-   TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Can only tilize bfloat16 tensors");
+   TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16 or input_tensor_a.get_dtype() == DataType::UINT32, "Can only tilize bfloat16 or uint32 tensors");
    TT_FATAL(input_shape.rank() >= 2, "Input tensor must be of rank >2, but its shape is {}", input_shape);
Original file (path not shown)

@@ -8,12 +8,13 @@

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operation.hpp"
+#include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding_common.hpp"

namespace ttnn::operations::data_movement {

struct TilizeWithValPadding {
    const tt::tt_metal::LegacyShape output_tensor_shape;
-   const float pad_value;
+   const PadValue pad_value;
    const MemoryConfig output_mem_config;
    const DataType output_dtype;
    const bool use_multicore;
(diffs for the remaining changed files not shown)