Skip to content

Commit

Permalink
🏚 Remove unused components (#2480)
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Dec 19, 2024
1 parent 88ad1a0 commit 8c49ea3
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 280 deletions.
4 changes: 2 additions & 2 deletions examples/research_projects/stack_llama/scripts/rl_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
HfArgumentParser,
RobertaForSequenceClassification,
RobertaTokenizer,
set_seed,
)

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
from trl.core import LengthSampler


Expand Down
6 changes: 5 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import torch

from trl.core import masked_mean, masked_var, masked_whiten, whiten
from trl.core import masked_mean, masked_var, masked_whiten


class CoreTester(unittest.TestCase):
Expand All @@ -36,6 +36,10 @@ def test_masked_var(self):
self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))

def test_masked_whiten(self):
def whiten(values: torch.Tensor) -> torch.Tensor:
mean, var = torch.mean(values), torch.var(values)
return (values - mean) * torch.rsqrt(var + 1e-8)

whiten_unmasked = whiten(self.test_input_unmasked)
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
diffs = (whiten_unmasked - whiten_masked).sum()
Expand Down
2 changes: 0 additions & 2 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

_import_structure = {
"scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
"core": ["set_seed"],
"data_utils": [
"apply_chat_template",
"extract_prompt",
Expand Down Expand Up @@ -115,7 +114,6 @@
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])

if TYPE_CHECKING:
from .core import set_seed
from .data_utils import (
apply_chat_template,
extract_prompt,
Expand Down
165 changes: 2 additions & 163 deletions trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,62 +13,14 @@
# limitations under the License.

import gc
import random
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import TopKLogitsWarper, TopPLogitsWarper, is_torch_npu_available, is_torch_xpu_available


try:
from collections.abc import Mapping
except ImportError:
from collections.abc import Mapping


# Fill value handed to `pad_sequence` by `stack_dicts` when stacking flattened stats tensors.
WANDB_PADDING = -1


def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Apply top-k and/or nucleus (top-p) filtering to a distribution of logits.

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, the logits are passed through a `TopKLogitsWarper` (top-k filtering).
        top_p (`float`, *optional*, defaults to 1.0):
            If within `[0, 1.0]` (the default included), the logits are passed through a `TopPLogitsWarper`
            (nucleus filtering). Nucleus filtering is described in Holtzman et al.
            (https://huggingface.co/papers/1904.09751)
        filter_value (`float`, *optional*, defaults to `-inf`):
            Value forwarded to both warpers as their `filter_value`.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    Returns:
        The logits after the selected warpers ran; the input object itself when neither condition holds.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # Both warpers share the same masking configuration.
    shared_kwargs = {"filter_value": filter_value, "min_tokens_to_keep": min_tokens_to_keep}

    if top_k > 0:
        keep_top_k = TopKLogitsWarper(top_k=top_k, **shared_kwargs)
        # The warpers take (input_ids, scores); no input ids are supplied here.
        logits = keep_top_k(None, logits)

    if 0 <= top_p <= 1.0:
        keep_nucleus = TopPLogitsWarper(top_p=top_p, **shared_kwargs)
        logits = keep_nucleus(None, logits)

    return logits
from transformers import is_torch_npu_available, is_torch_xpu_available


def flatten_dict(nested: dict, sep: str = "/") -> dict:
Expand All @@ -88,52 +40,6 @@ def recurse(nest: dict, prefix: str, into: dict) -> None:
return flat


def convert_to_scalar(stats: dict) -> dict:
    """
    Return a copy of `stats` in which single-element tensors/arrays are replaced by Python scalars.

    Tensorboard ignores arrays and tensors, so values that are a `torch.Tensor` or `np.ndarray` with shape
    `()` or `(1,)` are converted with `.item()`. Every other value is passed through unchanged.
    """

    def _holds_single_value(candidate) -> bool:
        # Only tensors/arrays are eligible; anything else is kept as-is.
        if not isinstance(candidate, (torch.Tensor, np.ndarray)):
            return False
        dims = candidate.shape
        return len(dims) == 0 or (len(dims) == 1 and dims[0] == 1)

    return {
        name: (entry.item() if _holds_single_value(entry) else entry)
        for name, entry in stats.items()
    }


def stack_dicts(stats_dicts: list[dict]) -> dict:
    """
    Stack the values of a list of dicts key by key.

    For every key of the first dict, the corresponding value of each dict is flattened and the results are
    combined with `pad_sequence(batch_first=True)`, using `WANDB_PADDING` as the fill value for shorter
    entries.

    Args:
        stats_dicts: dicts to stack; each one is indexed with every key of the first dict.

    Returns:
        A dict mapping each key to the padded, stacked tensor. An empty input yields an empty dict.
    """
    # Without any dict there is no key set to iterate over (indexing [0] would raise IndexError).
    if not stats_dicts:
        return {}

    return {
        key: pad_sequence(
            [torch.flatten(entry[key]) for entry in stats_dicts],
            batch_first=True,
            padding_value=WANDB_PADDING,
        )
        for key in stats_dicts[0]
    }


def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
    """
    Compute log-probabilities from logits, optionally selecting the entries indexed by `labels`.

    The log-softmax is taken over dimension 2 of `logits`. With `gather=True` the value at each label index
    along that dimension is picked and the trailing singleton dimension is dropped; with `gather=False` the
    full log-softmax output is returned.

    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
    """
    log_probs = F.log_softmax(logits, dim=2)

    if gather:
        label_index = labels.unsqueeze(2)
        return torch.gather(log_probs, 2, label_index).squeeze(-1)
    return log_probs


def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """
    Whiten values: subtract the global mean and scale by the inverse standard deviation.

    Args:
        values: tensor to normalize; mean and variance are taken over all elements.
        shift_mean: when `False`, the original mean is added back after scaling.

    Returns:
        A new tensor holding the whitened values.
    """
    center = torch.mean(values)
    # The small epsilon keeps rsqrt finite when the variance is zero.
    inv_std = torch.rsqrt(torch.var(values) + 1e-8)
    normalized = (values - center) * inv_std
    if shift_mean:
        return normalized
    normalized += center
    return normalized


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
if axis is not None:
Expand Down Expand Up @@ -170,73 +76,6 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T
return whitened


def clip_by_value(
    x: torch.Tensor,
    tensor_min: Union[float, torch.Tensor],
    tensor_max: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Tensor extension to torch.clamp
    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713

    Computes `max(min(x, tensor_max), tensor_min)` element-wise.

    Args:
        x: tensor to clip.
        tensor_min: lower bound, either a Python number or a tensor.
        tensor_max: upper bound, either a Python number or a tensor.

    Returns:
        The clipped tensor.
    """
    # torch.min/torch.max only treat a second *tensor* argument as element-wise bounds; a bare Python number
    # is rejected or read as a `dim`. Wrap scalar bounds so that both forms are accepted.
    upper = tensor_max if isinstance(tensor_max, torch.Tensor) else torch.as_tensor(tensor_max, device=x.device)
    lower = tensor_min if isinstance(tensor_min, torch.Tensor) else torch.as_tensor(tensor_min, device=x.device)
    return torch.max(torch.min(x, upper), lower)


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Calculate entropy from logits.

    The softmax is taken over the last dimension, and the entropy is computed as
    `logsumexp(logits) - sum(softmax(logits) * logits)` along that same dimension.
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    log_normalizer = torch.logsumexp(logits, dim=-1)
    expected_logit = torch.sum(probs * logits, dim=-1)
    return log_normalizer - expected_logit


def stats_to_np(stats_dict: dict) -> dict:
    """
    Cast all torch.tensors in dict to numpy arrays.

    Tensors are detached and moved to CPU first; bfloat16 tensors are converted with `.float()` before the
    numpy conversion. Any resulting value that `np.isscalar` accepts is turned into a Python `float`.
    Other values are copied over unchanged.
    """
    converted = {}
    for name, entry in stats_dict.items():
        if isinstance(entry, torch.Tensor):
            host_tensor = entry.detach().cpu()
            if host_tensor.dtype == torch.bfloat16:
                host_tensor = host_tensor.float()
            entry = host_tensor.numpy()
        converted[name] = float(entry) if np.isscalar(entry) else entry
    return converted


def respond_to_batch(
    model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
) -> torch.LongTensor:
    """
    Sample text from language model.

    Starting from `queries`, one token per step is sampled for `txt_len` steps: the model is called on the
    running sequence, the logits of the last position are filtered with `top_k_top_p_filtering`, and the next
    token is drawn with `torch.multinomial` from the resulting softmax distribution.

    Args:
        model: callable whose output's first element holds the logits.
        queries: prompt token ids; indexed as a 2-D tensor by this function.
        txt_len: number of tokens to sample.
        top_k: forwarded to `top_k_top_p_filtering`.
        top_p: forwarded to `top_k_top_p_filtering`.

    Returns:
        Only the newly sampled tokens, i.e. everything after the prompt. With `txt_len=0` this is an empty
        `(batch, 0)` tensor.
    """
    input_ids = queries
    # Remember where the prompt ends: slicing with `-txt_len:` would return the whole prompt for txt_len == 0.
    prompt_len = queries.shape[1]
    for _ in range(txt_len):
        # Get logits of the last position
        outputs = model(input_ids)
        next_token_logits = outputs[0][:, -1, :]
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        # Sample
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids[:, prompt_len:]


def set_seed(seed: int) -> None:
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.

    After seeding the three libraries, the seed is also applied with `manual_seed_all` on exactly one
    backend: `torch.xpu` if XPU is available, else `torch.npu` if NPU is available, else `torch.cuda`.

    Args:
        seed (`int`): The seed to set.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Pick the backend lazily so that torch.xpu / torch.npu are only touched when reported available.
    if is_torch_xpu_available():
        backend = torch.xpu
    elif is_torch_npu_available():
        backend = torch.npu
    else:
        backend = torch.cuda
    backend.manual_seed_all(seed)


class LengthSampler:
"""
Samples a length
Expand Down
3 changes: 1 addition & 2 deletions trl/extras/best_of_n_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from typing import Any, Callable, Optional, Union

import torch
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed

from ..core import set_seed
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper


Expand Down
18 changes: 6 additions & 12 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# There is a circular import in the PPOTrainer if we let isort sort these
from typing import TYPE_CHECKING

from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
Expand All @@ -21,7 +20,6 @@
_import_structure = {
"alignprop_config": ["AlignPropConfig"],
"alignprop_trainer": ["AlignPropTrainer"],
"base": ["BaseTrainer"],
"bco_config": ["BCOConfig"],
"bco_trainer": ["BCOTrainer"],
"callbacks": [
Expand All @@ -41,8 +39,8 @@
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"judges": [
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BaseJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"HfPairwiseJudge",
Expand All @@ -60,23 +58,21 @@
"orpo_trainer": ["ORPOTrainer"],
"ppo_config": ["PPOConfig"],
"ppo_trainer": ["PPOTrainer"],
"ppov2_config": ["PPOv2Config"],
"ppov2_trainer": ["PPOv2Trainer"],
"prm_config": ["PRMConfig"],
"prm_trainer": ["PRMTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"reward_trainer": ["RewardTrainer"],
"rloo_config": ["RLOOConfig"],
"rloo_trainer": ["RLOOTrainer"],
"sft_config": ["SFTConfig"],
"sft_trainer": ["SFTTrainer"],
"utils": [
"AdaptiveKLController",
"ConstantLengthDataset",
"DataCollatorForCompletionOnlyLM",
"FixedKLController",
"RunningMoments",
"compute_accuracy",
"disable_dropout_in_model",
"empty_cache",
"peft_module_casting_to_bf16",
],
"xpo_config": ["XPOConfig"],
Expand All @@ -93,7 +89,6 @@
if TYPE_CHECKING:
from .alignprop_config import AlignPropConfig
from .alignprop_trainer import AlignPropTrainer
from .base import BaseTrainer
from .bco_config import BCOConfig
from .bco_trainer import BCOTrainer
from .callbacks import (
Expand Down Expand Up @@ -135,17 +130,16 @@
from .prm_config import PRMConfig
from .prm_trainer import PRMTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .reward_trainer import RewardTrainer
from .rloo_config import RLOOConfig
from .rloo_trainer import RLOOTrainer
from .sft_config import SFTConfig
from .sft_trainer import SFTTrainer
from .utils import (
AdaptiveKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
FixedKLController,
RunningMoments,
compute_accuracy,
disable_dropout_in_model,
empty_cache,
peft_module_casting_to_bf16,
Expand Down
5 changes: 3 additions & 2 deletions trl/trainer/alignprop_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import PyTorchModelHubMixin
from transformers import is_wandb_available

from ..models import DDPOStableDiffusionPipeline
from . import AlignPropConfig, BaseTrainer
from .alignprop_config import AlignPropConfig
from .utils import generate_model_card, get_comet_experiment_url


Expand All @@ -35,7 +36,7 @@
logger = get_logger(__name__)


class AlignPropTrainer(BaseTrainer):
class AlignPropTrainer(PyTorchModelHubMixin):
"""
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
Expand Down
Loading

0 comments on commit 8c49ea3

Please sign in to comment.