[Feature] Add Ascend NPU accelerator support (#1096)
* add npu support

* make precommit
ji-huazhong authored Dec 15, 2023
1 parent 8140129 commit d708ec2
Showing 9 changed files with 52 additions and 13 deletions.
@@ -11,7 +11,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
 
 from trl import SFTTrainer
-from trl.import_utils import is_xpu_available
+from trl.import_utils import is_npu_available, is_xpu_available
 from trl.trainer import ConstantLengthDataset


@@ -175,6 +175,8 @@ def create_datasets(tokenizer, args):
 del base_model
 if is_xpu_available():
     torch.xpu.empty_cache()
+elif is_npu_available():
+    torch.npu.empty_cache()
 else:
     torch.cuda.empty_cache()

@@ -8,7 +8,7 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from trl.import_utils import is_xpu_available
+from trl.import_utils import is_npu_available, is_xpu_available
 
 
 toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
@@ -54,6 +54,8 @@
 context_length = args.context_length
 if is_xpu_available():
     device = torch.xpu.current_device()
+elif is_npu_available():
+    device = torch.npu.current_device()
 else:
     device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

@@ -123,6 +125,8 @@
 model = None
 if is_xpu_available():
     torch.xpu.empty_cache()
+elif is_npu_available():
+    torch.npu.empty_cache()
 else:
     torch.cuda.empty_cache()

examples/scripts/ddpo.py (9 changes: 7 additions & 2 deletions)
@@ -25,7 +25,7 @@
 from transformers import CLIPModel, CLIPProcessor
 
 from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
-from trl.import_utils import is_xpu_available
+from trl.import_utils import is_npu_available, is_xpu_available
 
 
 @dataclass
@@ -121,7 +121,12 @@ def aesthetic_scorer(hub_model_id, model_filename):
         model_filename=model_filename,
         dtype=torch.float32,
     )
-    scorer = scorer.xpu() if is_xpu_available() else scorer.cuda()
+    if is_npu_available():
+        scorer = scorer.npu()
+    elif is_xpu_available():
+        scorer = scorer.xpu()
+    else:
+        scorer = scorer.cuda()
 
     def _fn(images, prompts, metadata):
         images = (images * 255).round().clamp(0, 255).to(torch.uint8)
examples/scripts/ppo.py (4 changes: 3 additions & 1 deletion)
@@ -25,7 +25,7 @@

 from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
 from trl.core import LengthSampler
-from trl.import_utils import is_xpu_available
+from trl.import_utils import is_npu_available, is_xpu_available
 
 
 tqdm.pandas()
@@ -157,6 +157,8 @@ def collator(data):
 if ppo_trainer.accelerator.num_processes == 1:
     if is_xpu_available():
         device = "xpu:0"
+    elif is_npu_available():
+        device = "npu:0"
     else:
         device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
 ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
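Note: the `device` computed in the hunk above is handed to a `transformers` pipeline, which accepts device strings (such as "xpu:0" or "npu:0") as well as CUDA ordinals; that is why the CUDA branch keeps the integer index. A minimal sketch of that hand-off, assuming the single-process path (the task and model name are illustrative, not part of this diff):

    from transformers import pipeline

    # `device` comes from the branch above: "xpu:0", "npu:0",
    # a CUDA ordinal, or "cpu".
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)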
trl/__init__.py (8 changes: 7 additions & 1 deletion)
@@ -5,7 +5,13 @@
 from .core import set_seed
 from .environment import TextEnvironment, TextHistory
 from .extras import BestOfNSampler
-from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
+from .import_utils import (
+    is_diffusers_available,
+    is_npu_available,
+    is_peft_available,
+    is_wandb_available,
+    is_xpu_available,
+)
 from .models import (
     AutoModelForCausalLMWithValueHead,
     AutoModelForSeq2SeqLMWithValueHead,
trl/core.py (15 changes: 10 additions & 5 deletions)
@@ -23,7 +23,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from transformers import top_k_top_p_filtering
 
-from .import_utils import is_xpu_available
+from .import_utils import is_npu_available, is_xpu_available
 
 
 try:
@@ -245,6 +245,8 @@ def set_seed(seed: int):
     torch.manual_seed(seed)
     if is_xpu_available():
         torch.xpu.manual_seed_all(seed)
+    elif is_npu_available():
+        torch.npu.manual_seed_all(seed)
     else:
         torch.cuda.manual_seed_all(seed)

@@ -268,13 +270,16 @@ class PPODecorators(object):
     @contextmanager
     def empty_device_cache(cls):
         yield
-        if is_xpu_available():
-            if cls.optimize_device_cache and torch.xpu.is_available():
+        if cls.optimize_device_cache:
+            if is_xpu_available():
                 gc.collect()
                 torch.xpu.empty_cache()
                 gc.collect()
-        else:
-            if cls.optimize_device_cache and torch.cuda.is_available():
+            elif is_npu_available():
+                gc.collect()
+                torch.npu.empty_cache()
+                gc.collect()
+            elif torch.cuda.is_available():
                 gc.collect()
                 torch.cuda.empty_cache()
                 gc.collect()
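Because `empty_device_cache` is built with `contextlib.contextmanager`, the object it returns works both as a context manager and as a decorator, and after this change the clearing is gated by a single `optimize_device_cache` check regardless of backend. A minimal usage sketch (the decorated function is illustrative, not part of the diff):

    from trl.core import PPODecorators

    PPODecorators.optimize_device_cache = True  # opt in to cache clearing

    @PPODecorators.empty_device_cache()
    def train_step(batch):
        ...  # forward/backward on the active accelerator

    # Equivalent context-manager form; caches are emptied on exit,
    # since the clearing code sits after the `yield`.
    with PPODecorators.empty_device_cache():
        ...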
trl/import_utils.py (11 changes: 11 additions & 0 deletions)
@@ -88,3 +88,14 @@ def is_xpu_available() -> bool:
return hasattr(torch, "xpu") and torch.xpu.is_available()
except RuntimeError:
return False


def is_npu_available() -> bool:
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
return False

import torch
import torch_npu # noqa: F401

return hasattr(torch, "npu") and torch.npu.is_available()
trl/models/modeling_base.py (4 changes: 3 additions & 1 deletion)
@@ -29,7 +29,7 @@
 from safetensors.torch import load_file as safe_load_file
 from transformers import PreTrainedModel
 
-from ..import_utils import is_peft_available, is_transformers_greater_than, is_xpu_available
+from ..import_utils import is_npu_available, is_peft_available, is_transformers_greater_than, is_xpu_available
 
 
 if is_peft_available():
@@ -401,6 +401,8 @@ def _get_current_device(cls):
         state = PartialState()
         if is_xpu_available():
             return f"xpu:{state.local_process_index}"
+        elif is_npu_available():
+            return f"npu:{state.local_process_index}"
         else:
             return state.local_process_index if torch.cuda.is_available() else "cpu"

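Under a multi-process launch, `_get_current_device` keys the device off Accelerate's local process index, so each rank lands on its own card. A minimal sketch of what the added branch resolves to (the world size is illustrative):

    from accelerate import PartialState

    state = PartialState()
    # On an 8-NPU node started with `accelerate launch`, local rank 3
    # resolves to "npu:3", rank 0 to "npu:0", and so on.
    device = f"npu:{state.local_process_index}"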
trl/trainer/ppo_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -52,7 +52,7 @@
     stack_dicts,
     stats_to_np,
 )
-from ..import_utils import is_torch_greater_2_0, is_xpu_available
+from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
 from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
 from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments

@@ -349,6 +349,8 @@ def __init__(
         else:
             if is_xpu_available():
                 self.current_device = torch.device("xpu:0")
+            elif is_npu_available():
+                self.current_device = torch.device("npu:0")
             else:
                 self.current_device = torch.device("cuda:0")

