Skip to content

Commit

Permalink
#16415: fix moreh_adam
Browse files Browse the repository at this point in the history
  • Loading branch information
hschoi4448 committed Jan 14, 2025
1 parent 4cfb561 commit efd7356
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_moreh_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(self, x):
dev_param_out = create_tt_tensor(model.weight, device, dtype=dtype)
dev_exp_avg_out = create_tt_tensor(cpu_exp_avg, device, dtype=dtype)
dev_exp_avg_sq_out = create_tt_tensor(cpu_exp_avg_sq, device, dtype=dtype)
dev_max_exp_avg_sq_out = create_tt_tensor(cpu_max_exp_avg_sq, device, dtype=dtype)
dev_max_exp_avg_sq_out = create_tt_tensor(cpu_max_exp_avg_sq, device, dtype=dtype) if amsgrad else None

criterion = nn.L1Loss()
optimizer = optim.Adam({model.weight}, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ MorehAdamOperation::tensor_return_value_t MorehAdamOperation::create_output_tens
ret.push_back(tensor_args.output_tensors.at(idx).value());
} else if (output_specs[idx].has_value()) {
ret.push_back(create_device_tensor(*output_specs[idx], device));
} else {
ret.push_back(std::nullopt);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program
static_cast<uint32_t>(is_dram(param_out)),
static_cast<uint32_t>(is_dram(exp_avg_out)),
static_cast<uint32_t>(is_dram(exp_avg_sq_out)),
static_cast<uint32_t>(is_dram(max_exp_avg_sq_out.value()))};
static_cast<uint32_t>(max_exp_avg_sq_out.has_value() ? is_dram(max_exp_avg_sq_out.value()) : false)};

const auto reader_kernel_file =
"ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/"
Expand Down Expand Up @@ -272,7 +272,7 @@ void MorehAdamOperation::ProgramFactory::override_runtime_arguments(
auto param_out_buffer = tensor_return_value.at(0)->buffer();
auto exp_avg_out_buffer = tensor_return_value.at(1)->buffer();
auto exp_avg_sq_out_buffer = tensor_return_value.at(2)->buffer();
auto max_exp_avg_sq_out_buffer = tensor_return_value.at(3)->buffer();
auto max_exp_avg_sq_out_buffer = operation_attributes.amsgrad ? tensor_return_value.at(3)->buffer() : nullptr;

auto& core_group_1 = cached_program.shared_variables.core_group_1;
auto& core_group_2 = cached_program.shared_variables.core_group_2;
Expand Down

0 comments on commit efd7356

Please sign in to comment.