From 81e935a1403f932bf222b5e4a43d0a6ddd81c5e6 Mon Sep 17 00:00:00 2001 From: Jonathan Tow <41410219+jon-tow@users.noreply.github.com> Date: Fri, 10 Feb 2023 11:54:13 -0500 Subject: [PATCH] refactor: remove orchestrator abstraction from API (#289) * refactor: remove orchestrator abstraction from API * Remove orchestrator in GPT-J config * Add `reward_fn` arg to NeMo constructor to match base trainer API * Initial support for `make_experience` in NeMo ILQL * Run pre-commit * Remove unused sampling util --- configs/ilql_config.yml | 1 - configs/nemo_ilql_config.yml | 1 - configs/ppo_config.yml | 1 - configs/ppo_gptj.yml | 1 - configs/sft_config.yml | 1 - configs/test_config.yml | 3 +- docs/source/index.rst | 1 - docs/source/orchestrator.rst | 23 -- docs/source/pipeline.rst | 2 +- .../configs/trlx_ppo_config.yml | 1 - .../randomwalks/configs/ilql_randomwalks.yml | 1 - .../randomwalks/configs/ppo_randomwalks.yml | 1 - .../configs/ppo_config_cnn_daily.yml | 1 - .../configs/ppo_config_summ_gptj.yml | 1 - trlx/data/__init__.py | 17 +- trlx/data/configs.py | 4 - trlx/orchestrator/__init__.py | 46 --- trlx/orchestrator/offline_orchestrator.py | 132 -------- trlx/orchestrator/ppo_orchestrator.py | 290 ------------------ trlx/pipeline/offline_pipeline.py | 39 ++- trlx/trainer/__init__.py | 21 +- trlx/trainer/accelerate_ilql_trainer.py | 78 +++++ trlx/trainer/accelerate_ppo_trainer.py | 277 +++++++++++++++-- trlx/trainer/nemo_ilql_trainer.py | 82 ++++- trlx/trlx.py | 13 +- trlx/utils/__init__.py | 19 -- trlx/utils/loading.py | 16 - 27 files changed, 466 insertions(+), 607 deletions(-) delete mode 100644 docs/source/orchestrator.rst delete mode 100644 trlx/orchestrator/__init__.py delete mode 100644 trlx/orchestrator/offline_orchestrator.py delete mode 100644 trlx/orchestrator/ppo_orchestrator.py diff --git a/configs/ilql_config.yml b/configs/ilql_config.yml index 3a4fc0c02..40c162c70 100644 --- a/configs/ilql_config.yml +++ b/configs/ilql_config.yml @@ -8,7 +8,6 @@ train: eval_interval: 100 pipeline: "PromptPipeline" - orchestrator: "OfflineOrchestrator" trainer: "AccelerateILQLTrainer" seed: 1000 diff --git a/configs/nemo_ilql_config.yml b/configs/nemo_ilql_config.yml index 752424034..1d4cc71e2 100644 --- a/configs/nemo_ilql_config.yml +++ b/configs/nemo_ilql_config.yml @@ -7,7 +7,6 @@ train: eval_interval: 20 pipeline: "PromptPipeline" - orchestrator: "OfflineOrchestrator" trainer: "NeMoILQLTrainer" trainer_kwargs: pretrained_model: "/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/" diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index f8cff096e..92388a2be 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -8,7 +8,6 @@ train: eval_interval: 100 pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AcceleratePPOTrainer" model: diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml index 9706f14fb..0595f7ded 100644 --- a/configs/ppo_gptj.yml +++ b/configs/ppo_gptj.yml @@ -8,7 +8,6 @@ train: eval_interval: 16 pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AcceleratePPOTrainer" model: diff --git a/configs/sft_config.yml b/configs/sft_config.yml index a2ebc9603..4b1efe358 100644 --- a/configs/sft_config.yml +++ b/configs/sft_config.yml @@ -8,7 +8,6 @@ train: eval_interval: 100 pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AccelerateSFTTrainer" model: diff --git a/configs/test_config.yml b/configs/test_config.yml index b26228eeb..19adb4f6f 100644 --- a/configs/test_config.yml +++ b/configs/test_config.yml @@ -8,7 
+8,6 @@ train: eval_interval: 128 # eval interval pipeline: "PromptPipeline" # prompt pipeline to load - orchestrator: "PPOOrchestrator" # orchestrator to load trainer: "AcceleratePPOTrainer" # Name of model trainer to load model: @@ -36,7 +35,7 @@ scheduler: method: name: "ppoconfig" # Name of RL method config num_rollouts: 128 # Number of rollouts to collect per epoch - chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator + chunk_size: 128 # Number of rollouts to collect in one loop ppo_epochs: 4 # Number of ppo epochs init_kl_coef: 0.2 # init kl coefficient target: 6 # target kl coefficient, set None for fixed kl coef diff --git a/docs/source/index.rst b/docs/source/index.rst index 782e29ecc..1b2947593 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,7 +14,6 @@ currently supports training using PPO or ILQL for models up to 20B using Acceler data models - orchestrator configs pipeline examples diff --git a/docs/source/orchestrator.rst b/docs/source/orchestrator.rst deleted file mode 100644 index 0a8a6a059..000000000 --- a/docs/source/orchestrator.rst +++ /dev/null @@ -1,23 +0,0 @@ -.. _orchestrator: - -Orchestrators -******************* - -Orchestrators manage reading data from a pipeline and creating RL data elements (i.e. ``trlx.data.RLElement``) -to push to a models rollout storage. Use the ``trlx.orchestrator.register_orchestrator`` decorator when creating -new orchestrators. - -**General** - -.. autoclass:: trlx.orchestrator.Orchestrator - :members: - -**PPO** - -.. autoclass:: trlx.orchestrator.ppo_orchestrator.PPOOrchestrator - :members: - -**ILQL** - -.. autoclass:: trlx.orchestrator.offline_orchestrator.OfflineOrchestrator - :members: diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst index 04d1a8c04..68279d889 100644 --- a/docs/source/pipeline.rst +++ b/docs/source/pipeline.rst @@ -4,7 +4,7 @@ Pipelines ************************ Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created -for them by the orchestrator. It is these experiences in their rollout store that they are trained on. +for them. It is these experiences in their rollout store that they are trained on. 
**General** diff --git a/examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml b/examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml index ad1c6a282..825050aef 100644 --- a/examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml +++ b/examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml @@ -8,7 +8,6 @@ train: eval_interval: 16 pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AcceleratePPOTrainer" model: diff --git a/examples/randomwalks/configs/ilql_randomwalks.yml b/examples/randomwalks/configs/ilql_randomwalks.yml index f4dcf1c88..dd6f84a49 100644 --- a/examples/randomwalks/configs/ilql_randomwalks.yml +++ b/examples/randomwalks/configs/ilql_randomwalks.yml @@ -8,7 +8,6 @@ train: eval_interval: 16 pipeline: "PromptPipeline" - orchestrator: "OfflineOrchestrator" trainer: "AccelerateILQLTrainer" seed: 1000 diff --git a/examples/randomwalks/configs/ppo_randomwalks.yml b/examples/randomwalks/configs/ppo_randomwalks.yml index 6489f4cdd..b4d6e7a6c 100644 --- a/examples/randomwalks/configs/ppo_randomwalks.yml +++ b/examples/randomwalks/configs/ppo_randomwalks.yml @@ -8,7 +8,6 @@ train: eval_interval: 20 pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AcceleratePPOTrainer" model: diff --git a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml index 8b6c7f33a..2134beadd 100755 --- a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml +++ b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml @@ -9,7 +9,6 @@ train: save_best: False pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AcceleratePPOTrainer" model: diff --git a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml index 812970918..8055a49b5 100755 --- a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml +++ b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml @@ -8,7 +8,6 @@ train: eval_interval: 200 pipeline: "PromptPipeline" - orchestrator: "PPOOrchestrator" trainer: "AcceleratePPOTrainer" model: diff --git a/trlx/data/__init__.py b/trlx/data/__init__.py index a1d996c9a..96d46750e 100644 --- a/trlx/data/__init__.py +++ b/trlx/data/__init__.py @@ -1,31 +1,18 @@ from dataclasses import dataclass -from typing import Any, Iterable +from typing import Iterable from torchtyping import TensorType -from . import configs - @dataclass class GeneralElement: """ - General element outputted by data pipeline being read by orchestrator. + General element outputted by a data pipeline """ pass -@dataclass -class SimElement: - """ - Batch element for Gyarados or Gyarados-like similarity scoring model - """ - - content: Any = None - preference: Any = None - score: float = None - - @dataclass class RLElement: """ diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 7a5955bb2..a0a6feaec 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -152,9 +152,6 @@ class TrainConfig: :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline :type pipeline: str - :param orchestrator: Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator - :type orchestrator: str - :param trainer: Trainer to use for training. 
One of the registered trainers present in trlx.trainer :type trainer: str @@ -193,7 +190,6 @@ class TrainConfig: eval_interval: int pipeline: str # One of the pipelines in framework.pipeline - orchestrator: str # One of the orchestrators trainer: str # One of the trainers trainer_kwargs: Dict[str, Any] = field(default_factory=dict) # Extra keyword arguments for the trainer diff --git a/trlx/orchestrator/__init__.py b/trlx/orchestrator/__init__.py deleted file mode 100644 index a7678a85e..000000000 --- a/trlx/orchestrator/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -import sys -from abc import abstractmethod -from typing import Dict - -from trlx.pipeline import BasePipeline -from trlx.trainer import BaseRLTrainer - -# specifies a dictionary of architectures -_ORCH: Dict[str, any] = {} # registry - - -def register_orchestrator(name): - """Decorator used register a CARP architecture - Args: - name: Name of the architecture - """ - - def register_class(cls, name): - _ORCH[name] = cls - setattr(sys.modules[__name__], name, cls) - return cls - - if isinstance(name, str): - name = name.lower() - return lambda c: register_class(c, name) - - cls = name - name = cls.__name__ - register_class(cls, name.lower()) - - return cls - - -@register_orchestrator -class Orchestrator: - def __init__(self, pipeline: BasePipeline, trainer: BaseRLTrainer): - self.pipeline = pipeline - self.trainer = trainer - - @abstractmethod - def make_experience(self): - """ - Draw from pipeline, get action, generate reward - Push to models RolloutStorage - """ - pass diff --git a/trlx/orchestrator/offline_orchestrator.py b/trlx/orchestrator/offline_orchestrator.py deleted file mode 100644 index d426e7aad..000000000 --- a/trlx/orchestrator/offline_orchestrator.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -from typing import List, Union - -import numpy as np -import torch -from rich.console import Console -from rich.table import Table - -import trlx.utils.logging as logging -from trlx.orchestrator import Orchestrator, register_orchestrator -from trlx.pipeline.offline_pipeline import ILQLRolloutStorage - -logger = logging.get_logger(__name__) - - -def tokenize_dialogue( # noqa: C901 - dialogue: Union[str, List[str]], tokenizer, max_length=2048, truncation_side="left" -) -> List[int]: - """ - Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...) 
- """ - if isinstance(dialogue, str): - dialogue = [tokenizer.bos_token, dialogue] - elif isinstance(dialogue, tuple): - dialogue = list(dialogue) - dialogue[-1] += tokenizer.eos_token - - out = [] - ctx_length = max_length - if tokenizer.truncation_side == "left": - for phrase in reversed(dialogue): - tokens = tokenizer(phrase).input_ids[-ctx_length:] - ctx_length -= len(tokens) - out.insert(0, tokens) - if ctx_length == 0: - break - - # in case of odd number of phrases (possibly due to truncation) - # since the first phrase always has to be a prompt, force it to be - if len(out) % 2 == 1: - if sum(map(len, out)) == max_length: - out[0].pop(0) - out.insert(0, [tokenizer.bos_token_id]) - - elif tokenizer.truncation_side == "right": - for phrase in dialogue: - tokens = tokenizer(phrase).input_ids[:ctx_length] - ctx_length -= len(tokens) - out.append(tokens) - if ctx_length == 0: - break - return out - - -@register_orchestrator -class OfflineOrchestrator(Orchestrator): - """ - Orchestrator that creates a static dataset for offline training - """ - - def __init__(self, trainer): - self.trainer = trainer - - def make_experience(self, samples, rewards, max_length=2048): - """ - Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer - """ - logger.info("Collecting rollouts") - - if self.trainer.tokenizer: - samples = [tokenize_dialogue(s, self.trainer.tokenizer, max_length) for s in samples] - - all_input_ids = [] - all_actions_ixs = [] - all_states_ixs = [] - all_dones = [] - for sample in samples: - length = 0 - all_input_ids.append(torch.tensor(sum(sample, []))) - isoutput = False - actions_ixs = [] - for phrase in sample: - if isoutput: - actions_ixs.append(torch.arange(length - 1, length + len(phrase) - 1)) - - length += len(phrase) - isoutput = not isoutput - - states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1))) - all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int)) - all_actions_ixs.append(torch.hstack(actions_ixs)) - all_states_ixs.append(states_ixs) - - if self.trainer.tokenizer and os.environ.get("RANK", "0") == "0": - logger.info("Logging sample example") - prompt = self.trainer.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]]) - response = self.trainer.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :]) - columns = ["Prompt", "Response", "Reward"] - table = Table(*columns, title="Sample Example", show_lines=True) - table.add_row(prompt, response, str(rewards[0])) - Console().print(table) - - sample_lengths = np.array(list(map(len, all_input_ids))) - output_lengths = np.array(list(map(len, all_actions_ixs))) - prompt_lengths = sample_lengths - output_lengths - returns = torch.tensor(rewards, dtype=float) - - if os.environ.get("RANK", "0") == "0": - logger.info("Logging experience string statistics") - columns = ["Prompt Length", "Output Length", "Sample Length"] - table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) - row = [] - for lengths in [prompt_lengths, output_lengths, sample_lengths]: - row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") - table.add_row(*row) - Console().print(table) - - returns = (returns - returns.mean()) / (returns.std() + 1e-30) - rewards = [torch.zeros(len(x)) for x in all_actions_ixs] - for rs, ret in zip(rewards, returns): - rs[-1] = ret - - attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids] - - self.trainer.store = ILQLRolloutStorage( - all_input_ids, - 
attention_mask, - rewards, - all_states_ixs, - all_actions_ixs, - all_dones, - ) diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py deleted file mode 100644 index 4ab2d7f5d..000000000 --- a/trlx/orchestrator/ppo_orchestrator.py +++ /dev/null @@ -1,290 +0,0 @@ -import os -from time import time -from typing import List - -import ray -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader - -import trlx.utils.logging as logging -from trlx.data.accelerate_base_datatypes import PromptBatch -from trlx.data.ppo_types import PPORLElement -from trlx.orchestrator import Orchestrator, register_orchestrator -from trlx.pipeline import BasePipeline -from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer -from trlx.utils import Clock -from trlx.utils.modeling import RunningMoments, logprobs_of_labels - -logger = logging.get_logger(__name__) - - -@register_orchestrator -class PPOOrchestrator(Orchestrator): - """PPO Orchestrator - - Runs rollouts - generates samples from prompts using the model, calculates - KL divergence against the reference model, and then pushes these - samples/rewards etc to the trainer's store. - - Note this class is intwined with the trainer `AcceleratePPOTrainer` - it - adds an `orch` property to the trainer instance and also sets a `trainer` - property on itself. See the trainer class for more details. - """ - - def __init__( - self, - trainer: AcceleratePPOTrainer, - pipeline: BasePipeline, - chunk_size: int = 512, - ): - """_summary_ - - Args: - trainer: Trainer - pipeline: Dataset - chunk_size: Batch size - """ - self.pipeline = pipeline - self.trainer = trainer - self.chunk_size = chunk_size - - # Create the dataloader (for batches of prompts) - self.pipeline_loader: DataLoader = self.pipeline.create_loader(self.chunk_size, shuffle=True) - self.pipeline_loader = self.trainer.accelerator.prepare_data_loader(self.pipeline_loader) - self.pipeline_iterator = iter(self.pipeline_loader) - - if not hasattr(self.trainer.model, "frozen_head"): - self.ref_model = self.trainer.get_arch(self.trainer.config) - self.ref_model.to(self.trainer.accelerator.device) - - # Set this orchestrator as a property on the trainer, so that the - # trainer can call `make_experience` directly for each epoch. - self.trainer.orch = self - - self.running = RunningMoments() - self.ref_mean = self.trainer.config.method.ref_mean - self.ref_std = self.trainer.config.method.ref_std - - def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: - """Make experiences - - Takes `num_rollouts` prompts from `pipeline`, samples from the model and - then computes the KL against a reference model. Finally it then appends - PPOElements to trainer's `store`. - - Args: - num_rollouts: Number of rollouts to generate - iter_count: Total number of updates run (i.e. number of updates run - for all batches & epochs) - """ - logger.info("Collecting rollouts") - tbar = logging.tqdm( - total=num_rollouts, - disable=os.environ.get("RANK", 0) != "0", - desc=f"[rollout 0 / {num_rollouts}]", - # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress - # bars (e.g. 
loss progress in trainers) - position=logging.get_verbosity() >= logging.WARNING, - # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels - leave=logging.get_verbosity() < logging.WARNING, - ) - - ppo_rl_elements = [] - stats = {} - clock = Clock() - - while len(ppo_rl_elements) < num_rollouts: - # Get next batch in prompt dataset and refresh if exhausted - try: - batch: PromptBatch = next(self.pipeline_iterator) - except StopIteration: - self.pipeline_iterator = iter(self.pipeline_loader) - batch = next(self.pipeline_iterator) - - exp_generate_time = time() - - # Generate samples from the language model (similar to using - # HuggingFace `generate` method) - samples = self.trainer.generate(**batch) - stats["time/exp_generate"] = time() - exp_generate_time - - prompt_tensors = batch.input_ids - device = samples.device - str_samples, str_prompts, str_outputs = self.trainer.decode(prompt_tensors, samples) - - # Pad the sample outputs - outputs = self.trainer.tokenizer(str_outputs).input_ids - outputs = list(map(torch.LongTensor, outputs)) - maxsize = max(map(len, outputs)) - outputs = [ - F.pad( - output, - (0, maxsize - len(output)), - value=self.trainer.tokenizer.pad_token_id, - ) - for output in outputs - ] - sample_outputs = torch.vstack(outputs).to(device) - - exp_score_time = time() - - scores = torch.tensor( - self.trainer.reward_fn( - samples=str_samples, - prompts=str_prompts, - outputs=str_outputs, - ), - dtype=torch.float, - ).to(device) - stats["time/exp_score"] = time() - exp_score_time - - # store statistics of the initial rollout as reference - if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.mean(), scores.std() - all_scores_mean, all_scores_std = self.running.update(scores) - stats["exp_scores/mean"] = all_scores_mean - stats["exp_scores/std"] = all_scores_std - stats["exp_scores/running_mean"] = self.running.mean - stats["exp_scores/running_std"] = self.running.std - - if self.trainer.config.method.scale_reward == "running": - scores /= self.running.std - elif self.trainer.config.method.scale_reward == "ref": - scores /= self.ref_std - - clip_reward = self.trainer.config.method.cliprange_reward - if clip_reward: - scores = torch.clip(scores, -clip_reward, clip_reward) - - # Precompute logprobs, values - if self.trainer.config.model.model_arch_type == "seq2seq": - attention_mask = batch.attention_mask.to(device) - prompt_tensors = batch.input_ids.to(device) - with torch.no_grad(): - outputs = self.trainer.model( - input_ids=prompt_tensors, - attention_mask=attention_mask, - decoder_input_ids=sample_outputs, - ) - logits = outputs.logits - values = outputs.value - if hasattr(self.trainer.model, "frozen_head"): - ref_logits = self.trainer.model.forward_hydra( - input_ids=prompt_tensors, - attention_mask=attention_mask, - decoder_input_ids=sample_outputs, - ) - else: - ref_logits = self.ref_model( - input_ids=prompt_tensors, - attention_mask=attention_mask, - decoder_input_ids=sample_outputs, - ).logits - else: - all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) - attention_mask = all_tokens.not_equal(self.trainer.tokenizer.pad_token_id).long().to(device) - with torch.no_grad(): - logits, *_, values = self.trainer.model( - all_tokens, - attention_mask=attention_mask, - ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.trainer.model, "frozen_head"): - ref_logits = self.trainer.model.forward_hydra( - all_tokens, - 
attention_mask=attention_mask, - return_dict=False, - ) - else: - ref_logits, _, *_ = self.ref_model( - all_tokens, - attention_mask=attention_mask, - return_dict=False, - ) - ref_logits = ref_logits.to(device) - - if self.trainer.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) - else: - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) - - n_samples: int = samples.shape[0] - logprobs = logprobs.cpu() - ref_logprobs = ref_logprobs.cpu() - prompt_tensors = prompt_tensors.cpu() - sample_outputs = sample_outputs.cpu() - - # Estimate the KL divergence between the model and reference model - if self.trainer.config.model.model_arch_type == "seq2seq": - # Skip the beginning of sequence token - start = 1 - - # Get the number of non-padding tokens for each sample - # This assumes all padding is on the right side - padding_token: int = 0 - ends = (sample_outputs[:, start:] != padding_token).sum(1) - - # Get the logprobs and values, for tokens that are not padding - # or beginning of sequences tokens. These are from the model - # (not the reference model) - all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] - all_values = [values[ix, start - 1 : ends[ix] - 1] for ix in range(n_samples)] - - kl_divergence_estimate: List[torch.Tensor] = [ - -self.trainer.kl_ctl.value - * ( - logprobs[sample_idx, start : ends[sample_idx]] - - ref_logprobs[sample_idx, start : ends[sample_idx]] - ) - for sample_idx in range(n_samples) - ] - - # Else if not seq2seq (i.e. causal) - else: - values = values.cpu()[:, :-1] - start = prompt_tensors.shape[1] - 1 - ends = start + attention_mask[:, start:].sum(1) - all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] - all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] - - kl_divergence_estimate = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs) - kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)] - - rollout_count = 0 - - for sample_idx in range(n_samples): - sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx] - - if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0: - continue - - rewards = sample_kl_divergence_estimate - rewards[-1] += scores[sample_idx].cpu() - - ppo_rl_elements.append( - PPORLElement( - query_tensor=prompt_tensors[sample_idx], - response_tensor=sample_outputs[sample_idx], - logprobs=all_logprobs[sample_idx], - values=all_values[sample_idx], - rewards=rewards, - ) - ) - - rollout_count += 1 - exp_time = clock.tick() - tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") - tbar.update(min(rollout_count, num_rollouts)) - tbar.close() - - stats["kl_ctl_value"] = self.trainer.kl_ctl.value - stats["time/exp"] = exp_time - - if not ray.is_initialized(): - self.trainer.accelerator.log(stats, step=iter_count) - - # Push samples and rewards to trainer's rollout storage - self.trainer.push_to_store(ppo_rl_elements) diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index 773eabd86..bd70ad79c 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -1,4 +1,4 @@ -from typing import Iterable, List +from typing import Iterable, List, Union import torch from torch.nn.utils.rnn import 
pad_sequence @@ -9,6 +9,43 @@ from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline +def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=2048) -> List[int]: # noqa: C901 + """ + Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...) + """ + if isinstance(dialogue, str): + dialogue = [tokenizer.bos_token, dialogue] + elif isinstance(dialogue, tuple): + dialogue = list(dialogue) + dialogue[-1] += tokenizer.eos_token + + out = [] + ctx_length = max_length + if tokenizer.truncation_side == "left": + for phrase in reversed(dialogue): + tokens = tokenizer(phrase).input_ids[-ctx_length:] + ctx_length -= len(tokens) + out.insert(0, tokens) + if ctx_length == 0: + break + + # in case of odd number of phrases (possibly due to truncation) + # since the first phrase always has to be a prompt, force it to be + if len(out) % 2 == 1: + if sum(map(len, out)) == max_length: + out[0].pop(0) + out.insert(0, [tokenizer.bos_token_id]) + + elif tokenizer.truncation_side == "right": + for phrase in dialogue: + tokens = tokenizer(phrase).input_ids[:ctx_length] + ctx_length -= len(tokens) + out.append(tokens) + if ctx_length == 0: + break + return out + + @register_datapipeline class PromptPipeline(BasePipeline): """ diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 68142bab2..e1c469e21 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -58,17 +58,9 @@ def push_to_store(self, data): self.store.push(data) def add_eval_pipeline(self, eval_pipeline): - """Adds pipeline from with validation prompts""" + """Adds pipeline for validation prompts""" self.eval_pipeline = eval_pipeline - @abstractmethod - def act(self, data: RLElement) -> RLElement: - """ - Given RLElement with state, produce an action and add it to the RLElement. 
- Orchestrator should call this, get reward and push subsequent RLElement to RolloutStore - """ - pass - @abstractmethod def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterable[str]: """ @@ -113,14 +105,3 @@ def save(self, directory=None): def load(self, directory=None): """Loads a checkpoint created from `save`""" pass - - def intervals(self, steps: int) -> Dict[str, bool]: - """ - Using config and current step number, returns a dict of whether certain things should be done - """ - - return { - "do_log": (steps + 1) % self.config.train.log_interval == 0, - "do_eval": (steps + 1) % self.config.train.eval_interval == 0, - "do_save": (steps + 1) % self.config.train.checkpoint_interval == 0, - } diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py index 35d0031b6..1cab072d2 100644 --- a/trlx/trainer/accelerate_ilql_trainer.py +++ b/trlx/trainer/accelerate_ilql_trainer.py @@ -1,14 +1,22 @@ +import os from typing import Optional, Sequence, Union, cast +import numpy as np import torch +from rich.console import Console +from rich.table import Table +import trlx.utils.logging as logging from trlx.data.configs import TRLConfig from trlx.data.ilql_types import ILQLBatch +from trlx.pipeline.offline_pipeline import ILQLRolloutStorage, tokenize_dialogue from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer from trlx.trainer.nn.ilql_models import CausalLMWithValueHeads, ILQLConfig from trlx.utils import to_device +logger = logging.get_logger(__name__) + @register_trainer class AccelerateILQLTrainer(AccelerateRLTrainer): @@ -95,3 +103,73 @@ def save_pretrained(self, directory: Optional[str] = None): "`AccelerateILQLTrainer` does not currently support automatic saving " "with `transformers.PreTrainedModel.save_pretrained`." 
) + + def make_experience(self, samples, rewards, max_length=2048): + """ + Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer + """ + logger.info("Collecting rollouts") + + if self.tokenizer: + samples = [tokenize_dialogue(s, self.tokenizer, max_length) for s in samples] + + all_input_ids = [] + all_actions_ixs = [] + all_states_ixs = [] + all_dones = [] + for sample in samples: + length = 0 + all_input_ids.append(torch.tensor(sum(sample, []))) + isoutput = False + actions_ixs = [] + for phrase in sample: + if isoutput: + actions_ixs.append(torch.arange(length - 1, length + len(phrase) - 1)) + + length += len(phrase) + isoutput = not isoutput + + states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1))) + all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int)) + all_actions_ixs.append(torch.hstack(actions_ixs)) + all_states_ixs.append(states_ixs) + + if self.tokenizer and os.environ.get("RANK", "0") == "0": + logger.info("Logging sample example") + prompt = self.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]]) + response = self.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :]) + columns = ["Prompt", "Response", "Reward"] + table = Table(*columns, title="Sample Example", show_lines=True) + table.add_row(prompt, response, str(rewards[0])) + Console().print(table) + + sample_lengths = np.array(list(map(len, all_input_ids))) + output_lengths = np.array(list(map(len, all_actions_ixs))) + prompt_lengths = sample_lengths - output_lengths + returns = torch.tensor(rewards, dtype=float) + + if os.environ.get("RANK", "0") == "0": + logger.info("Logging experience string statistics") + columns = ["Prompt Length", "Output Length", "Sample Length"] + table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) + row = [] + for lengths in [prompt_lengths, output_lengths, sample_lengths]: + row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") + table.add_row(*row) + Console().print(table) + + returns = (returns - returns.mean()) / (returns.std() + 1e-30) + rewards = [torch.zeros(len(x)) for x in all_actions_ixs] + for rs, ret in zip(rewards, returns): + rs[-1] = ret + + attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids] + + self.store = ILQLRolloutStorage( + all_input_ids, + attention_mask, + rewards, + all_states_ixs, + all_actions_ixs, + all_dones, + ) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 69f3c3d68..4b4f31e62 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -1,14 +1,20 @@ import json import os import uuid +from time import time from typing import Callable, List, Optional +import ray import torch +import torch.nn.functional as F from torch.utils.data import DataLoader from transformers import AutoTokenizer +import trlx.utils.logging as logging +from trlx.data.accelerate_base_datatypes import PromptBatch from trlx.data.configs import TRLConfig -from trlx.data.ppo_types import PPORLBatch +from trlx.data.ppo_types import PPORLBatch, PPORLElement +from trlx.pipeline.offline_pipeline import PromptPipeline from trlx.pipeline.ppo_pipeline import PPORolloutStorage from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer @@ -18,26 +24,20 @@ FixedKLController, Seq2SeqLMHydraWithValueHead, ) -from trlx.utils.modeling import logprobs_of_labels +from trlx.utils import 
Clock +from trlx.utils.modeling import RunningMoments, logprobs_of_labels + +logger = logging.get_logger(__name__) @register_trainer class AcceleratePPOTrainer(AccelerateRLTrainer): - """PPO Accelerate Trainer - - Note this class is intwined with `PPOOrchestrator`. The PPO trainer must be - created first and then given as a parameter to the PPO orchestrator. The PPO - orchestrator then adds a `orch` property to the trainer and also sets a - `trainer` property on itself. This broadly has the effect of the - trainer class extending the orchestrator class (and thus having access to - the `orch.make_experience` method that creates rollouts). - """ + """PPO Accelerate Trainer""" reward_fn: Callable[[List[str], List[str], List[str]], List[float]] - tokenizer: AutoTokenizer - def __init__(self, config, **kwargs): + def __init__(self, config: TRLConfig, **kwargs): """PPO Accelerate Trainer initialization Args: @@ -53,11 +53,11 @@ def __init__(self, config, **kwargs): self.log_rollouts = False # Setup the rollout store - # Rollouts contain the prompt & response, log probs, values - # and rewards - from each rollout + # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout self.store = PPORolloutStorage(self.tokenizer.pad_token_id) # Create the rollout store dataloader (for batching up rollouts) + # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True) # Prepare multi-GPU acceleration @@ -65,8 +65,12 @@ def __init__(self, config, **kwargs): self.model, self.opt, self.scheduler, rollout_loader ) - # Clear the rollout store - self.store.clear_history() + self.store.clear_history() # Clear the rollout store + + # Setup a reference model when hydra heads are not used + if not hasattr(self.model, "frozen_head"): + self.ref_model = self.get_arch(self.config) + self.ref_model.to(self.accelerator.device) # Setup the KL controller # This helps prevent large divergences in the controller (policy) @@ -107,6 +111,11 @@ def __init__(self, config, **kwargs): else: self.generate_experience_kwargs = None + # Setup stats tracker + self.running_moments = RunningMoments() + self.ref_mean = self.config.method.ref_mean + self.ref_std = self.config.method.ref_std + def get_arch(self, config: TRLConfig): """Get the model""" if config.model.model_arch_type == "seq2seq": @@ -202,9 +211,8 @@ def post_epoch_callback(self): if self.log_rollouts: self.store.export_history(location=self.rollout_logging_dir) self.store.clear_history() - self.orch.make_experience( - self.config.method.num_rollouts, self.iter_count - ) # Collect more rollouts for training + # Collect more rollouts for training + self.make_experience(self.config.method.num_rollouts, self.iter_count) def post_backward_callback(self): self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size) @@ -226,3 +234,232 @@ def save_pretrained(self, directory: Optional[str] = None): directory = f"{self.config.train.checkpoint_dir}/hf_model" self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory) self.tokenizer.save_pretrained(directory) + + def add_prompt_pipeline(self, pipeline: PromptPipeline): + """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + self.prompt_dataloader = 
self.accelerator.prepare_data_loader(prompt_dataloader) + self.prompt_iterator = iter(self.prompt_dataloader) + + def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: + """Make experiences + + Takes `chunk_size` number of prompts from `prompt_iterator`, samples + from the model and then computes the KL against a reference model. Finally it + then appends PPOElements to trainer's `store`. + + Args: + num_rollouts: Number of rollouts to generate + iter_count: Total number of updates run (i.e. number of updates run for all batches & epochs) + """ + logger.info("Collecting rollouts") + tbar = logging.tqdm( + total=num_rollouts, + disable=os.environ.get("RANK", 0) != "0", + desc=f"[rollout 0 / {num_rollouts}]", + # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress + # bars (e.g. loss progress in trainers) + position=logging.get_verbosity() >= logging.WARNING, + # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels + leave=logging.get_verbosity() < logging.WARNING, + ) + + ppo_rl_elements = [] + stats = {} + clock = Clock() + + while len(ppo_rl_elements) < num_rollouts: + # Get next batch in prompt dataset and refresh if exhausted + # TOOD (jon-tow): Make `prompt_dataloader` a cyclic/infinite DataLoader to not require manually + # "refreshing" the contents of the `prompt_iterator` + try: + batch: PromptBatch = next(self.prompt_iterator) + except StopIteration: + self.prompt_iterator = iter(self.prompt_dataloader) + batch = next(self.prompt_iterator) + + exp_generate_time = time() + + # Generate samples from the language model (similar to using HuggingFace `generate` method) + samples = self.generate(**batch) + stats["time/exp_generate"] = time() - exp_generate_time + + prompt_tensors = batch.input_ids + device = samples.device + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples) + + # Pad the sample outputs + outputs = self.tokenizer(str_outputs).input_ids + outputs = list(map(torch.LongTensor, outputs)) + maxsize = max(map(len, outputs)) + outputs = [ + F.pad( + output, + (0, maxsize - len(output)), + value=self.tokenizer.pad_token_id, + ) + for output in outputs + ] + sample_outputs = torch.vstack(outputs).to(device) + + exp_score_time = time() + + scores = torch.tensor( + self.reward_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + ), + dtype=torch.float, + ).to(device) + stats["time/exp_score"] = time() - exp_score_time + + # store statistics of the initial rollout as reference + if self.ref_mean is None: + self.ref_mean, self.ref_std = scores.mean(), scores.std() + all_scores_mean, all_scores_std = self.running_moments.update(scores) + stats["exp_scores/mean"] = all_scores_mean + stats["exp_scores/std"] = all_scores_std + stats["exp_scores/running_mean"] = self.running_moments.mean + stats["exp_scores/running_std"] = self.running_moments.std + + if self.config.method.scale_reward == "running": + scores /= self.running_moments.std + elif self.config.method.scale_reward == "ref": + scores /= self.ref_std + + clip_reward = self.config.method.cliprange_reward + if clip_reward: + scores = torch.clip(scores, -clip_reward, clip_reward) + + # Precompute logprobs, values + if self.config.model.model_arch_type == "seq2seq": + attention_mask = batch.attention_mask.to(device) + prompt_tensors = batch.input_ids.to(device) + with torch.no_grad(): + outputs = self.model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + 
decoder_input_ids=sample_outputs, + ) + logits = outputs.logits + values = outputs.value + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + ) + else: + ref_logits = self.ref_model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + ).logits + else: + all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) + attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens, + attention_mask=attention_mask, + ) + # TODO(dahoas): When hydra model works need to also support generation on hydra head + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens, + attention_mask=attention_mask, + return_dict=False, + ) + else: + ref_logits, _, *_ = self.ref_model( + all_tokens, + attention_mask=attention_mask, + return_dict=False, + ) + ref_logits = ref_logits.to(device) + + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + else: + logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) + + n_samples: int = samples.shape[0] + logprobs = logprobs.cpu() + ref_logprobs = ref_logprobs.cpu() + prompt_tensors = prompt_tensors.cpu() + sample_outputs = sample_outputs.cpu() + + # Estimate the KL divergence between the model and reference model + if self.config.model.model_arch_type == "seq2seq": + # Skip the beginning of sequence token + start = 1 + + # Get the number of non-padding tokens for each sample + # This assumes all padding is on the right side + padding_token: int = 0 + ends = (sample_outputs[:, start:] != padding_token).sum(1) + + # Get the logprobs and values, for tokens that are not padding + # or beginning of sequences tokens. These are from the model + # (not the reference model) + all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + all_values = [values[ix, start - 1 : ends[ix] - 1] for ix in range(n_samples)] + + kl_divergence_estimate: List[torch.Tensor] = [ + -self.kl_ctl.value + * ( + logprobs[sample_idx, start : ends[sample_idx]] + - ref_logprobs[sample_idx, start : ends[sample_idx]] + ) + for sample_idx in range(n_samples) + ] + + # Else if not seq2seq (i.e. 
causal) + else: + values = values.cpu()[:, :-1] + start = prompt_tensors.shape[1] - 1 + ends = start + attention_mask[:, start:].sum(1) + all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] + all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + + kl_divergence_estimate = -self.kl_ctl.value * (logprobs - ref_logprobs) + kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)] + + rollout_count = 0 + + for sample_idx in range(n_samples): + sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx] + + if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0: + continue + + rewards = sample_kl_divergence_estimate + rewards[-1] += scores[sample_idx].cpu() + + ppo_rl_elements.append( + PPORLElement( + query_tensor=prompt_tensors[sample_idx], + response_tensor=sample_outputs[sample_idx], + logprobs=all_logprobs[sample_idx], + values=all_values[sample_idx], + rewards=rewards, + ) + ) + + rollout_count += 1 + exp_time = clock.tick() + tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") + tbar.update(min(rollout_count, num_rollouts)) + tbar.close() + + stats["kl_ctl_value"] = self.kl_ctl.value + stats["time/exp"] = exp_time + + if not ray.is_initialized(): + self.accelerator.log(stats, step=iter_count) + + # Push samples and rewards to trainer's rollout storage + self.push_to_store(ppo_rl_elements) diff --git a/trlx/trainer/nemo_ilql_trainer.py b/trlx/trainer/nemo_ilql_trainer.py index a027bf930..04729bd50 100644 --- a/trlx/trainer/nemo_ilql_trainer.py +++ b/trlx/trainer/nemo_ilql_trainer.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Iterable, Sequence, Union, cast +import numpy as np import torch from nemo.collections.nlp.parts.nlp_overrides import ( GradScaler, @@ -8,7 +9,7 @@ NLPDDPStrategy, PipelineMixedPrecisionPlugin, ) -from nemo.utils import logging +from nemo.utils import get_rank, logging from nemo.utils.exp_manager import StatelessTimer, exp_manager from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer @@ -16,10 +17,16 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import ( CheckpointConnector, ) +from rich.console import Console +from rich.table import Table from trlx.data.configs import TRLConfig from trlx.data.ilql_types import ILQLBatch, ILQLElement, flatten_dataclass -from trlx.pipeline.offline_pipeline import ILQLRolloutStorage, ilql_collate_fn +from trlx.pipeline.offline_pipeline import ( + ILQLRolloutStorage, + ilql_collate_fn, + tokenize_dialogue, +) from trlx.trainer import register_trainer from trlx.trainer.nemo.gpt import ILQLGPT from trlx.trainer.nn.ilql_models import ILQLConfig @@ -103,6 +110,7 @@ class NeMoILQLTrainer(BaseRLTrainer): def __init__( self, config: TRLConfig, + reward_fn=None, logit_mask=None, metric_fn=None, stop_sequences=None, @@ -208,3 +216,73 @@ def eval_collate(elems): torch.set_float32_matmul_precision("medium") self.trainer.fit(self.model) + + def make_experience(self, samples, rewards, max_length=2048): + """ + Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer + """ + logging.info("Collecting rollouts") + + if self.tokenizer: + samples = [tokenize_dialogue(s, self.tokenizer, max_length) for s in samples] + + all_input_ids = [] + all_actions_ixs = [] + all_states_ixs = [] + all_dones = [] + for sample in samples: + length = 0 + all_input_ids.append(torch.tensor(sum(sample, []))) + 
isoutput = False + actions_ixs = [] + for phrase in sample: + if isoutput: + actions_ixs.append(torch.arange(length - 1, length + len(phrase) - 1)) + + length += len(phrase) + isoutput = not isoutput + + states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1))) + all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int)) + all_actions_ixs.append(torch.hstack(actions_ixs)) + all_states_ixs.append(states_ixs) + + if get_rank.is_global_rank_zero(): + logging.info("Logging sample example") + prompt = self.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]]) + response = self.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :]) + columns = ["Prompt", "Response", "Reward"] + table = Table(*columns, title="Sample Example", show_lines=True) + table.add_row(prompt, response, str(rewards[0])) + Console().print(table) + + sample_lengths = np.array(list(map(len, all_input_ids))) + output_lengths = np.array(list(map(len, all_actions_ixs))) + prompt_lengths = sample_lengths - output_lengths + returns = torch.tensor(rewards, dtype=float) + + if get_rank.is_global_rank_zero(): + logging.info("Logging experience string statistics") + columns = ["Prompt Length", "Output Length", "Sample Length"] + table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) + row = [] + for lengths in [prompt_lengths, output_lengths, sample_lengths]: + row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") + table.add_row(*row) + Console().print(table) + + returns = (returns - returns.mean()) / (returns.std() + 1e-30) + rewards = [torch.zeros(len(x)) for x in all_actions_ixs] + for rs, ret in zip(rewards, returns): + rs[-1] = ret + + attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids] + + self.store = ILQLRolloutStorage( + all_input_ids, + attention_mask, + rewards, + all_states_ixs, + all_actions_ixs, + all_dones, + ) diff --git a/trlx/trlx.py b/trlx/trlx.py index 6a45b655d..10e3621d1 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -4,7 +4,7 @@ from trlx.data.configs import TRLConfig from trlx.utils import set_seed -from trlx.utils.loading import get_orchestrator, get_pipeline, get_trainer +from trlx.utils.loading import get_pipeline, get_trainer def train( # noqa: C901 @@ -89,8 +89,12 @@ def train( # noqa: C901 eval_prompts = prompts[:batch_size] pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer) - orch = get_orchestrator(config.train.orchestrator)(trainer, pipeline, chunk_size=config.method.chunk_size) - orch.make_experience(config.method.num_rollouts) + trainer.add_prompt_pipeline(pipeline) + + if eval_prompts is None: + eval_prompts = prompts[:batch_size] + + trainer.make_experience(config.method.num_rollouts) # Offline training from the collected samples (e.g. 
SFT, ILQL) elif samples: @@ -102,8 +106,7 @@ def train( # noqa: C901 eval_prompts = [trainer.tokenizer.bos_token] * batch_size if rewards: - orch = get_orchestrator(config.train.orchestrator)(trainer) - orch.make_experience(samples, rewards, config.train.seq_length) + trainer.make_experience(samples, rewards, config.train.seq_length) else: trainer.store = get_pipeline(config.train.pipeline)(samples, max_prompt_length, trainer.tokenizer) diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py index cc081fcf7..4f33d1502 100644 --- a/trlx/utils/__init__.py +++ b/trlx/utils/__init__.py @@ -12,7 +12,6 @@ import torch from accelerate import Accelerator from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR -from torchtyping import TensorType def print_rank_0(*message): @@ -141,9 +140,6 @@ def get_scheduler_class(name: SchedulerName): raise ValueError(f"`{name}` is not a supported scheduler. " f"Supported schedulers are: {supported_schedulers}") -# Stats - - class Clock: """ Helper object for keeping track of time for computations. @@ -185,21 +181,6 @@ def get_stat(self, n_samp: int = 1000, reset: bool = False): return sec_per_samp * n_samp -# Sampling - - -def topk_mask(xs: TensorType["Batch", "Vocab"], k: int): - """ - Takes batched distribution over tokens and masks out scores for tokens - that are not in the top k for that distribution. - """ - - # Get topk per distribution - # For each dist, getting last value gives k-th largest - mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1) - return torch.where(xs < mintop, -np.inf * torch.ones_like(xs), xs) - - def tree_map(f, tree: Any) -> Any: """ Apply function f to all leaves in tree diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index a71608da3..cdb49926c 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -1,10 +1,5 @@ from typing import Callable -# Register load orchestrators via module import -from trlx.orchestrator import _ORCH -from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator -from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator - # Register load pipelines via module import from trlx.pipeline import _DATAPIPELINE from trlx.pipeline.offline_pipeline import PromptPipeline @@ -48,14 +43,3 @@ def get_pipeline(name: str) -> Callable: return _DATAPIPELINE[name] else: raise Exception("Error: Trying to access a pipeline that has not been registered") - - -def get_orchestrator(name: str) -> Callable: - """ - Return constructor for specified orchestrator - """ - name = name.lower() - if name in _ORCH: - return _ORCH[name] - else: - raise Exception("Error: Trying to access an orchestrator that has not been registered")
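
For reference, a minimal sketch of the control flow after this refactor, mirroring the updated `trlx/trlx.py` hunks above. This is not part of the patch itself; the surrounding setup (`config`, `prompts`, `samples`, `rewards`, `max_prompt_length`, `trainer`) is assumed to be built earlier in `trlx.train()` exactly as before, and only the orchestrator indirection is removed.

from trlx.utils.loading import get_pipeline  # get_orchestrator no longer exists

# Online methods (e.g. PPO): attach a prompt pipeline to the trainer and let the
# trainer collect its own rollouts via `make_experience`.
pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
trainer.add_prompt_pipeline(pipeline)
trainer.make_experience(config.method.num_rollouts)

# Offline methods (e.g. ILQL): the trainer now tokenizes samples and shapes rewards
# itself before filling its rollout storage.
if rewards:
    trainer.make_experience(samples, rewards, config.train.seq_length)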