diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py
new file mode 100644
index 000000000..84435b225
--- /dev/null
+++ b/examples/ppo_redemption.py
@@ -0,0 +1,82 @@
+# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
+# with a sentiment reward function
+import json
+import os
+import sys
+from typing import List
+
+import torch
+from datasets import load_dataset
+from transformers import pipeline
+
+import trlx
+from trlx.data.default_configs import TRLConfig, default_ppo_config
+
+
+def get_positive_score(scores):
+    "Extract value associated with a positive sentiment from pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+
+def get_negative_score(scores):
+    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
+    config.method.cliprange_reward = False
+    config.method.gen_kwargs["max_new_tokens"] = 70
+    config.method.gen_kwargs["temperature"] = 0.3
+    config.train.total_steps = 20000
+    config.train.checkpoint_interval = 10000000
+    # config.method.init_kl_coef = 0
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "lvwerra/distilbert-imdb",
+        top_k=2,
+        truncation=True,
+        batch_size=256,
+        device=device,
+    )
+
+    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[List[float]]:
+        # Reward positively for an initially negative, then positive review
+        # Reward functions should never receive padded text except for a single EOS at the end
+        # Reward function should return token rewards for just the response
+        first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples]
+        negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
+        second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples]
+        positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
+        text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
+        tok_scores = []
+        for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
+            toks = model_tok(response).input_ids
+            tok_score = [0] * len(toks)
+            # Hacky way of assigning intermediate score
+            tok_score[len(tok_score) // 2] = text_score[0]
+            tok_score[-1] = text_score[1]
+            tok_scores.append(tok_score)
+        return tok_scores
+
+    # Take a few words off of movie reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+
+    trlx.train(
+        reward_fn=dense_reward_fn,
+        prompts=prompts,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py
index 5277d7010..3acee97ab 100644
--- a/trlx/data/default_configs.py
+++ b/trlx/data/default_configs.py
@@ -49,11 +49,13 @@ def default_ppo_config():
             ref_mean=None,
             ref_std=None,
             cliprange_reward=10,
+            num_train_sequences=1,
             gen_kwargs=dict(
                 max_new_tokens=40,
                 top_k=0,
                 top_p=1.0,
                 do_sample=True,
+                num_return_sequences=1,
             ),
         ),
     )
diff --git a/trlx/data/ppo_types.py b/trlx/data/ppo_types.py
index 
375d7d3ab..8cbe654a4 100644 --- a/trlx/data/ppo_types.py +++ b/trlx/data/ppo_types.py @@ -33,6 +33,7 @@ class PPORLElement: logprobs: TensorType["response_size"] values: TensorType["response_size"] rewards: TensorType["response_size"] + loss_mask: TensorType["response_size"] @dataclass @@ -54,6 +55,9 @@ class PPORLBatch: :param rewards: A batch of rewards :type rewards: torch.Tensor + + :param loss_masks: A mask for tokens during the loss computation + :type loss_masks: torch.Tensor """ query_tensors: TensorType["batch_size", "query_size"] @@ -61,3 +65,4 @@ class PPORLBatch: logprobs: TensorType["batch_size", "response_size"] values: TensorType["batch_size", "response_size"] rewards: TensorType["batch_size", "response_size"] + loss_masks: TensorType["batch_size", "response_size"] diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 82d3ec637..261f86553 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -112,6 +112,9 @@ class PPOConfig(MethodConfig): :param gen_experience_kwargs: if this is not None, then the experience is generated using this :type gen_experience_kwargs: Dict[str, Any] + + :param num_train_sequences: top_k of n sampled sequences from prompt + :type num_train_sequences: int """ ppo_epochs: int @@ -131,12 +134,15 @@ class PPOConfig(MethodConfig): cliprange_reward: float gen_kwargs: dict gen_experience_kwargs: Optional[dict] = None + num_train_sequences: int = 1 + dist_ref_model: bool = False def get_advantages_and_returns( self, values: TensorType["batch_size", "response_size"], rewards: TensorType["batch_size", "response_size"], response_length: int, + mask: TensorType["batch_size", "response_size"], use_whitening: Optional[bool] = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function that computes advantages and returns from rewards and values. 
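The two new knobs introduced above are meant to be used together: `gen_kwargs.num_return_sequences` controls how many continuations are sampled per prompt during rollouts, and `num_train_sequences` keeps only the best k of each group for training; the trainer implements the selection with the `get_topk_indices` helper added further down in this diff. The sketch below illustrates the intended "top-k of n" semantics on already-summed, per-sample scores (the trainer sums per-token rewards first). `select_top_k_per_prompt` and the toy scores are illustrative only, not code from this PR.

import torch


def select_top_k_per_prompt(scores: torch.Tensor, num_return_sequences: int, k: int) -> torch.Tensor:
    # `scores` has shape [num_prompts * num_return_sequences]; samples from the same
    # prompt are contiguous because the trainer repeat_interleaves prompts along dim 0.
    grouped = scores.view(-1, num_return_sequences)        # [num_prompts, n]
    top_idx = grouped.topk(k, dim=1).indices               # best k within each group
    offsets = torch.arange(grouped.size(0)).unsqueeze(1) * num_return_sequences
    return (top_idx + offsets).reshape(-1)                 # flat indices into the rollout batch


# Two prompts, four samples each, keep the best two per prompt
scores = torch.tensor([0.1, 0.9, 0.3, 0.2, 0.5, 0.4, 0.8, 0.7])
print(select_top_k_per_prompt(scores, num_return_sequences=4, k=2))  # tensor([1, 2, 6, 7])

With the defaults added to `default_ppo_config` (`num_return_sequences=1`, `num_train_sequences=1`) the selection is a no-op, so existing configs keep their current behaviour.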
@@ -168,7 +174,7 @@ def get_advantages_and_returns( advantages = torch.stack(advantages_reversed[::-1], dim=1) returns = advantages + values if use_whitening: - advantages = whiten(advantages) + advantages = whiten(advantages, mask) return advantages.detach(), returns def loss( diff --git a/trlx/pipeline/ppo_pipeline.py b/trlx/pipeline/ppo_pipeline.py index 7bcfebedc..1f8921a0f 100644 --- a/trlx/pipeline/ppo_pipeline.py +++ b/trlx/pipeline/ppo_pipeline.py @@ -47,6 +47,7 @@ def ppo_collate_fn(padding_side: str, pad_token_id: int, elems: Iterable[PPORLEl padding_value=0.0, batch_first=True, ), + pad_sequence([elem.loss_mask for elem in elems], batch_first=True, padding_value=0.0), ) diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 8e0d239df..1e53938d5 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -41,6 +41,7 @@ def __init__( logit_mask=None, stop_sequences=None, train_mode=False, + inference_pipeline=None, ): self.store: BaseRolloutStore = None self.config = config diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 5c82335c0..56aa295fc 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -4,6 +4,7 @@ import sys from abc import abstractmethod from contextlib import contextmanager +from copy import copy from time import time from typing import Dict, List, Optional, Tuple @@ -13,6 +14,7 @@ from ray.air import session from rich.console import Console from rich.table import Table +from torch.nn.utils.rnn import pad_sequence from transformers import AutoTokenizer import trlx.utils.logging as logging @@ -25,6 +27,7 @@ get_git_tag, get_optimizer_class, get_scheduler_class, + remove_bos, significant, ) from trlx.utils.modeling import ( @@ -67,12 +70,15 @@ def __init__(self, config, **kwargs): # noqa: C901 self.opt = self.setup_optimizer() self.scheduler = self.setup_scheduler() - self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) + self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path, use_fast=True) self.tokenizer.padding_side = config.tokenizer.padding_side self.tokenizer.truncation_side = config.tokenizer.truncation_side self.tokenizer.sep_token = "" - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = "<|padding|>" + if config.model.model_arch_type != "seq2seq": + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + # Add eos token to self.stop_sequences + self.stop_sequences.append(self.tokenizer.eos_token) script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] if not isinstance(config.model.model_path, str): @@ -141,6 +147,12 @@ def __init__(self, config, **kwargs): # noqa: C901 else: self.generate_sweep_kwarg = (k, v) + # Setup inference pipeline for model inference during training + if kwargs.get("inference_pipeline", None): + self.inference_pipeline = kwargs.get("inference_pipeline")( + self.model, self.accelerator.device, self.tokenizer + ) + def setup_model(self): """ Returns a model derived from an instance's TRLConfig @@ -201,82 +213,173 @@ def decode( prompts: List[torch.LongTensor], samples: List[torch.LongTensor], prompt_sizes: torch.LongTensor = None, - append_eos_token: bool = False, - ) -> Tuple[List[str], List[str], List[str]]: + append_eos_token: bool = True, + ) -> Tuple[List[str], List[str], List[str], List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor]]: """ - Decode tensor 
generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) + Decode tensor generations with stopping criteria into lists of strings: + (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]), + and remove padding from samples, responses. Note prompts maybe sometimes be right padded, as well as samples """ if prompt_sizes is None: # Assuming prompts were left-padded prompt_sizes = [prompts.shape[1]] * len(prompts) str_samples, str_prompts, str_outputs = [], [], [] + tok_samples, tok_prompts, tok_outputs = [], [], [] for prompt, sample, prompt_size in zip(prompts, samples, prompt_sizes): if self.config.model.model_arch_type == "seq2seq": output_start_ix = 0 else: output_start_ix = prompt_size - str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True) - str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True) - # Trim outputs up to `self.stop_sequences` if any are present - trimmed = False - if self.stop_sequences: - for stop in self.stop_sequences: - stop_ix = str_output.find(stop) - if stop_ix >= 0: - str_output = str_output[:stop_ix].rstrip() - trimmed = True - - # Recover the last if it was present in the original sample - # or add one if it was trimmed with `self.stop_sequences`. - # When a generation ended due to `max_new_tokens` exhaustion, - # only then or token would not be present in the original sample at the end - if append_eos_token and ( - trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id - ): - str_output += self.tokenizer.eos_token + # We must decode by skipping padding in the middle with skip_special_tokens + tok_prompt = prompt[:prompt_size].cpu() + str_prompt = self.tokenizer.decode(tok_prompt, skip_special_tokens=True) + + tok_output = sample[output_start_ix:] + # Get prefix corresponding to text + EOS token tag + end = tok_output.not_equal(self.tokenizer.eos_token_id).sum() + 1 + tok_output = tok_output[:end].cpu() + str_output = self.tokenizer.decode(tok_output, skip_special_tokens=True) + # Add first EOS token back + str_output += self.tokenizer.eos_token str_prompts.append(str_prompt) str_outputs.append(str_output) + tok_prompts.append(tok_prompt) + tok_outputs.append(tok_output) if self.config.model.model_arch_type == "seq2seq": sample = str_prompt + self.tokenizer.sep_token + str_output + tok_sample = torch.cat( + [tok_prompt, torch.tensor(self.tokenizer.sep_token_id, dtype=torch.long).view((1,)), tok_output] + ) else: sample = str_prompt + str_output + tok_sample = torch.cat([tok_prompt, tok_output], dim=0) str_samples.append(sample) + tok_samples.append(tok_sample) - return str_samples, str_prompts, str_outputs + return str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs - def generate(self, input_ids, attention_mask=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" - input_ids = input_ids.to(self.accelerator.device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.accelerator.device) - if self.generate_experience_kwargs is not None: - kwargs = dict(self.generate_experience_kwargs, **kwargs) - else: - kwargs = dict(self.generate_kwargs, **kwargs) + def batched_generate(self, input_ids, attention_mask, chunk_size, generate_kwargs): + # Chunk input_ids and attention_mask + input_ids = input_ids.split(chunk_size, dim=0) + attention_mask = attention_mask.split(chunk_size, dim=0) if attention_mask is not None else None - with 
torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs + all_tok_samples = [] + all_loss_masks = [] + for chunk_idx in range(len(input_ids)): + input_ids_chunk = input_ids[chunk_idx].to(self.accelerator.device) + attention_mask_chunk = ( + attention_mask[chunk_idx].to(self.accelerator.device) if attention_mask is not None else None ) - def generate_eval(self, input_ids, attention_mask=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" - input_ids = input_ids.to(self.accelerator.device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.accelerator.device) + # Sample trajectory from model + with torch.no_grad(): + tok_samples = ( + self.accelerator.unwrap_model(self.model) + .generate(input_ids=input_ids_chunk, attention_mask=attention_mask_chunk, **generate_kwargs) + .cpu() + ) - kwargs = dict(self.generate_kwargs, **kwargs) + # Decode and apply stopping criteria + stopped_tok_samples = [] + loss_masks = [] + for tok_prompt, tok_sample in zip(input_ids_chunk, tok_samples): + tok_output = tok_sample[tok_prompt.shape[0] :] + str_output = self.tokenizer.decode(tok_output) + if self.stop_sequences: + for stop in self.stop_sequences: + stop_ix = str_output.find(stop) + if stop_ix >= 0: + str_output = str_output[:stop_ix].rstrip() + stopped_tok_output = self.tokenizer(str_output, return_tensors="pt").input_ids[0] + # Remove BOS tok from output if present + stopped_tok_output = remove_bos(stopped_tok_output, self.tokenizer) + # Concat prompt, output to get sample + stopped_tok_sample = torch.cat([tok_prompt.cpu(), stopped_tok_output], dim=0) + stopped_tok_samples.append(stopped_tok_sample) + # Loss mask will be applied to entire output sequence + mask = torch.ones_like(stopped_tok_output) + loss_masks.append(mask) + + all_tok_samples += stopped_tok_samples + all_loss_masks += loss_masks + + return all_tok_samples, all_loss_masks + + def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): + """Wraps hf's `generate` adding some specific method's defaults""" + # Update generate kwargs with default generate_kwargs + generate_kwargs = copy(self.generate_kwargs) + generate_kwargs.update(kwargs) + + # Update max_new_tokens to respect max_seq_length + prompt_length = input_ids.shape[1] + if generate_kwargs.get("max_new_tokens") is not None: + generate_kwargs["max_new_tokens"] = min( + max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"] + ) + else: + generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) + + # Set default num_return_sequences value + if generate_kwargs.get("num_return_sequences") is None: + generate_kwargs["num_return_sequences"] = 1 + + # Repeat input_ids, attention_mask by 'num_return_sequences' times + num_return_sequences = generate_kwargs.pop( + "num_return_sequences" + ) # Pop to hide from unwrapped_model.generate call + input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0) + attention_mask = ( + attention_mask.repeat_interleave(num_return_sequences, dim=0) if attention_mask is not None else None + ) - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs + chunk_size = input_ids.shape[0] if chunk_size is None else chunk_size + + # Check for inference pipeline + use_inference_pipeline = generate_kwargs.get("use_inference_pipeline") + if use_inference_pipeline is not 
None: + generate_kwargs.pop("use_inference_pipeline") # Pop to hide from generator kwargs + + if use_inference_pipeline: + assert hasattr(self, "inference_pipeline") + # Remove all but temperature and sampling parameters from gen_kwargs + generate_kwargs = {k: v for k, v in generate_kwargs.items() if k == "do_sample" or k == "temperature"} + data = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prompt": self.tokenizer.batch_decode(input_ids, skip_special_tokens=True), + "gen_kwargs": len(input_ids) * [generate_kwargs], + } + data = self.inference_pipeline(data) + # NOTE: Inference pipeline should ensure tok_samples have same shapes as loss_masks + all_tok_samples = data["tok_samples"] + all_loss_masks = data["loss_mask"] + else: + # If no inference pipeline then do batched generation locally + all_tok_samples, all_loss_masks = self.batched_generate( + input_ids, attention_mask, chunk_size, generate_kwargs ) + # Convert to list of 1-d tensors to be padded + # Add extra EOS token w/ loss_mask[-1] == 0 to prevent empty responses + all_tok_samples = [ + torch.cat([tok_sample, torch.tensor(self.tokenizer.eos_token_id).view((1,)).long()], dim=0) + for tok_sample in all_tok_samples + ] + all_loss_masks = [torch.cat([m, torch.tensor(0).view((1,))], dim=0) for m in all_loss_masks] + + # Concat/pad samples + all_tok_samples = pad_sequence( + all_tok_samples, batch_first=True, padding_value=self.tokenizer.eos_token_id + ).squeeze() + all_loss_masks = pad_sequence(all_loss_masks, batch_first=True, padding_value=0.0) + return all_tok_samples, all_loss_masks + def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for later use. @@ -376,12 +479,25 @@ def evaluate(self): # noqa: C901 generate_time = time() for i_prompt, prompts in enumerate(self.eval_dataloader): metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"} + chunk_size = self.config.method.chunk_size if hasattr(self.config.method, "chunk_size") else None if self.generate_sweep_kwarg: - samples = self.generate_eval( - prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} + samples, _ = self.generate( + prompts["input_ids"], + prompts["attention_mask"], + chunk_size=chunk_size, + **{gen_sweep_arg: gen_sweep_value}, ) else: - samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"]) + samples, _ = self.generate(prompts["input_ids"], prompts["attention_mask"], chunk_size=chunk_size) + samples = samples.to(self.accelerator.device) + + # Repeat prompts, metadata num_return_sequence times + num_return_sequences = 1 + if self.generate_kwargs.get("num_return_sequences") is not None: + num_return_sequences = self.generate_kwargs["num_return_sequences"] + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(num_return_sequences, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(num_return_sequences, dim=0) + metadata = {k: self.repeat_interleave(v, num_return_sequences) for k, v in metadata.items()} # TODO(reciprocated): this should be moved into `decode` # but that needs to be synced with indexing in `make_experience` @@ -396,9 +512,9 @@ def evaluate(self): # noqa: C901 pad_index=self.tokenizer.pad_token_id, ) ) - all_samples.extend(samples.tolist()) - all_prompts.extend(prompts.tolist()) - all_prompt_sizes.extend(prompt_sizes.tolist()) + all_samples.extend(samples) + all_prompts.extend(prompts) + 
all_prompt_sizes.extend(prompt_sizes) metadata = gather_dict(metadata, self.accelerator.gradient_state) all_metadata.append(metadata) @@ -414,7 +530,9 @@ def evaluate(self): # noqa: C901 stats["time/generate"] = time() - generate_time if self.accelerator.is_main_process: - str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) + str_samples, str_prompts, str_outputs, tok_samples, tok_prompts, tok_outputs = self.decode( + all_prompts, all_samples, all_prompt_sizes + ) columns = ["prompt", "output"] columns_data = [str_prompts, str_outputs] @@ -427,10 +545,25 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: logger.info("Computing rewards") - rewards = torch.tensor( - self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata), - dtype=float, + rewards = self.reward_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, + **metadata, ) + # Remove kl terms from reward + if hasattr(self, "dist_ref_model") and self.dist_ref_model: + rewards = [[r[0] for r in reward] for reward in rewards] + if type(rewards[0]) is torch.Tensor: + rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) + elif type(rewards[0]) is list: + rewards = torch.tensor([sum(reward) for reward in rewards]) + else: + rewards = torch.tensor(rewards) mean_reward = rewards.mean().item() columns.append("reward") if not isinstance(rewards, list): @@ -442,7 +575,16 @@ def evaluate(self): # noqa: C901 if self.metric_fn: logger.info("Computing metrics") metric_time = time() - metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata) + metrics = self.metric_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + tok_samples=tok_samples, + tok_prompts=tok_prompts, + tok_outputs=tok_outputs, + model_tok=self.tokenizer, + **metadata, + ) stats["time/metric"] = time() - metric_time mean_metrics = { @@ -633,6 +775,15 @@ def learn(self): # noqa: C901 self.post_epoch_callback() tbar.close() + @staticmethod + def repeat_interleave(l, n): + if type(l) is torch.Tensor: + l = l.repeat_interleave(n, dim=0) + elif type(l) is list: + l = [[s] * n for s in l] + l = [item for sublist in l for item in sublist] + return l + @abstractmethod def get_arch(self, config: TRLConfig): """Returns a specific wrapper of the decoder architecture""" diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a3af9aa3f..a61912d18 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -4,9 +4,10 @@ from time import time from typing import Callable, List +import numpy as np import torch -import torch.nn.functional as F import transformers +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -67,8 +68,9 @@ def __init__(self, config: TRLConfig, **kwargs): self.store.clear_history() # Clear the rollout store - # Set up a reference model when hydra heads are not used - if not hasattr(self.model, "frozen_head") and not self.model.peft_type: + # Setup a reference model when hydra heads and distributed ref model are not used + self.dist_ref_model = config.method.dist_ref_model + if not hasattr(self.model, "frozen_head") and not self.model.peft_type and not self.dist_ref_model: 
self.ref_model = self.get_arch(self.config) self.ref_model.to(self.accelerator.device) self.ref_model.eval() @@ -130,9 +132,14 @@ def loss(self, batch: PPORLBatch): old_logprobs = batch.logprobs.to(self.accelerator.device) old_values = batch.values.to(self.accelerator.device) old_rewards = batch.rewards.to(self.accelerator.device) + loss_masks = batch.loss_masks.to(self.accelerator.device) response_length = old_rewards.shape[1] - advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length) + # TODO: loss_mask should affect advantages if discount < 1, GAE cannot be used (lam=1) + # NOTE: Rewards from KL on masked tokens should already be zeroed out + advantages, returns = self.config.method.get_advantages_and_returns( + old_values, old_rewards, response_length, loss_masks + ) if self.config.model.model_arch_type == "seq2seq": input_ids = query_tensors @@ -160,14 +167,12 @@ def loss(self, batch: PPORLBatch): logprobs, values_pred, mask = ( logprobs[:, start:end], values_pred[:, start:end], - mask[:, start + 1 : end + 1], + mask[:, start:end], ) else: tokens = torch.cat((query_tensors, response_tensors), dim=1) attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - outputs = self.model(tokens, attention_mask, return_dict=True, position_ids=position_ids) + outputs = self.model(tokens, attention_mask, return_dict=True) logits = outputs.logits values_pred = outputs.value values_pred = values_pred[:, :-1] @@ -178,7 +183,7 @@ def loss(self, batch: PPORLBatch): logprobs, values_pred, mask = ( logprobs[:, start:end], values_pred[:, start:end], - attention_mask[:, start + 1 : end + 1], + loss_masks, ) loss, stats = self.config.method.loss( @@ -228,13 +233,15 @@ def prepare_learning(self): self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) + self.make_experience(self.config.method.num_rollouts) + self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" - prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=False) prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) self.prompt_iterator = infinite_dataloader(prompt_dataloader) @@ -265,6 +272,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] + # Require chunk_size * num_train_sequences divides num_rollouts + assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_train_sequences) == 0 + while len(ppo_rl_elements) < num_rollouts: stats = {} # Get next batch in prompt dataset @@ -273,12 +283,24 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"]) + samples = self.generate( + batch["input_ids"], + batch["attention_mask"], + 
chunk_size=self.config.method.chunk_size, + **self.generate_experience_kwargs, + ) stats["time/rollout_generate"] = time() - rollout_generate_time - prompt_tensors = batch.input_ids - device = samples.device + num_return_sequences = ( + self.generate_experience_kwargs["num_return_sequences"] + if self.generate_experience_kwargs.get("num_return_sequences") is not None + else 1 + ) + prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) + # Prepare samples for reward_fn via broadcast to rank 0 + device = self.accelerator.device + samples = samples.to(device) prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) padded_samples = self.accelerator.pad_across_processes( samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False @@ -289,61 +311,102 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq gathered_samples = self.accelerator.gather(padded_samples) gathered_prompts = self.accelerator.gather(padded_prompts) gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = gather_dict( + { + k: self.repeat_interleave(v, num_return_sequences) + for k, v in batch.items() + if k != "input_ids" and k != "attention_mask" + } + ) if self.accelerator.is_main_process: - all_str_samples, all_str_prompts, all_str_outputs = self.decode( - gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True - ) + ( + all_str_samples, + all_str_prompts, + all_str_outputs, + all_tok_samples, + all_tok_prompts, + all_tok_outputs, + ) = self.decode(gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True) rollout_score_time = time() - all_scores = torch.tensor( - self.reward_fn( - samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata - ), - dtype=torch.float, - device=device, + # reward_fn should return list of rewards at each token per sample + # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) + # NOTE: reward_fn can optionally also compute the ref_logits. 
+ # In this case size will be [batch_size, response_length, 2] + all_scores = self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + tok_samples=all_tok_samples, + tok_prompts=all_tok_prompts, + tok_outputs=all_tok_outputs, + model_tok=self.tokenizer, + **metadata, ) + all_scores = [ + torch.tensor(score, dtype=torch.float, device=device).view( + -1, + ) + for score in all_scores + ] + # Pad -np.inf reward on the ends + all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf) + max_len = torch.tensor( + len(all_scores[0]) / (1 + int(self.dist_ref_model)), dtype=torch.long, device=device + ) + stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + all_scores = list( + all_scores.reshape(self.accelerator.num_processes, len(samples), max_len, -1).unbind() + ) else: all_scores = None + max_len = torch.tensor(0, dtype=torch.long, device=device) if torch.distributed.is_initialized(): - scores = torch.empty(len(samples), device=device) + torch.distributed.broadcast(max_len, 0) + # Allocate extra space if scores include ref_logits + scores = torch.empty((len(samples), max_len, 1 + int(self.dist_ref_model)), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) - - # Pad the sample outputs - outputs = self.tokenizer(str_outputs).input_ids - if self.config.model.model_arch_type == "seq2seq": - # add to the start of the output - for i in range(len(outputs)): - outputs[i] = [self.tokenizer.pad_token_id] + outputs[i] - - outputs = list(map(torch.LongTensor, outputs)) - maxsize = max(map(len, outputs)) - outputs = [ - F.pad( - output, - (0, maxsize - len(output)), - value=self.tokenizer.pad_token_id, - ) - for output in outputs - ] - sample_outputs = torch.vstack(outputs).to(device) + # Remove ref_logits from scores if present + if self.dist_ref_model: + all_ref_logprobs = scores[:, :, 1] + scores = scores[:, :, 0] + else: + all_ref_logprobs = None + scores = scores.squeeze(-1) + scores_mask = scores != -np.inf + # Remove infs so mask can be used if self.config.method.cliprange_reward: scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) + # Best-of-N Sampling. 
+ train_indices = self.get_topk_indices( + input_tensor=scores_mask * scores, + window_size=num_return_sequences, + k=self.config.method.num_train_sequences, + device=device, + ) + scores = scores[train_indices] + scores_mask = scores_mask[train_indices] + samples = samples[train_indices] + prompt_tensors = prompt_tensors[train_indices] + loss_masks = loss_masks[train_indices.cpu()] + if all_ref_logprobs is not None: + all_ref_logprobs = all_ref_logprobs[train_indices] + # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.mean(), scores.std() - all_scores_mean, all_scores_std = self.running_moments.update(scores) + self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum( + dim=1 + ).std() + all_scores_mean, all_scores_std = self.running_moments.update(scores, scores_mask) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() stats["rollout_scores/running_mean"] = self.running_moments.mean.item() @@ -354,7 +417,33 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq elif self.config.method.scale_reward == "ref": scores /= self.ref_std + # Only use these samples, prompts, outputs to compute ppo stats + _, _, _, tok_samples, tok_prompts, tok_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + tok_prompts = torch.stack(tok_prompts, dim=0) + + # Pad the sample outputs + # outputs = self.tokenizer(str_outputs).input_ids + # TODO: Why is this here? Should this be a sep token? + if self.config.model.model_arch_type == "seq2seq": + # add to the start of the output + for i in range(len(tok_outputs)): + tok_outputs[i] = [self.tokenizer.pad_token_id] + outputs[i].tolist() + + if self.config.model.model_arch_type == "seq2seq": + attention_mask = sample_outputs != self.tokenizer.pad_token_id + start = 0 + else: + # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response + start = tok_prompts.shape[1] - 1 + + padded_tok_samples = pad_sequence(tok_samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + padded_tok_outputs = pad_sequence(tok_outputs, batch_first=True, padding_value=self.tokenizer.pad_token_id) + # Remove extra padding from loss masks (may occur if using BoN sampling) + loss_masks = loss_masks[:, : padded_tok_outputs.shape[1]] + attention_mask = padded_tok_samples.not_equal(self.tokenizer.pad_token_id).long() + # Precompute logprobs, values + # TODO: Come back to seq2seq if self.config.model.model_arch_type == "seq2seq": attention_mask = batch.attention_mask.to(device) prompt_tensors = batch.input_ids.to(device) @@ -386,88 +475,178 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq return_dict=True, ).logits else: + values_chunks = [] + log_probs_chunks = [] + ref_logprobs_chunks = [] all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - with torch.no_grad(): - logits, *_, values = self.model( - all_tokens, attention_mask=attention_mask, position_ids=position_ids - ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head") or self.model.peft_type: - ref_logits = self.model.forward_hydra( - all_tokens, - 
attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ).logits + all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) + attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) + position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( + all_tokens_chunks, attention_mask_chunks, position_ids_chunks + ): + all_tokens_chunk = all_tokens_chunk.to(device) + attention_mask_chunk = attention_mask_chunk.to(device) + position_ids_chunk = position_ids_chunk.to(device) + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + ) + # If all_ref_logits is not None they have already been generated during call to reward_fn + if all_ref_logprobs is None: + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + return_dict=True, + ).logits + elif hasattr(self, "ref_model"): + ref_logits = self.ref_model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + # If no ref model is provided then we compute no kl penalty + else: + ref_logits = logits.clone() + + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) + ref_logprobs = logprobs_of_labels(ref_logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) else: - ref_logits = self.ref_model( - all_tokens, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ).logits - ref_logits = ref_logits.to(device) - - if self.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled + # So need to index at start = prompt_tensors.shape[1] - 1 which is + # the logprob corresponding to the first sampled token + # Indexing ends at -1 because the last logprob corresponds to an unsampled token + logprobs = logprobs_of_labels(logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :]) + if all_ref_logprobs is None: + ref_logprobs = logprobs_of_labels( + ref_logits[:, start:-1, :], all_tokens_chunk[:, start + 1 :] + ) + + values_chunks.append(values.cpu()) + log_probs_chunks.append(logprobs.cpu()) + if all_ref_logprobs is None: + ref_logprobs_chunks.append(ref_logprobs.cpu()) + + # Remove values before v[start] (this is the value of the state before any tokens are sampled) + # and remove the last value v[-1] (this is a terminal state after all tokens have been generated with value 0) + values = torch.cat(values_chunks, dim=0)[:, start:-1] + logprobs = torch.cat(log_probs_chunks, dim=0) + attention_mask = attention_mask[:, start:].cpu() + + if all_ref_logprobs is None: + ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) + # all_ref_logprobs returned from reward already has prompt prefix removed else: - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) - - n_samples: int = samples.shape[0] + # Remove (some) padding 
from distributed communication + # So arithmetic with logprobs can be done + ref_logprobs = all_ref_logprobs[:, : logprobs.shape[1]].cpu() # Estimate the KL divergence between the model and reference model - if self.config.model.model_arch_type == "seq2seq": - attention_mask = sample_outputs != self.tokenizer.pad_token_id - start = 0 - else: - start = prompt_tensors.shape[1] - 1 - - log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] + # NOTE: nan is interfering with kl estimates since 0 * nan = 0 + # Convert inf padding terms in ref_logprobs to number removable with attention mask mult + if logprobs.shape[1] != loss_masks.shape[1]: + raise ValueError( + f"Shape mismatch between logprobs and loss_masks:\n\ + logprobs: {logprobs.shape}\n\ + ref_logprobs: {ref_logprobs.shape}\n\ + loss_masks: {loss_masks.shape}\n\ + padded_tok_outputs: {padded_tok_outputs.shape}" + ) + log_ratio = (logprobs - torch.nan_to_num(ref_logprobs)) * loss_masks kl = log_ratio.exp() - 1 - log_ratio mean_kl_per_token = kl.mean() mean_kl = kl.sum(1).mean() + kl_penalties = self.kl_ctl.value * -log_ratio.cpu() - logprobs = logprobs.cpu() - ref_logprobs = ref_logprobs.cpu() - prompt_tensors = prompt_tensors.cpu() - sample_outputs = sample_outputs.cpu() - values = values.cpu()[:, :-1] + n_samples = padded_tok_samples.shape[0] + rollout_count = 0 # Get the logprobs and values, for tokens that are not padding, - # from the start of the prompt up to the token, while also including the latter + # from the end of the prompt up to the token, while also including the latter # (these are taken from the student model and not the reference model) - ends = start + attention_mask[:, start:].sum(1) + 1 - all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] - all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] - - kl_penalty = self.kl_ctl.value * -log_ratio.cpu() - kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)] - - rollout_count = 0 - + ends = attention_mask[:, 1:].sum(1) + 1 for sample_idx in range(n_samples): - rewards = kl_penalty[sample_idx] - rewards[-1] += scores[sample_idx].cpu() + value = values[sample_idx, : ends[sample_idx]] + logprob = logprobs[sample_idx, : ends[sample_idx]] + kl_penalty = kl_penalties[sample_idx, : ends[sample_idx]] + loss_mask = loss_masks[sample_idx, : ends[sample_idx]] + query_tensor = tok_prompts[sample_idx] + response_tensor = tok_outputs[sample_idx] + if ( + len(value) != len(logprob) + or len(logprob) != len(kl_penalty) + or len(kl_penalty) != len(response_tensor) + or len(response_tensor) != len(loss_mask) + ): + raise ValueError( + f"Length mismatch between value, logprob, kl, and response_tensor:\n\ + Value: {value.shape}, {value}\n\ + Logprob: {logprob.shape}, {logprob}\n\ + KL: {kl_penalty.shape}, {kl_penalty}\n\ + end: {ends[sample_idx]}\n\ + Response: {response_tensor.shape}, {response_tensor}, \ + {self.tokenizer.decode(response_tensor)}\n\ + Loss mask: {loss_mask.shape}, {loss_mask}" + ) + + # Zero out terms on masked tokens + kl_penalty = loss_mask * kl_penalty + + # Then add in rewards + if scores.shape[1] == 1: + # NOTE: Final reward given at EOS token following HHH practice + score = scores[sample_idx][0].cpu() + kl_penalty[-1] += score + rewards = kl_penalty + else: + score = scores[sample_idx] + score_right_padding = torch.sum(scores_mask[sample_idx]) + score = score[:score_right_padding].cpu() + if len(score) != len(kl_penalty): + raise ValueError( + f"Length mismatch between score and kl penalty:\n\ + 
Logprob: {logprob.shape}, {logprob}\n\
+                            kl_penalty: {kl_penalty.shape}, {kl_penalty}\n\
+                            Score: {score.shape}, {score}"
+                    )
+                rewards = kl_penalty + score
+
+            if kl_penalty.isnan().any() or score.isnan().any():
+                raise ValueError(
+                    f"nan in tensor:\n\
+                        KL: {kl_penalty}\n\
+                        Score: {score}\n\
+                        logprob: {logprob}\n\
+                        ref logprob: {ref_logprobs[sample_idx][:ends[sample_idx]-1]}\n\
+                        mask: {attention_mask[sample_idx]}\n\
+                        kl ctl: {self.kl_ctl.value}"
+                )
 
             ppo_rl_elements.append(
                 PPORLElement(
-                    query_tensor=prompt_tensors[sample_idx],
-                    response_tensor=sample_outputs[sample_idx],
-                    logprobs=all_logprobs[sample_idx],
-                    values=all_values[sample_idx],
+                    query_tensor=query_tensor,
+                    response_tensor=response_tensor,
+                    logprobs=logprob,
+                    values=value,
                     rewards=rewards,
+                    loss_mask=loss_mask,
                 )
             )
 
             rollout_count += 1
 
         if torch.distributed.is_initialized():
-            torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG)
+            torch.distributed.all_reduce(mean_kl.to(self.accelerator.device), torch.distributed.ReduceOp.AVG)
 
         stats["time/rollout_time"] = clock.tick()
         stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item()
@@ -485,3 +664,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
 
         # Push samples and rewards to trainer's rollout storage
         self.push_to_store(ppo_rl_elements)
+
+    @staticmethod
+    def get_topk_indices(input_tensor, window_size: int, k: int, device="cpu"):
+        # Sum the scores along dim 1
+        input_tensor = input_tensor.sum(1).unsqueeze(1)
+        # Use unfold to create the sliding windows
+        unfolded = input_tensor.unfold(0, window_size, window_size)
+        # Find the topk values and indices along the unfolded dimension
+        _, indices = torch.topk(unfolded, k, dim=2)
+        # Adjust indices to be relative to original tensor
+        indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(
+            device
+        ).unsqueeze(1)
+        return indices.reshape(-1)
diff --git a/trlx/trlx.py b/trlx/trlx.py
index 7fbce94f4..445976c4e 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -21,6 +21,7 @@ def train(  # noqa: C901
    prompts: Optional[List[str]] = None,
    eval_prompts: Optional[List[str]] = None,
    metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,
+   inference_pipeline=None,
    config: Optional[TRLConfig] = None,
    stop_sequences: Optional[List[str]] = [],
 ):
@@ -80,6 +81,7 @@ def train(  # noqa: C901
        config=config,
        reward_fn=reward_fn,
        metric_fn=metric_fn,
+       inference_pipeline=inference_pipeline,
        stop_sequences=stop_sequences,
        **config.train.trainer_kwargs,
     )
diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py
index 784ec4437..b690cfe21 100644
--- a/trlx/utils/__init__.py
+++ b/trlx/utils/__init__.py
@@ -248,3 +248,12 @@ def infinite_dataloader(dataloader: Iterable, sampler=None) -> Iterable:
            epoch += 1
 
        yield from dataloader
+
+
+# Text processing utils
+
+
+def remove_bos(toks, tokenizer):
+    if hasattr(tokenizer, "bos_token") and len(toks) > 0 and toks[0].item() == tokenizer.bos_token_id:
+        toks = toks[1:]
+    return toks
diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py
index 47688f553..29d6bf502 100644
--- a/trlx/utils/modeling.py
+++ b/trlx/utils/modeling.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Dict, MutableMapping, Tuple, Union
+from typing import Dict, MutableMapping, Optional, Tuple, Union
 
 import accelerate
 import numpy as np
@@ -175,11 +175,13 @@ def hf_get_num_hidden_layers(config: transformers.PretrainedConfig) -> int:
     return findattr(config, num_hidden_layers_attrs)
 
 
-def get_global_statistics(xs: torch.Tensor, group=None) -> Tuple[float, float, int]:
+def get_global_statistics(xs: torch.Tensor, mask=None, group=None) -> Tuple[float, float, int]:
     """
     Computes element-wise mean and variance of the tensor across processes
     """
-    sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
+    if mask is None:
+        mask = torch.ones_like(xs)
+    sum_and_count = torch.tensor([xs.sum(), mask.sum()], device=xs.device)
     dist.all_reduce(sum_and_count, dist.ReduceOp.SUM, group=group)
     global_sum, count = sum_and_count
     global_mean = global_sum / count
@@ -190,15 +192,16 @@ def get_global_statistics(xs: torch.Tensor, group=None) -> Tuple[float, float, i
     return global_mean, global_var, count
 
 
-def whiten(xs: torch.Tensor, shift_mean=True, distributed=True, group=None) -> torch.Tensor:
+def whiten(xs: torch.Tensor, mask: torch.Tensor, shift_mean=True, distributed=True, group=None) -> torch.Tensor:
     """Whitens values"""
     if distributed and dist.is_initialized():
-        mean, var, _ = get_global_statistics(xs, group=group)
+        mean, var, _ = get_global_statistics(xs, mask=mask, group=group)
     else:
        var, mean = torch.var_mean(xs)
 
     whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
     if not shift_mean:
+        # TODO: Why not whitened += mean*torch.rsqrt(var+1e-8)?
        whitened += mean
     return whitened
 
@@ -276,8 +279,11 @@ def __init__(self):
         self.var = 1
         self.count = 1e-24
 
-    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
+    def update(self, xs: torch.Tensor, xs_mask: Optional[torch.Tensor] = None) -> Tuple[float, float]:
         """Updates running moments from batch's moments computed across ranks"""
+        if xs_mask is None:
+            xs_mask = torch.ones_like(xs)
+        xs = torch.sum(xs * xs_mask, dim=1)
         if dist.is_initialized():
             xs_mean, xs_var, xs_count = get_global_statistics(xs)
         else:
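For reference, here is a single-process sketch of the masked statistics these changes are after: means and variances taken only over unmasked positions, which is what `whiten(advantages, mask)` relies on when normalizing advantages over variable-length responses, and is close to what `get_global_statistics(xs, mask=...)` computes globally (it all-reduces `xs.sum()` together with `mask.sum()`, so it assumes masked positions of `xs` are already zero). `masked_mean_var` and `masked_whiten` are illustrative names, not functions from this PR, and re-zeroing the padded positions at the end is an extra assumption rather than something `whiten` itself does.

import torch


def masked_mean_var(xs: torch.Tensor, mask: torch.Tensor):
    # Statistics over unmasked elements only
    count = mask.sum()
    mean = (xs * mask).sum() / count
    var = ((xs - mean) ** 2 * mask).sum() / count
    return mean, var


def masked_whiten(xs: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    mean, var = masked_mean_var(xs, mask)
    whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened * mask  # keep padded positions at zero (extra assumption, see note above)


# Advantages for two responses of different lengths, padded to the same width
advantages = torch.tensor([[0.5, 1.0, -0.5, 0.0], [2.0, -1.0, 0.0, 0.0]])
response_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
print(masked_whiten(advantages, response_mask))

One related observation on the diff above: the non-distributed branch of `whiten` still calls `torch.var_mean(xs)` without the mask, so single-process runs fold padded positions into the statistics; that may be worth double-checking.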