From 163695e85c4dfce391eef5a222bc423e3fea4014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sun, 24 Nov 2024 16:55:43 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=99=88=20Suppress=20warning=20for=20estim?= =?UTF-8?q?ating=20tokens=20in=20trainers=20(#2389)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Suppress warning for estimating tokens in trainer * Suppress warning for estimating FLOPs in ORPO and Reward trainers --- trl/trainer/bco_trainer.py | 9 +++++++++ trl/trainer/cpo_trainer.py | 9 +++++++++ trl/trainer/dpo_trainer.py | 9 +++++++++ trl/trainer/kto_trainer.py | 9 +++++++++ trl/trainer/online_dpo_trainer.py | 8 ++++++++ trl/trainer/orpo_trainer.py | 9 +++++++++ trl/trainer/reward_trainer.py | 9 +++++++++ 7 files changed, 62 insertions(+) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index fa901a46f4..287cce8436 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -581,6 +581,15 @@ def make_inputs_require_grad(module, input, output): self.embedding_func = embedding_func self.embedding_tokenizer = embedding_tokenizer + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result, + # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point + # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's + # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been + # issued. + model.warnings_issued["estimate_tokens"] = True + with PartialState().local_main_process_first(): # Apply the chat template if needed train_dataset = train_dataset.map( diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index f51dfa0c34..5ff8ca5885 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -317,6 +317,15 @@ def make_inputs_require_grad(module, input, output): self._stored_metrics = defaultdict(lambda: defaultdict(list)) + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + # Compute that only on the main process for faster data processing. # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 70707b5073..4880c231df 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -432,6 +432,15 @@ def make_inputs_require_grad(module, input, output): self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} self.dataset_num_proc = args.dataset_num_proc + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + # Compute that only on the main process for faster data processing. # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 790788fb77..629486674e 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -576,6 +576,15 @@ def make_inputs_require_grad(module, input, output): " meaning the auxiliary loss will not be used." ) + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result, + # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point + # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's + # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been + # issued. + model.warnings_issued["estimate_tokens"] = True + # Compute that only on the main process for faster data processing. # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 23ca6dd047..c615cd6349 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -248,6 +248,14 @@ def __init__( use_cache=False if args.gradient_checkpointing else True, ) + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include + # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + super().__init__( model=model, args=args, diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 0e3bbc7421..45fe7009ff 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -307,6 +307,15 @@ def make_inputs_require_grad(module, input, output): self._stored_metrics = defaultdict(lambda: defaultdict(list)) + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + # Compute that only on the main process for faster data processing. # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index f3456cab41..d9281d4e27 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -209,6 +209,15 @@ def __init__( else: self.use_reward_data_collator = False + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result, + # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point + # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's + # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been + # issued. + model.warnings_issued["estimate_tokens"] = True + if "input_ids_chosen" not in train_dataset.column_names: with PartialState().local_main_process_first(): fn_kwargs = {"tokenizer": processing_class}