diff --git a/csrc/adagrad/cpu_adagrad.cpp b/csrc/adagrad/cpu_adagrad.cpp
index 8eebe00349bef..2aa2106107e43 100644
--- a/csrc/adagrad/cpu_adagrad.cpp
+++ b/csrc/adagrad/cpu_adagrad.cpp
@@ -47,6 +47,8 @@ void Adagrad_Optimizer::Step_1(float* _params,
             size_t offset = copy_size + t;
 #if defined(__ENABLE_CUDA__)
             if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#elif defined(__ENABLE_CANN__)
+            if ((t / TILE) >= 2) { c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[_buf_index]); }
 #endif
 #pragma omp parallel for
             for (size_t k = t; k < offset; k++) {
@@ -62,7 +64,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
                 grad += _eps;
                 grad = momentum / grad;
                 param = grad * step_size + param;
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
                 if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
 #endif
                 if (half_precision)
@@ -79,6 +81,14 @@ void Adagrad_Optimizer::Step_1(float* _params,
                     _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
                 _buf_index = !_buf_index;
             }
+#elif defined(__ENABLE_CANN__)
+            if (dev_params) {
+                size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
+                aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], memcpy_size,
+                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
+
+                _buf_index = !_buf_index;
+            }
 #endif
         }
     }
@@ -180,7 +190,7 @@ int ds_adagrad_step(int optimizer_id,
     opt->update_state(lr, epsilon, weight_decay);
     opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
 
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     opt->SynchronizeStreams();
 #endif
     return 0;
@@ -196,7 +206,7 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
                               torch::Tensor& exp_avg_sq,
                               torch::Tensor& gpu_params)
 {
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     auto params_c = params.contiguous();
     auto gpu_params_c = gpu_params.contiguous();
     auto exp_avg_sq_c = exp_avg_sq.contiguous();
diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp
index 742cb42927773..de4838adbb5b6 100644
--- a/csrc/adam/cpu_adam_impl.cpp
+++ b/csrc/adam/cpu_adam_impl.cpp
@@ -61,6 +61,8 @@ void Adam_Optimizer::Step_1(float* _params,
             size_t offset = copy_size + t;
 #if defined(__ENABLE_CUDA__)
             if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#elif defined(__ENABLE_CANN__)
+            if ((t / TILE) >= 2) { c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[_buf_index]); }
 #endif
 #pragma omp parallel for
             for (size_t k = t; k < offset; k++) {
@@ -81,7 +83,7 @@ void Adam_Optimizer::Step_1(float* _params,
                 grad = momentum / grad;
                 if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
                 param = grad * step_size + param;
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
                 if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
 #endif
                 if (half_precision)
@@ -96,6 +98,14 @@ void Adam_Optimizer::Step_1(float* _params,
                 launch_param_update(
                     _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
 
+                _buf_index = !_buf_index;
+            }
+#elif defined(__ENABLE_CANN__)
+            if (dev_params) {
+                size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
+                aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], memcpy_size,
+                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
+
                 _buf_index = !_buf_index;
             }
 #endif
@@ -239,7 +249,7 @@ int ds_adam_step(int optimizer_id,
                 nullptr,
                 (params.options().dtype() == at::kHalf));
 
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     opt->SynchronizeStreams();
 #endif
     return 0;
@@ -257,18 +267,18 @@ int ds_adam_step_plus_copy(int optimizer_id,
                            torch::Tensor& grads,
                            torch::Tensor& exp_avg,
                            torch::Tensor& exp_avg_sq,
-                           torch::Tensor& gpu_params)
+                           torch::Tensor& device_params)
 {
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     auto params_c = params.contiguous();
-    auto gpu_params_c = gpu_params.contiguous();
+    auto device_params_c = device_params.contiguous();
     auto exp_avg_c = exp_avg.contiguous();
     auto exp_avg_sq_c = exp_avg_sq.contiguous();
     auto grads_c = grads.contiguous();
 
     float* params_ptr = (float*)params_c.data_ptr();
     float* grads_ptr = (float*)grads_c.data_ptr();
-    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
+    ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
     float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
     float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
 
@@ -281,7 +291,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
                 exp_avg_ptr,
                 exp_avg_sq_ptr,
                 params_c.numel(),
-                gpu_params_ptr,
+                device_params_ptr,
                 (params.options().dtype() == at::kHalf));
 
     opt->SynchronizeStreams();
diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h
index ba40fcf7b62a7..f1df62c695475 100644
--- a/csrc/includes/cpu_adagrad.h
+++ b/csrc/includes/cpu_adagrad.h
@@ -18,6 +18,10 @@
 #include "cuda.h"
 #include "custom_cuda_layers.h"
 typedef __half ds_half_precision_t;
+#elif defined(__ENABLE_CANN__)
+#include "acl/acl.h"
+#include "torch_npu/csrc/core/npu/NPUStream.h"
+typedef c10::Half ds_half_precision_t;
 #else
 typedef unsigned short ds_half_precision_t;
 #endif
@@ -41,6 +45,11 @@ class Adagrad_Optimizer {
 
         _streams[0] = TrainingContext::Instance().GetCurrentStream();
         _streams[1] = TrainingContext::Instance().GetNewStream();
+        _buf_index = false;
+#elif defined(__ENABLE_CANN__)
+        aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
+        aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+        _buf_index = false;
 #endif
     }
@@ -49,6 +58,9 @@ class Adagrad_Optimizer {
 #if defined(__ENABLE_CUDA__)
         cudaFreeHost(_doubled_buffer[0]);
         cudaFreeHost(_doubled_buffer[1]);
+#elif defined(__ENABLE_CANN__)
+        aclrtFreeHost(_doubled_buffer[0]);
+        aclrtFreeHost(_doubled_buffer[1]);
 #endif
     }
 #if defined(__AVX512__) or defined(__AVX256__)
@@ -69,6 +81,11 @@ class Adagrad_Optimizer {
     {
         for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
     }
+#elif defined(__ENABLE_CANN__)
+    inline void SynchronizeStreams()
+    {
+        for (int i = 0; i < 2; i++) c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[i]);
+    }
 #endif
     inline void IncrementStep(size_t step)
     {
@@ -95,6 +112,10 @@ class Adagrad_Optimizer {
     bool _buf_index;
     float* _doubled_buffer[2];
     cudaStream_t _streams[2];
+#elif defined(__ENABLE_CANN__)
+    float* _doubled_buffer[2];
+    c10_npu::NPUStream _streams[2] = { c10_npu::getCurrentNPUStream(), c10_npu::getNPUStreamFromPool() };
+    bool _buf_index;
 #endif
 };
 
@@ -125,6 +146,8 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
         size_t offset = copy_size + t;
 #if defined(__ENABLE_CUDA__)
         if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#elif defined(__ENABLE_CANN__)
+        if ((t / TILE) >= 2) { c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[_buf_index]); }
 #endif
 #pragma omp parallel for
         for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -149,7 +172,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
             simd_fma(param_4, grad_4, step_size_4, param_4);
 
             simd_store(_params + i, param_4, half_precision);
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
             if (dev_params) {
                 simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
             }
@@ -167,6 +190,16 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
 
             _buf_index = !_buf_index;
         }
+#elif defined(__ENABLE_CANN__)
+        if (dev_params) {
+            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
+            if (half_precision)
+                memcpy_size /= 2;
+            aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], memcpy_size,
+                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
+
+            _buf_index = !_buf_index;
+        }
 #endif
     }
     *rounded_size = new_rounded_size;
diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
index c4f7edcd74102..07e7d84f5dead 100644
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -19,6 +19,10 @@
 #include "cuda.h"
 #include "custom_cuda_layers.h"
 typedef __half ds_half_precision_t;
+#elif defined(__ENABLE_CANN__)
+#include "acl/acl.h"
+#include "torch_npu/csrc/core/npu/NPUStream.h"
+typedef c10::Half ds_half_precision_t;
 #else
 #include <cmath>
 typedef unsigned short ds_half_precision_t;
@@ -57,6 +61,11 @@ class Adam_Optimizer {
 
         _streams[0] = TrainingContext::Instance().GetCurrentStream();
         _streams[1] = TrainingContext::Instance().GetNewStream();
+        _buf_index = false;
+#elif defined(__ENABLE_CANN__)
+        aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
+        aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+        _buf_index = false;
 #endif
     }
@@ -65,6 +74,9 @@ class Adam_Optimizer {
 #if defined(__ENABLE_CUDA__)
         cudaFreeHost(_doubled_buffer[0]);
         cudaFreeHost(_doubled_buffer[1]);
+#elif defined(__ENABLE_CANN__)
+        aclrtFreeHost(_doubled_buffer[0]);
+        aclrtFreeHost(_doubled_buffer[1]);
 #endif
     }
 
@@ -87,6 +99,11 @@ class Adam_Optimizer {
     {
         for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
     }
+#elif defined(__ENABLE_CANN__)
+    inline void SynchronizeStreams()
+    {
+        for (int i = 0; i < 2; i++) c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[i]);
+    }
 #endif
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
@@ -142,6 +159,10 @@ class Adam_Optimizer {
     float* _doubled_buffer[2];
     cudaStream_t _streams[2];
     bool _buf_index;
+#elif defined(__ENABLE_CANN__)
+    float* _doubled_buffer[2];
+    c10_npu::NPUStream _streams[2] = { c10_npu::getCurrentNPUStream(), c10_npu::getNPUStreamFromPool() };
+    bool _buf_index;
 #endif
 };
 
@@ -192,6 +213,8 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
         size_t offset = copy_size + t;
 #if defined(__ENABLE_CUDA__)
         if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#elif defined(__ENABLE_CANN__)
+        if ((t / TILE) >= 2) { c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[_buf_index]); }
 #endif
 #pragma omp parallel for
         for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -227,7 +250,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
             simd_fma(param_4, grad_4, step_size_4, param_4);
 
             simd_store(_params + (i >> rshft), param_4, half_precision);
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
             if (dev_params) {
                 simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
             }
@@ -246,6 +269,16 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
 
             _buf_index = !_buf_index;
         }
+#elif defined(__ENABLE_CANN__)
+        if (dev_params) {
+            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
+            if (half_precision)
+                memcpy_size /= 2;
+            aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], memcpy_size,
+                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
+
+            _buf_index = !_buf_index;
+        }
 #endif
     }
     *rounded_size = new_rounded_size;
diff --git a/csrc/includes/cpu_lion.h b/csrc/includes/cpu_lion.h
index 76034ceb34593..b1aa55f4fc413 100644
--- a/csrc/includes/cpu_lion.h
+++ b/csrc/includes/cpu_lion.h
@@ -19,6 +19,10 @@
 #include "cuda.h"
 #include "custom_cuda_layers.h"
 typedef __half ds_half_precision_t;
+#elif defined(__ENABLE_CANN__)
+#include "acl/acl.h"
+#include "torch_npu/csrc/core/npu/NPUStream.h"
+typedef c10::Half ds_half_precision_t;
 #else
 #include <cmath>
 typedef unsigned short ds_half_precision_t;
@@ -46,6 +50,11 @@ class Lion_Optimizer {
 
         _streams[0] = TrainingContext::Instance().GetCurrentStream();
         _streams[1] = TrainingContext::Instance().GetNewStream();
+        _buf_index = false;
+#elif defined(__ENABLE_CANN__)
+        aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
+        aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+        _buf_index = false;
 #endif
     }
@@ -54,6 +63,9 @@ class Lion_Optimizer {
 #if defined(__ENABLE_CUDA__)
         cudaFreeHost(_doubled_buffer[0]);
         cudaFreeHost(_doubled_buffer[1]);
+#elif defined(__ENABLE_CANN__)
+        aclrtFreeHost(_doubled_buffer[0]);
+        aclrtFreeHost(_doubled_buffer[1]);
 #endif
     }
 
@@ -75,6 +87,11 @@ class Lion_Optimizer {
     {
         for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
     }
+#elif defined(__ENABLE_CANN__)
+    inline void SynchronizeStreams()
+    {
+        for (int i = 0; i < 2; i++) c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[i]);
+    }
 #endif
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
@@ -102,6 +119,10 @@ class Lion_Optimizer {
     float* _doubled_buffer[2];
     cudaStream_t _streams[2];
     bool _buf_index;
+#elif defined(__ENABLE_CANN__)
+    float* _doubled_buffer[2];
+    c10_npu::NPUStream _streams[2] = { c10_npu::getCurrentNPUStream(), c10_npu::getNPUStreamFromPool() };
+    bool _buf_index;
 #endif
 };
 
@@ -149,6 +170,8 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
         size_t offset = copy_size + t;
 #if defined(__ENABLE_CUDA__)
         if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#elif defined(__ENABLE_CANN__)
+        if ((t / TILE) >= 2) { c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[_buf_index]); }
 #endif
 #pragma omp parallel for
         for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -178,7 +201,7 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
             simd_fma(momentum_4, grad_4, betta2_minus1_4, momentum_4);
 
             simd_store(_params + (i >> rshft), param_4, half_precision);
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
             if (dev_params) {
                 simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
             }
@@ -196,6 +219,16 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
 
             _buf_index = !_buf_index;
         }
+#elif defined(__ENABLE_CANN__)
+        if (dev_params) {
+            size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
+            if (half_precision)
+                memcpy_size /= 2;
+            aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], memcpy_size,
+                        aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
+
+            _buf_index = !_buf_index;
+        }
 #endif
     }
     *rounded_size = new_rounded_size;
diff --git a/csrc/lion/cpu_lion_impl.cpp b/csrc/lion/cpu_lion_impl.cpp
index 5c24e23b4b205..fcc2c77d31ba0 100644
--- a/csrc/lion/cpu_lion_impl.cpp
+++ b/csrc/lion/cpu_lion_impl.cpp
@@ -54,6 +54,8 @@ void Lion_Optimizer::Step_1(float* _params,
             size_t offset = copy_size + t;
 #if defined(__ENABLE_CUDA__)
             if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#elif defined(__ENABLE_CANN__)
+            if ((t / TILE) >= 2) { c10_npu::acl::AclrtSynchronizeStreamWithTimeout(_streams[_buf_index]); }
 #endif
 #pragma omp parallel for
             for (size_t k = t; k < offset; k++) {
@@ -72,7 +74,7 @@ void Lion_Optimizer::Step_1(float* _params,
                 }
                 momentum = momentum * _betta2;
                 momentum = grad * betta2_minus1 + momentum;
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
                 if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
 #endif
                 if (half_precision)
@@ -86,6 +88,14 @@ void Lion_Optimizer::Step_1(float* _params,
                 launch_param_update(
                     _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
 
+                _buf_index = !_buf_index;
+            }
+#elif defined(__ENABLE_CANN__)
+            if (dev_params) {
+                size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
+                aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], memcpy_size,
+                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
+
                 _buf_index = !_buf_index;
             }
 #endif
@@ -201,7 +211,7 @@ int ds_lion_step(int optimizer_id,
                 nullptr,
                 (params.options().dtype() == at::kHalf));
 
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     opt->SynchronizeStreams();
 #endif
     return 0;
@@ -218,7 +228,7 @@ int ds_lion_step_plus_copy(int optimizer_id,
                            torch::Tensor& exp_avg,
                            torch::Tensor& gpu_params)
 {
-#if defined(__ENABLE_CUDA__)
+#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
     auto params_c = params.contiguous();
     auto gpu_params_c = gpu_params.contiguous();
     auto exp_avg_c = exp_avg.contiguous();
diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py
index 2c3a9e701d4d5..ea0204aaa7ae3 100644
--- a/deepspeed/env_report.py
+++ b/deepspeed/env_report.py
@@ -79,6 +79,30 @@ def nvcc_version():
     release = output_split[release_idx + 1].replace(',', '').split(".")
     return ".".join(release)
 
+
+def installed_cann_path():
+    if "ASCEND_HOME_PATH" in os.environ or os.path.exists(os.environ["ASCEND_HOME_PATH"]):
+        return os.environ["ASCEND_HOME_PATH"]
+    return None
+
+
+def cann_version():
+    import re
+    ascend_path = installed_cann_path()
+    if ascend_path is None:
+        return f"CANN_HOME does not exist, unable to compile NPU op(s)"
+    cann_version = ""
+    for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)):
+        if cann_version:
+            break
+        install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)]
+        if install_files:
+            filepath = os.path.join(dirpath, install_files[0])
+            with open(filepath, "r") as f:
+                for line in f:
+                    if line.find("version") != -1:
+                        cann_version = line.strip().split("=")[-1]
+                        break
+    return cann_version
+
 
 def get_shm_size():
     try:
@@ -122,6 +146,12 @@ def debug_report():
                        ("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
                         (f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
                        ])
+    elif get_accelerator().device_name() == 'npu':
+        import torch_npu
+        report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']}"),
+                       ("torch_npu install path", torch_npu.__path__), ("torch_npu version", torch_npu.__version__),
+                       ("cann version", cann_version())
+                       ])
     else:
         report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])
 
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 8c025a1a2b9f1..2d9e8ca62db61 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -587,8 +587,11 @@ def _configure_moe_settings(self):
             assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
 
     def _update_model_bit16_weights(self, group_index):
+        # Work around due to bug in torch_npu, @see https://gitee.com/ascend/pytorch/pulls/6484
+        # Remove me after torch_npu fixed.
+        unflatten_sizes = [tensor.to(get_accelerator().current_device_name()) for tensor in self.round_robin_bit16_groups[group_index]]
         updated_params = self.unflatten(self.bit16_groups_flat[group_index],
-                                        self.round_robin_bit16_groups[group_index])
+                                        unflatten_sizes)
 
         for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
             p.data = q.data
@@ -1876,7 +1879,8 @@ def has_overflow_partitioned_grads_serial(self):
     def has_overflow(self, partition_gradients=True):
         if partition_gradients:
             overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
-            overflow_gpu = get_accelerator().ByteTensor([overflow])
+            # Work around due to bug in HCCL, revert me after fixed.
+            overflow_gpu = get_accelerator().IntTensor([overflow])
             '''This will capture overflow across all data parallel and expert parallel process
             Since expert parallel process are a subset of data parallel process'''
             dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
@@ -1888,7 +1892,7 @@ def has_overflow(self, partition_gradients=True):
                     params.append(param)
 
             overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
-            overflow_gpu = get_accelerator().ByteTensor([overflow])
+            overflow_gpu = get_accelerator().IntTensor([overflow])
 
         # Since each model parallel GPU carries only part of the model,
         # make sure overflow flag is synced across all the model parallel GPUs
diff --git a/op_builder/npu/__init__.py b/op_builder/npu/__init__.py
index 0d9e76b3903f1..6ad9124d36f0c 100644
--- a/op_builder/npu/__init__.py
+++ b/op_builder/npu/__init__.py
@@ -4,6 +4,8 @@
 # DeepSpeed Team
 '''Copyright The Microsoft DeepSpeed Team'''
 
-# NPU related operators will be added in the future.
 from .fused_adam import FusedAdamBuilder
 from .no_impl import NotImplementedBuilder
+from .cpu_adam import CPUAdamBuilder
+from .cpu_adagrad import CPUAdagradBuilder
+from .cpu_lion import CPULionBuilder
diff --git a/op_builder/npu/builder.py b/op_builder/npu/builder.py
index 7773388737a22..76b6901ba5b3d 100644
--- a/op_builder/npu/builder.py
+++ b/op_builder/npu/builder.py
@@ -3,6 +3,9 @@
 
 # DeepSpeed Team
 
+import re
+import os
+import torch_npu
 try:
     # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
     # if successful this also means we're doing a local install and not JIT compile path
@@ -13,22 +16,64 @@
 
 
 class NPUOpBuilder(OpBuilder):
+    _ascend_path = None
+    _torch_npu_path = None
+    _cann_version = None
 
-    def builder(self):
-        from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
+    def __init__(self, name):
+        super().__init__(name)
+        self._ascend_path = self.installed_cann_path()
+        self._torch_npu_path = os.path.join(os.path.dirname(os.path.abspath(torch_npu.__file__)))
+        try:
+            self._cann_version = self.installed_cann_version(self.name)
+        except BaseException:
+            print(f"{self.name} cann is missing, npu ops cannot be compiled!")
 
-        compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())}
+    def cann_defs(self):
+        if self._cann_version:
+            return '-D__ENABLE_CANN__'
+        return '-D__DISABLE_CANN__'
 
-        cpp_ext = ExtensionBuilder(name=self.absolute_name(),
-                                   sources=self.strip_empty_entries(self.sources()),
-                                   include_dirs=self.strip_empty_entries(self.include_paths()),
-                                   libraries=self.strip_empty_entries(self.libraries_args()),
-                                   extra_compile_args=compile_args)
+    def installed_cann_path(self):
+        if "ASCEND_HOME_PATH" in os.environ or os.path.exists(os.environ["ASCEND_HOME_PATH"]):
+            return os.environ["ASCEND_HOME_PATH"]
+        return None
 
-        return cpp_ext
+    def installed_cann_version(self, name=""):
+        ascend_path = self.installed_cann_path()
+        assert ascend_path is not None, "CANN_HOME does not exist, unable to compile NPU op(s)"
+        cann_version = ""
+        for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)):
+            if cann_version:
+                break
+            install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)]
+            if install_files:
+                filepath = os.path.join(dirpath, install_files[0])
+                with open(filepath, "r") as f:
+                    for line in f:
+                        if line.find("version") != -1:
+                            cann_version = line.strip().split("=")[-1]
+                            break
+        return cann_version
+
+    def include_paths(self):
+        paths = super().include_paths()
+        paths += [os.path.join(self._ascend_path, 'include'),
+                  os.path.join(self._torch_npu_path, 'include')]
+        return paths
 
     def cxx_args(self):
-        return []
+        args = super().cxx_args()
+        args += ['-O3', '-std=c++17', '-g', '-Wno-reorder', '-fopenmp']
+        args += ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack',
+                 '-Wl,--disable-new-dtags,--rpath']
+        args += [self.cann_defs(), self.cpu_arch(), self.simd_width(),
+                 '-L' + os.path.join(self._ascend_path, 'lib64'),
+                 '-L' + os.path.join(self._torch_npu_path, 'lib')]
+        return args
 
-    def libraries_args(self):
-        return []
+    def extra_ldflags(self):
+        flags = super().extra_ldflags()
+        flags += ['-L' + os.path.join(self._ascend_path, 'lib64'), '-lascendcl',
+                  '-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu']
+        return flags
diff --git a/op_builder/npu/cpu_adagrad.py b/op_builder/npu/cpu_adagrad.py
new file mode 100644
index 0000000000000..161bc82efe1ca
--- /dev/null
+++ b/op_builder/npu/cpu_adagrad.py
@@ -0,0 +1,25 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import NPUOpBuilder
+
+
+class CPUAdagradBuilder(NPUOpBuilder):
+    BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
+    NAME = "cpu_adagrad"
+
+    def __init__(self):
+        super().__init__(name=self.NAME)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.adagrad.{self.NAME}_op'
+
+    def sources(self):
+        return ['csrc/adagrad/cpu_adagrad.cpp']
+
+    def include_paths(self):
+        args = super().include_paths()
+        args += ['csrc/includes']
+        return args
diff --git a/op_builder/npu/cpu_adam.py b/op_builder/npu/cpu_adam.py
new file mode 100644
index 0000000000000..a4e9569c0f336
--- /dev/null
+++ b/op_builder/npu/cpu_adam.py
@@ -0,0 +1,25 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import NPUOpBuilder
+
+
+class CPUAdamBuilder(NPUOpBuilder):
+    BUILD_VAR = "DS_BUILD_CPU_ADAM"
+    NAME = "cpu_adam"
+
+    def __init__(self):
+        super().__init__(name=self.NAME)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.adam.{self.NAME}_op'
+
+    def sources(self):
+        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
+
+    def include_paths(self):
+        args = super().include_paths()
+        args += ['csrc/includes']
+        return args
diff --git a/op_builder/npu/cpu_lion.py b/op_builder/npu/cpu_lion.py
new file mode 100644
index 0000000000000..6917e0fd03d08
--- /dev/null
+++ b/op_builder/npu/cpu_lion.py
@@ -0,0 +1,25 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import NPUOpBuilder
+
+
+class CPULionBuilder(NPUOpBuilder):
+    BUILD_VAR = "DS_BUILD_CPU_LION"
+    NAME = "cpu_lion"
+
+    def __init__(self):
+        super().__init__(name=self.NAME)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.lion.{self.NAME}_op'
+
+    def sources(self):
+        return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']
+
+    def include_paths(self):
+        args = super().include_paths()
+        args += ['csrc/includes']
+        return args
diff --git a/op_builder/npu/no_impl.py b/op_builder/npu/no_impl.py
index f17973fda401e..5b1771fabc22f 100644
--- a/op_builder/npu/no_impl.py
+++ b/op_builder/npu/no_impl.py
@@ -22,3 +22,12 @@ def load(self, verbose=True):
 
     def sources(self):
         return []
+
+    def cxx_args(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def include_paths(self):
+        return []
diff --git a/tests/unit/common.py b/tests/unit/common.py
index 3fb335318fde7..020ce04b22801 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -81,6 +81,9 @@ def set_accelerator_visible():
             match = re.search('Device Type.*GPU', line)
             if match:
                 num_accelerators += 1
+    elif get_accelerator().device_name() == 'npu':
+        npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
+        num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
     else:
         assert get_accelerator().device_name() == 'cpu'
         cpu_sockets = int(
@@ -204,13 +207,13 @@ def _dist_run(self, local_rank, num_procs, master_port):
         if get_accelerator().is_available():
             set_accelerator_visible()
 
+        if get_accelerator().is_available():
+            get_accelerator().set_device(local_rank)
+
         if self.init_distributed:
             deepspeed.init_distributed(dist_backend=self.backend)
             dist.barrier()
 
-        if get_accelerator().is_available():
-            get_accelerator().set_device(local_rank)
-
         try:
             self.run(**self._fixture_kwargs)
         except BaseException as e:
diff --git a/tests/unit/util.py b/tests/unit/util.py
index 536e8b79e1d1a..e037c37bf8ccf 100644
--- a/tests/unit/util.py
+++ b/tests/unit/util.py
@@ -14,6 +14,8 @@ def skip_on_arch(min_arch=7):
     if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
         if torch.cuda.get_device_capability()[0] < min_arch:  #ignore-cuda
             pytest.skip(f"needs higher compute capability than {min_arch}")
+    elif deepspeed.accelerator.get_accelerator().device_name() == 'npu':
+        return
     else:
         assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
         return
@@ -26,6 +28,8 @@ def skip_on_cuda(valid_cuda):
         CUDA_VERSION = (CUDA_MAJOR * 10) + CUDA_MINOR
         if valid_cuda.count(CUDA_VERSION) == 0:
             pytest.skip(f"requires cuda versions {valid_cuda}")
+    elif deepspeed.accelerator.get_accelerator().device_name() == 'npu':
+        return
     else:
         assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
        return
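
Usage note (not part of the patch): with the NPU builders above in place, DeepSpeed JIT-compiles the CPU optimizer ops through NPUOpBuilder on an Ascend host, adding -D__ENABLE_CANN__, the CANN/torch_npu include paths, and the ascendcl/torch_npu link flags, so that the host-to-device copy path in the "step plus copy" kernels goes through aclrtMemcpy instead of a CUDA stream copy. A minimal sketch of exercising the rebuilt op, assuming torch_npu and CANN (ASCEND_HOME_PATH) are installed and the NPU accelerator is selected; the model and tensors here are illustrative only:

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Parameters stay on the host; constructing DeepSpeedCPUAdam triggers a JIT build of
# csrc/adam/cpu_adam.cpp + cpu_adam_impl.cpp via the accelerator's CPUAdamBuilder.
# On an NPU host that is the op_builder/npu/cpu_adam.py builder added in this patch.
model = torch.nn.Linear(32, 32)
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

# One optimizer step on CPU-resident parameters; the fused step-with-copy variant
# (used by ZeRO-Offload to push updated fp16 params back to the device) is the path
# that relies on aclrtMemcpy under __ENABLE_CANN__.
loss = model(torch.randn(8, 32)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()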