feat: Add support for DPO #556

Open · wants to merge 11 commits into base: main
3 changes: 2 additions & 1 deletion .gitignore
@@ -150,4 +150,5 @@ OUT/
examples/experiments/grounded_program_synthesis/dataset
ckpts/

ray_results/
ray_result/
examples/checkpoints/
99 changes: 99 additions & 0 deletions examples/dpo_ultrafeedback.py
@@ -0,0 +1,99 @@
import sys
import json
from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer

import trlx
from trlx.data.default_configs import (
    DPOConfig,
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

model_path = "HuggingFaceH4/mistral-7b-sft-beta"
wandb_project = "trlx"

default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=1,
        total_steps=70000,
        batch_size=1,
        checkpoint_interval=100000,
        eval_interval=5000,
        seed=42,
        project_name=wandb_project,
        pipeline="PromptPipeline",
        trainer="AccelerateDPOTrainer",
        checkpoint_dir="checkpoints/dpo_ultrafeedback",
    ),
    model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1),
    tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=5e-7, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=1.0e-5)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
    method=DPOConfig(
        name="DPOConfig",
        gen_kwargs=dict(max_new_tokens=768, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
        beta=0.1,
        label_pad_token_id=-100,
        padding_value=0,
    ),
)


def preprocess(sample, tokenizer, test=False):
    """
    Formats the input to the same training style used for mistral-7b-v0.1.
    When fine-tuning, modify your pre-processing to match the prompt template used during pretraining.
    """
    assert len(sample["chosen"]) == len(sample["rejected"]) == 2

    assistant_prompt = "<|assistant|>"

    prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt)
    rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1]

    return {
        "prompt": prompt if not test else prompt + assistant_prompt,
        "chosen": assistant_prompt + chosen,
        "rejected": assistant_prompt + rejected,
    }


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized")

    dataset["dpo_train"] = dataset["train_prefs"].map(
        partial(preprocess, tokenizer=tokenizer, test=False),
        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
    )
    dataset["dpo_test"] = dataset["test_prefs"].map(
        partial(preprocess, tokenizer=tokenizer, test=True),
        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
    )

    print(
        f"Length of training dataset: {len(dataset['dpo_train'])} \
        Length of test dataset: {len(dataset['dpo_test'])}"
    )

    trlx.train(
        config=config,
        samples=dataset["dpo_train"],
        eval_prompts=dataset["dpo_test"]["prompt"][:2],  # running eval on subset only
        stop_sequences=["<|user|>", "<|User|>"],
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
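
For a quick run, the example can be driven with a JSON dict of hyperparameter overrides, which main merges into default_config via TRLConfig.update. A minimal sketch, assuming nested override dicts merge the way the hparams entry point above expects; the override values are illustrative, not recommendations:

import json
import subprocess

# Hedged sketch: pass overrides through sys.argv[1], mirroring the __main__ block above.
overrides = {"train": {"total_steps": 1000, "eval_interval": 500}}
subprocess.run(
    ["python", "examples/dpo_ultrafeedback.py", json.dumps(overrides)],
    check=True,
)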
64 changes: 64 additions & 0 deletions examples/hh/dpo_hh.py
@@ -0,0 +1,64 @@
import json
import sys

from datasets import load_dataset

import trlx
from trlx.data.default_configs import (
    DPOConfig,
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=100,
        total_steps=1000,
        batch_size=4,
        checkpoint_interval=10000,
        eval_interval=100,
        pipeline="PromptPipeline",
        trainer="AccelerateDPOTrainer",
        checkpoint_dir="checkpoints/dpo_hh",
    ),
    model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
    tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
    method=DPOConfig(
        name="DPOConfig",
        gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True),
        beta=0.1,
        label_pad_token_id=-100,
        padding_value=0,
    ),
)


def preprocess(sample):
    sample["dpo"] = [sample["prompt"], sample["chosen"], sample["rejected"]]
    return sample


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess)

    trlx.train(
        config=config,
        samples=dataset["train"]["dpo"],
        eval_prompts=dataset["test"]["prompt"][:280],
        # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
        stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
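
For reference, preprocess here packs each sample into a flat [prompt, chosen, rejected] triple, so trlx.train receives a list of string triples rather than dicts. An illustrative, made-up element of dataset["train"]["dpo"]:

# Illustrative only: the shape of one element produced by preprocess above.
dpo_sample = [
    "\n\nHuman: How do I boil an egg?\n\nAssistant:",  # prompt
    " Place the egg in boiling water for about 8 minutes.",  # chosen
    " Eggs cannot be boiled.",  # rejected
]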
27 changes: 27 additions & 0 deletions trlx/data/default_configs.py
@@ -2,6 +2,7 @@

from trlx.models.modeling_ilql import ILQLConfig
from trlx.models.modeling_ppo import PPOConfig
from trlx.trainer.accelerate_dpo_trainer import DPOConfig
from trlx.trainer.accelerate_sft_trainer import SFTConfig

from .configs import (
@@ -146,3 +147,29 @@ def default_nemo_1_3b_config():

    here = Path(__file__).parent
    return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_1.3b.yaml")


def default_dpo_config():
    return TRLConfig(
        train=TrainConfig(
            seq_length=1024,
            epochs=100,
            total_steps=1000,
            batch_size=8,
            checkpoint_interval=10000,
            eval_interval=100,
            pipeline="PromptPipeline",
            trainer="DPOTrainer",
        ),
        model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
        tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
        optimizer=OptimizerConfig(
            name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
        ),
        scheduler=SchedulerConfig(
            name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)  # train.total_steps
        ),
        method=DPOConfig(
            name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), beta=0.1
        ),
    )
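
Downstream code can start from these defaults and override individual fields the same way the example scripts call TRLConfig.update(default_config, hparams). A minimal sketch, assuming nested override dicts (including the "method" key) are merged key-by-key as in the examples; the values shown are illustrative:

from trlx.data.default_configs import TRLConfig, default_dpo_config

# Hedged sketch: reuse the new DPO defaults with a couple of overrides.
config = TRLConfig.update(
    default_dpo_config(),
    {"train": {"total_steps": 200}, "method": {"beta": 0.05}},
)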
13 changes: 13 additions & 0 deletions trlx/data/dpo_types.py
@@ -0,0 +1,13 @@
from dataclasses import dataclass

from transformers import BatchEncoding


@dataclass
class DPOElement:
    prompt_tokens: BatchEncoding
    chosen_tokens: BatchEncoding
    rejected_tokens: BatchEncoding


# TODO: Extend to include a concrete class for DPOPreferenceBatch
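
A DPOElement simply groups the three tokenizer outputs for one preference pair. A minimal sketch of building one by hand; the gpt2 tokenizer and the strings are purely illustrative:

from transformers import AutoTokenizer

from trlx.data.dpo_types import DPOElement

# Hedged sketch: one preference pair tokenized into a DPOElement.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
element = DPOElement(
    prompt_tokens=tokenizer("Human: What is DPO? Assistant:", add_special_tokens=False),
    chosen_tokens=tokenizer(" Direct Preference Optimization.", add_special_tokens=False),
    rejected_tokens=tokenizer(" A sports drink.", add_special_tokens=False),
)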
4 changes: 2 additions & 2 deletions trlx/pipeline/__init__.py
@@ -166,8 +166,8 @@ def __next__(self): # noqa: C901
                minibatch = BatchEncoding(sliced_data)
            elif is_dataclass(batch):
                minibatch = batch.__class__(**sliced_data)
            # else:
            #     minibatch = sliced_data
            else:
                minibatch = sliced_data

            minibatches.append(minibatch)

124 changes: 124 additions & 0 deletions trlx/pipeline/offline_pipeline.py
@@ -10,6 +10,7 @@
    PreTrainedTokenizerFast,
)

from trlx.data.dpo_types import DPOElement
from trlx.data.ilql_types import (
    ILQLBatch,
    ILQLElement,
@@ -277,3 +278,126 @@ def create_loader(self, batch_size: int):
            collate_fn=ilql_seq2seq_collate_fn,
            drop_last=torch.distributed.is_initialized(),
        )


class DPOStore(BaseRolloutStore):
    # Adapted from TRL
    def __init__(
        self,
        preferences: List[DPOElement],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        label_pad_token_id: int,
        padding_value: int,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.label_pad_token_id = label_pad_token_id
        self.padding_value = padding_value

        self.history = [
            self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences
        ]

    @staticmethod
    def tokenize_preferences(
        sample: Iterable[str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_length=2048,
        max_prompt_length=256,
    ) -> DPOElement:
        if isinstance(sample, Iterable):
            if len(sample) != 3:
                raise ValueError(
                    f"Expected iterable of length 3 (prompt, chosen response, rejected response). Got {len(sample)}"
                )
            prompt_tokens = tokenizer(sample["prompt"], add_special_tokens=False)
            chosen_tokens = tokenizer(sample["chosen"], add_special_tokens=False)
            rejected_tokens = tokenizer(sample["rejected"], add_special_tokens=False)
        else:
            raise ValueError(f"{sample} is not an iterable")

        chosen_tokens["input_ids"].append(tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)

        rejected_tokens["input_ids"].append(tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt only
        if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
            if tokenizer.truncation_side == "right":
                prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
            elif tokenizer.truncation_side == "left":
                prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}

        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
            chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}

        return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)

    def _build_batch_from_preference_tokens(self, preference_tokens: DPOElement) -> Dict:
        # Create labels
        chosen_sequence_tokens = {
            k: preference_tokens.prompt_tokens[k] + preference_tokens.chosen_tokens[k]
            for k in preference_tokens.chosen_tokens
        }
        rejected_sequence_tokens = {
            k: preference_tokens.prompt_tokens[k] + preference_tokens.rejected_tokens[k]
            for k in preference_tokens.rejected_tokens
        }
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [
            self.label_pad_token_id
        ] * len(preference_tokens.prompt_tokens["input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [
            self.label_pad_token_id
        ] * len(preference_tokens.prompt_tokens["input_ids"])

        batch = {}

        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": preference_tokens.prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens

        return batch

    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
        def collate_fn(batch: Iterable[dict]):
            # first, pad everything to the same length
            padded_batch = {}
            for k in batch[0].keys():
                if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
                    # adapted from https://stackoverflow.com/questions/73256206
                    if "prompt" in k:
                        to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                    else:
                        to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                    if k.endswith("_input_ids"):
                        padding_value = self.tokenizer.pad_token_id
                    elif k.endswith("_labels"):
                        padding_value = self.label_pad_token_id
                    elif k.endswith("_attention_mask"):
                        padding_value = self.padding_value
                    else:
                        raise ValueError(f"Unexpected key in batch '{k}'")

                    padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                    # for the prompt, flip back so padding is on left side
                    if "prompt" in k:
                        padded_batch[k] = padded_batch[k].flip(dims=[1])
                else:
                    padded_batch[k] = [ex[k] for ex in batch]

            return padded_batch

        return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle, pin_memory=True)
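
Putting the pieces together, a hedged sketch of the intended flow: tokenize each {prompt, chosen, rejected} sample with DPOStore.tokenize_preferences, wrap the elements in a DPOStore, and draw padded batches from its loader (prompts end up left-padded). The toy samples, lengths, and the gpt2 tokenizer are illustrative:

from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DPOStore

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

samples = [
    {"prompt": "Human: Hi! Assistant:", "chosen": " Hello!", "rejected": " Go away."},
    {"prompt": "Human: 2 + 2? Assistant:", "chosen": " 4.", "rejected": " 5."},
]

preferences = [
    DPOStore.tokenize_preferences(s, tokenizer, max_length=512, max_prompt_length=128) for s in samples
]
store = DPOStore(preferences, tokenizer, label_pad_token_id=-100, padding_value=0)

loader = store.create_loader(batch_size=2)
batch = next(iter(loader))
# Keys follow the "<prefix>_<field>" pattern, e.g. chosen_input_ids, chosen_labels,
# rejected_attention_mask, prompt_input_ids (prompts are flipped so padding lands on the left).
print({k: tuple(v.shape) for k, v in batch.items()})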