#17077: convert bfp8 to bf16 before performing fillpad, and convert back to bf8 after (#18063)

### Ticket
#17077

### Problem description
Support BFP8 for fill_implicit_pad.
Also addresses some review comments from the original fill_pad PR.

### What's changed
Convert bfp8 to bf16 before the fill, then convert back to bfp8 afterwards.
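
A rough sketch of what this amounts to, expressed with the ttnn Python API used by the tests below (the helper name is made up; the real change lives in the C++ `FillPadOperation::invoke` further down):

```python
import ttnn

def fill_pad_bfloat8_b_sketch(tensor, fill_value):
    # bfloat8_b tensors are round-tripped through bfloat16 around the fill.
    as_bf16 = ttnn.typecast(tensor, ttnn.bfloat16)
    filled = ttnn.fill_implicit_tile_padding(as_bf16, fill_value)
    return ttnn.typecast(filled, ttnn.bfloat8_b)
```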

### Checklist
- [ ] [All post commit](https://github.com/tenstorrent/tt-metal/actions/runs/13466079605)
yugi957 authored Feb 22, 2025
1 parent 5aab19f commit a409dad
Showing 4 changed files with 121 additions and 20 deletions.
95 changes: 93 additions & 2 deletions tests/ttnn/unit_tests/operations/test_fill_pad.py
@@ -51,9 +51,100 @@ def create_nd_padded_tiled_tensor(shape, tile_size, fill_value, dtype):
ttnn_dtype_to_torch_dtype = {
ttnn.uint32: torch.int32,
ttnn.bfloat16: torch.float32,
ttnn.bfloat8_b: torch.bfloat16,
}


@pytest.mark.parametrize(
"shape",
[
(1, 16),
(16, 1),
(1, 17),
(17, 1),
(16, 16),
(17, 17),
(31, 31),
(33, 33),
(65, 65),
(97, 97),
(1, 2, 3, 2, 1, 2, 97, 97),
],
)
@pytest.mark.parametrize("fill_value", [1.5, float("inf"), float("-inf")])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_fill_pad_bfloat16(
device,
shape,
fill_value,
dtype,
input_mem_config,
output_mem_config,
):
torch.manual_seed(1234)
torch_input_tensor, padded_torch_tensor = create_nd_padded_tiled_tensor(
shape, 32, fill_value, ttnn_dtype_to_torch_dtype[dtype]
)
input_tensor = ttnn.to_device(
ttnn.from_torch(torch_input_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT),
device,
memory_config=input_mem_config,
)

output_tensor = ttnn.fill_implicit_tile_padding(input_tensor, fill_value, memory_config=output_mem_config)
padded_torch_output_tensor = ttnn.from_device(output_tensor).to_torch_with_padded_shape()

assert_with_pcc(padded_torch_tensor, padded_torch_output_tensor)


@pytest.mark.parametrize(
"shape",
[
(1, 32),
(16, 32),
(1, 32),
(17, 32),
(16, 32),
(17, 32),
(31, 32),
(33, 32),
(65, 64),
(97, 96),
(1, 2, 3, 2, 1, 2, 97, 96),
],
)

# separate test for bfloat8_b where last dim is tile_width aligned (required for bf8b)
@pytest.mark.parametrize("fill_value", [1.5, float("inf"), float("-inf")])
@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b])
@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_fill_pad_bfloat8_b(
device,
shape,
fill_value,
dtype,
input_mem_config,
output_mem_config,
):
torch.manual_seed(1234)
torch_input_tensor, padded_torch_tensor = create_nd_padded_tiled_tensor(
shape, 32, fill_value, ttnn_dtype_to_torch_dtype[dtype]
)
input_tensor = ttnn.to_device(
ttnn.from_torch(torch_input_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT),
device,
memory_config=input_mem_config,
)

output_tensor = ttnn.fill_implicit_tile_padding(input_tensor, fill_value, memory_config=output_mem_config)
padded_torch_output_tensor = ttnn.from_device(output_tensor).to_torch_with_padded_shape()

assert_with_pcc(padded_torch_tensor, padded_torch_output_tensor)


@pytest.mark.parametrize(
"shape",
[
@@ -71,10 +162,10 @@ def create_nd_padded_tiled_tensor(shape, tile_size, fill_value, dtype):
],
)
@pytest.mark.parametrize("fill_value", [1])
@pytest.mark.parametrize("dtype", [ttnn.uint32, ttnn.bfloat16])
@pytest.mark.parametrize("dtype", [ttnn.uint32])
@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_fill_pad(
def test_fill_pad_int(
device,
shape,
fill_value,
@@ -85,18 +85,20 @@ operation::ProgramWithCallbacks fill_pad_multi_core(const Tensor& input_tensor,
(std::uint32_t)tiles_per_2d_tensor,
(std::uint32_t)tiles_per_tile_row,
(std::uint32_t)tt::constants::TILE_HEIGHT,
(std::uint32_t)tt::constants::FACE_HEIGHT,
(std::uint32_t)sharded};
(std::uint32_t)tt::constants::FACE_HEIGHT};

std::map<string, string> compute_defines;
if (sharded) {
shard_builder::extend_sharding_compile_time_args(input_tensor, writer_compile_time_args);
compute_defines["SHARDED"] = "1";
}

tt::tt_metal::KernelHandle writer_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/data_movement/fill_pad/device/kernels/dataflow/fill_pad_writer.cpp",
all_cores,
tt_metal::WriterDataMovementConfig(writer_compile_time_args)); // writer only for in-place operation
tt_metal::WriterDataMovementConfig(
writer_compile_time_args, compute_defines)); // writer only for in-place operation

auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false);
std::vector<uint32_t> writer_runtime_args = {
@@ -21,7 +21,6 @@ void kernel_main() {
constexpr uint32_t tile_size = get_compile_time_arg_val(10);
constexpr uint32_t tile_hw = tile_size * tile_size;
constexpr uint32_t face_size = get_compile_time_arg_val(11);
#define SHARDED get_compile_time_arg_val(12) == 1
constexpr uint32_t face_hw = face_size * face_size;
constexpr uint32_t alignment_adjustor = 16;

@@ -31,15 +30,15 @@ void kernel_main() {
uint32_t starting_tile_offset = get_arg_val<uint32_t>(rt_arg_ind++);
uint32_t num_2d_tensors = get_arg_val<uint32_t>(rt_arg_ind++);

#if (SHARDED)
#ifdef SHARDED
typedef ShardedInfo<
get_compile_time_arg_val(13),
get_compile_time_arg_val(14),
get_compile_time_arg_val(15),
get_compile_time_arg_val(16),
get_compile_time_arg_val(17),
get_compile_time_arg_val(18),
get_compile_time_arg_val(19)>
get_compile_time_arg_val(12), // Memory layout
get_compile_time_arg_val(13), // The number of sharding cores
get_compile_time_arg_val(14), // The page size we offset each write to
get_compile_time_arg_val(15), // The number of pages in each sharding row not including padding pages
get_compile_time_arg_val(16), // This defines times when contiguous pages can't be calculated
get_compile_time_arg_val(17), // pages_per_shard_x
get_compile_time_arg_val(18)> // pages_per_shard_y
tensor_shard_info;

const auto [mapping_table, rt_increment] =
21 changes: 15 additions & 6 deletions ttnn/cpp/ttnn/operations/data_movement/fill_pad/fill_pad.cpp
@@ -9,6 +9,7 @@
#include "ttnn/common/queue_id.hpp"
#include "ttnn/operations/core/core.hpp"
#include <utility>
#include "cpp/ttnn/operations/copy.hpp"

using namespace tt::tt_metal;

@@ -27,27 +28,35 @@ ttnn::Tensor FillPadOperation::invoke(
if (padded_width == input_tensor.get_logical_shape()[-1] && padded_height == input_tensor.get_logical_shape()[-2]) {
return input_tensor;
}
auto mutable_input_tensor = input_tensor;
auto output_memory_config = memory_config.value_or(input_tensor.memory_config());
if (input_tensor.get_dtype() == DataType::BFLOAT8_B) {
mutable_input_tensor = ttnn::typecast(mutable_input_tensor, DataType::BFLOAT16);
}
// if input_tensor is rank > 3, then we need to reshape it to rank 3 such that the last 2 dims are the same
if (input_tensor.get_logical_shape().rank() > 3) {
ttnn::Shape original_shape = input_tensor.get_logical_shape();
if (mutable_input_tensor.get_logical_shape().rank() > 3) {
ttnn::Shape original_shape = mutable_input_tensor.get_logical_shape();

uint32_t third_dim = 1;
for (uint32_t i = 0; i < original_shape.rank() - 2; i++) {
third_dim *= original_shape[i];
}

ttnn::Shape new_shape = ttnn::Shape{std::array<uint32_t, 3>{third_dim, original_shape[-2], original_shape[-1]}};
auto reshaped_tensor = ttnn::reshape(input_tensor, new_shape);
auto reshaped_tensor = ttnn::reshape(mutable_input_tensor, new_shape);

reshaped_tensor = operation::run_without_autoformat(
FillPad{fill_value, output_memory_config}, {reshaped_tensor}, {}, {}, queue_id)
.at(0);
return ttnn::reshape(reshaped_tensor, original_shape);
}
return operation::run_without_autoformat(
FillPad{fill_value, output_memory_config}, {input_tensor}, {}, {}, queue_id)
.at(0);
auto output_tensor = operation::run_without_autoformat(
FillPad{fill_value, output_memory_config}, {mutable_input_tensor}, {}, {}, queue_id)
.at(0);
if (input_tensor.get_dtype() == DataType::BFLOAT8_B) {
return ttnn::typecast(output_tensor, DataType::BFLOAT8_B);
}
return output_tensor;
}

} // namespace ttnn::operations::data_movement
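
As a worked example of the rank-collapse that `FillPadOperation::invoke` performs for inputs with rank > 3 (a sketch, not repository code; the helper name is invented):

```python
import math

def collapse_to_rank3(shape):
    # Fold all leading dims into one so fill_pad runs on a rank-3 view;
    # the original shape is restored after the op.
    leading = math.prod(shape[:-2]) if len(shape) > 2 else 1
    return (leading, shape[-2], shape[-1])

# e.g. the (1, 2, 3, 2, 1, 2, 97, 96) test shape collapses to (24, 97, 96)
assert collapse_to_rank3((1, 2, 3, 2, 1, 2, 97, 96)) == (24, 97, 96)
```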
