diff --git a/examples/dpo_ultrafeedback.py b/examples/dpo_ultrafeedback.py
index c97a4ce5e..f02a608c5 100644
--- a/examples/dpo_ultrafeedback.py
+++ b/examples/dpo_ultrafeedback.py
@@ -1,8 +1,9 @@
-import itertools
-import json
 import sys
+import json
+from functools import partial

 from datasets import load_dataset
+from transformers import AutoTokenizer

 import trlx
 from trlx.data.default_configs import (
@@ -15,25 +16,30 @@
     TRLConfig,
 )

+model_path = "HuggingFaceH4/mistral-7b-sft-beta"
+wandb_project = "trlx"
+
 default_config = TRLConfig(
     train=TrainConfig(
         seq_length=1024,
-        epochs=100,
-        total_steps=1000,
+        epochs=2,
+        total_steps=1000000,
         batch_size=1,
-        checkpoint_interval=10000,
-        eval_interval=100,
+        checkpoint_interval=100000,
+        eval_interval=1000,
+        seed=42,
+        project_name=wandb_project,
         pipeline="PromptPipeline",
         trainer="AccelerateDPOTrainer",
         checkpoint_dir="checkpoints/dpo_ultrafeedback",
     ),
-    model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1),
-    tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"),
-    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
+    model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1),
+    tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"),
+    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)),
     scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
     method=DPOConfig(
         name="DPOConfig",
-        gen_kwargs=dict(max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
+        gen_kwargs=dict(max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
         beta=0.1,
         label_pad_token_id=-100,
         padding_value=0,
@@ -41,28 +47,50 @@
 )


-def preprocess(sample):
+def preprocess(sample, tokenizer, test=False):
     """
-    Return list of lists with Context/Prompt at index 0, Chosen at index 1 and rejected at index 2
+    Formats the input with the chat template used to train mistral-7b-sft-beta.
+    When fine-tuning, modify your pre-processing to match the prompt template the base model was trained with.
     """
     assert len(sample["chosen"]) == len(sample["rejected"]) == 2
-    sample["dpo"] = [sample["prompt"], sample["chosen"][1]["content"], sample["rejected"][1]["content"]]
-    return sample
+    assistant_prompt = "<|assistant|>"
+
+    prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt)
+    rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1]
+
+    return {
+        "prompt": prompt if not test else prompt + assistant_prompt,
+        "chosen": assistant_prompt + chosen,
+        "rejected": assistant_prompt + rejected,
+    }


 def main(hparams={}):
     config = TRLConfig.update(default_config, hparams)

-    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized").map(preprocess)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized")
+
+    dataset["dpo_train"] = dataset["train_prefs"].map(
+        partial(preprocess, tokenizer=tokenizer, test=False),
+        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
+    )
+    dataset["dpo_test"] = dataset["test_prefs"].map(
+        partial(preprocess, tokenizer=tokenizer, test=True),
+        remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"],
+    )
+
+    print(
+        f"Length of training dataset: {len(dataset['dpo_train'])}, "
+        f"length of test dataset: {len(dataset['dpo_test'])}"
+    )

     trlx.train(
         config=config,
-        samples=dataset["train_prefs"]["dpo"],
-        eval_prompts=dataset["test_prefs"]["prompt"][:128],
-        # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
-        stop_sequences=["User:", "user:", "Assistant:", "assistant:"]
-        + ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])],
+        samples=dataset["dpo_train"],
+        eval_prompts=dataset["dpo_test"]["prompt"][:8],  # running eval on a subset only
+        stop_sequences=["<|user|>", "<|User|>"],
     )
diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index a824ffb98..6bab1db73 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -300,16 +300,19 @@ def __init__(

     @staticmethod
     def tokenize_preferences(
-        sample: Iterable[str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
+        sample: Iterable[str],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        max_length=2048,
+        max_prompt_length=256,
     ) -> DPOElement:
         if isinstance(sample, Iterable):
             if len(sample) != 3:
                 raise ValueError(
                     f"Expected iterable of length 3 (prompt, chosen response, rejected response). Got {len(sample)}"
                 )
-            prompt_tokens = tokenizer(sample[0], add_special_tokens=False)
-            chosen_tokens = tokenizer(sample[1], add_special_tokens=False)
-            rejected_tokens = tokenizer(sample[2], add_special_tokens=False)
+            prompt_tokens = tokenizer(sample["prompt"], add_special_tokens=False)
+            chosen_tokens = tokenizer(sample["chosen"], add_special_tokens=False)
+            rejected_tokens = tokenizer(sample["rejected"], add_special_tokens=False)
         else:
             raise ValueError(f"{sample} is not an iterable")

@@ -324,14 +327,14 @@ def tokenize_preferences(

         # if combined sequence is too long, truncate the prompt only
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
             if tokenizer.truncation_side == "right":
-                prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
             elif tokenizer.truncation_side == "left":
-                prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}

         # if that's still too long, truncate the response
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
-            chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()}
-            rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()}
+            chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
+            rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}

         return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)
diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py
index c6f91dd76..e01b8e684 100644
--- a/trlx/trainer/accelerate_dpo_trainer.py
+++ b/trlx/trainer/accelerate_dpo_trainer.py
@@ -10,6 +10,7 @@
 if is_deepspeed_available():
     import deepspeed

+import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
 from trlx.data.method_configs import MethodConfig, register_method
 from trlx.pipeline.offline_pipeline import DPOStore
@@ -18,6 +19,9 @@
 from trlx.utils.modeling import pad_to_length


+logger = logging.get_logger(__name__)
+
+
 @dataclass
 @register_method
 class DPOConfig(MethodConfig):
@@ -47,9 +51,10 @@ def __init__(self, config: TRLConfig, **kwargs):

         # TODO: Avoid setting up a reference model when hydra heads are used
         self.ref_model = self.get_arch(self.config)
-        if self.accelerator.state.deepspeed_plugin.zero_stage == 3:
-            self.ref_model = self._prepare_deepspeed_zero3(self.ref_model)
-        else:
+        deepspeed_plugin = getattr(self.accelerator.state, "deepspeed_plugin", None)  # None when DeepSpeed is unused
+        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
+            self.ref_model = self._prepare_deepspeed_zero3(self.ref_model)
+        else:
             self.ref_model.to(self.accelerator.device)
         self.ref_model.eval()

@@ -311,6 +316,8 @@ def prepare_learning(self):
         self.total_steps = self.config.train.epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)

-    def make_experience(self, samples: Iterable[Iterable], seq_length: int):
-        preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples]
+    def make_experience(self, samples: Iterable[Iterable], seq_length: int, max_prompt_length: int):
+        preferences = [
+            DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length, max_prompt_length) for sample in samples
+        ]
         self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value)
diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py
index d5cbe3ea5..b76f50f14 100644
--- a/trlx/trainer/accelerate_sft_trainer.py
+++ b/trlx/trainer/accelerate_sft_trainer.py
@@ -87,7 +87,7 @@ def prepare_learning(self):
         self.total_steps = self.config.train.epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)

-    def make_experience(self, samples, seq_length):
+    def make_experience(self, samples, seq_length, **kwargs):
         if isinstance(samples[0], str):
             self.store = PromptPipeline(samples, seq_length, self.tokenizer)
         else:
diff --git a/trlx/trainer/nemo_sft_trainer.py b/trlx/trainer/nemo_sft_trainer.py
index 7f25254f1..d7410b35c 100644
--- a/trlx/trainer/nemo_sft_trainer.py
+++ b/trlx/trainer/nemo_sft_trainer.py
@@ -133,7 +133,7 @@ def eval_collate(elems):
         torch.set_float32_matmul_precision("medium")
         self.trainer.fit(self.model)

-    def make_experience(self, samples, seq_length):
+    def make_experience(self, samples, seq_length, **kwargs):
         if isinstance(samples[0], str):
             self.store = PromptPipeline(samples, seq_length, self.tokenizer)
         else:
diff --git a/trlx/trlx.py b/trlx/trlx.py
index 7f2aef9c0..ba75b5f54 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -114,7 +114,8 @@ def train(  # noqa: C901
         if rewards is not None:
             trainer.make_experience(samples, rewards, config.train.seq_length)
         else:
-            trainer.make_experience(samples, config.train.seq_length)
+            # this should be abstracted for all trainers with **kwargs
+            trainer.make_experience(samples, config.train.seq_length, max_prompt_length)
     else:
         raise ValueError("Either `samples` or `reward_fn` should be given for training")
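As a quick sanity check of the reworked `preprocess`, the following standalone sketch shows how one UltraFeedback-binarized-style record is split into `prompt` / `chosen` / `rejected` around the `<|assistant|>` marker. It is only illustrative: the toy `sample` dict is made up, and it assumes the `HuggingFaceH4/mistral-7b-sft-beta` tokenizer can be downloaded and that the example module is importable from the repository root.

```python
# Illustrative only: toy record in the ultrafeedback_binarized layout,
# run with the repository root on PYTHONPATH so `examples` is importable.
from transformers import AutoTokenizer

from examples.dpo_ultrafeedback import preprocess

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/mistral-7b-sft-beta")

# Two-turn conversations where the last turn is the chosen/rejected answer,
# matching the assert in `preprocess`.
sample = {
    "chosen": [
        {"role": "user", "content": "Name a color."},
        {"role": "assistant", "content": "Blue."},
    ],
    "rejected": [
        {"role": "user", "content": "Name a color."},
        {"role": "assistant", "content": "Potato."},
    ],
}

pair = preprocess(sample, tokenizer=tokenizer, test=False)

# The prompt ends right before "<|assistant|>"; both responses start with it,
# so prompt + chosen (or prompt + rejected) reproduces the full chat string.
assert pair["chosen"].startswith("<|assistant|>")
assert pair["rejected"].startswith("<|assistant|>")
print(pair["prompt"] + pair["chosen"])
```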