diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 38acd85994..a3d7b1e974 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -47,7 +47,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import check_min_version, is_peft_available
+from transformers.utils import is_peft_available
 from transformers.utils.deprecation import deprecate_kwarg
 
 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
@@ -531,10 +531,6 @@ def make_inputs_require_grad(module, input, output):
             self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
 
-        # num_logits_to_keep is supported since transformers v4.45.0
-        if self.use_num_logits_to_keep:
-            check_min_version("4.45.0")
-
         if self.loss_type == "bco_pair":
             self.running = RunningMoments(self.accelerator)
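
For context, the guard deleted in the second hunk relied on transformers.utils.check_min_version, which compares the installed transformers version against a required minimum and raises an ImportError when the installation is too old. Below is a minimal sketch of that behavior; the module-level use_num_logits_to_keep flag is a hypothetical stand-in for the trainer attribute and is not code from this patch:

    # Sketch only: check_min_version compares transformers.__version__
    # against the given minimum and raises ImportError if it is older.
    from transformers.utils import check_min_version

    use_num_logits_to_keep = True  # hypothetical stand-in for self.use_num_logits_to_keep

    if use_num_logits_to_keep:
        # num_logits_to_keep is supported since transformers v4.45.0
        # (per the removed comment), so fail fast on older installs.
        check_min_version("4.45.0")

With the explicit check removed, users on transformers older than 4.45.0 would presumably hit an error inside the model's forward pass instead of this early, descriptive failure.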