Flash attention error #973

Closed

SpursLipu opened this issue Nov 9, 2023 · 6 comments
@SpursLipu

I want to use DPO to fine-tune qwen-chat-14b, but I hit the error below. The inputs (q, k, v) to flash-attention in Qwen must be float16 or bfloat16, but in dpo_trainer they are float32. If I turn off flash-attention the error does not occur, but training becomes very slow. How can I solve this problem?

Traceback (most recent call last):
File "/mnt/afs/smartbrain/FastChat/fastchat/rlhf/dpo_qwen.py", line 215, in
dpo_trainer.train()
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 594, in compute_loss
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 545, in get_batch_metrics
) = self.concatenated_forward(model, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 511, in concatenated_forward
all_logits = model(
^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 632, in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 620, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/peft_model.py", line 918, in forward
return self.base_model(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 94, in forward
return self.model.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 1108, in forward
transformer_outputs = self.transformer(
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 938, in forward
outputs = block(
^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 639, in forward
attn_outputs = self.attn(
^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 546, in forward
context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 174, in forward
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
AssertionError
0%| | 0/127611 [00:01<?, ?it/s]
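
For context, a standalone sketch of the dtype check that fires (my own illustration, not the actual Qwen code; the tensor shapes are made up and only the dtype matters):

import torch

# Hypothetical q/k/v tensors; the check in modeling_qwen.py (line 174 in the
# traceback above) gates the flash-attn call on half-precision inputs.
q = torch.randn(1, 8, 32, 64, dtype=torch.float32)  # fp32 is what the trainer ends up feeding in
k, v = torch.randn_like(q), torch.randn_like(q)

print(all(t.dtype in (torch.float16, torch.bfloat16) for t in (q, k, v)))  # False -> the assert raises
print(all(t.dtype in (torch.float16, torch.bfloat16) for t in (q.half(), k.half(), v.half())))  # True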

@lvwerra (Member) commented Nov 9, 2023

Tagging @kashif and @younesbelkada.

@kashif (Collaborator) commented Nov 9, 2023

I have to check the dpo_qwen.py script, but it seems like there is a dtype mismatch between fp16/bf16 ... perhaps you can check whether the model is cast to float16/bf16 properly?
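
For reference, a quick way to inspect how the model was actually cast (a sketch against the script posted below; `model` is the AutoModelForCausalLM it loads):

import torch
from collections import Counter

# Count the dtypes of all parameters and buffers. With load_in_4bit, some modules
# (e.g. layer norms) can remain in float32 even though torch_dtype=torch.float16
# was requested, which is exactly the kind of mismatch to look for here.
print(Counter(p.dtype for p in model.parameters()))
print(Counter(b.dtype for b in model.buffers()))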

@younesbelkada (Contributor)

Hey @SpursLipu
That model uses a custom FA-2 implementation: https://huggingface.co/Qwen/Qwen-14B-Chat/blob/main/modeling_qwen.py#L83. I suggest opening an issue on the Hub repo directly.

@SpursLipu (Author)

> I have to check the dpo_qwen.py script, but it seems like there is a dtype mismatch between fp16/bf16 ... perhaps you can check whether the model is cast to float16/bf16 properly?

Here is my dpo_qwen.py, please check:

import os
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from fastchat.conversation import get_conv_template
from trl import DPOTrainer

@dataclass
class ScriptArguments:
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    model_name_or_path: Optional[str] = field(
        default="gpt2",
        metadata={"help": "the model name"}
    )
    dataset: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset path"})
    trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "trust_remote_code"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})


    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})

    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )


def preprocess(dataset: str, split: str, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset(dataset, split=split, cache_dir=cache_dir)
    conv = get_conv_template("qwen-7b-chat")
    def split_prompt_and_responses(sample) -> Dict[str, str]:
        chosen = sample["chosen"].split("\n\nAssistant: ")[-1]
        rejected = sample["rejected"].split("\n\nAssistant: ")[-1]

        prompt = sample["chosen"][len("\n\nHuman: "): sample["chosen"].rfind("\n\nAssistant: ")]
        prompt = prompt.replace("\n\nAssistant: ", conv.sep + conv.roles[1] + '\n')
        prompt = prompt.replace("\n\nHuman: ", conv.sep + conv.roles[0] + '\n')
        prompt = conv.roles[0] + '\n' + prompt + conv.sep + conv.roles[1] + '\n'
        return {
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected,
        }

    return dataset.map(split_prompt_and_responses)


if __name__ == "__main__":
    global local_rank
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        load_in_4bit=True)

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        load_in_4bit=True)

    tokenizer = AutoTokenizer.from_pretrained(
        script_args.model_name_or_path,
        trust_remote_code=script_args.trust_remote_code,
        pad_token='<|endoftext|>',
        eos_token='<|im_end|>',
        bos_token='<|im_start|>')

    train_dataset = preprocess(script_args.dataset, "train")

    eval_dataset = preprocess(script_args.dataset, "test")

    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir="./test",
        optim="adamw_torch",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=True,
        gradient_checkpointing=script_args.gradient_checkpointing,
    )
    local_rank = training_args.local_rank
    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_length=script_args.max_length,
        max_prompt_length=script_args.max_prompt_length,
        generate_during_eval=False,
    )
    dpo_trainer.train()

@younesbelkada (Contributor) commented Nov 10, 2023

Can you pass BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) to from_pretrained:

    quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map={"": Accelerator().local_process_index},
        trust_remote_code=script_args.trust_remote_code,
        quantization_config=quantization_config)

But I am really not sure this will solve your bug; I just suspect there might be some odd interaction between the compute dtype and the FA-2 implementation in their repository.
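
For completeness, BitsAndBytesConfig needs to be imported from transformers, and since the training arguments in the script use bf16=True, a bf16 compute dtype is a natural variant to try as well. A sketch of the loading call under that assumption (not a confirmed fix; `script_args` comes from the script above):

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# bnb_4bit_compute_dtype sets the computation dtype of the 4-bit linear layers,
# so the q/k/v projections feeding the attention kernel are produced in this dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches bf16=True in TrainingArguments; the comment above used float16
)

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    device_map={"": Accelerator().local_process_index},
    trust_remote_code=script_args.trust_remote_code,
    quantization_config=quantization_config,
)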

github-actions bot commented Dec 9, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
