diff --git a/examples/research_projects/stack_llama/scripts/rl_training.py b/examples/research_projects/stack_llama/scripts/rl_training.py
index 011c00554f..633cd53d05 100644
--- a/examples/research_projects/stack_llama/scripts/rl_training.py
+++ b/examples/research_projects/stack_llama/scripts/rl_training.py
@@ -20,9 +20,9 @@
 from datasets import load_dataset
 from peft import LoraConfig
 from tqdm import tqdm
-from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline
+from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed
 
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
+from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
 from trl.core import LengthSampler
diff --git a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py
index d3998c7882..2684d31680 100644
--- a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py
+++ b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py
@@ -25,9 +25,10 @@
     HfArgumentParser,
     RobertaForSequenceClassification,
     RobertaTokenizer,
+    set_seed,
 )
 
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed
+from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
 from trl.core import LengthSampler
diff --git a/tests/test_core.py b/tests/test_core.py
index 88ecf38fad..16a6284753 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -16,7 +16,7 @@
 
 import torch
 
-from trl.core import masked_mean, masked_var, masked_whiten, whiten
+from trl.core import masked_mean, masked_var, masked_whiten
 
 
 class CoreTester(unittest.TestCase):
@@ -36,6 +36,10 @@ def test_masked_var(self):
         self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
 
     def test_masked_whiten(self):
+        def whiten(values: torch.Tensor) -> torch.Tensor:
+            mean, var = torch.mean(values), torch.var(values)
+            return (values - mean) * torch.rsqrt(var + 1e-8)
+
         whiten_unmasked = whiten(self.test_input_unmasked)
         whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
         diffs = (whiten_unmasked - whiten_masked).sum()
diff --git a/trl/__init__.py b/trl/__init__.py
index 2df981580c..5598230781 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -21,7 +21,6 @@
 _import_structure = {
     "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
-    "core": ["set_seed"],
     "data_utils": [
         "apply_chat_template",
         "extract_prompt",
@@ -115,7 +114,6 @@
     _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
 
 if TYPE_CHECKING:
-    from .core import set_seed
     from .data_utils import (
         apply_chat_template,
         extract_prompt,
diff --git a/trl/core.py b/trl/core.py
index f62c8bc414..776ed5bdce 100644
--- a/trl/core.py
+++ b/trl/core.py
@@ -13,62 +13,14 @@
 # limitations under the License.
 
 import gc
-import random
 import warnings
+from collections.abc import Mapping
 from contextlib import contextmanager
 from typing import Optional, Union
 
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
-from transformers import TopKLogitsWarper, TopPLogitsWarper, is_torch_npu_available, is_torch_xpu_available
-
-
-try:
-    from collections.abc import Mapping
-except ImportError:
-    from collections.abc import Mapping
-
-
-WANDB_PADDING = -1
-
-
-def top_k_top_p_filtering(
-    logits: torch.FloatTensor,
-    top_k: int = 0,
-    top_p: float = 1.0,
-    filter_value: float = -float("Inf"),
-    min_tokens_to_keep: int = 1,
-) -> torch.FloatTensor:
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (https://huggingface.co/papers/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimumber of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-
-    if top_k > 0:
-        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    if 0 <= top_p <= 1.0:
-        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    return logits
+from transformers import is_torch_npu_available, is_torch_xpu_available
 
 
 def flatten_dict(nested: dict, sep: str = "/") -> dict:
@@ -88,52 +40,6 @@ def recurse(nest: dict, prefix: str, into: dict) -> None:
     return flat
 
 
-def convert_to_scalar(stats: dict) -> dict:
-    """
-    Converts the stats from a flattened dict to single scalar dicts
-    """
-    tensorboard_stats = {}
-    for k, v in stats.items():
-        # for tensorboard compatibility - arrays and tensors are ignored with tensorboard
-        # therefore we convert single element tensors to scalars
-        if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
-            len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
-        ):
-            v = v.item()
-        tensorboard_stats[k] = v
-    return tensorboard_stats
-
-
-def stack_dicts(stats_dicts: list[dict]) -> dict:
-    """Stack the values of a dict."""
-    results = dict()
-    for k in stats_dicts[0]:
-        stats_list = [torch.flatten(d[k]) for d in stats_dicts]
-        results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
-    return results
-
-
-def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
-    """
-    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
-    """
-    logp = F.log_softmax(logits, dim=2)
-
-    if not gather:
-        return logp
-    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
-    return logpy
-
-
-def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
-    """Whiten values."""
-    mean, var = torch.mean(values), torch.var(values)
-    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
-    if not shift_mean:
-        whitened += mean
-    return whitened
-
-
 def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
     """Compute mean of tensor with a masked values."""
     if axis is not None:
@@ -170,73 +76,6 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T
     return whitened
 
 
-def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
-    """
-    Tensor extension to torch.clamp
-    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
-    """
-    clipped = torch.max(torch.min(x, tensor_max), tensor_min)
-    return clipped
-
-
-def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
-    """Calculate entropy from logits."""
-    pd = torch.nn.functional.softmax(logits, dim=-1)
-    entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
-    return entropy
-
-
-def stats_to_np(stats_dict: dict) -> dict:
-    """Cast all torch.tensors in dict to numpy arrays."""
-    new_dict = dict()
-    for k, v in stats_dict.items():
-        if isinstance(v, torch.Tensor):
-            new_dict[k] = v.detach().cpu()
-            if new_dict[k].dtype == torch.bfloat16:
-                new_dict[k] = new_dict[k].float()
-            new_dict[k] = new_dict[k].numpy()
-        else:
-            new_dict[k] = v
-        if np.isscalar(new_dict[k]):
-            new_dict[k] = float(new_dict[k])
-    return new_dict
-
-
-def respond_to_batch(
-    model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
-) -> torch.LongTensor:
-    """Sample text from language model."""
-    input_ids = queries
-    for _i in range(txt_len):
-        # Get Logits
-        outputs = model(input_ids)
-        next_token_logits = outputs[0][:, -1, :]
-        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-        # Sample
-        probs = F.softmax(next_token_logits, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
-        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
-    return input_ids[:, -txt_len:]
-
-
-def set_seed(seed: int) -> None:
-    """
-    Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
-
-    Args:
-        seed (`int`): The seed to set.
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if is_torch_xpu_available():
-        torch.xpu.manual_seed_all(seed)
-    elif is_torch_npu_available():
-        torch.npu.manual_seed_all(seed)
-    else:
-        torch.cuda.manual_seed_all(seed)
-
-
 class LengthSampler:
     """
     Samples a length
diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py
index c0f02152c2..cf7b43ff25 100644
--- a/trl/extras/best_of_n_sampler.py
+++ b/trl/extras/best_of_n_sampler.py
@@ -15,9 +15,8 @@
 from typing import Any, Callable, Optional, Union
 
 import torch
-from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
 
-from ..core import set_seed
 from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 85a2e4d57c..9759609f56 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# There is a circular import in the PPOTrainer if we let isort sort these
 from typing import TYPE_CHECKING
 
 from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
@@ -21,7 +20,6 @@
 _import_structure = {
     "alignprop_config": ["AlignPropConfig"],
     "alignprop_trainer": ["AlignPropTrainer"],
-    "base": ["BaseTrainer"],
     "bco_config": ["BCOConfig"],
     "bco_trainer": ["BCOTrainer"],
     "callbacks": [
@@ -41,8 +39,8 @@
     "iterative_sft_trainer": ["IterativeSFTTrainer"],
     "judges": [
         "AllTrueJudge",
-        "BaseJudge",
         "BaseBinaryJudge",
+        "BaseJudge",
         "BasePairwiseJudge",
         "BaseRankJudge",
         "HfPairwiseJudge",
@@ -60,23 +58,21 @@
     "orpo_trainer": ["ORPOTrainer"],
     "ppo_config": ["PPOConfig"],
     "ppo_trainer": ["PPOTrainer"],
-    "ppov2_config": ["PPOv2Config"],
-    "ppov2_trainer": ["PPOv2Trainer"],
     "prm_config": ["PRMConfig"],
     "prm_trainer": ["PRMTrainer"],
     "reward_config": ["RewardConfig"],
-    "reward_trainer": ["RewardTrainer", "compute_accuracy"],
+    "reward_trainer": ["RewardTrainer"],
     "rloo_config": ["RLOOConfig"],
     "rloo_trainer": ["RLOOTrainer"],
     "sft_config": ["SFTConfig"],
     "sft_trainer": ["SFTTrainer"],
     "utils": [
-        "AdaptiveKLController",
         "ConstantLengthDataset",
         "DataCollatorForCompletionOnlyLM",
-        "FixedKLController",
         "RunningMoments",
+        "compute_accuracy",
         "disable_dropout_in_model",
+        "empty_cache",
         "peft_module_casting_to_bf16",
     ],
     "xpo_config": ["XPOConfig"],
@@ -93,7 +89,6 @@
 if TYPE_CHECKING:
     from .alignprop_config import AlignPropConfig
     from .alignprop_trainer import AlignPropTrainer
-    from .base import BaseTrainer
     from .bco_config import BCOConfig
     from .bco_trainer import BCOTrainer
     from .callbacks import (
@@ -135,17 +130,16 @@
     from .prm_config import PRMConfig
     from .prm_trainer import PRMTrainer
     from .reward_config import RewardConfig
-    from .reward_trainer import RewardTrainer, compute_accuracy
+    from .reward_trainer import RewardTrainer
     from .rloo_config import RLOOConfig
     from .rloo_trainer import RLOOTrainer
     from .sft_config import SFTConfig
     from .sft_trainer import SFTTrainer
     from .utils import (
-        AdaptiveKLController,
         ConstantLengthDataset,
         DataCollatorForCompletionOnlyLM,
-        FixedKLController,
         RunningMoments,
+        compute_accuracy,
         disable_dropout_in_model,
         empty_cache,
         peft_module_casting_to_bf16,
diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py
index c9cd718c44..7c5018892e 100644
--- a/trl/trainer/alignprop_trainer.py
+++ b/trl/trainer/alignprop_trainer.py
@@ -22,10 +22,11 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import PyTorchModelHubMixin
 from transformers import is_wandb_available
 
 from ..models import DDPOStableDiffusionPipeline
-from . import AlignPropConfig, BaseTrainer
+from .alignprop_config import AlignPropConfig
 from .utils import generate_model_card, get_comet_experiment_url
@@ -35,7 +36,7 @@
 logger = get_logger(__name__)
 
 
-class AlignPropTrainer(BaseTrainer):
+class AlignPropTrainer(PyTorchModelHubMixin):
     """
     The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
     Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
diff --git a/trl/trainer/base.py b/trl/trainer/base.py
deleted file mode 100644
index 7730e6af9a..0000000000
--- a/trl/trainer/base.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from huggingface_hub import PyTorchModelHubMixin
-
-
-class BaseTrainer(PyTorchModelHubMixin):
-    r"""
-    Base class for all trainers - this base class implements the basic functions that we
-    need for a trainer.
-
-    The trainer needs to have the following functions:
-        - step: takes in a batch of data and performs a step of training
-        - loss: takes in a batch of data and returns the loss
-        - compute_rewards: takes in a batch of data and returns the rewards
-        - _build_models_and_tokenizer: builds the models and tokenizer
-        - _build_dataset: builds the dataset
-    Each user is expected to implement their own trainer class that inherits from this base
-    if they want to use a new training algorithm.
-    """
-
-    def __init__(self, config):
-        self.config = config
-
-    def step(self, *args):
-        raise NotImplementedError("Not implemented")
-
-    def loss(self, *args):
-        raise NotImplementedError("Not implemented")
-
-    def compute_rewards(self, *args):
-        raise NotImplementedError("Not implemented")
-
-    def _save_pretrained(self, save_directory):
-        raise NotImplementedError("Not implemented")
diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py
index 01a1d0d5c5..846ab2730e 100644
--- a/trl/trainer/ddpo_trainer.py
+++ b/trl/trainer/ddpo_trainer.py
@@ -23,10 +23,11 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import PyTorchModelHubMixin
 from transformers import is_wandb_available
 
 from ..models import DDPOStableDiffusionPipeline
-from . import BaseTrainer, DDPOConfig
+from .ddpo_config import DDPOConfig
 from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url
@@ -37,7 +38,7 @@
 logger = get_logger(__name__)
 
 
-class DDPOTrainer(BaseTrainer):
+class DDPOTrainer(PyTorchModelHubMixin):
     """
     The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
     Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index bf760570bc..ab9a06e469 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -51,7 +51,6 @@
     is_torch_xpu_available,
 )
 
-from ..import_utils import is_unsloth_available
 from ..trainer.model_config import ModelConfig
@@ -62,34 +61,6 @@
     from peft import LoraConfig, PeftConfig
 
 
-class AdaptiveKLController:
-    """
-    Adaptive KL controller described in the paper:
-    https://huggingface.co/papers/1909.08593
-    """
-
-    def __init__(self, init_kl_coef, target, horizon):
-        self.value = init_kl_coef
-        self.target = target
-        self.horizon = horizon
-
-    def update(self, current, n_steps):
-        target = self.target
-        proportional_error = np.clip(current / target - 1, -0.2, 0.2)
-        mult = 1 + proportional_error * n_steps / self.horizon
-        self.value *= mult
-
-
-class FixedKLController:
-    """Fixed KL controller."""
-
-    def __init__(self, kl_coef):
-        self.value = kl_coef
-
-    def update(self, current, n_steps):
-        pass
-
-
 class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
     """
     Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
@@ -878,24 +849,6 @@ def peft_module_casting_to_bf16(model):
                     module = module.to(torch.bfloat16)
 
 
-def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
-    if is_unsloth_available():
-        # Unsloth adds a new attribute in the model config `unsloth_version`
-        # to keep track of models that have been patched with unsloth.
-        if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None:
-            tag_names.append("unsloth")
-
-    if kwargs is not None:
-        if "tags" not in kwargs:
-            kwargs["tags"] = tag_names
-        elif "tags" in kwargs and isinstance(kwargs["tags"], list):
-            kwargs["tags"].extend(tag_names)
-        elif "tags" in kwargs and isinstance(kwargs["tags"], str):
-            tag_names.append(kwargs["tags"])
-            kwargs["tags"] = tag_names
-    return kwargs
-
-
 def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
     if model_args.load_in_4bit:
         quantization_config = BitsAndBytesConfig(