Skip to content

Commit

Permalink
🙈 Suppress warning for estimating tokens in trainers (#2389)
Browse files Browse the repository at this point in the history
* Suppress warning for estimating tokens in trainer

* Suppress warning for estimating FLOPs in ORPO and Reward trainers
  • Loading branch information
qgallouedec authored Nov 24, 2024
1 parent 672c965 commit 163695e
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 0 deletions.
9 changes: 9 additions & 0 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,15 @@ def make_inputs_require_grad(module, input, output):
self.embedding_func = embedding_func
self.embedding_tokenizer = embedding_tokenizer

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
# issued.
model.warnings_issued["estimate_tokens"] = True

with PartialState().local_main_process_first():
# Apply the chat template if needed
train_dataset = train_dataset.map(
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,15 @@ def make_inputs_require_grad(module, input, output):

self._stored_metrics = defaultdict(lambda: defaultdict(list))

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
# that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,15 @@ def make_inputs_require_grad(module, input, output):
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
self.dataset_num_proc = args.dataset_num_proc

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
# that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,15 @@ def make_inputs_require_grad(module, input, output):
" meaning the auxiliary loss will not be used."
)

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
# issued.
model.warnings_issued["estimate_tokens"] = True

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
Expand Down
8 changes: 8 additions & 0 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ def __init__(
use_cache=False if args.gradient_checkpointing else True,
)

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
# the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
# that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True

super().__init__(
model=model,
args=args,
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,15 @@ def make_inputs_require_grad(module, input, output):

self._stored_metrics = defaultdict(lambda: defaultdict(list))

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
# that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,15 @@ def __init__(
else:
self.use_reward_data_collator = False

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in the Reward trainer, the sampled data does not
# include the "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As
# a result, the trainer issues the warning: "Could not estimate the number of tokens of the input,
# floating-point operations will not be computed." To suppress this warning, we set the "estimate_tokens" key
# in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has
# already been issued.
model.warnings_issued["estimate_tokens"] = True

if "input_ids_chosen" not in train_dataset.column_names:
with PartialState().local_main_process_first():
fn_kwargs = {"tokenizer": processing_class}
Expand Down

0 comments on commit 163695e

Please sign in to comment.