diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
index 290e74b5b..67863bf7d 100755
--- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
+++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
@@ -51,14 +51,14 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
 
     for i in tqdm(range(len(prompts))):
         key = tokenizer.decode(
-            tokenizer(prompts[i], truncation=True, max_length=max_length)["input_ids"],
+            tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
             skip_special_tokens=True,
         )  # get prompt like trlx's prompt
         prompt_label[key.strip()] = summaries[i]
 
     for i in tqdm(range(len(val_prompts))):
         key = tokenizer.decode(
-            tokenizer(val_prompts[i], truncation=True, max_length=max_length)["input_ids"],
+            tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
             skip_special_tokens=True,
         )  # get prompt like trlx's prompt
         prompt_label[key.strip()] = val_summaries[i]
diff --git a/examples/summarize_rlhf/trlx_gptj_text_summarization.py b/examples/summarize_rlhf/trlx_gptj_text_summarization.py
index 001e157f0..3d9e3c5f3 100755
--- a/examples/summarize_rlhf/trlx_gptj_text_summarization.py
+++ b/examples/summarize_rlhf/trlx_gptj_text_summarization.py
@@ -67,12 +67,13 @@ def get_prompt_dataset(prompts, max_length):
                 prompts[i].split("TL;DR:")[0],
                 truncation=True,
                 max_length=max_length - 5,  # to make sure "TL;DR" dont get truncated
+                add_special_tokens=False,
             )["input_ids"],
             skip_special_tokens=True,
         ).strip()
         tmp = tmp + "\nTL;DR:"
         tmp = tokenizer.decode(
-            tokenizer(tmp, truncation=True, max_length=max_length)["input_ids"],
+            tokenizer(tmp, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
             skip_special_tokens=True,
         ).strip()
         formatted_prompts.append(tmp)
diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index bd70ad79c..dfd09dea8 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -3,7 +3,7 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, PreTrainedTokenizer
 
 from trlx.data.ilql_types import ILQLBatch, ILQLElement
 from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline
@@ -23,7 +23,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
     ctx_length = max_length
     if tokenizer.truncation_side == "left":
         for phrase in reversed(dialogue):
-            tokens = tokenizer(phrase).input_ids[-ctx_length:]
+            # Manually added BOS and EOS above so we don't want to add special tokens here
+            tokens = tokenizer(phrase, add_special_tokens=False).input_ids[-ctx_length:]
             ctx_length -= len(tokens)
             out.insert(0, tokens)
             if ctx_length == 0:
@@ -38,7 +39,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
 
     elif tokenizer.truncation_side == "right":
         for phrase in dialogue:
-            tokens = tokenizer(phrase).input_ids[:ctx_length]
+            # Manually added BOS and EOS above so we don't want to add special tokens here
+            tokens = tokenizer(phrase, add_special_tokens=False).input_ids[:ctx_length]
             ctx_length -= len(tokens)
             out.append(tokens)
             if ctx_length == 0:
@@ -52,13 +54,20 @@ class PromptPipeline(BasePipeline):
     Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
     """
 
-    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer=None):
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
         super().__init__()
-        model_inputs = tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length)
-        prompts = model_inputs["input_ids"]
+
+        model_inputs = tokenizer(
+            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
+        )
+
+        prompts_tokens = model_inputs["input_ids"]
         attention_mask = model_inputs["attention_mask"]
+
         self.tokenizer = tokenizer
-        self.prompts = [{"input_ids": prompt, "attention_mask": mask} for prompt, mask in zip(prompts, attention_mask)]
+        self.prompts = [
+            {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask)
+        ]
 
     def __getitem__(self, ix: int):
         return self.prompts[ix]
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index a6acaccda..09b2c6eee 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -3,7 +3,7 @@
 import sys
 from abc import abstractmethod
 from time import time
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Tuple
 
 import ray
 import torch
@@ -175,25 +175,6 @@ def setup_scheduler(self):
         scheduler = scheduler_class(self.opt, **self.config.scheduler.kwargs)
         return scheduler
 
-    def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]):
-        """
-        Tokenize a batch of text after adding bos token to each of the samples
-        """
-        if isinstance(text[0], torch.LongTensor):
-            return text
-
-        text = [self.tokenizer.bos_token + txt for txt in text]
-        return self.tokenizer(
-            text,
-            truncation=True,
-            max_length=self.config.seq_length,
-            return_tensors="pt",
-            # NOTE: We manually add special tokens (bos) above so we set this False
-            # to avoid models that automatically add special tokens (e.g. OPT)
-            # adding them twice more.
-            add_special_tokens=False,
-        )
-
     def decode(
         self,
         prompts: List[torch.LongTensor],
diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py
index 1cab072d2..231b2c059 100644
--- a/trlx/trainer/accelerate_ilql_trainer.py
+++ b/trlx/trainer/accelerate_ilql_trainer.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Sequence, Union, cast
+from typing import Optional, cast
 
 import numpy as np
 import torch
@@ -43,22 +43,6 @@ def get_arch(self, config):
             num_layers_unfrozen=config.model.num_layers_unfrozen,
         )
 
-    def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
-        if isinstance(texts[0], torch.LongTensor):
-            return texts
-
-        tokenized = self.tokenizer(
-            [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
-            max_length=self.max_length,
-            truncation=True,
-            # NOTE: We manually add special tokens (bos) above so we set this False
-            # to avoid models that automatically add special tokens (e.g. OPT)
-            # adding them twice more.
-            add_special_tokens=False,
-        )
-        input_ids = list(map(torch.as_tensor, tokenized.input_ids))
-        return input_ids
-
     def post_backward_callback(self):
         if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
             self.accelerator.unwrap_model(self.model).sync_target_q_heads()
diff --git a/trlx/trainer/nemo_ilql_trainer.py b/trlx/trainer/nemo_ilql_trainer.py
index 04729bd50..c58cc3249 100644
--- a/trlx/trainer/nemo_ilql_trainer.py
+++ b/trlx/trainer/nemo_ilql_trainer.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Iterable, Sequence, Union, cast
+from typing import Iterable, Sequence, cast
 
 import numpy as np
 import torch
@@ -156,22 +156,6 @@ def __init__(
         if stop_sequences is not None and len(stop_sequences) > 0:
             logging.warning(f"Ignoring stop_sequences {stop_sequences=}")
 
-    def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
-        if isinstance(texts[0], torch.LongTensor):
-            return texts
-
-        tokenized = self.tokenizer(
-            [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
-            max_length=self.max_length,
-            truncation=True,
-            # NOTE: We manually add special tokens (bos) above so we set this False
-            # to avoid models that automatically add special tokens (e.g. OPT)
-            # adding them twice more.
-            add_special_tokens=False,
-        )
-        input_ids = list(map(torch.as_tensor, tokenized.input_ids))
-        return input_ids
-
     def learn(self):
        def collate_fn(elems: Iterable[ILQLElement]):
            batch = ilql_collate_fn(elems)
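Why `add_special_tokens=False` is passed throughout: some tokenizers (OPT's, for example) prepend their own BOS token on every call, which would duplicate the BOS/EOS these code paths concatenate by hand. A minimal sketch of that behaviour, using the `facebook/opt-125m` checkpoint purely for illustration:

```python
# Illustrative only: why add_special_tokens=False matters when BOS/EOS are
# concatenated by hand. Assumes the facebook/opt-125m checkpoint, whose
# tokenizer prepends its BOS token ("</s>") by default on every call.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
text = tokenizer.bos_token + "hello world" + tokenizer.eos_token

duplicated = tokenizer(text).input_ids                        # tokenizer adds a second BOS
clean = tokenizer(text, add_special_tokens=False).input_ids   # only the manual BOS/EOS remain

print(tokenizer.convert_ids_to_tokens(duplicated))  # starts with two BOS tokens
print(tokenizer.convert_ids_to_tokens(clean))       # starts with a single BOS token
```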