From 0ce3b659280fc61bb1c6b18b6c35a2b1d6f30703 Mon Sep 17 00:00:00 2001 From: anch0vy Date: Mon, 28 Oct 2024 19:49:35 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=8C=20Fix=20type=20hint=20in=20`LogCom?= =?UTF-8?q?pletionsCallback`=20(#2285)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update callbacks.py for fix small python type error * Update callbacks.py --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index eefe83ce96..24e33a7566 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -364,9 +364,9 @@ class LogCompletionsCallback(WandbCallback): column containing the prompts for generating completions. generation_config (`GenerationConfig`, *optional*): The generation config to use for generating completions. - num_prompts (`int`, *optional*): + num_prompts (`int` or `None`, *optional*): The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset. - freq (`int`, *optional*): + freq (`int` or `None`, *optional*): The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. """ @@ -374,8 +374,8 @@ def __init__( self, trainer: Trainer, generation_config: Optional[GenerationConfig] = None, - num_prompts: int = None, - freq: int = None, + num_prompts: Optional[int] = None, + freq: Optional[int] = None, ): super().__init__() self.trainer = trainer