#14316: refactoring moreh_helper function
hschoi4448 committed Oct 26, 2024
1 parent 734c8c1 commit dd80a01
Showing 80 changed files with 456 additions and 470 deletions.
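
The change is mechanical across the commit: the moreh helper functions (check_tensor, is_dram, CreateCircularBuffer, CreateReadKernel, CreateWriteKernel, CreateComputeKernel) move from tt::operations::primary to ttnn::operations. Call sites in the deprecated tt_dnn op library gain an explicit ttnn::operations:: qualifier, while call sites already inside ttnn::operations sub-namespaces drop the old qualifier. A minimal sketch of what the relocated declarations plausibly look like; the header path and exact signatures are inferred from the call sites below, not copied from the real header:

    // moreh_helper_functions.hpp (hypothetical path; signatures inferred from
    // call sites in this commit rather than the real header).
    namespace ttnn::operations {

    // Validates a tensor's layout and storage and, when an allow-list is
    // given, its dtype; op_name and arg_name feed the error message.
    void check_tensor(
        const Tensor& tensor,
        const std::string& op_name,
        const std::string& arg_name,
        const std::initializer_list<DataType>& supported_dtypes = {});

    // True when the tensor's buffer is allocated in DRAM rather than L1.
    bool is_dram(const Tensor& tensor);

    }  // namespace ttnn::operations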
@@ -41,11 +41,11 @@ void MorehClipGradNormStep1::validate(
     const std::vector<Tensor> &input_tensors,
     const std::vector<std::optional<const Tensor>> &optional_input_tensors) const {
     for (const auto &input : input_tensors) {
-        check_tensor(input, "moreh_clip_grad_norm_step1", "input");
+        ttnn::operations::check_tensor(input, "moreh_clip_grad_norm_step1", "input");
     }
 
     const auto &tmp_pow_sum = optional_input_tensors.at(0).value();
-    check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum");
+    ttnn::operations::check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum");
 };
 
 std::vector<ttnn::SimpleShape> MorehClipGradNormStep1::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
@@ -99,10 +99,10 @@ void moreh_clip_grad_norm_step1(const std::vector<Tensor> &inputs, float norm_ty
 
 void MorehClipGradNormStep2::validate(const std::vector<Tensor> &input_tensors) const {
     const auto &tmp_pow_sum = input_tensors.at(0);
-    check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step2", "tmp_pow_sum");
+    ttnn::operations::check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step2", "tmp_pow_sum");
 
     const auto &total_norm = input_tensors.at(1);
-    check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm");
+    ttnn::operations::check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm");
 }
 
 std::vector<ttnn::SimpleShape> MorehClipGradNormStep2::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
@@ -139,11 +139,11 @@ void MorehClipGradNormStep3::validate(
     const std::vector<Tensor> &input_tensors,
     const std::vector<std::optional<const Tensor>> &optional_input_tensors) const {
     for (const auto &input : input_tensors) {
-        check_tensor(input, "moreh_clip_grad_norm_step3", "input");
+        ttnn::operations::check_tensor(input, "moreh_clip_grad_norm_step3", "input");
     }
 
     const auto &clip_coef_clamped = optional_input_tensors.at(0).value();
-    check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped");
+    ttnn::operations::check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped");
 }
 
 std::vector<ttnn::SimpleShape> MorehClipGradNormStep3::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
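
All three steps use the two-string form of check_tensor; the FastReduceNC hunk further down calls the same helper with an explicit dtype allow-list. A usage sketch, assuming the defaulted-parameter signature sketched above:

    // Usage sketch (signature assumed): the short form checks layout/storage,
    // the long form additionally restricts the accepted dtypes.
    ttnn::operations::check_tensor(input, "moreh_clip_grad_norm_step1", "input");
    ttnn::operations::check_tensor(
        input, "moreh_clip_grad_norm_step1", "input",
        {DataType::BFLOAT16, DataType::BFLOAT8_B});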
@@ -84,7 +84,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
 
     const auto cb_data_format = tt_metal::datatype_to_dataformat_converter(tmp_pow_sum.get_dtype());
 
-    CreateCircularBuffer(
+    ttnn::operations::CreateCircularBuffer(
         program,
         core_group_1,
         cb_data_format,
@@ -112,8 +112,8 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
         "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/kernels/"
         "writer_moreh_clip_grad_norm_step1.cpp";
 
-    const auto reader_kernels_id = CreateReadKernel(program, reader_kernel_file, core_group_1);
-    const auto writer_kernels_id = CreateWriteKernel(program, writer_kernel_file, core_group_1);
+    const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, core_group_1);
+    const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, core_group_1);
 
     ////////////////////////////////////////////////////////////////////////////
     // ComputeKernel SetUp
@@ -127,7 +127,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
         "moreh_clip_grad_norm_step1_kernel.cpp";
 
     const auto compute_kernels_id =
-        CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1}, compute_defines);
+        ttnn::operations::CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1}, compute_defines);
 
     ////////////////////////////////////////////////////////////////////////////
     // RuntimeArgs SetUp
@@ -146,7 +146,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
         // reader
         const std::vector<uint32_t> reader_runtime_args{
             input_addr,
-            static_cast<uint32_t>(is_dram(input)),
+            static_cast<uint32_t>(ttnn::operations::is_dram(input)),
             num_tiles,
             *reinterpret_cast<uint32_t*>(&decimal),
             origin_h,
@@ -155,7 +155,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl(
 
         // writer
         const std::vector<uint32_t> writer_runtime_args{
-            output_addr, static_cast<uint32_t>(is_dram(tmp_pow_sum)), tile_offset};
+            output_addr, static_cast<uint32_t>(ttnn::operations::is_dram(tmp_pow_sum)), tile_offset};
         SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);
 
         // compute
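
Every program factory touched here follows the same sequence: derive a circular-buffer data format from the dtype, create circular buffers, create reader and writer kernels from source paths, create a compute kernel, then set per-core runtime args. A condensed sketch of that flow with the newly qualified helpers; argument lists mirror the call sites above, and anything the diff truncates is a placeholder:

    // Condensed program-factory flow; placeholders stand in for arguments the
    // diff does not show.
    auto cb_data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype());
    ttnn::operations::CreateCircularBuffer(program, core_group_1, cb_data_format /*, ...cb specs*/);

    const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, core_group_1);
    const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, core_group_1);
    const auto compute_kernels_id = ttnn::operations::CreateComputeKernel(
        program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1}, compute_defines);

    SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args);
    SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);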
@@ -58,7 +58,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
 
     const auto cb_data_format = tt_metal::datatype_to_dataformat_converter(total_norm.get_dtype());
 
-    CreateCircularBuffer(
+    ttnn::operations::CreateCircularBuffer(
         program,
         single_core,
         cb_data_format,
@@ -82,8 +82,8 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
         "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/"
         "writer_moreh_clip_grad_norm_step2.cpp";
 
-    const auto reader_kernels_id = CreateReadKernel(program, reader_kernel_file, single_core);
-    const auto writer_kernels_id = CreateWriteKernel(program, writer_kernel_file, single_core);
+    const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, single_core);
+    const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, single_core);
 
     ////////////////////////////////////////////////////////////////////////////
     // ComputeKernel SetUp
@@ -92,7 +92,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
         "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/kernels/"
         "moreh_clip_grad_norm_step2_kernel.cpp";
 
-    const auto compute_kernels_id = CreateComputeKernel(program, compute_kernel_file, {single_core, num_tiles});
+    const auto compute_kernels_id = ttnn::operations::CreateComputeKernel(program, compute_kernel_file, {single_core, num_tiles});
 
     ////////////////////////////////////////////////////////////////////////////
     // RuntimeArgs SetUp
@@ -102,11 +102,11 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl(
 
     // reader
     const std::vector<uint32_t> reader_runtime_args{
-        input_addr, static_cast<uint32_t>(is_dram(tmp_pow_sum)), num_tiles, *reinterpret_cast<uint32_t*>(&decimal)};
+        input_addr, static_cast<uint32_t>(ttnn::operations::is_dram(tmp_pow_sum)), num_tiles, *reinterpret_cast<uint32_t*>(&decimal)};
     SetRuntimeArgs(program, reader_kernels_id, single_core, reader_runtime_args);
 
     // writer
-    const std::vector<uint32_t> writer_runtime_args{output_addr, static_cast<uint32_t>(is_dram(total_norm))};
+    const std::vector<uint32_t> writer_runtime_args{output_addr, static_cast<uint32_t>(ttnn::operations::is_dram(total_norm))};
     SetRuntimeArgs(program, writer_kernels_id, single_core, writer_runtime_args);
 
     // compute
@@ -62,7 +62,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
 
     const auto cb_data_format = tt_metal::datatype_to_dataformat_converter(inputs.at(0).get_dtype());
 
-    CreateCircularBuffer(
+    ttnn::operations::CreateCircularBuffer(
         program,
         core_group_1,
         cb_data_format,
@@ -82,8 +82,8 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
         "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/kernels/"
         "writer_moreh_clip_grad_norm_step3.cpp";
 
-    const auto reader_kernels_id = CreateReadKernel(program, reader_kernel_file, core_group_1);
-    const auto writer_kernels_id = CreateWriteKernel(program, writer_kernel_file, core_group_1);
+    const auto reader_kernels_id = ttnn::operations::CreateReadKernel(program, reader_kernel_file, core_group_1);
+    const auto writer_kernels_id = ttnn::operations::CreateWriteKernel(program, writer_kernel_file, core_group_1);
 
     ////////////////////////////////////////////////////////////////////////////
     // ComputeKernel SetUp
@@ -93,7 +93,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
         "moreh_clip_grad_norm_step3_kernel.cpp";
 
     const auto compute_kernels_id =
-        CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1});
+        ttnn::operations::CreateComputeKernel(program, compute_kernel_file, {core_group_1, num_inputs_per_core_group_1});
 
     ////////////////////////////////////////////////////////////////////////////
     // RuntimeArgs SetUp
@@ -109,14 +109,14 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl(
         // reader
         const std::vector<uint32_t> reader_runtime_args{
             input_addr,
-            static_cast<uint32_t>(is_dram(input)),
+            static_cast<uint32_t>(ttnn::operations::is_dram(input)),
             clip_coef_clamped_addr,
-            static_cast<uint32_t>(is_dram(clip_coef_clamped)),
+            static_cast<uint32_t>(ttnn::operations::is_dram(clip_coef_clamped)),
             num_tiles};
         SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args);
 
         // writer
-        const std::vector<uint32_t> writer_runtime_args{input_addr, static_cast<uint32_t>(is_dram(input)), num_tiles};
+        const std::vector<uint32_t> writer_runtime_args{input_addr, static_cast<uint32_t>(ttnn::operations::is_dram(input)), num_tiles};
         SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args);
 
         // compute
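
A detail worth noting in these runtime-args hunks: the argument arrays are vectors of uint32_t, so buffer placement travels as a cast bool from is_dram and the float decimal travels as its raw bit pattern. A self-contained sketch of that packing idiom in plain C++, independent of tt-metal:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Pack heterogeneous kernel arguments into the uint32_t array the device
    // expects. std::memcpy is the strictly portable spelling of the
    // *reinterpret_cast<uint32_t*>(&decimal) idiom used in the diff.
    std::vector<uint32_t> pack_reader_args(uint32_t addr, bool in_dram, float decimal, uint32_t num_tiles) {
        uint32_t decimal_bits = 0;
        std::memcpy(&decimal_bits, &decimal, sizeof decimal_bits);
        return {addr, static_cast<uint32_t>(in_dram), decimal_bits, num_tiles};
    }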
@@ -46,12 +46,14 @@ Tensor _fast_reduce_nc(
 
 void FastReduceNCDeviceOperation::validate_with_output_tensors(
     const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
+    using namespace ttnn::operations;
+
     const auto& input = input_tensors.at(0);
     auto& output = output_tensors.at(0);
 
     // validate tensor
-    tt::operations::primary::check_tensor(input, "FastReduceNC", "input", {DataType::BFLOAT16, DataType::BFLOAT8_B});
-    tt::operations::primary::check_tensor(output, "FastReduceNC", "output", {DataType::BFLOAT16, DataType::BFLOAT8_B});
+    check_tensor(input, "FastReduceNC", "input", {DataType::BFLOAT16, DataType::BFLOAT8_B});
+    check_tensor(output, "FastReduceNC", "output", {DataType::BFLOAT16, DataType::BFLOAT8_B});
 
     // validate input dim
     const auto input_rank = input.get_logical_shape().rank();
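
This file takes a different route from the deprecated op library above: rather than qualifying every call, validate_with_output_tensors opens with a function-scope using namespace ttnn::operations;. The directive's effect ends with the function, so the unqualified names do not leak into the rest of the translation unit. A minimal self-contained illustration:

    // Function-scope using-directive: visibility ends at the closing brace.
    namespace ttnn::operations {
    inline void check_tensor() {}  // stand-in for the real helper
    }

    void validate_a() {
        using namespace ttnn::operations;  // applies to this function only
        check_tensor();                    // OK: unqualified lookup succeeds
    }

    void validate_b() {
        // check_tensor();                 // error: directive above not in scope
        ttnn::operations::check_tensor();  // explicit qualification still works
    }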
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/full/device/full_program_factory.cpp
@@ -40,7 +40,7 @@ FullOperation::ProgramFactory::cached_program_t FullOperation::ProgramFactory::c
 
     // Create circular buffer
     auto cb_index = tt::CB::c_intermed0;
-    tt::operations::primary::CreateCircularBuffer(
+    CreateCircularBuffer(
         program,
         all_cores,
         data_format,
@@ -57,7 +57,7 @@ FullOperation::ProgramFactory::cached_program_t FullOperation::ProgramFactory::c
         default: break;
     }
 
-    auto writer_id = tt::operations::primary::CreateWriteKernel(
+    auto writer_id = CreateWriteKernel(
         program,
         "ttnn/cpp/ttnn/operations/full/device/kernels/writer_full.cpp",
         all_cores,
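
Here the old tt::operations::primary:: qualifier is dropped with no replacement. That works if the surrounding code already lives in a sub-namespace of ttnn::operations, an assumption consistent with the file's location under ttnn/cpp/ttnn/operations/, because unqualified lookup walks outward through enclosing namespaces:

    // Name lookup from a nested namespace finds helpers in its parents.
    namespace ttnn::operations {
    inline bool is_dram(int /*placeholder for Tensor*/) { return true; }
    }

    namespace ttnn::operations::full {
    inline void build_args() {
        // Resolves to ttnn::operations::is_dram without qualification.
        bool in_dram = is_dram(0);
        (void)in_dram;
    }
    }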
@@ -18,31 +18,31 @@ void MorehAdamOperation::validate_inputs(
     auto& exp_avg_in = tensor_args.exp_avg_in;
     auto& exp_avg_sq_in = tensor_args.exp_avg_sq_in;
 
-    tt::operations::primary::check_tensor(params_in, "moreh_adam", "params_in");
-    tt::operations::primary::check_tensor(grad, "moreh_adam", "grad");
-    tt::operations::primary::check_tensor(exp_avg_in, "moreh_adam", "exp_avg_in");
-    tt::operations::primary::check_tensor(exp_avg_sq_in, "moreh_adam", "exp_avg_sq_in");
+    check_tensor(params_in, "moreh_adam", "params_in");
+    check_tensor(grad, "moreh_adam", "grad");
+    check_tensor(exp_avg_in, "moreh_adam", "exp_avg_in");
+    check_tensor(exp_avg_sq_in, "moreh_adam", "exp_avg_sq_in");
 
     if (tensor_args.max_exp_avg_sq_in) {
-        tt::operations::primary::check_tensor(*tensor_args.max_exp_avg_sq_in, "moreh_adam", "max_exp_avg_sq_in");
+        check_tensor(*tensor_args.max_exp_avg_sq_in, "moreh_adam", "max_exp_avg_sq_in");
     }
 
     const auto& params_out = tensor_args.output_tensors.at(0);
 
     if (params_out.has_value()) {
-        tt::operations::primary::check_tensor(params_out.value(), "moreh_adam", "params_out");
+        check_tensor(params_out.value(), "moreh_adam", "params_out");
     }
 
     if (tensor_args.output_tensors.at(1).has_value()) {
-        tt::operations::primary::check_tensor(tensor_args.output_tensors.at(1).value(), "moreh_adam", "exp_avg_out");
+        check_tensor(tensor_args.output_tensors.at(1).value(), "moreh_adam", "exp_avg_out");
     }
 
     if (tensor_args.output_tensors.at(2).has_value()) {
-        tt::operations::primary::check_tensor(tensor_args.output_tensors.at(2).value(), "moreh_adam", "exp_avg_sq_out");
+        check_tensor(tensor_args.output_tensors.at(2).value(), "moreh_adam", "exp_avg_sq_out");
     }
 
     if (tensor_args.output_tensors.at(3).has_value()) {
-        tt::operations::primary::check_tensor(
+        check_tensor(
             tensor_args.output_tensors.at(3).value(), "moreh_adam", "max_exp_avg_sq_out");
     }
 }
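
The moreh_adam and moreh_adamw validators repeat the same guard, validate only if the optional is populated, once per optional tensor. A hypothetical helper (not part of this commit) that would factor out the repetition, kept generic so it does not presume the real Tensor type:

    #include <optional>
    #include <string_view>
    #include <utility>

    // Hypothetical convenience wrapper: run a check only when the optional
    // tensor is present. Not part of this commit; shown for illustration.
    template <typename TensorT, typename CheckFn>
    void check_if_present(
        const std::optional<TensorT>& maybe_tensor,
        std::string_view op_name,
        std::string_view arg_name,
        CheckFn&& check) {
        if (maybe_tensor.has_value()) {
            std::forward<CheckFn>(check)(maybe_tensor.value(), op_name, arg_name);
        }
    }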
@@ -61,7 +61,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
     ////////////////////////////////////////////////////////////////////////////
     auto data_format = tt::tt_metal::datatype_to_dataformat_converter(param_in.get_dtype());
     auto intermed_cb_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format;
-    tt::operations::primary::CreateCircularBuffer(
+    CreateCircularBuffer(
         program,
         all_cores,
         data_format,
@@ -94,17 +94,17 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
     ////////////////////////////////////////////////////////////////////////////
 
     const std::vector<uint32_t> reader_compile_time_args{
-        static_cast<uint32_t>(tt::operations::primary::is_dram(param_in)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(grad)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_in)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_sq_in)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(max_exp_avg_sq_in))};
+        static_cast<uint32_t>(is_dram(param_in)),
+        static_cast<uint32_t>(is_dram(grad)),
+        static_cast<uint32_t>(is_dram(exp_avg_in)),
+        static_cast<uint32_t>(is_dram(exp_avg_sq_in)),
+        static_cast<uint32_t>(is_dram(max_exp_avg_sq_in))};
 
     const std::vector<uint32_t> writer_compile_time_args{
-        static_cast<uint32_t>(tt::operations::primary::is_dram(param_out)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_out)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(exp_avg_sq_out)),
-        static_cast<uint32_t>(tt::operations::primary::is_dram(max_exp_avg_sq_out.value()))};
+        static_cast<uint32_t>(is_dram(param_out)),
+        static_cast<uint32_t>(is_dram(exp_avg_out)),
+        static_cast<uint32_t>(is_dram(exp_avg_sq_out)),
+        static_cast<uint32_t>(is_dram(max_exp_avg_sq_out.value()))};
 
     const auto reader_kernel_file =
         "ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/"
@@ -120,9 +120,9 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
     if (fp32_dest_acc_en) {
        data_movement_defines["FP32_DEST_ACC_EN"] = "1";
     }
-    const auto reader_kernel_id = tt::operations::primary::CreateReadKernel(
+    const auto reader_kernel_id = CreateReadKernel(
         program, reader_kernel_file, all_cores, reader_compile_time_args, data_movement_defines);
-    const auto writer_kernel_id = tt::operations::primary::CreateWriteKernel(
+    const auto writer_kernel_id = CreateWriteKernel(
         program, writer_kernel_file, all_cores, writer_compile_time_args, data_movement_defines);
 
     ////////////////////////////////////////////////////////////////////////////
@@ -143,7 +143,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
         "ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/"
         "moreh_adam.cpp";
 
-    auto compute_kernel_1_id = tt ::operations::primary::CreateComputeKernel(
+    auto compute_kernel_1_id = CreateComputeKernel(
         program,
         compute_kernel_file,
         {core_group_1, num_tiles_per_core_group_1, compute_args_group_1},
@@ -155,7 +155,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
     if (!core_group_2.ranges().empty()) {
         const std::vector<uint32_t> compute_args_group_2{num_tiles_per_core_group_2};
 
-        compute_kernel_2_id = tt::operations::primary::CreateComputeKernel(
+        compute_kernel_2_id = CreateComputeKernel(
             program,
             compute_kernel_file,
             {core_group_2, num_tiles_per_core_group_2, compute_args_group_2},
@@ -18,27 +18,27 @@ MorehAdamWDeviceOperation::program_factory_t MorehAdamWDeviceOperation::select_p
 
 void MorehAdamWDeviceOperation::validate_inputs(
     const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
-    tt::operations::primary::check_tensor(tensor_args.param_in, "moreh_adamw", "param_in");
-    tt::operations::primary::check_tensor(tensor_args.grad, "moreh_adamw", "grad");
-    tt::operations::primary::check_tensor(tensor_args.exp_avg_in, "moreh_adamw", "exp_avg_in");
-    tt::operations::primary::check_tensor(tensor_args.exp_avg_sq_in, "moreh_adamw", "exp_avg_sq_in");
+    check_tensor(tensor_args.param_in, "moreh_adamw", "param_in");
+    check_tensor(tensor_args.grad, "moreh_adamw", "grad");
+    check_tensor(tensor_args.exp_avg_in, "moreh_adamw", "exp_avg_in");
+    check_tensor(tensor_args.exp_avg_sq_in, "moreh_adamw", "exp_avg_sq_in");
 
     if (tensor_args.max_exp_avg_sq_in.has_value()) {
-        tt::operations::primary::check_tensor(
+        check_tensor(
             tensor_args.max_exp_avg_sq_in.value(), "moreh_adamw", "max_exp_avg_sq_in");
     }
 
     if (tensor_args.param_out.has_value()) {
-        tt::operations::primary::check_tensor(tensor_args.param_out.value(), "moreh_adamw", "param_out");
+        check_tensor(tensor_args.param_out.value(), "moreh_adamw", "param_out");
     }
     if (tensor_args.exp_avg_out.has_value()) {
-        tt::operations::primary::check_tensor(tensor_args.exp_avg_out.value(), "moreh_adamw", "exp_avg_out");
+        check_tensor(tensor_args.exp_avg_out.value(), "moreh_adamw", "exp_avg_out");
     }
     if (tensor_args.exp_avg_sq_out.has_value()) {
-        tt::operations::primary::check_tensor(tensor_args.exp_avg_sq_out.value(), "moreh_adamw", "exp_avg_sq_out");
+        check_tensor(tensor_args.exp_avg_sq_out.value(), "moreh_adamw", "exp_avg_sq_out");
     }
     if (tensor_args.max_exp_avg_sq_out.has_value()) {
-        tt::operations::primary::check_tensor(
+        check_tensor(
             tensor_args.max_exp_avg_sq_out.value(), "moreh_adamw", "max_exp_avg_sq_out");
     }
 }