Fix prompt truncation bug and handle deepspeed preparation
sandeepchittilla authored and s00652993 committed Nov 20, 2023
1 parent dfa814d commit 506fbbd
Showing 6 changed files with 75 additions and 36 deletions.
68 changes: 48 additions & 20 deletions examples/dpo_ultrafeedback.py
@@ -1,8 +1,9 @@
import itertools
import json
import sys
import json
from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer

import trlx
from trlx.data.default_configs import (
@@ -15,54 +16,81 @@
TRLConfig,
)

model_path = "HuggingFaceH4/mistral-7b-sft-beta"
wandb_project = "trlx"

default_config = TRLConfig(
train=TrainConfig(
seq_length=1024,
epochs=100,
total_steps=1000,
epochs=2,
total_steps=1000000,
batch_size=1,
checkpoint_interval=10000,
eval_interval=100,
checkpoint_interval=100000,
eval_interval=1000,
seed=42,
project_name=wandb_project,
pipeline="PromptPipeline",
trainer="AccelerateDPOTrainer",
checkpoint_dir="checkpoints/dpo_ultrafeedback",
),
model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1),
tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"),
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1),
tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"),
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps
method=DPOConfig(
name="DPOConfig",
gen_kwargs=dict(max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
gen_kwargs=dict(max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
beta=0.1,
label_pad_token_id=-100,
padding_value=0,
),
)


def preprocess(sample):
def preprocess(sample, tokenizer, test=False):
"""
Return list of lists with Context/Prompt at index 0, Chosen at index 1 and rejected at index 2
Formats the input to the same training style used for mistral-7b-v0.1
When fine-tuning, modify your pre-processing to match the prompt template used during pretraining.
"""
assert len(sample["chosen"]) == len(sample["rejected"]) == 2

sample["dpo"] = [sample["prompt"], sample["chosen"][1]["content"], sample["rejected"][1]["content"]]
return sample
assistant_prompt = "<|assistant|>"

prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt)
rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1]

return {
"prompt": prompt if not test else prompt + assistant_prompt,
"chosen": assistant_prompt + chosen,
"rejected": assistant_prompt + rejected,
}


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized").map(preprocess)
tokenizer = AutoTokenizer.from_pretrained(model_path)
dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized")

dataset["dpo_train"] = dataset["train_prefs"].map(
partial(preprocess, tokenizer=tokenizer, test=False),
remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
)
dataset["dpo_test"] = dataset["test_prefs"].map(
partial(preprocess, tokenizer=tokenizer, test=True),
remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
)

print(
f"Length of training dataset : {len(dataset['dpo_train'])} \
Length of test dataset : {len(dataset['dpo_test'])}"
)

trlx.train(
config=config,
samples=dataset["train_prefs"]["dpo"],
eval_prompts=dataset["test_prefs"]["prompt"][:128],
# metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
stop_sequences=["User:", "user:", "Assistant:", "assistant:"]
+ ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])],
samples=dataset["dpo_train"],
eval_prompts=dataset["dpo_test"]["prompt"][:8], # running eval on subset only
stop_sequences=["<|user|>", "<|User|>"],
)


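The reworked preprocess() above splits the chat-templated conversation on the assistant marker, so the prompt and the two candidate responses can be stored (and later tokenized) separately. Below is a minimal sketch of that split, using a hard-coded template string in place of the real tokenizer.apply_chat_template output so it runs without downloading the model; the exact template text is an assumption and may differ from HuggingFaceH4/mistral-7b-sft-beta.

ASSISTANT = "<|assistant|>"

def split_pair(chosen_text: str, rejected_text: str, test: bool = False) -> dict:
    # Everything before the assistant marker is the prompt; everything after it is the response.
    prompt, chosen = chosen_text.split(ASSISTANT)
    rejected = rejected_text.split(ASSISTANT)[-1]
    return {
        # For eval prompts the assistant marker is appended so generation starts right after it.
        "prompt": (prompt + ASSISTANT) if test else prompt,
        "chosen": ASSISTANT + chosen,
        "rejected": ASSISTANT + rejected,
    }

chosen_text = "<|user|>\nWhat is 2 + 2?</s>\n<|assistant|>\n4</s>\n"
rejected_text = "<|user|>\nWhat is 2 + 2?</s>\n<|assistant|>\n5</s>\n"
print(split_pair(chosen_text, rejected_text, test=True)["prompt"])
# The printed prompt now ends with "<|assistant|>", ready for generation.
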
19 changes: 11 additions & 8 deletions trlx/pipeline/offline_pipeline.py
@@ -300,16 +300,19 @@ def __init__(

@staticmethod
def tokenize_preferences(
sample: Iterable[str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
sample: Iterable[str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
max_length=2048,
max_prompt_length=256,
) -> DPOElement:
if isinstance(sample, Iterable):
if len(sample) != 3:
raise ValueError(
f"Expected iterable of length 3 (prompt, chosen response, rejected response). Got {len(sample)}"
)
prompt_tokens = tokenizer(sample[0], add_special_tokens=False)
chosen_tokens = tokenizer(sample[1], add_special_tokens=False)
rejected_tokens = tokenizer(sample[2], add_special_tokens=False)
prompt_tokens = tokenizer(sample["prompt"], add_special_tokens=False)
chosen_tokens = tokenizer(sample["chosen"], add_special_tokens=False)
rejected_tokens = tokenizer(sample["rejected"], add_special_tokens=False)
else:
raise ValueError(f"{sample} is not an iterable")

@@ -324,14 +327,14 @@ def tokenize_preferences(
# if combined sequence is too long, truncate the prompt only
if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
if tokenizer.truncation_side == "right":
prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()}
prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
elif tokenizer.truncation_side == "left":
prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()}
prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}

# if that's still too long, truncate the response
if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()}
rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()}
chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}

return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)

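The prompt-truncation bug fixed above: the old code capped the prompt at max_length instead of max_prompt_length, and then truncated the responses to max_length - max_length, i.e. zero tokens. A sketch of the corrected rule on plain lists of token ids (an illustration of the logic, not the trainer's API):

def truncate_preference(prompt_ids, chosen_ids, rejected_ids,
                        max_length=2048, max_prompt_length=256,
                        truncation_side="right"):
    # The longer of the two responses decides whether the combined sequence fits.
    longer_response = max(len(chosen_ids), len(rejected_ids))

    # If the combined sequence is too long, truncate the prompt only,
    # now down to max_prompt_length rather than max_length.
    if len(prompt_ids) + longer_response > max_length:
        if truncation_side == "right":
            prompt_ids = prompt_ids[:max_prompt_length]
        else:
            prompt_ids = prompt_ids[-max_prompt_length:]

    # If that is still too long, truncate the responses to the remaining budget
    # (previously max_length - max_length, which emptied them).
    if len(prompt_ids) + longer_response > max_length:
        budget = max_length - max_prompt_length
        chosen_ids, rejected_ids = chosen_ids[:budget], rejected_ids[:budget]

    return prompt_ids, chosen_ids, rejected_ids

# A 300-token prompt with 1900-token responses now yields 256 + 1792 tokens per pair.
p, c, r = truncate_preference(list(range(300)), list(range(1900)), list(range(1900)))
print(len(p), len(c), len(r))  # 256 1792 1792
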
17 changes: 12 additions & 5 deletions trlx/trainer/accelerate_dpo_trainer.py
@@ -10,6 +10,7 @@
if is_deepspeed_available():
import deepspeed

import trlx.utils.logging as logging
from trlx.data.configs import TRLConfig
from trlx.data.method_configs import MethodConfig, register_method
from trlx.pipeline.offline_pipeline import DPOStore
@@ -18,6 +19,9 @@
from trlx.utils.modeling import pad_to_length


logger = logging.get_logger(__name__)


@dataclass
@register_method
class DPOConfig(MethodConfig):
@@ -47,9 +51,10 @@ def __init__(self, config: TRLConfig, **kwargs):

# TODO: Avoid setting up a reference model when hydra heads are used
self.ref_model = self.get_arch(self.config)
if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
self.ref_model = self._prepare_deepspeed_zero3(self.ref_model)
else:
try:
if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
self.ref_model = self._prepare_deepspeed_zero3(self.ref_model)
except:
self.ref_model.to(self.accelerator.device)
self.ref_model.eval()

@@ -311,6 +316,8 @@ def prepare_learning(self):
self.total_steps = self.config.train.epochs * len(self.train_dataloader)
self.total_steps = min(self.total_steps, self.config.train.total_steps)

def make_experience(self, samples: Iterable[Iterable], seq_length: int):
preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples]
def make_experience(self, samples: Iterable[Iterable], seq_length: int, max_prompt_length: int):
preferences = [
DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length, max_prompt_length) for sample in samples
]
self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value)
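
The DeepSpeed half of the commit wraps the ZeRO-3 check in a try/except so that runs without a DeepSpeed plugin (where accelerator.state.deepspeed_plugin is typically None) fall back to moving the reference model onto the local device. A self-contained sketch of that control flow with stand-in objects; the committed code uses a bare except, so catching AttributeError here is an assumption about the intended failure mode:

from types import SimpleNamespace

def place_ref_model(state, device="cuda:0"):
    try:
        if state.deepspeed_plugin.zero_stage == 3:
            return "prepare the reference model for ZeRO-3 weight partitioning"
        return "no extra preparation (ZeRO stage < 3)"
    except AttributeError:  # deepspeed_plugin is None when DeepSpeed is not configured
        return f"move the reference model to {device}"

print(place_ref_model(SimpleNamespace(deepspeed_plugin=None)))                           # no DeepSpeed
print(place_ref_model(SimpleNamespace(deepspeed_plugin=SimpleNamespace(zero_stage=3))))  # ZeRO-3
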
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_sft_trainer.py
@@ -87,7 +87,7 @@ def prepare_learning(self):
self.total_steps = self.config.train.epochs * len(self.train_dataloader)
self.total_steps = min(self.total_steps, self.config.train.total_steps)

def make_experience(self, samples, seq_length):
def make_experience(self, samples, seq_length, **kwargs):
if isinstance(samples[0], str):
self.store = PromptPipeline(samples, seq_length, self.tokenizer)
else:
2 changes: 1 addition & 1 deletion trlx/trainer/nemo_sft_trainer.py
@@ -133,7 +133,7 @@ def eval_collate(elems):
torch.set_float32_matmul_precision("medium")
self.trainer.fit(self.model)

def make_experience(self, samples, seq_length):
def make_experience(self, samples, seq_length, **kwargs):
if isinstance(samples[0], str):
self.store = PromptPipeline(samples, seq_length, self.tokenizer)
else:
3 changes: 2 additions & 1 deletion trlx/trlx.py
@@ -114,7 +114,8 @@ def train( # noqa: C901
if rewards is not None:
trainer.make_experience(samples, rewards, config.train.seq_length)
else:
trainer.make_experience(samples, config.train.seq_length)
# this should be abstracted for all trainers with **kwargs
trainer.make_experience(samples, config.train.seq_length, max_prompt_length)
else:
raise ValueError("Either `samples` or `reward_fn` should be given for training")

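Related to this call site, the SFT and NeMo trainers above gained a **kwargs parameter so that extra keyword arguments can be forwarded and ignored; only the DPO trainer consumes max_prompt_length (the in-line comment notes the call still needs to be abstracted for all trainers). A small sketch of that pattern, with illustrative function names and keyword-style forwarding assumed:

def make_experience_dpo(samples, seq_length, max_prompt_length):
    return f"DPO: cap prompts at {max_prompt_length} of {seq_length} tokens"

def make_experience_sft(samples, seq_length, **kwargs):
    # Extra keyword arguments such as max_prompt_length are accepted and ignored.
    return f"SFT: pack samples into sequences of {seq_length} tokens"

for make_experience in (make_experience_dpo, make_experience_sft):
    print(make_experience([], seq_length=1024, max_prompt_length=256))
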
