From 979c5c5e171fd1280d311e38ab36937622e64cc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 09:01:16 +0000 Subject: [PATCH 01/16] `get_batch_sample` -> `generate_from_model[_and_ref]` --- trl/trainer/bco_trainer.py | 4 ++-- trl/trainer/cpo_trainer.py | 4 ++-- trl/trainer/dpo_trainer.py | 4 ++-- trl/trainer/kto_trainer.py | 4 ++-- trl/trainer/orpo_trainer.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 91461a9b0d..480fa7a12c 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -1290,7 +1290,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: return None return SequentialSampler(self.train_dataset) - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -1407,7 +1407,7 @@ def evaluation_loop( "prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]), "prompt": itemgetter(*target_indicies)(random_batch["prompt"]), } - policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch) + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) self.log( { diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 5847cb182b..5bf5238482 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -847,7 +847,7 @@ def compute_loss( return (loss, metrics) return loss - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -938,7 +938,7 @@ def evaluation_loop( random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) - policy_output_decoded = self.get_batch_samples(self.model, random_batch) + policy_output_decoded = self.generate_from_model(self.model, random_batch) self.log( { diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8b9843cd91..380b1ce6e3 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1561,7 +1561,7 @@ def compute_loss( return (loss, metrics) return loss - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -1672,7 +1672,7 @@ def evaluation_loop( random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) - policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) self.log( { diff --git a/trl/trainer/kto_trainer.py 
b/trl/trainer/kto_trainer.py index ab9ba87e41..91aeb1a079 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1264,7 +1264,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: return None return SequentialSampler(self.train_dataset) - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -1383,7 +1383,7 @@ def evaluation_loop( "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies], "prompt": itemgetter(*target_indicies)(random_batch["prompt"]), } - policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch) + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) self.log( { diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 4edbf9b1a5..dc9e4080c0 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -866,7 +866,7 @@ def compute_loss( return (loss, metrics) return loss - def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with @@ -957,7 +957,7 @@ def evaluation_loop( random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) - policy_output_decoded = self.get_batch_samples(self.model, random_batch) + policy_output_decoded = self.generate_from_model(self.model, random_batch) self.log( { From ada53cfa157bae50154d1c26d3d5f5a5e98803ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 09:21:49 +0000 Subject: [PATCH 02/16] add `num_items_in_batch=None` --- trl/trainer/bco_trainer.py | 1 + trl/trainer/cpo_trainer.py | 1 + trl/trainer/dpo_trainer.py | 1 + trl/trainer/gkd_trainer.py | 2 +- trl/trainer/kto_trainer.py | 1 + trl/trainer/orpo_trainer.py | 1 + trl/trainer/reward_trainer.py | 1 + 7 files changed, 7 insertions(+), 1 deletion(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 480fa7a12c..c6ce2d4902 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -1260,6 +1260,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 5bf5238482..c85ffb0a50 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -828,6 +828,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 380b1ce6e3..082a627ce0 100644 --- 
a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1547,6 +1547,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext() with compute_loss_context_manager: diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 1b7c77557d..efeb9e6d9a 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -215,7 +215,7 @@ def generalized_jsd_loss( else: return jsd - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # compute student output outputs_student = model( input_ids=inputs["input_ids"], diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 91aeb1a079..7f32424812 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1234,6 +1234,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index dc9e4080c0..54b97cd6d1 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -844,6 +844,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 787c6cbd54..0ebdee68b4 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -266,6 +266,7 @@ def compute_loss( model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, + num_items_in_batch=None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_reward_data_collator: warnings.warn( From 10bffa0f5acaa8cc85983599d11043660061ba27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 09:42:49 +0000 Subject: [PATCH 03/16] `num_items_in_batch` in `training_step` --- trl/trainer/gkd_trainer.py | 6 ++++-- trl/trainer/nash_md_trainer.py | 4 +++- trl/trainer/online_dpo_trainer.py | 4 +++- trl/trainer/xpo_trainer.py | 4 +++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index efeb9e6d9a..49e93e269b 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -273,7 +273,9 @@ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=No return generated_tokens, new_attention_mask, new_labels - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: """ Perform a training step for the Generalized Knowledge Distillation (GKD) model. 
@@ -298,7 +300,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, inputs["attention_mask"] = new_attention_mask inputs["labels"] = new_labels - loss = super().training_step(model, inputs) + loss = super().training_step(model, inputs, num_items_in_batch) return loss def _prepare_deepspeed(self, model: PreTrainedModelWrapper): diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index db0c3046b3..73aab7899a 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -328,7 +328,9 @@ def gather_mean(tensor): self.stats["beta"].append(self.beta) self.stats["mixture_coef"].append(self.mixture_coef) - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: model.train() # Apply chat template and tokenize the input diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index ffc407b57d..c480c61fc5 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -366,7 +366,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None return self.accelerator.prepare(eval_dataloader) - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: model.train() # Apply chat template and tokenize the input. diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 0255e6206f..a154875821 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -377,7 +377,9 @@ def gather_mean(tensor): self.stats["alpha"].append(self.alpha) self.stats["beta"].append(self.beta) - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: model.train() # Apply chat template and tokenize the input From ca2d98f26d47be3161e48a8cc29fb38b9c6679c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 18 Oct 2024 12:11:57 +0200 Subject: [PATCH 04/16] Fix return type hint --- trl/trainer/cpo_trainer.py | 2 +- trl/trainer/orpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index c85ffb0a50..5e74fdaceb 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -848,7 +848,7 @@ def compute_loss( return (loss, metrics) return loss - def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 54b97cd6d1..123f935208 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -867,7 +867,7 @@ def compute_loss( return (loss, metrics) return loss - def generate_from_model(self, model, batch: 
Dict[str, torch.LongTensor]) -> Tuple[str, str]: + def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with From 0ecc5fb67c927e942ff09b5465bf7c377397bfb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 14:16:02 +0000 Subject: [PATCH 05/16] desc for unpair dataset util --- trl/data_utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index 569d398b52..146466bd6b 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -199,7 +199,9 @@ def _unpair_row(examples: List[Dict[str, List[Dict[str, str]]]]) -> List[Dict[st return new_rows -def unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = None) -> DatasetType: +def unpair_preference_dataset( + dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None +) -> DatasetType: r""" Unpair a preference dataset. @@ -209,6 +211,8 @@ def unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = No `"prompt"`. num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. + desc (`str` or `None`, *optional*, defaults to `None`): + Meaningful description to be displayed alongside with the progress bar while mapping examples. Returns: `Dataset`: The unpaired preference dataset. @@ -233,10 +237,12 @@ def unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = No {'prompt': 'The sky is', 'completion': ' blue.', 'label': True} ``` """ - return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc) + return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc, desc=desc) -def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = None) -> DatasetType: +def maybe_unpair_preference_dataset( + dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None +) -> DatasetType: r""" Unpair a preference dataset if it is paired. @@ -246,6 +252,8 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int `"prompt"`. num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. + desc (`str` or `None`, *optional*, defaults to `None`): + Meaningful description to be displayed alongside with the progress bar while mapping examples. Returns: `Dataset` or `DatasetDict`: The unpaired preference dataset if it was paired, otherwise the original dataset. 
@@ -275,7 +283,7 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int else: column_names = dataset.column_names if "chosen" in column_names and "rejected" in column_names: - return unpair_preference_dataset(dataset, num_proc=num_proc) + return unpair_preference_dataset(dataset, num_proc=num_proc, desc=desc) else: return dataset @@ -380,6 +388,8 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]: # "chosen": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], # "rejected": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}]} # That's why we check if the prompt is also conversational before deciding not to extract it. + if "chosen" not in example or "rejected" not in example: # not a preference example + return example if "prompt" in example: # Both conversational or both non-conversational chosen_conv = is_conversational({"chosen": example["chosen"]}) From 2f60dd0c8c70f286d27baab8e0e6f4cfed97794a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 14:16:52 +0000 Subject: [PATCH 06/16] update example --- examples/scripts/kto.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 84d56ac379..50dbcd5f36 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -55,7 +55,6 @@ --lora_alpha=16 """ -from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser @@ -65,7 +64,6 @@ ModelConfig, ScriptArguments, get_peft_config, - maybe_unpair_preference_dataset, setup_chat_format, ) @@ -95,24 +93,6 @@ # Load the dataset dataset = load_dataset(script_args.dataset_name) - # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label) - dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc) - - # Apply chat template - def format_dataset(example): - if isinstance(example["completion"], str): - example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False) - example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False) - else: - example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False) - example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False) - return example - - # Compute that only on the main process for faster data processing. 
- # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().local_main_process_first(): - dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc) - # Initialize the KTO trainer trainer = KTOTrainer( model, From c53d0fcc337b9c86cc01450279eb78b51e8630e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 14:17:07 +0000 Subject: [PATCH 07/16] process in KTO --- trl/trainer/kto_trainer.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 7f32424812..fa1541bb44 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -48,6 +48,7 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available +from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig from .utils import ( @@ -566,11 +567,37 @@ def make_inputs_require_grad(module, input, output): " meaning the auxiliary loss will not be used." ) + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): - # Shuffle the datasets - train_dataset = train_dataset.shuffle(seed=args.data_seed) + # Extract the prompt if needed + train_dataset = train_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" + ) + # Unpair the dataset if needed + train_dataset = maybe_unpair_preference_dataset( + train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" + ) + # Apply the chat template if needed + train_dataset = train_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to train dataset", + ) if eval_dataset is not None: - eval_dataset = eval_dataset.shuffle(seed=args.data_seed) + eval_dataset = eval_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" + ) + eval_dataset = maybe_unpair_preference_dataset( + eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" + ) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to eval dataset", + ) # Tokenize and prepare the training datasets train_dataset = train_dataset.map( From 27f483bed93714d4940e14d76f99f0fdbb3699d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 14:29:41 +0000 Subject: [PATCH 08/16] Update doc --- docs/source/dataset_formats.mdx | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index cc92ec0ff1..794102fb3e 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -199,21 +199,21 @@ unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", " Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer. 
-| Trainer | Expected dataset type | -| ----------------------- | ------------------------------------------------------- | -| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) | -| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | -| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | -| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) | -| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | -| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | -| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`PPOTrainer`] | Tokenized language modeling | -| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | -| [`SFTTrainer`] | [Language modeling](#language-modeling) | -| [`XPOTrainer`] | [Prompt-only](#prompt-only) | +| Trainer | Expected dataset type | +| ----------------------- | ------------------------------------------------------------------------------------------------------ | +| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) | +| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | +| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | +| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`PPOTrainer`] | Tokenized language modeling | +| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | +| [`SFTTrainer`] | [Language modeling](#language-modeling) | +| [`XPOTrainer`] | [Prompt-only](#prompt-only) | From 7a634185360e784574e4c8b9bb07c9ac97cee70b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 16:10:06 +0000 Subject: [PATCH 09/16] KTO doc rewrite --- docs/source/kto_trainer.mdx | 182 ++++++++++++++++++++---------------- 1 file changed, 103 insertions(+), 79 deletions(-) diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 7c6433be43..9200162e34 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -2,109 +2,133 @@ [![](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl) -TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://huggingface.co/papers/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela. -For a full example have a look at [`examples/scripts/kto.py`]. +## Overview -Depending on how good your base model is, you may or may not need to do SFT before KTO. -This is different from standard RLHF and DPO, which always require SFT. -You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below). 
+Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela). -## Expected dataset type -The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns: +The abstract from the paper is the following: -- `prompt` -- `completion` -- `label` +> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive. -for example: +The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs). +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente. + +## Quick start + +This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). 
You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_kto.py +from datasets import load_dataset +from trl import KTOConfig, KTOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train") + +training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10) +trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() ``` -kto_dataset_dict = { - "prompt": [ - "Hey, hello", - "How are you", - "What is your name?", - "What is your name?", - "Which is the best programming language?", - "Which is the best programming language?", - "Which is the best programming language?", - ], - "completion": [ - "hi nice to meet you", - "leave me alone", - "I don't have a name", - "My name is Mary", - "Python", - "C++", - "Java", - ], - "label": [ - True, - False, - False, - True, - True, - False, - False, - ], -} + +Execute the script using the following command: + +```bash +accelerate launch train_kto.py ``` -where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`). -A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. -In theory, the dataset must contain at least one desirable and one undesirable completion; however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate). +Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png) -## Expected model format -The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface). -## Using the `KTOTrainer` +
$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
+<quentin_gallouedec>:
+What is the best programming language?
 
-For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response. 
+<trl-lib/Qwen2-0.5B-KTO>:
+The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:                                                                                  
 
-The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
+Here are some other factors to consider when choosing a programming language for a project:
 
-The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
-By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
+ 1 JavaScript: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.                                                                   
+ 2 Java: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.                                                                                                                                                            
+ 3 C++: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.                                                                                                                                         
+ 4 Python: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.   
+
- -Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1', this learning rate is `1e-6` for most models. The lower the beta is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs. - +## Expected dataset format - -Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor. - - -```py -training_args = KTOConfig( - beta=0.1, - desirable_weight=1.0, - undesirable_weight=1.0, - learning_rate=5e-7, -) - -kto_trainer = KTOTrainer( - model, - ref_model, - args=training_args, - train_dataset=train_dataset, - processing_class=tokenizer, -) -``` -After this one can then call: +KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones. + +The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate. + +## Example script -```py -kto_trainer.train() +We provide an example script to train a model using the KTO method. The script is available in [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) + +To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command: + +```bash +accelerate launch examples/scripts/kto.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/kto-mix-14k \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-KTO ``` +## Usage tips + ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. -To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + + +### Batch size recommendations + +Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. 
Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor. + +### Learning rate recommendations + +Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results. + +### Imbalanced data + +The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples. +By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. + +## Logged metrics + +While training and evaluating we record the following reward metrics: -This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig). -To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001). +- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta +- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta +- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards +- `logps/chosen`: ... +- `logps/rejected`: ... +- `logits/chosen` +- `logits/rejected` +- `kl`: the KL divergence between the policy model and the reference model ## KTOTrainer From b9f9ce27f0c525132865794208a5a27b14889db6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 16:10:17 +0000 Subject: [PATCH 10/16] fix orpo doc --- docs/source/orpo_trainer.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index 628383c5bd..02d0b9c86b 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Odds Ratio Preference Optimization (ORPO) wa introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes). +Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes). 
The abstract from the paper is the following: @@ -95,7 +95,7 @@ accelerate launch examples/scripts/orpo.py \ --dataset_name trl-lib/ultrafeedback_binarized \ --num_train_epochs 1 \ --logging_steps 25 \ - --output_dir Qwen2-0.5B-DPO + --output_dir Qwen2-0.5B-ORPO ``` ## Usage tips From 3bfcd4b3ff5f64b8fc3b71c6aab5b674785ee06b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 16:24:33 +0000 Subject: [PATCH 11/16] add other dataset config names in test --- tests/test_kto_trainer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index 6154aeeecf..d3ef30aa06 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -42,17 +42,17 @@ def setUp(self): @parameterized.expand( [ - ["gpt2", "kto", True, True], - ["gpt2", "kto", True, False], - ["gpt2", "kto", False, True], - ["gpt2", "kto", False, False], - ["gpt2", "apo_zero_unpaired", True, True], - ["gpt2", "apo_zero_unpaired", True, False], - ["gpt2", "apo_zero_unpaired", False, True], - ["gpt2", "apo_zero_unpaired", False, False], + ("gpt2", "standard_preference", "kto", True, True), + ("t5", "standard_implicit_prompt_preference", "kto", True, False), + ("gpt2", "standard_unpaired_preference", "kto", False, True), + ("t5", "conversational_preference", "kto", False, False), + ("gpt2", "conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True), + ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False), + ("gpt2", "standard_unpaired_preference", "apo_zero_unpaired", False, True), + ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False), ] ) - def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset): + def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset): with tempfile.TemporaryDirectory() as tmp_dir: training_args = KTOConfig( output_dir=tmp_dir, @@ -68,7 +68,7 @@ def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset): report_to="none", ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) if name == "gpt2": model = self.model From 9e977aff66e602f6b34c00ddccae93c1672d1592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 16:34:39 +0000 Subject: [PATCH 12/16] update doc image --- docs/source/kto_trainer.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 9200162e34..00b4f733cb 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -51,7 +51,7 @@ accelerate launch train_kto.py Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. -![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png) +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kto-qwen2-reward-margin.png) To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface). 
From 97702d3d502fad3156b53e69b26c310075793474 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 16:36:46 +0000 Subject: [PATCH 13/16] fix links in doc --- docs/source/kto_trainer.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 00b4f733cb..4376fa2c53 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -74,7 +74,7 @@ Here are some other factors to consider when choosing a programming language for KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones. -The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. +The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate. From 531441aaad58ef98f97825c54968bdb3241f2cf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 16:39:17 +0000 Subject: [PATCH 14/16] Update reward and log probability metrics in KTOTrainer doc --- docs/source/kto_trainer.mdx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 4376fa2c53..dc881f9577 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -124,10 +124,10 @@ While training and evaluating we record the following reward metrics: - `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta - `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta - `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards -- `logps/chosen`: ... -- `logps/rejected`: ... 
-- `logits/chosen` -- `logits/rejected` +- `logps/chosen`: the mean log probabilities of the chosen completions +- `logps/rejected`: the mean log probabilities of the rejected completions +- `logits/chosen`: the mean logits of the chosen completions +- `logits/rejected`: the mean logits of the rejected completions - `kl`: the KL divergence between the policy model and the reference model ## KTOTrainer From 50e8e9775d40cb08d19e1383be7a0e7494e88b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 18 Oct 2024 17:44:41 +0000 Subject: [PATCH 15/16] skip enc-dec test --- tests/test_kto_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index d3ef30aa06..d5a094a66e 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -43,13 +43,13 @@ def setUp(self): @parameterized.expand( [ ("gpt2", "standard_preference", "kto", True, True), - ("t5", "standard_implicit_prompt_preference", "kto", True, False), + # ("t5", "standard_implicit_prompt_preference", "kto", True, False), # KTO broken for enc-dec ("gpt2", "standard_unpaired_preference", "kto", False, True), - ("t5", "conversational_preference", "kto", False, False), + # ("t5", "conversational_preference", "kto", False, False), ("gpt2", "conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True), - ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False), + # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False), ("gpt2", "standard_unpaired_preference", "apo_zero_unpaired", False, True), - ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False), + # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False), ] ) def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset): From 58f64b8b7ab8d4c5912de825f1b772056a1214b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:39:05 +0200 Subject: [PATCH 16/16] Update docs/source/kto_trainer.mdx Co-authored-by: lewtun --- docs/source/kto_trainer.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index dc881f9577..1ed6a33613 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -49,7 +49,7 @@ Execute the script using the following command: accelerate launch train_kto.py ``` -Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. +Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kto-qwen2-reward-margin.png)
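
Taken together, patches 02/16 and 03/16 align TRL's trainers with the newer `transformers.Trainer` API, which forwards a `num_items_in_batch` argument into `compute_loss` and `training_step`. The sketch below is illustrative only — it is not part of the patch series — and assumes a `transformers` version whose `Trainer.training_step` and `Trainer.compute_loss` already accept `num_items_in_batch`; the class name is made up for the example. It shows the minimal shape a custom trainer subclass needs in order to keep working once the base `Trainer` starts passing the extra argument.

```python
# Minimal sketch (illustrative, not part of the patch series): a custom trainer
# whose overrides accept the `num_items_in_batch` argument passed by newer
# `transformers.Trainer` versions, mirroring the signature changes above.
from typing import Any, Dict, Optional, Union

import torch
from torch import nn
from transformers import Trainer


class MyCompatibleTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Accept (and here simply ignore) `num_items_in_batch` so the base
        # Trainer can call this override with either the old or the new signature.
        # Assumes `inputs` contains labels so the model returns a loss.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        # Forward the extra argument to the parent implementation, as the GKD,
        # NashMD, OnlineDPO and XPO trainers do in patch 03/16.
        return super().training_step(model, inputs, num_items_in_batch)
```

An override that omits the keyword would raise a `TypeError` with `transformers` versions that pass `num_items_in_batch` through, which is exactly what the added `num_items_in_batch=None` defaults in these patches prevent.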