💾 Deprecate config in favor of args in PPOTrainer (#2384)
qgallouedec authored Nov 25, 2024
1 parent 17e8060 · commit ee3cbe1
Showing 5 changed files with 11 additions and 10 deletions.
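For callers, the change is a keyword rename: the PPOConfig instance is now passed to PPOTrainer as `args`, matching the `args` convention of the transformers `Trainer` that PPOTrainer subclasses, while the old `config` keyword keeps working with a deprecation warning until it is removed in v0.15.0. A minimal runnable sketch of the resulting behavior, using a hypothetical `MiniTrainer` stand-in (the real PPOTrainer constructor also requires policy, reference, reward, and value models):

    # Hypothetical MiniTrainer stand-in, for illustration only; it mimics the
    # keyword handling that this commit's deprecation gives PPOTrainer.
    import warnings

    class MiniTrainer:
        def __init__(self, args=None, **kwargs):
            if "config" in kwargs:
                if args is not None:
                    raise TypeError("cannot pass both `config` and `args`")
                warnings.warn(
                    "`config` is deprecated and will be removed in v0.15.0; use `args`.",
                    FutureWarning,
                )
                args = kwargs.pop("config")
            self.args = args

    MiniTrainer(config={"lr": 1e-5})  # old keyword: still works, emits FutureWarning
    MiniTrainer(args={"lr": 1e-5})    # new keyword: no warning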
examples/research_projects/tools/python_interpreter.py (1 addition, 1 deletion)

@@ -154,7 +154,7 @@ def solution():
     optimize_cuda_cache=True,
 )

-ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
+ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
 test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)

 # text env
examples/research_projects/tools/triviaqa.py (1 addition, 1 deletion)

@@ -105,7 +105,7 @@ class ScriptArguments:
     seed=script_args.seed,
     optimize_cuda_cache=True,
 )
-ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
+ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
 dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
 local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003  # Prime
 dataset = dataset.shuffle(local_seed)
examples/scripts/ppo/ppo.py (1 addition, 1 deletion)

@@ -152,7 +152,7 @@ def tokenize(element):
     # Training
     ################
     trainer = PPOTrainer(
-        config=training_args,
+        args=training_args,
         processing_class=tokenizer,
         policy=policy,
         ref_policy=ref_policy,
examples/scripts/ppo/ppo_tldr.py (1 addition, 1 deletion)

@@ -163,7 +163,7 @@ def tokenize(element):
     # Training
     ################
     trainer = PPOTrainer(
-        config=training_args,
+        args=training_args,
         processing_class=tokenizer,
         policy=policy,
         ref_policy=ref_policy,
trl/trainer/ppo_trainer.py (7 additions, 6 deletions)

@@ -51,20 +51,21 @@
 from ..core import masked_mean, masked_whiten
 from ..models import create_reference_model
 from ..models.utils import unwrap_model_for_generation
-from ..trainer.utils import (
+from .ppo_config import PPOConfig
+from .utils import (
     OnlineTrainerState,
     batch_generation,
     disable_dropout_in_model,
     exact_div,
     first_true_indices,
     forward,
+    generate_model_card,
     get_reward,
+    peft_module_casting_to_bf16,
     prepare_deepspeed,
     print_rich_table,
     truncate_response,
 )
-from .ppo_config import PPOConfig
-from .utils import generate_model_card, peft_module_casting_to_bf16


 if is_peft_available():

@@ -97,10 +98,11 @@ def forward(self, **kwargs):
 class PPOTrainer(Trainer):
     _tag_names = ["trl", "ppo"]

+    @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
     @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
     def __init__(
         self,
-        config: PPOConfig,
+        args: PPOConfig,
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ],

@@ -122,8 +124,7 @@ def __init__(
                 "same as `policy`, you must make a copy of it, or `None` if you use peft."
             )

-        self.args = config
-        args = config
+        self.args = args
         self.processing_class = processing_class
         self.policy = policy
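The deprecation itself is carried by the stacked @deprecate_kwarg decorators rather than by checks inside `__init__`. As a rough illustration of that pattern (a simplified sketch, not the actual decorator TRL imports, which lives upstream and covers more edge cases), such a decorator remaps the old keyword onto the new one, warns, and can raise when both names are supplied:

    # Simplified sketch of a deprecate_kwarg-style decorator; illustrative only.
    import functools
    import warnings

    def deprecate_kwarg(old_name, *, new_name, version, raise_if_both_names=False):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if old_name in kwargs:
                    if new_name in kwargs and raise_if_both_names:
                        raise ValueError(f"cannot pass both `{old_name}` and `{new_name}`")
                    warnings.warn(
                        f"`{old_name}` is deprecated and will be removed in version "
                        f"{version}; use `{new_name}` instead.",
                        FutureWarning,
                    )
                    kwargs[new_name] = kwargs.pop(old_name)
                return func(*args, **kwargs)
            return wrapper
        return decorator

Applied as @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True), this makes PPOTrainer(config=...) warn and forward the value to `args`, while passing both keywords raises immediately.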
