
PPO on multi-GPU but get Error: Expected all tensors to be on the same device #809

Closed
Ricardokevins opened this issue Sep 22, 2023 · 22 comments



Ricardokevins commented Sep 22, 2023

I am training alpaca-7B on 4 × A100 80G.
I am using the DeepSpeed ZeRO-2 YAML file provided in the repository as the configuration file. Even when I set the model's device_map to None and the batch size to 2, I still run out of GPU memory with the mini-batch size set to 1.

Therefore, I tried setting device_map to 'auto' so that accelerate can shard the model across different GPUs. However, when the code reaches ppo_trainer.generate, the error above occurs. Could you please advise on how to resolve this?

The command I use:

accelerate launch --config_file=deepspeed_zero2.yaml --num_processes 4 train.py --reward_fuction $RF --model_name $model_name

The model loading code:

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    script_args.model_name,
    device_map='auto',
    #device_map=None,
    torch_dtype=torch.bfloat16,
)

The PPOConfig:

config = PPOConfig(
    model_name=script_args.model_name,
    log_with=script_args.log_with,
    learning_rate=5e-6,
    batch_size=2,
    mini_batch_size=1,
    gradient_accumulation_steps=2,
    ppo_epochs=4,
    early_stopping=True,
    optimize_cuda_cache=True,
    seed=script_args.seed,
    project_kwargs = project_kwargs,
    remove_unused_columns=False,
)
@Ricardokevins

@younesbelkada
Hi, sorry to interrupt. I'm currently unsure whether ppo_trainer.generate supports sharding the model parameters across different GPUs. If it doesn't, flash_attn alone cannot solve the problem, especially as the model size grows and requires more memory.

@younesbelkada

Hi @Ricardokevins,
I see, thanks for the description. I think the fix is to make sure your lm_head is on the same device as the input. Can you print model.pretrained_model.hf_device_map?
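For context, a minimal sketch of what this check could look like (not code from the thread; the checkpoint path is a placeholder, and hf_device_map is only populated when the model is loaded with device_map="auto"):

# Sketch only: inspect where accelerate placed each submodule.
import torch
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "path/to/alpaca-7b",        # placeholder checkpoint path
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# The map lives on the wrapped pretrained model, not on the value-head wrapper.
print(model.pretrained_model.hf_device_map)

# Inputs are generally expected on the device that holds the first module (the embeddings).
first_device = next(iter(model.pretrained_model.hf_device_map.values()))
input_ids = torch.ones(1, 8, dtype=torch.long, device=f"cuda:{first_device}")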


Ricardokevins commented Sep 25, 2023

> Hi @Ricardokevins, I see, thanks for the description. I think the fix is to make sure your lm_head is on the same device as the input. Can you print model.pretrained_model.hf_device_map?

Hi, thank you for your reply! @younesbelkada
I followed your instructions and tried to print it, but encountered a new issue: AttributeError: 'LlamaForCausalLM' object has no attribute 'hf_device_map'

Is this a transformers version issue? I am using transformers==4.33.1.

@younesbelkada

Hmm, this is strange; loading a model with device_map="auto" should add the hf_device_map attribute to the model. Have you loaded your model with device_map="auto"?

@Ricardokevins

> Hmm, this is strange; loading a model with device_map="auto" should add the hf_device_map attribute to the model. Have you loaded your model with device_map="auto"?

Hi, thank you for your help.

I checked the setting, set device_map="auto", and the output is the following:

{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 1, 'model.layers.8': 1, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 2, 'model.layers.17': 2, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 3, 'model.layers.26': 3, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.norm': 3, 'lm_head': 3}

But I still receive the error:

  File "/mnt/data/sheshuaijie/anaconda3local/envs/shesj_ds/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 89, in forward
    return self.weight * hidden_states.to(input_dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

I checked, and the error is caused by the LlamaRMSNorm module.
In LlamaDecoderLayer(nn.Module), the decoder layer calls hidden_states = self.input_layernorm(hidden_states), which causes the error.
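For reference, the failure is just a cross-device elementwise multiply inside LlamaRMSNorm; a tiny illustration (not from the thread, and it needs two visible GPUs):

# Illustration only: reproduces the error message when two tensors live on different GPUs.
import torch

weight = torch.ones(4, device="cuda:0")            # layer parameter left on GPU 0
hidden_states = torch.ones(2, 4, device="cuda:1")  # activations arriving on GPU 1
out = weight * hidden_states  # RuntimeError: Expected all tensors to be on the same device ...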


Ricardokevins commented Oct 1, 2023

I dove into the code and found that the model's 7th layer is actually on GPU 0, although according to the device_map it should be placed on GPU 1. @younesbelkada

residual = hidden_states
print("hidden_states.device", hidden_states.device,self.input_layernorm.weight.device,self.self_attn.q_proj.weight.device)
hidden_states = self.input_layernorm(hidden_states)

The output:

entering decoder layer  0  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  1  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  2  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  3  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  4  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  5  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  6  of  32
hidden_states.device cuda:0 cuda:0 cuda:0
entering decoder layer  7  of  32
hidden_states.device cuda:1 cuda:0 cuda:0
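One way to double-check that kind of mismatch is to compare the declared map against the actual parameter devices; a small sketch (assuming model is the AutoModelForCausalLMWithValueHead loaded earlier):

# Sketch only: where does accelerate say layer 7 should be, and where are its weights?
declared = model.pretrained_model.hf_device_map.get("model.layers.7")
print("declared device for model.layers.7:", declared)
for name, param in model.pretrained_model.named_parameters():
    if name.startswith("model.layers.7."):
        print(name, "->", param.device)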


allanj commented Oct 31, 2023

I have a similar "Expected all tensors to be on the same device" issue when I run the example in https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py

Later I found that using the example command works fine:
accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam

But when I enable DeepSpeed ZeRO-2, I get that error.

@Ricardokevins

Any updates here?
The problem still exists.


allanj commented Nov 9, 2023

Remove Accelerator() from your code. I have fixed the issue now.

@Ricardokevins

> Remove Accelerator() from your code. I have fixed the issue now.

Wow, amazing! I noticed that in the original code, the Accelerator was mainly used to provide the current_device. Have you replaced all the places that used current_device with 'auto'?
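In case it helps, a hedged sketch of the substitution being discussed (an assumption on my part, not a confirmed fix): derive the local device from the LOCAL_RANK environment variable that accelerate launch sets, instead of instantiating a second Accelerator():

# Sketch of the change under discussion (assumption, not a verified fix).
# Before: device_map = {"": Accelerator().local_process_index}
import os
import torch
from trl import AutoModelForCausalLMWithValueHead

local_rank = int(os.environ.get("LOCAL_RANK", 0))   # set by `accelerate launch`
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "path/to/alpaca-7b",          # placeholder checkpoint path
    device_map={"": local_rank},  # pin the whole model to this process's GPU
    torch_dtype=torch.bfloat16,
)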

@Ricardokevins

> Remove Accelerator() from your code. I have fixed the issue now.

Thank you for your suggestion for resolving the issue. I implemented the changes based on your advice, but unfortunately the problem still persists. If it's convenient for you, could you please share the code that you ran successfully, so that I can reference it for further troubleshooting?

Thank you for your help.


imgremlin commented Nov 14, 2023

@Ricardokevins have you solved the issue? I have the same problem


allanj commented Nov 14, 2023

Hi @Ricardokevins @imgremlin, could you paste a reproducing code snippet for me to check?

@Ricardokevins

> @Ricardokevins have you solved the issue? I have the same problem

No, I can't run the model correctly.


Ricardokevins commented Nov 14, 2023

> Hi @Ricardokevins @imgremlin, could you paste a reproducing code snippet for me to check?

Sure:


from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from dataclasses import dataclass, field
from typing import Optional
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer,LlamaForCausalLM
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSequenceClassification
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from trl import PreTrainedModelWrapper


VALUE_HEAD_FILE_NAME = "value_head.bin"



def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True):
    if isinstance(model,AutoModelForCausalLMWithValueHead):
        state_dict = model.pretrained_model.state_dict()
    else:
        state_dict = model.state_dict()
    print("Enter")
    for k, v in state_dict.items():
        print(k)
    filtered_state_dict = {}
    for k, v in model.named_parameters():
        if 'v_head' in k:
            continue
        k = k.replace("pretrained_model.",'')
        print(k)
        filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
    return filtered_state_dict

@dataclass
class ScriptArguments:
    """
    The name of the causal LM model we wish to fine-tune with PPO
    """
    local_rank: Optional[int] = field(default=-1, metadata={"help": "Local rank for distributed training (-1: not distributed)"})
    training_config: Optional[str] = field(default=None, metadata={"help": "Path to training config"})
    log_with: Optional[str] = field(default='tensorboard', metadata={"help": "use 'wandb' to log with wandb"})
    seed: Optional[int] = field(default=42, metadata={"help": "Random seed"})
    output_dir : Optional[str] = field(default="")


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
import json
training_details = {
    "model_name" : "alpaca-7b",
    "lr": 2e-5,
    "batch_size": 4,
    "mini_batch_size" : 1,
    "gradient_accumulation_steps" : 4,
    "ppo_epoch" : 4,
    "version_name" : "LoRA_PPO"
}
script_args.model_name = training_details['model_name']
script_args.reward_functon = training_details.get('reward_functon')  # note: this key is not defined in training_details above

project_name = "{model}-{setting}".format(model=script_args.model_name.split("/")[-1],setting = script_args.training_config)
script_args.output_dir = script_args.output_dir + training_details['version_name'] + '/'
script_args.output_dir = script_args.output_dir + project_name
script_args.explore_data = "alpaca.json"

tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name,padding_side = "left")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

def tokenize(sample):
    sample['text'] = sample['instruction']
    sample["input_ids"] = tokenizer.encode(sample["text"])
    sample["query"] = tokenizer.decode(sample["input_ids"])
    if 'answer' in sample:
        sample["resonse_label"] = str(sample["answer"])
    else:
        sample["resonse_label"] = "NONE ANSWER"
    return sample

def create_and_prepare_dataset(config):
    ds = load_dataset("json", data_files=config.explore_data)['train'].shuffle(seed=42)
    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds



import os
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    script_args.model_name,
    torch_dtype=torch.bfloat16,
)


tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name,padding_side = "left")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

dataset = create_and_prepare_dataset(script_args)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

project_kwargs={"logging_dir": script_args.output_dir}


config = PPOConfig(
    model_name=script_args.model_name,
    log_with=script_args.log_with,
    learning_rate=training_details['lr'],
    batch_size=training_details['batch_size'],
    mini_batch_size=training_details['mini_batch_size'],
    gradient_accumulation_steps=training_details['gradient_accumulation_steps'],
    ppo_epochs=training_details['ppo_epoch'],
    early_stopping=True,
    optimize_cuda_cache=True,
    seed=script_args.seed,
    project_kwargs = project_kwargs,
    remove_unused_columns=False,
)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(script_args.model_name)
model.gradient_checkpointing_enable()
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)

device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a ` pipeline` bug


generation_config = GenerationConfig(
    top_p=1.0,
    top_k=0,
    max_new_tokens=1024,
    do_sample=True,
)




    
step = 0
accuracy = 0
total = 0



generated_sample = []
torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
)
for i in range(30):
    print("EPOCH ",i)
    for iteration, batch in tqdm(enumerate(ppo_trainer.dataloader)):    
        question_tensors = batch["input_ids"]
        ppo_trainer.accelerator.unwrap_model(model).gradient_checkpointing_disable()
        response_tensors = ppo_trainer.generate(
            question_tensors,
            return_prompt=False,
            generation_config = generation_config,
            #**generation_kwargs,
        )
        ppo_trainer.accelerator.unwrap_model(model).gradient_checkpointing_enable()
        batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

        labels = batch["resonse_label"]
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        rewards = [torch.tensor(1) for _,l in zip(texts,labels)]
       
        

        if script_args.local_rank == 0:
            print(batch["query"][0])
            print(batch["resonse_label"][0])
            print(batch["response"][0])
            print(rewards[0])
        stats = ppo_trainer.step(question_tensors, response_tensors, rewards)

        
        step += 1

        ppo_trainer.log_stats(stats, batch, rewards)   
        if step % 30 == 0:
            ppo_trainer.save_pretrained(script_args.output_dir + f"/ModelSaved/step_{step}")


Ricardokevins commented Nov 14, 2023

> Hi @Ricardokevins @imgremlin, could you paste a reproducing code snippet for me to check?

Here is my launch script:

PORT=$(( $RANDOM % 1000 + 32768 ))
export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export NCCL_ASYNC_ERROR_HANDLING=1
export GRUB_CMDLINE_LINUX_DEFAULT="iommu=soft"
accelerate launch --config_file="deepspeed_zero2.yaml" --num_processes 4 ppo_full.py

And here is deepspeed_zero2.yaml:

compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


allanj commented Nov 14, 2023

The code looks fine to me. Do you have a minimal runnable snippet?

@imgremlin

This code works well with:

  • accelerate launch --config_file=scripts/deepspeed_configs/deepspeed_zero1.yaml --num_processes 1 scripts/deepspeed_repro.py
  • accelerate launch --config_file=scripts/deepspeed_configs/multi_gpu.yaml --num_processes 2 scripts/deepspeed_repro.py

But doesn't work with:

accelerate launch --config_file=scripts/deepspeed_configs/deepspeed_zero1.yaml --num_processes 2 scripts/deepspeed_repro.py

import torch
from accelerate import Accelerator
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from trl import AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer
from trl.import_utils import is_xpu_available

MODEL = 'facebook/bart-base'
BS = 32

device_map = {"": Accelerator().local_process_index}

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    MODEL,
    device_map=device_map
)
tokenizer.pad_token = tokenizer.eos_token

class TextDataset(Dataset):

    def __init__(self):
        self.texts = ['Do you love Paris?'] * 1024

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]
    
dataset = TextDataset()

config = PPOConfig(ppo_epochs=1, batch_size=BS, mini_batch_size=BS//2)
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)

device = ppo_trainer.accelerator.device

if ppo_trainer.accelerator.num_processes == 1:
    if is_xpu_available():
        device = "xpu:0"
    else:
        device = 0 if torch.cuda.is_available() else "cpu"

for texts in ppo_trainer.dataloader:
    queries_arr = [tokenizer.encode(i, return_tensors='pt').squeeze().to(device) for i in texts]
    response_arr = [ppo_trainer.generate(prompt, return_prompt=False).squeeze() for prompt in queries_arr]
    rewards=[torch.tensor(1.5)] * len(queries_arr)
    stats = ppo_trainer.step(queries_arr, response_arr, rewards)
    print('Batch is done!')

@paraGONG

Hello! I have the same problem. Have you solved this issue?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@CHLEE-Leo

I am suffering from the same issue mentioned above. Does anybody have a solution to this?

@NuoJohnChen

I am suffering from the same issue mentioned above. Does anybody have a solution to this?
