[CANN] Support cpu offload optimizer for Ascend NPU #4568

Merged: 4 commits, Nov 14, 2023
19 changes: 16 additions & 3 deletions csrc/adagrad/cpu_adagrad.cpp
@@ -47,6 +47,8 @@ void Adagrad_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
@@ -62,7 +64,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
grad += _eps;
grad = momentum / grad;
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
@@ -79,6 +81,17 @@ void Adagrad_Optimizer::Step_1(float* _params,
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
}
@@ -180,7 +193,7 @@ int ds_adagrad_step(int optimizer_id,
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());

#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
@@ -196,7 +209,7 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
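Note on the pattern above: Step_1 stages each TILE-sized chunk of updated parameters in one of two pinned host buffers and copies it to the device while the next chunk is computed. A minimal standalone sketch of that rotation, using the same ACL calls as the hunk (function and buffer names here are illustrative, not part of the PR):

```cpp
#include <algorithm>

#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"

constexpr size_t TILE = 1048576;  // elements staged per copy, as in DeepSpeed

// Hypothetical driver showing the double-buffer rotation used by Step_1.
void offload_params(const float* params, float* dev_params, size_t param_count,
                    float* host_buf[2], c10_npu::NPUStream streams[2])
{
    bool buf_index = false;
    for (size_t t = 0; t < param_count; t += TILE) {
        size_t copy_size = std::min(TILE, param_count - t);
        // Tile k reuses the buffer of tile k-2, so waiting is only needed once
        // two tiles are in flight -- the "(t / TILE) >= 2" guard in the PR.
        if ((t / TILE) >= 2) { aclrtSynchronizeStream(streams[buf_index].stream()); }
        std::copy(params + t, params + t + copy_size, host_buf[buf_index]);
        size_t bytes = copy_size * sizeof(float);
        // Blocking host-to-device copy; the CUDA path instead enqueues an
        // asynchronous launch_param_update on streams[buf_index].
        aclrtMemcpy(dev_params + t, bytes, host_buf[buf_index], bytes,
                    aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
        buf_index = !buf_index;  // rotate to the other pinned buffer
    }
}
```

Because aclrtMemcpy blocks, the synchronize guard is conservative on the CANN path; keeping it preserves the loop structure shared with the CUDA branch.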
27 changes: 20 additions & 7 deletions csrc/adam/cpu_adam_impl.cpp
@@ -61,6 +61,8 @@ void Adam_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
@@ -81,7 +83,7 @@ void Adam_Optimizer::Step_1(float* _params,
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
@@ -96,6 +98,17 @@ void Adam_Optimizer::Step_1(float* _params,
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
@@ -239,7 +252,7 @@ int ds_adam_step(int optimizer_id,
nullptr,
(params.options().dtype() == at::kHalf));

#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
@@ -257,18 +270,18 @@ int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
torch::Tensor& device_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto device_params_c = device_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

@@ -281,7 +294,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
device_params_ptr,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
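One design note on the copies above: the CANN branch uses the blocking aclrtMemcpy, while the CUDA branch enqueues an asynchronous kernel on the tile's stream. A non-blocking CANN variant would look roughly like the sketch below; this is an assumption about a possible alternative (aclrtMemcpyAsync is part of the ACL runtime, but the PR does not use it):

```cpp
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"

// Sketch only: a non-blocking alternative to the blocking aclrtMemcpy used in
// the merged code. Not part of this PR.
inline aclError copy_tile_async(void* dst, const void* src, size_t bytes,
                                const c10_npu::NPUStream& stream)
{
    return aclrtMemcpyAsync(dst, bytes, src, bytes,
                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE,
                            stream.stream());
}
```

With an asynchronous copy, the `(t / TILE) >= 2` stream synchronization before a buffer is refilled becomes load-bearing rather than conservative.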
37 changes: 36 additions & 1 deletion csrc/includes/cpu_adagrad.h
@@ -18,6 +18,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif
@@ -41,6 +45,11 @@ class Adagrad_Optimizer {

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
@@ -49,6 +58,9 @@ class Adagrad_Optimizer {
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
#if defined(__AVX512__) or defined(__AVX256__)
@@ -69,6 +81,11 @@ class Adagrad_Optimizer {
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step)
{
@@ -95,6 +112,11 @@ class Adagrad_Optimizer {
bool _buf_index;
float* _doubled_buffer[2];
cudaStream_t _streams[2];
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

@@ -125,6 +147,8 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -149,7 +173,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
@@ -167,6 +191,17 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
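For reference, the buffer lifecycle this header adds for CANN, collected into one self-contained sketch (the struct name is illustrative; error codes are ignored, as in the PR):

```cpp
#include "acl/acl.h"

constexpr size_t TILE = 1048576;  // matches the TILE used by the optimizer

// Illustrative wrapper for the two pinned host buffers the constructor and
// destructor manage on the CANN path.
struct DoubledBuffer {
    float* buf[2];
    DoubledBuffer()
    {
        // Page-locked host memory so the ACL runtime can DMA it to the NPU.
        aclrtMallocHost((void**)&buf[0], TILE * sizeof(float));
        aclrtMallocHost((void**)&buf[1], TILE * sizeof(float));
    }
    ~DoubledBuffer()
    {
        aclrtFreeHost(buf[0]);
        aclrtFreeHost(buf[1]);
    }
};
```

Note that Step_AVX halves memcpy_size when half_precision is set: simd_store then packs 2-byte values into the same buffer, so only half the bytes are live.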
38 changes: 37 additions & 1 deletion csrc/includes/cpu_adam.h
@@ -19,6 +19,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
@@ -57,6 +61,11 @@ class Adam_Optimizer {

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
@@ -65,6 +74,9 @@ class Adam_Optimizer {
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}

@@ -87,6 +99,11 @@ class Adam_Optimizer {
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step, float beta1, float beta2)
{
@@ -142,6 +159,11 @@ class Adam_Optimizer {
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

@@ -192,6 +214,9 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -227,7 +252,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
@@ -246,6 +271,17 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
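Finally, the stream setup both headers now declare, as one compact sketch (struct name illustrative, not from the PR):

```cpp
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"

// Illustrative stand-in for the _streams member on the CANN path: the current
// torch_npu stream plus one drawn from the pool, both drained by
// SynchronizeStreams() after an optimizer step.
struct StreamPair {
    c10_npu::NPUStream s[2] = {c10_npu::getCurrentNPUStream(),
                               c10_npu::getNPUStreamFromPool()};
    void synchronize()
    {
        for (int i = 0; i < 2; i++) aclrtSynchronizeStream(s[i].stream());
    }
};
```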