#13317: revise moreh_adam (#13318)
o2buzzle authored Oct 9, 2024
1 parent 399c744 commit e9cf074
Showing 6 changed files with 91 additions and 6 deletions.
23 changes: 22 additions & 1 deletion tests/ttnn/unit_tests/operations/test_moreh_adam.py
@@ -8,7 +8,9 @@

import ttnn
import pytest
from models.utility_functions import is_wormhole_b0, comp_allclose_and_pcc, comp_pcc, is_wormhole_b0
from models.utility_functions import (
comp_allclose_and_pcc,
)
from loguru import logger
from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
get_compute_kernel_options,
@@ -141,3 +143,22 @@ def forward(self, x):
logger.debug(f"Out passing (max_exp_avg_sq)={passing}")
logger.debug(f"Output pcc={out}")
assert passing


@pytest.mark.parametrize(
"params",
(
# shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en
([32, 32], 0.0, (0.9, 0.999), 1e-06, 0.0, True, True),
([2, 2, 2, 2, 2, 2, 64, 64], 0.0, (0.9, 0.999), 1e-06, 0.0, False, False),
),
)
def test_moreh_adam_enable_cache(params, device, use_program_cache):
for i in range(4):
shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en = params
if i % 2 == 1:
amsgrad = not amsgrad

test_moreh_adam(shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en, device)

assert device.num_program_cache_entries() == 2
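
For reference, the per-element update that moreh_adam is expected to apply is the standard Adam rule, with the AMSGrad variant when amsgrad is set (the same formulation as torch.optim.Adam, which the test's reference path presumably follows). The scalar sketch below only mirrors that math and is not the ttnn kernel; the helper name adam_step is invented for illustration.

import math

def adam_step(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq,
              lr, beta1, beta2, eps, weight_decay, step, amsgrad):
    """One scalar Adam/AMSGrad update (coupled weight decay, step counts from 1)."""
    if weight_decay != 0.0:
        grad = grad + weight_decay * param  # classic (non-decoupled) weight decay
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * grad                # first moment
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad   # second moment
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step
    if amsgrad:
        max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq)            # running max of second moment
        denom = math.sqrt(max_exp_avg_sq) / math.sqrt(bias_correction2) + eps
    else:
        denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
    param = param - (lr / bias_correction1) * exp_avg / denom
    return param, exp_avg, exp_avg_sq, max_exp_avg_sq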
@@ -146,7 +146,8 @@ std::tuple<MorehAdamOperation::operation_attributes_t, MorehAdamOperation::tenso
step.value_or(0),
amsgrad.value_or(false),
memory_config.value_or(param_in.memory_config()),
compute_kernel_config},
init_device_compute_kernel_config(param_in.device()->arch(), compute_kernel_config, MathFidelity::HiFi4),
},
tensor_args_t{
param_in,
grad,
@@ -22,7 +22,7 @@ struct MorehAdamOperation {
bool amsgrad = false;

const MemoryConfig output_mem_config;
const std::optional<DeviceComputeKernelConfig> compute_kernel_config;
const DeviceComputeKernelConfig compute_kernel_config;
};

struct tensor_args_t {
@@ -36,8 +36,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
auto step = operation_attributes.step;
auto amsgrad = operation_attributes.amsgrad;

auto compute_kernel_config =
init_device_compute_kernel_config(param_in.device()->arch(), operation_attributes.compute_kernel_config);
auto compute_kernel_config = operation_attributes.compute_kernel_config;

uint32_t num_tiles = param_in.volume() / tt::constants::TILE_HW;

41 changes: 41 additions & 0 deletions ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp
@@ -5,6 +5,7 @@
#include "moreh_adam.hpp"

#include "ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.hpp"
#include "ttnn/run_operation.hpp"

namespace ttnn::operations::moreh::moreh_adam {
std::vector<std::optional<Tensor>> MorehAdam::invoke(
@@ -46,4 +47,44 @@ std::vector<std::optional<Tensor>> MorehAdam::invoke(
memory_config,
compute_kernel_config);
}

std::vector<Tensor> MorehAdam::create_async_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
const auto& param_in = input_tensors.at(0);
const auto& grad = input_tensors.at(1);
const auto& exp_avg_in = input_tensors.at(2);
const auto& exp_avg_sq_in = input_tensors.at(3);

const auto& max_exp_avg_sq_in = optional_inputs.at(0);

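// One async placeholder per possible output, in order: param_out, exp_avg_out, exp_avg_sq_out, max_exp_avg_sq_out.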
return {
Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})),
Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})),
Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})),
Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})),
};
}

std::vector<bool> MorehAdam::create_async_return_flag(
const Tensor& param_in,
const Tensor& grad,
const Tensor& exp_avg_in,
const Tensor& exp_avg_sq_in,
const std::optional<float> lr,
const std::optional<float> beta1,
const std::optional<float> beta2,
const std::optional<float> eps,
const std::optional<float> weight_decay,
const std::optional<uint32_t> step,
const std::optional<bool> amsgrad,
const std::optional<const Tensor> max_exp_avg_sq_in,
const std::optional<const Tensor> param_out,
const std::optional<const Tensor> exp_avg_out,
const std::optional<const Tensor> exp_avg_sq_out,
const std::optional<const Tensor> max_exp_avg_sq_out,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config) {
// First three are always true, last one depends on amsgrad
return {true, true, true, amsgrad.value_or(false)};
}
} // namespace ttnn::operations::moreh::moreh_adam
25 changes: 24 additions & 1 deletion ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp
@@ -28,10 +28,33 @@ struct MorehAdam {
const std::optional<const Tensor> max_exp_avg_sq_out,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);

static std::vector<Tensor> create_async_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs);

static std::vector<bool> create_async_return_flag(
const Tensor& param_in,
const Tensor& grad,
const Tensor& exp_avg_in,
const Tensor& exp_avg_sq_in,
const std::optional<float> lr,
const std::optional<float> beta1,
const std::optional<float> beta2,
const std::optional<float> eps,
const std::optional<float> weight_decay,
const std::optional<uint32_t> step,
const std::optional<bool> amsgrad,
const std::optional<const Tensor> max_exp_avg_sq_in,
const std::optional<const Tensor> param_out,
const std::optional<const Tensor> exp_avg_out,
const std::optional<const Tensor> exp_avg_sq_out,
const std::optional<const Tensor> max_exp_avg_sq_out,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config);
};
} // namespace ttnn::operations::moreh::moreh_adam

namespace ttnn {
constexpr auto moreh_adam =
ttnn::register_operation<"ttnn::moreh_adam", ttnn::operations::moreh::moreh_adam::MorehAdam>();
ttnn::register_operation_with_auto_launch_op<"ttnn::moreh_adam", ttnn::operations::moreh::moreh_adam::MorehAdam>();
}
