Set add_special_tokens=False to not add EOS unexpectedly (#287)
Do not `add_special_tokens` when preprocessing prompts. This improves performance (reward) for non-GPT models, whose tokenizers append EOS by default (GPT tokenizers don't), because the prompt no longer ends with an unexpected `<endoftext>`-style EOS token.

Adding `<endofcontext>` (EOC, from HH) could still be worth exploring in the future, though.
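For context, a minimal sketch of the default behaviour this change avoids (assuming the `transformers` library, with `t5-small` and `gpt2` standing in for non-GPT and GPT tokenizers; the prompt string is made up):

```python
from transformers import AutoTokenizer

prompt = "Summarize: The cat sat on the mat."

# T5's tokenizer appends EOS (</s>) when add_special_tokens=True (the default),
# so every "prompt" silently ends with an end-of-sequence token.
t5_tok = AutoTokenizer.from_pretrained("t5-small")
default_ids = t5_tok(prompt)["input_ids"]
plain_ids = t5_tok(prompt, add_special_tokens=False)["input_ids"]
print(default_ids[-1] == t5_tok.eos_token_id)  # True: prompt ends with EOS
print(plain_ids[-1] == t5_tok.eos_token_id)    # False: EOS no longer appended

# GPT-2's tokenizer adds no special tokens either way, which is why GPT models
# were unaffected by the old default.
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
print(gpt2_tok(prompt)["input_ids"] == gpt2_tok(prompt, add_special_tokens=False)["input_ids"])  # True
```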
cat-state authored Feb 10, 2023
1 parent 81e935a commit b91da7b
Showing 6 changed files with 23 additions and 64 deletions.
4 changes: 2 additions & 2 deletions examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
@@ -51,14 +51,14 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):

for i in tqdm(range(len(prompts))):
key = tokenizer.decode(
tokenizer(prompts[i], truncation=True, max_length=max_length)["input_ids"],
tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
skip_special_tokens=True,
) # get prompt like trlx's prompt
prompt_label[key.strip()] = summaries[i]

for i in tqdm(range(len(val_prompts))):
key = tokenizer.decode(
tokenizer(val_prompts[i], truncation=True, max_length=max_length)["input_ids"],
tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
skip_special_tokens=True,
) # get prompt like trlx's prompt
prompt_label[key.strip()] = val_summaries[i]
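The keys above must round-trip through the tokenizer exactly the way trlX builds its prompts. A rough illustration (hypothetical prompt, `t5-small` assumed) of how the default EOS breaks the match once a prompt hits `max_length`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
max_length = 8
prompt = "a long article body that will certainly be truncated " * 10

with_eos = tokenizer(prompt, truncation=True, max_length=max_length)["input_ids"]
without_eos = tokenizer(prompt, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"]

# With the default, </s> takes up the last slot inside max_length, so one fewer
# content token survives truncation and the decoded key differs from trlX's prompt.
print(tokenizer.decode(with_eos, skip_special_tokens=True))
print(tokenizer.decode(without_eos, skip_special_tokens=True))  # one token longer
```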
3 changes: 2 additions & 1 deletion examples/summarize_rlhf/trlx_gptj_text_summarization.py
@@ -67,12 +67,13 @@ def get_prompt_dataset(prompts, max_length):
prompts[i].split("TL;DR:")[0],
truncation=True,
max_length=max_length - 5, # to make sure "TL;DR" dont get truncated
add_special_tokens=False,
)["input_ids"],
skip_special_tokens=True,
).strip()
tmp = tmp + "\nTL;DR:"
tmp = tokenizer.decode(
tokenizer(tmp, truncation=True, max_length=max_length)["input_ids"],
tokenizer(tmp, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
skip_special_tokens=True,
).strip()
formatted_prompts.append(tmp)
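For readability, a self-contained version of the two-pass truncation this helper performs (a sketch only; `gpt2` is assumed here as a stand-in for GPT-J's tokenizer and `max_length` is arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # GPT-J uses a GPT-2-style tokenizer
max_length = 512

def build_prompt(text: str) -> str:
    # Pass 1: drop everything after "TL;DR:" and reserve ~5 tokens of headroom
    # so the appended "\nTL;DR:" suffix is never truncated away.
    body_ids = tokenizer(
        text.split("TL;DR:")[0],
        truncation=True,
        max_length=max_length - 5,
        add_special_tokens=False,
    )["input_ids"]
    prompt = tokenizer.decode(body_ids, skip_special_tokens=True).strip() + "\nTL;DR:"
    # Pass 2: re-tokenize the assembled prompt and clamp it to max_length.
    final_ids = tokenizer(
        prompt, truncation=True, max_length=max_length, add_special_tokens=False
    )["input_ids"]
    return tokenizer.decode(final_ids, skip_special_tokens=True).strip()
```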
23 changes: 16 additions & 7 deletions trlx/pipeline/offline_pipeline.py
@@ -3,7 +3,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from transformers import DataCollatorWithPadding, PreTrainedTokenizer

from trlx.data.ilql_types import ILQLBatch, ILQLElement
from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline
@@ -23,7 +23,8 @@ def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=204
ctx_length = max_length
if tokenizer.truncation_side == "left":
for phrase in reversed(dialogue):
tokens = tokenizer(phrase).input_ids[-ctx_length:]
# Manually added BOS and EOS above so we don't want to add special tokens here
tokens = tokenizer(phrase, add_special_tokens=False).input_ids[-ctx_length:]
ctx_length -= len(tokens)
out.insert(0, tokens)
if ctx_length == 0:
@@ -38,7 +39,8 @@

elif tokenizer.truncation_side == "right":
for phrase in dialogue:
tokens = tokenizer(phrase).input_ids[:ctx_length]
# Manually added BOS and EOS above so we don't want to add special tokens here
tokens = tokenizer(phrase, add_special_tokens=False).input_ids[:ctx_length]
ctx_length -= len(tokens)
out.append(tokens)
if ctx_length == 0:
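The comments added above refer to the double-insertion problem: `tokenize_dialogue` has already wrapped the dialogue in BOS/EOS strings, so a tokenizer that adds its own special tokens would duplicate them. A small illustration (assuming `facebook/opt-125m`, whose tokenizer prepends `</s>` by default):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# tokenize_dialogue has already prepended the BOS string manually ...
phrase = tokenizer.bos_token + "Hello there"

# ... so letting the tokenizer add special tokens again doubles them up.
print(tokenizer(phrase).input_ids[:2])                            # two identical BOS ids in a row
print(tokenizer(phrase, add_special_tokens=False).input_ids[:1])  # a single BOS, as intended
```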
@@ -52,13 +54,20 @@ class PromptPipeline(BasePipeline):
Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
"""

def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer=None):
def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
super().__init__()
model_inputs = tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length)
prompts = model_inputs["input_ids"]

model_inputs = tokenizer(
prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
)

prompts_tokens = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]

self.tokenizer = tokenizer
self.prompts = [{"input_ids": prompt, "attention_mask": mask} for prompt, mask in zip(prompts, attention_mask)]
self.prompts = [
{"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask)
]

def __getitem__(self, ix: int):
return self.prompts[ix]
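With the new signature the tokenizer is required rather than optional; a minimal usage sketch (hypothetical prompts, `gpt2` assumed):

```python
from transformers import AutoTokenizer
from trlx.pipeline.offline_pipeline import PromptPipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")
pipeline = PromptPipeline(
    prompts=["Once upon a time", "The quick brown fox"],
    max_prompt_length=16,
    tokenizer=tokenizer,
)
# Each element is a dict of input_ids and attention_mask, with no BOS/EOS appended.
print(pipeline[0])
```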
21 changes: 1 addition & 20 deletions trlx/trainer/accelerate_base_trainer.py
@@ -3,7 +3,7 @@
import sys
from abc import abstractmethod
from time import time
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Tuple

import ray
import torch
@@ -175,25 +175,6 @@ def setup_scheduler(self):
scheduler = scheduler_class(self.opt, **self.config.scheduler.kwargs)
return scheduler

def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]):
"""
Tokenize a batch of text after adding bos token to each of the samples
"""
if isinstance(text[0], torch.LongTensor):
return text

text = [self.tokenizer.bos_token + txt for txt in text]
return self.tokenizer(
text,
truncation=True,
max_length=self.config.seq_length,
return_tensors="pt",
# NOTE: We manually add special tokens (bos) above so we set this False
# to avoid models that automatically add special tokens (e.g. OPT)
# adding them twice more.
add_special_tokens=False,
)

def decode(
self,
prompts: List[torch.LongTensor],
18 changes: 1 addition & 17 deletions trlx/trainer/accelerate_ilql_trainer.py
@@ -1,5 +1,5 @@
import os
from typing import Optional, Sequence, Union, cast
from typing import Optional, cast

import numpy as np
import torch
@@ -43,22 +43,6 @@ def get_arch(self, config):
num_layers_unfrozen=config.model.num_layers_unfrozen,
)

def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
if isinstance(texts[0], torch.LongTensor):
return texts

tokenized = self.tokenizer(
[self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
max_length=self.max_length,
truncation=True,
# NOTE: We manually add special tokens (bos) above so we set this False
# to avoid models that automatically add special tokens (e.g. OPT)
# adding them twice more.
add_special_tokens=False,
)
input_ids = list(map(torch.as_tensor, tokenized.input_ids))
return input_ids

def post_backward_callback(self):
if self.iter_count % self.config.method.steps_for_target_q_sync == 0:
self.accelerator.unwrap_model(self.model).sync_target_q_heads()
18 changes: 1 addition & 17 deletions trlx/trainer/nemo_ilql_trainer.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Iterable, Sequence, Union, cast
from typing import Iterable, Sequence, cast

import numpy as np
import torch
@@ -156,22 +156,6 @@ def __init__(
if stop_sequences is not None and len(stop_sequences) > 0:
logging.warning(f"Ignoring stop_sequences {stop_sequences=}")

def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
if isinstance(texts[0], torch.LongTensor):
return texts

tokenized = self.tokenizer(
[self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
max_length=self.max_length,
truncation=True,
# NOTE: We manually add special tokens (bos) above so we set this False
# to avoid models that automatically add special tokens (e.g. OPT)
# adding them twice more.
add_special_tokens=False,
)
input_ids = list(map(torch.as_tensor, tokenized.input_ids))
return input_ids

def learn(self):
def collate_fn(elems: Iterable[ILQLElement]):
batch = ilql_collate_fn(elems)
