diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py index d21ecd3d4b..b173694239 100644 --- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -78,7 +78,7 @@ class ScriptArguments: def get_stack_exchange_paired( data_dir: str = "data/rl", sanity_check: bool = False, - cache_dir: str = None, + cache_dir: Optional[str] = None, num_proc=24, ) -> Dataset: """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index f0031f2881..1a08f50818 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -50,7 +50,7 @@ --lora_alpha=16 """ from dataclasses import dataclass, field -from typing import Dict +from typing import Dict, Optional import torch from datasets import Dataset, load_dataset @@ -87,7 +87,7 @@ def extract_anthropic_prompt(prompt_and_response): return prompt_and_response[: search_term_idx + len(search_term)] -def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset: +def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset: """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format. The dataset is converted to a dictionary with the following structure: diff --git a/trl/core.py b/trl/core.py index d4ecba81da..1a0e8761a6 100644 --- a/trl/core.py +++ b/trl/core.py @@ -113,7 +113,7 @@ def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: return whitened -def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor: +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: """Compute mean of tensor with a masked values.""" if axis is not None: return (values * mask).sum(axis=axis) / mask.sum(axis=axis) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index 25f44ae935..58b61fd17f 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -14,6 +14,7 @@ import re import warnings +from typing import Optional import torch from accelerate.utils import extract_model_from_parallel @@ -416,7 +417,7 @@ def _generate_batched( self, query_tensors, batch_size: int = 16, - pad_to_multiple_of: int = None, + pad_to_multiple_of: Optional[int] = None, ): """ Generate responses for a list of query tensors. diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py index eb639543ad..f6d4e86bba 100644 --- a/trl/models/modeling_base.py +++ b/trl/models/modeling_base.py @@ -15,6 +15,7 @@ import logging import os from copy import deepcopy +from typing import Optional import torch import torch.nn as nn @@ -600,7 +601,7 @@ def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): def create_reference_model( - model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None + model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None ) -> PreTrainedModelWrapper: """ Creates a static reference copy of a model. Note that model will be in `.eval()` mode. diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3a27c50889..e4f77a64b5 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -139,12 +139,12 @@ class DPOTrainer(Trainer): def __init__( self, - model: Union[PreTrainedModel, nn.Module, str] = None, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, beta: float = 0.1, label_smoothing: float = 0, loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid", - args: TrainingArguments = None, + args: Optional[TrainingArguments] = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, padding_value: Optional[int] = None, @@ -165,11 +165,11 @@ def __init__( generate_during_eval: bool = False, compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, precompute_ref_log_probs: bool = False, - dataset_num_proc: int = None, + dataset_num_proc: Optional[int] = None, model_init_kwargs: Optional[Dict] = None, ref_model_init_kwargs: Optional[Dict] = None, - model_adapter_name: str = None, - ref_adapter_name: str = None, + model_adapter_name: Optional[str] = None, + ref_adapter_name: Optional[str] = None, ): if model_init_kwargs is None: model_init_kwargs = {} @@ -585,7 +585,7 @@ def build_tokenized_answer(self, prompt, answer): attention_mask=answer_attention_mask, ) - def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict: + def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict: """Tokenize a single row from a DPO specific dataset. At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py index 006b02ad51..cd231939c0 100644 --- a/trl/trainer/iterative_sft_trainer.py +++ b/trl/trainer/iterative_sft_trainer.py @@ -60,9 +60,9 @@ class IterativeSFTTrainer(Trainer): def __init__( self, - model: PreTrainedModel = None, - args: TrainingArguments = None, - tokenizer: PreTrainedTokenizerBase = None, + model: Optional[PreTrainedModel] = None, + args: Optional[TrainingArguments] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( None, None, diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 2f53bbdbfb..537f77c8d9 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -142,10 +142,10 @@ class PPOTrainer(BaseTrainer): def __init__( self, - config: PPOConfig = None, - model: PreTrainedModelWrapper = None, + config: Optional[PPOConfig] = None, + model: Optional[PreTrainedModelWrapper] = None, ref_model: Optional[PreTrainedModelWrapper] = None, - tokenizer: PreTrainedTokenizerBase = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, optimizer: Optional[torch.optim.Optimizer] = None, data_collator: Optional[typing.Callable] = None, @@ -431,7 +431,7 @@ def _remove_unused_columns(self, dataset: "Dataset"): def generate( self, query_tensor: Union[torch.Tensor, List[torch.Tensor]], - length_sampler: Callable = None, + length_sampler: Optional[Callable] = None, batch_size: int = 4, return_prompt: bool = True, generate_ref_response: bool = False, @@ -508,10 +508,10 @@ def _generate_batched( self, model: PreTrainedModelWrapper, query_tensors: List[torch.Tensor], - length_sampler: Callable = None, + length_sampler: Optional[Callable] = None, batch_size: int = 4, return_prompt: bool = True, - pad_to_multiple_of: int = None, + pad_to_multiple_of: Optional[int] = None, remove_padding: bool = True, **generation_kwargs, ): diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 5a466f8489..45babf3bfe 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -55,7 +55,7 @@ class RewardTrainer(Trainer): def __init__( self, - model: Union[PreTrainedModel, nn.Module] = None, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, args: Optional[RewardConfig] = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index e5164c4e38..9f103e31c0 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -121,8 +121,8 @@ class SFTTrainer(Trainer): def __init__( self, - model: Union[PreTrainedModel, nn.Module, str] = None, - args: TrainingArguments = None, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: Optional[TrainingArguments] = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index e20927a095..fc91e4d4b2 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -82,7 +82,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): def __init__( self, response_template: Union[str, List[int]], - instruction_template: Union[str, List[int]] = None, + instruction_template: Optional[Union[str, List[int]]] = None, *args, mlm: bool = False, ignore_index: int = -100,