#17768: Float32 support for Inference mode in Batch Norm (#17587)
### Ticket
#17768

### Problem description
Provide fp32 support for the inference mode of batch norm (BN).

### What's changed
Added support for the fp32 data type in the inference mode of batch norm.
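For illustration, a minimal usage sketch of the fp32 inference path (the shapes, `eps`, and the device open/close are illustrative assumptions, not part of this change):

```python
import torch
import ttnn

# Sketch only: shapes and eps are arbitrary; assumes a single device opened via ttnn.open_device.
device = ttnn.open_device(device_id=0)

torch_input = torch.rand(2, 3, 32, 32, dtype=torch.float32)
running_mean = torch.rand(1, 3, 1, 1, dtype=torch.float32)
running_var = torch.rand(1, 3, 1, 1, dtype=torch.float32) + 1.0  # keep variance positive

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
tt_mean = ttnn.from_torch(running_mean, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
tt_var = ttnn.from_torch(running_var, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

# Inference mode: running_mean and running_var must be provided.
tt_out = ttnn.batch_norm(tt_input, running_mean=tt_mean, running_var=tt_var, eps=1e-5, training=False)
result = ttnn.to_torch(tt_out)

ttnn.close_device(device)
```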

### Checklist
- [x] [All post-commit
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13217558701)
- [x] [Blackhole post-commit
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13157671059)
- [x] [(Single-card) Tests for new
models](https://github.com/tenstorrent/tt-metal/actions/runs/13217560775)
- Passed as in main
- [x] [(Single-card) Demo
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13217560090)
- Passed as in main
- [x] [(Single-card) Device perf
regressions](https://github.com/tenstorrent/tt-metal/actions/runs/13217559606)
- [x] [(Single-card) Model perf
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13217559245)
- Passed as in main
VirdhatchaniKN authored Feb 9, 2025
1 parent a4b0687 commit e1a028f
Showing 7 changed files with 476 additions and 42 deletions.
@@ -18,20 +18,23 @@ def data_gen_with_range_batch_norm(
device,
is_input=False,
required_grad=False,
testing_dtype="bfloat16",
):
assert high > low, "Incorrect range provided"
torch.manual_seed(213919)
channels = input_shapes[1]
size = input_shapes if is_input else channels
pt_tensor = torch.rand(size, requires_grad=required_grad).bfloat16() * (high - low) + low
torch_dtype = getattr(torch, testing_dtype)
ttnn_dtype = getattr(ttnn, testing_dtype)
pt_tensor = torch.rand(size, requires_grad=required_grad, dtype=torch_dtype) * (high - low) + low
reshaped_tensor = pt_tensor
if not is_input:
reshaped_tensor = pt_tensor.view(1, channels, 1, 1)
tt_tensor = ttnn.from_torch(
reshaped_tensor,
device=device,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
dtype=ttnn_dtype,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
return pt_tensor, tt_tensor
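The helper now takes a `testing_dtype` argument and builds the torch/ttnn tensor pair in that dtype; a call mirroring the fp32 tests added below (the shape is illustrative) would look like:

```python
# Mirrors the fp32 tests below; the shape is illustrative.
in_data, input_tensor = data_gen_with_range_batch_norm(
    torch.Size([2, 3, 32, 32]), 5, 10, device, is_input=True, testing_dtype="float32"
)
# Per-channel stats/weight/bias tensors are generated from the channel dim and reshaped to (1, C, 1, 1).
mean_data, mean_tensor = data_gen_with_range_batch_norm(
    torch.Size([2, 3, 32, 32]), 4, 10, device, testing_dtype="float32"
)
```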
123 changes: 123 additions & 0 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -10,6 +10,129 @@
compare_results_batch_norm,
)
from itertools import product
from models.utility_functions import skip_for_grayskull


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("channel_size", [1, 2, 3, 4])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
input_tensor_torch = torch.full(torch.Size([3, channel_size, 64, 120]), 1, dtype=torch.float32)
batch_mean_torch = torch.full(torch.Size([channel_size]), 0.00030171126, dtype=torch.float32)
batch_var_torch = torch.full(torch.Size([channel_size]), 0.1262342343, dtype=torch.float32)
weight_torch = torch.full(torch.Size([channel_size]), 0.246943565369, dtype=torch.float32) if weight else None
bias_torch = torch.full(torch.Size([channel_size]), 0.59, dtype=torch.float32) if bias else None

result_torch = torch.nn.functional.batch_norm(
input=input_tensor_torch,
running_mean=batch_mean_torch,
running_var=batch_var_torch,
weight=weight_torch,
bias=bias_torch,
eps=eps,
)

batch_mean_torch = batch_mean_torch.view(1, channel_size, 1, 1)
batch_var_torch = batch_var_torch.view(1, channel_size, 1, 1)
weight_torch = weight_torch.view(1, channel_size, 1, 1) if weight else None
bias_torch = bias_torch.view(1, channel_size, 1, 1) if bias else None

input_tensor_tt = ttnn.from_torch(input_tensor_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
batch_mean_tt = ttnn.from_torch(batch_mean_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
batch_var_tt = ttnn.from_torch(batch_var_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
weight_tt = (
ttnn.from_torch(weight_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) if weight else None
)
bias_tt = ttnn.from_torch(bias_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) if bias else None

result_tt = ttnn.batch_norm(
input_tensor_tt, running_mean=batch_mean_tt, running_var=batch_var_tt, eps=eps, weight=weight_tt, bias=bias_tt
)
tt_out = ttnn.to_torch(result_tt)

status_1 = torch.allclose(result_torch, tt_out, atol=1e-10, rtol=1e-5)
status_2 = compare_results_batch_norm([result_torch], [tt_out])
assert status_2 and status_1


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
torch.Size([3, 1, 64, 120]),
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize(
"check_mean, check_var",
[
(False, False), # xfail case
(True, False), # xfail case
(False, True), # xfail case
(True, True),
],
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
def test_batch_norm_fp32(
input_shapes, check_mean, check_var, weight, bias, eps, device, training=False, testing_dtype="float32"
):
in_data, input_tensor = data_gen_with_range_batch_norm(
input_shapes, 5, 10, device, is_input=True, testing_dtype=testing_dtype
)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
if (check_mean)
else (None, None)
)
var_data, var_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 20, device, testing_dtype=testing_dtype)
if (check_var)
else (None, None)
)
weight_data, weight_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
if weight
else (None, None)
)
bias_data, bias_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
if bias
else (None, None)
)

if (not training) and ((not check_mean) or (not check_var)):
pytest.xfail("running_mean and running_var must be defined in evaluation mode")

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor,
running_mean=mean_tensor,
running_var=var_tensor,
training=training,
eps=eps,
weight=weight_tensor,
bias=bias_tensor,
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)
torch_result = torch.nn.functional.batch_norm(
input=in_data,
running_mean=mean_data,
running_var=var_data,
weight=weight_data,
bias=bias_data,
training=training,
eps=eps,
)
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) and torch.allclose(
torch_result, tt_output, atol=1e-6, rtol=1e-3
)
assert comp_pass


@pytest.mark.parametrize(
@@ -8,42 +8,49 @@
#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::normalization {

namespace {
inline void check_tensor_BN(const Tensor& tensor, std::string_view name, std::uint32_t input_c_dim) {
TT_FATAL(
tensor.get_layout() == Layout::TILE, "batch_norm only supports tiled layout. Got: {}", tensor.get_layout());
TT_FATAL(
tensor.get_dtype() == DataType::BFLOAT16 || tensor.get_dtype() == DataType::FLOAT32,
"batch_norm only supports bfloat16, float32. Got: {}",
tensor.get_dtype());
TT_FATAL(
tensor.storage_type() == StorageType::DEVICE,
"Operands to batch_norm need to be on device! Got: {}",
tensor.storage_type());
TT_FATAL(tensor.buffer() != nullptr, "Operands to batch_norm need to be allocated in buffers on device!");
TT_FATAL(tensor.get_logical_shape().rank() == 4, "batch_norm supports tensors of rank 4");
TT_FATAL(tensor.get_logical_shape()[1] == input_c_dim, "{}[1] must be the same as input's channel size.", name);
}
} // namespace

void BatchNormOperation::validate_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;

check_tensor(input, "batch_norm", "input");
check_tensor(batch_mean, "batch_norm", "batch_mean");
check_tensor(batch_var, "batch_norm", "batch_var");
check_tensor(weight, "batch_norm", "weight");
check_tensor(bias, "batch_norm", "bias");
check_tensor(output, "batch_norm", "output");

// input (N, C, H, W)
auto C = input.get_logical_shape()[1];

check_tensor_BN(input, "input_shape", C);
check_tensor_BN(batch_mean, "batch_mean_shape", C);
check_tensor_BN(batch_var, "batch_mean_shape", C);

// output (N, C, H, W)
if (output.has_value()) {
auto check_C = output.value().get_logical_shape()[1];
TT_FATAL(C == check_C, "output_shape[1] must be the same as input's channel size.");
check_tensor_BN(output.value(), "output_shape", C);
}

// mean (1, C, 1, 1)
TT_FATAL(batch_mean.get_logical_shape()[1] == C, "batch_mean_shape[1] must be the same as input's channel size.");
// var (1, C, 1, 1)
TT_FATAL(batch_var.get_logical_shape()[1] == C, "batch_var_shape[1] must be the same as input's channel size.");

// weight (1, C, 1, 1)
if (weight.has_value()) {
TT_FATAL(
weight.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size.");
TT_FATAL(
weight.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size.");
check_tensor_BN(weight.value(), "weight_shape", C);
}

// bias (1, C, 1, 1)
if (bias.has_value()) {
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
check_tensor_BN(bias.value(), "bias_shape", C);
}
}

@@ -127,7 +134,7 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
std::optional<Tensor> bias,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config())};
operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config()), input.get_dtype()};
tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
return {operation_attributes, tensor_args};
}
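The validation above is consolidated into `check_tensor_BN`, which requires every operand to be a rank-4, tile-layout, on-device tensor in bfloat16 or float32 whose channel dimension matches the input's. A hedged Python-level sketch of what that implies (the exact exception type surfaced from `TT_FATAL` is an assumption here):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)
x = ttnn.from_torch(torch.rand(1, 2, 32, 32, dtype=torch.float32),
                    dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
stats = torch.rand(1, 2, 1, 1, dtype=torch.float32)
good = ttnn.from_torch(stats, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
bad = ttnn.from_torch(stats, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

ttnn.batch_norm(x, running_mean=good, running_var=good, eps=1e-5)  # accepted: fp32, rank 4, C matches
try:
    ttnn.batch_norm(x, running_mean=bad, running_var=good, eps=1e-5)  # rejected: unsupported dtype
except Exception as err:
    print(err)  # expected: "batch_norm only supports bfloat16, float32. Got: ..."
ttnn.close_device(device)
```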
@@ -73,8 +73,11 @@ void set_or_update_runtime_arguments(
}

uint32_t cHtWt = cHt * cWt;
class bfloat16 bfloat_scalar_eps(eps);
uint32_t packed_scalar_eps = pack_two_bfloat16_into_uint32({bfloat_scalar_eps, bfloat_scalar_eps});
const auto scalar = eps;
const auto packed_scalar_eps = input_tensor.get_dtype() == DataType::FLOAT32
? std::bit_cast<uint32_t>(scalar)
: pack_two_bfloat16_into_uint32({scalar, scalar});

std::array reader_runtime_args = {
packed_scalar_eps,
input_tensor.buffer()->address(),
@@ -218,38 +221,83 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
const auto e_is_dram = weight_has_value and weight_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
const auto f_is_dram = bias_has_value and bias_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;

std::map<std::string, std::string> dataflow_defines; // Currently support only for fp32, bf16
if (input_tensor.get_dtype() == DataType::FLOAT32) {
dataflow_defines["FILL_TILE_WITH_FIRST_ELEMENT"] = "fill_tile_with_first_element<float>";
dataflow_defines["FILL_WITH_VALUE_FLOAT"] = "fill_with_val<1024, float>";
} else {
dataflow_defines["FILL_TILE_WITH_FIRST_ELEMENT"] = "fill_tile_with_first_element_bfloat16";
dataflow_defines["FILL_WITH_VALUE"] = "fill_with_val_bfloat16";
}

// READER KERNEL
auto reader_defines = dataflow_defines;
auto reader_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp",
all_device_cores,
tt_metal::ReaderDataMovementConfig({a_is_dram}));
tt_metal::ReaderDataMovementConfig({a_is_dram}, std::move(reader_defines)));

// WRITER KERNEL
auto writer_defines = dataflow_defines;
auto writer_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp",
all_device_cores,
tt_metal::WriterDataMovementConfig({
b_is_dram,
c_is_dram,
d_is_dram,
e_is_dram,
f_is_dram,
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
}));
tt_metal::WriterDataMovementConfig(
{
b_is_dram,
c_is_dram,
d_is_dram,
e_is_dram,
f_is_dram,
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
},
std::move(writer_defines)));

// COMPUTE KERNEL
bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
c_data_format == tt::DataFormat::Float32;

uint32_t src_input_cb_index = tt::CBIndex::c_0;
uint32_t src_batch_mean_cb_index = tt::CBIndex::c_1;
uint32_t src_batch_var_cb_index = tt::CBIndex::c_3;
uint32_t src_eps_cb_index = tt::CBIndex::c_4;
uint32_t src_temp_den_cb_index = tt::CBIndex::c_5;
uint32_t src_temp_num_cb_index = tt::CBIndex::c_6;
uint32_t src_weight_cb_index = tt::CBIndex::c_16;
uint32_t src_temp_1_cb_index = tt::CBIndex::c_17;
uint32_t src_bias_cb_index = tt::CBIndex::c_18;

std::vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
if (fp32_dest_acc_en) {
for (const auto cb_index :
{src_input_cb_index,
src_batch_mean_cb_index,
src_batch_var_cb_index,
src_temp_num_cb_index,
src_temp_den_cb_index,
src_eps_cb_index,
src_weight_cb_index,
src_temp_1_cb_index,
src_bias_cb_index}) {
unpack_to_dest_mode[cb_index] = UnpackToDestMode::UnpackToDestFp32;
}
}

std::vector<uint32_t> compute_kernel_args = {
static_cast<uint32_t>(weight_has_value), static_cast<uint32_t>(bias_has_value)};
auto compute_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp",
fmt::format(
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_{}.cpp",
fp32_dest_acc_en ? "sfpu_kernel" : "kernel"),
all_device_cores,
tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_kernel_args});
tt_metal::ComputeConfig{
.fp32_dest_acc_en = fp32_dest_acc_en,
.unpack_to_dest_mode = std::move(unpack_to_dest_mode),
.compile_args = compute_kernel_args});

auto set_runtime_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) {
tt_metal::SetRuntimeArgs(program, kernel_id, core, args);
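Two things to note in the program-factory changes above. First, the `eps` runtime argument is now packed per dtype: for float32 inputs its raw bits are reinterpreted as a `uint32_t` (`std::bit_cast`), while otherwise the value is converted to bfloat16 and duplicated into both halves of a `uint32_t`. A host-side Python sketch of the equivalent packing (assuming bfloat16 is obtained by truncating to the top 16 bits of the fp32 pattern; since both halves are identical, their ordering does not matter):

```python
import struct

def pack_eps(eps: float, is_float32: bool) -> int:
    """Sketch of the host-side eps packing; assumes bfloat16 = top 16 bits of the fp32 pattern."""
    bits = struct.unpack("<I", struct.pack("<f", eps))[0]  # raw fp32 bit pattern
    if is_float32:
        return bits                                        # std::bit_cast<uint32_t>(eps)
    bf16 = bits >> 16                                       # truncate to bfloat16
    return (bf16 << 16) | bf16                              # same bfloat16 value in both halves

print(hex(pack_eps(1e-5, True)), hex(pack_eps(1e-5, False)))
```

Second, when the output data format is Float32 (or an integer format), the factory enables `fp32_dest_acc_en`, switches the listed circular buffers to `UnpackToDestMode::UnpackToDestFp32`, and compiles `batch_norm_sfpu_kernel.cpp` instead of `batch_norm_kernel.cpp` as the compute kernel.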