
Commit 88685f2
Types: Fix PEP 484 implicit-optional compliance (#1297)
This was done automatically with hauntsaninja/no_implicit_optional.
akx authored Jan 31, 2024
1 parent 6f40f20
Showing 11 changed files with 27 additions and 25 deletions.
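The pattern behind all 11 files: PEP 484 originally allowed an annotation such as "cache_dir: str = None" to be read as implicitly Optional, but that reading was later dropped from the spec, and type checkers such as mypy now reject it by default. A parameter whose default is None must be annotated Optional[X] (or Union[X, None]). A minimal sketch of the before/after, using an illustrative function that does not appear in this diff:

    from typing import Optional

    # Before: implicit optional. A bare `str` annotation with a None default
    # is no longer valid under PEP 484:
    #     def load_split(cache_dir: str = None) -> str: ...

    # After: the None possibility is spelled out in the annotation.
    def load_split(cache_dir: Optional[str] = None) -> str:
        # Fall back to a hypothetical default cache location when none is given.
        return cache_dir if cache_dir is not None else "~/.cache/example"

The commit message credits hauntsaninja/no_implicit_optional, a codemod that applies this rewrite tree-wide; assuming its documented command line, an invocation along the lines of "no_implicit_optional trl examples" would produce hunks like the ones below, including those that merely add a "from typing import Optional" import.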
@@ -78,7 +78,7 @@ class ScriptArguments:
 def get_stack_exchange_paired(
     data_dir: str = "data/rl",
     sanity_check: bool = False,
-    cache_dir: str = None,
+    cache_dir: Optional[str] = None,
     num_proc=24,
 ) -> Dataset:
     """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
4 changes: 2 additions & 2 deletions examples/scripts/dpo.py
@@ -50,7 +50,7 @@
     --lora_alpha=16
 """
 from dataclasses import dataclass, field
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 from datasets import Dataset, load_dataset
@@ -87,7 +87,7 @@ def extract_anthropic_prompt(prompt_and_response):
     return prompt_and_response[: search_term_idx + len(search_term)]
 
 
-def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
+def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
     """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
 
     The dataset is converted to a dictionary with the following structure:
2 changes: 1 addition & 1 deletion trl/core.py
@@ -113,7 +113,7 @@ def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
     return whitened
 
 
-def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor:
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
     """Compute mean of tensor with a masked values."""
     if axis is not None:
         return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
3 changes: 2 additions & 1 deletion trl/environment/base_environment.py
@@ -14,6 +14,7 @@
 
 import re
 import warnings
+from typing import Optional
 
 import torch
 from accelerate.utils import extract_model_from_parallel
@@ -416,7 +417,7 @@ def _generate_batched(
         self,
         query_tensors,
         batch_size: int = 16,
-        pad_to_multiple_of: int = None,
+        pad_to_multiple_of: Optional[int] = None,
     ):
         """
         Generate responses for a list of query tensors.
3 changes: 2 additions & 1 deletion trl/models/modeling_base.py
@@ -15,6 +15,7 @@
 import logging
 import os
 from copy import deepcopy
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -600,7 +601,7 @@ def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
 
 
 def create_reference_model(
-    model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None
+    model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
 ) -> PreTrainedModelWrapper:
     """
     Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
12 changes: 6 additions & 6 deletions trl/trainer/dpo_trainer.py
@@ -139,12 +139,12 @@ class DPOTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module, str] = None,
+        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
         ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
         beta: float = 0.1,
         label_smoothing: float = 0,
         loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
-        args: TrainingArguments = None,
+        args: Optional[TrainingArguments] = None,
         data_collator: Optional[DataCollator] = None,
         label_pad_token_id: int = -100,
         padding_value: Optional[int] = None,
@@ -165,11 +165,11 @@ def __init__(
         generate_during_eval: bool = False,
         compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
         precompute_ref_log_probs: bool = False,
-        dataset_num_proc: int = None,
+        dataset_num_proc: Optional[int] = None,
         model_init_kwargs: Optional[Dict] = None,
         ref_model_init_kwargs: Optional[Dict] = None,
-        model_adapter_name: str = None,
-        ref_adapter_name: str = None,
+        model_adapter_name: Optional[str] = None,
+        ref_adapter_name: Optional[str] = None,
     ):
         if model_init_kwargs is None:
             model_init_kwargs = {}
@@ -585,7 +585,7 @@ def build_tokenized_answer(self, prompt, answer):
             attention_mask=answer_attention_mask,
         )
 
-    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
+    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
         """Tokenize a single row from a DPO specific dataset.
 
         At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
6 changes: 3 additions & 3 deletions trl/trainer/iterative_sft_trainer.py
@@ -60,9 +60,9 @@ class IterativeSFTTrainer(Trainer):
 
     def __init__(
         self,
-        model: PreTrainedModel = None,
-        args: TrainingArguments = None,
-        tokenizer: PreTrainedTokenizerBase = None,
+        model: Optional[PreTrainedModel] = None,
+        args: Optional[TrainingArguments] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
             None,
             None,
12 changes: 6 additions & 6 deletions trl/trainer/ppo_trainer.py
@@ -142,10 +142,10 @@ class PPOTrainer(BaseTrainer):
 
     def __init__(
         self,
-        config: PPOConfig = None,
-        model: PreTrainedModelWrapper = None,
+        config: Optional[PPOConfig] = None,
+        model: Optional[PreTrainedModelWrapper] = None,
         ref_model: Optional[PreTrainedModelWrapper] = None,
-        tokenizer: PreTrainedTokenizerBase = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
         optimizer: Optional[torch.optim.Optimizer] = None,
         data_collator: Optional[typing.Callable] = None,
@@ -431,7 +431,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):
     def generate(
         self,
         query_tensor: Union[torch.Tensor, List[torch.Tensor]],
-        length_sampler: Callable = None,
+        length_sampler: Optional[Callable] = None,
         batch_size: int = 4,
         return_prompt: bool = True,
         generate_ref_response: bool = False,
@@ -508,10 +508,10 @@ def _generate_batched(
         self,
         model: PreTrainedModelWrapper,
         query_tensors: List[torch.Tensor],
-        length_sampler: Callable = None,
+        length_sampler: Optional[Callable] = None,
         batch_size: int = 4,
         return_prompt: bool = True,
-        pad_to_multiple_of: int = None,
+        pad_to_multiple_of: Optional[int] = None,
         remove_padding: bool = True,
         **generation_kwargs,
     ):
2 changes: 1 addition & 1 deletion trl/trainer/reward_trainer.py
@@ -55,7 +55,7 @@ class RewardTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module] = None,
+        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[RewardConfig] = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
4 changes: 2 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -121,8 +121,8 @@ class SFTTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module, str] = None,
-        args: TrainingArguments = None,
+        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+        args: Optional[TrainingArguments] = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
2 changes: 1 addition & 1 deletion trl/trainer/utils.py
@@ -82,7 +82,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
     def __init__(
         self,
         response_template: Union[str, List[int]],
-        instruction_template: Union[str, List[int]] = None,
+        instruction_template: Optional[Union[str, List[int]]] = None,
         *args,
         mlm: bool = False,
         ignore_index: int = -100,
