Commit: take the first steps towards an end-to-end solution
lucidrains committed Apr 7, 2023
1 parent a4a44dc commit 127a65b
Showing 4 changed files with 176 additions and 8 deletions.
README.md (65 changes: 64 additions & 1 deletion)
@@ -20,6 +20,65 @@ $ pip install toolformer-pytorch

## Usage

Example usage: giving a language model awareness of the current date and time.

```python
import torch
from toolformer_pytorch import Toolformer, PaLM

# simple calendar api call - function that returns a string

def Calendar():
    import datetime
    from calendar import day_name, month_name
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'

# prompt for teaching it to use the Calendar function from above

prompt = f"""
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output:
"""

data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]

# model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine

model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# toolformer

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar
)

data_with_api_calls = toolformer.generate_data_with_api_calls(data)

# complete the filtering and fine tuning step
```

The main novelty of the paper is defining a fitness score for the outputs of a transformer instructed to insert API calls. The score is used to filter the sampled outputs and fine-tune the transformer to make API calls that decrease the perplexity of the text that follows them.
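
Concretely, each sampled API call is scored by how much the tool's response helps the model predict the text after the call: the weighted loss over the continuation is computed once without the call and once with the call and its response prepended, and only calls whose score clears a threshold are kept. A minimal sketch of that criterion, assuming per-token losses are already computed (the helper below is illustrative, not the library's API):

```python
import torch

def api_call_fitness(nll_without, nll_with, weights, threshold = 1.0):
    # nll_without: (batch, seq) per-token loss on the continuation, no API call
    # nll_with:    (batch, seq) per-token loss with the call + response prepended
    # weights:     (batch, seq) weighting that favors tokens right after the call
    loss_minus = (weights * nll_without).sum(dim = -1)
    loss_plus = (weights * nll_with).sum(dim = -1)

    # keep only calls whose response makes the continuation easier to predict
    return (loss_minus - loss_plus) >= threshold
```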

@@ -98,10 +157,14 @@ invoke_tools(function_registry, text)

- [x] create custom generate function for palm that can do external API calls
- [x] allow for generating tokens at different cursor indices
- [x] api token (which was left and right brackets in paper) needs to be customizable
- [ ] allow for customizing how to handle errors in function name, parameters, or execution and output
- [ ] Toolformer should eventually calculate all statistics (how many were properly sampled, how many were filtered out by different criteria, the distribution of scores, as well as how many were rejected) before the final fine-tuning
- [ ] do end-to-end training in `Toolformer`
- [x] doing the prompting and bootstrapping the data
- [ ] prefiltering of bootstrapped data followed by api calls and then another round of filtering
- [ ] keep track of all stats
- [ ] take care of fine-tuning, with the interleaving of datasets + optimizer hyperparams
- [ ] hook up gpt-j
- [ ] test for a simple calculator eval dataset

setup.py (5 changes: 3 additions & 2 deletions)
@@ -3,7 +3,7 @@
setup(
  name = 'toolformer-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.18',
+  version = '0.0.19',
  license='MIT',
  description = 'Toolformer - Pytorch',
  author = 'Phil Wang',
@@ -21,7 +21,8 @@
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
-    'tqdm'
+    'tqdm',
+    'x-clip'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
toolformer_pytorch/palm.py (6 changes: 4 additions & 2 deletions)
@@ -2,6 +2,8 @@
from torch import nn, einsum
from einops import rearrange

from x_clip.tokenizer import tokenizer

# helpers

def exists(val):
@@ -162,7 +164,7 @@ def __init__(
        depth,
        heads,
        dim_head,
-        ff_mult=4,
+        ff_mult = 4,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
@@ -184,8 +186,8 @@ class PaLM(nn.Module):
    def __init__(
        self,
        dim,
-        num_tokens,
        depth,
+        num_tokens=tokenizer.vocab_size,
        dim_head=64,
        heads=8,
        ff_mult=4,
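
With this change, `num_tokens` defaults to the vocabulary size of the bundled `x_clip` tokenizer, so PaLM can be constructed without specifying it. A minimal sketch, assuming the same hyperparameters as the README example:

```python
from toolformer_pytorch import PaLM

# num_tokens is now inferred from the x_clip tokenizer, so it can be omitted
model = PaLM(dim = 512, depth = 2, heads = 8, dim_head = 64)
```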
toolformer_pytorch/toolformer_pytorch.py (108 changes: 105 additions & 3 deletions)
@@ -7,6 +7,7 @@
import torch.nn.functional as F
from torch import nn, einsum
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, reduce

@@ -18,6 +19,7 @@
from beartype.typing import Callable, Optional, Union, List

from tqdm import tqdm
from x_clip.tokenizer import tokenizer

# helpers

@@ -233,6 +235,7 @@ def sample(
    select_api_start_id_top_k = 10,
):
    device = next(model.parameters()).device
    positions = positions.clone()
    max_seq_len = seq_len + 1

    # validate
@@ -262,7 +265,7 @@
    # lengthen the prime to the entire sequence length

    remain_iterations = seq_len - prime_length
-    output = F.pad(prime, (max_seq_len - prime_length, 0), value = 0.)
+    output = F.pad(prime, (0, max_seq_len - prime_length), value = 0.)

    batch_indices = torch.arange(batch_size, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')
@@ -337,7 +340,6 @@ def create_api_token_mask(num_tokens, api_start_token_id):
    # remove the last token in output (use as noop placeholder)

    output = output[:, :-1]
    return output

@beartype
@@ -511,6 +513,42 @@ def loss_fn(weight, probs):

    return ret

# datasets and dataloaders

# for bootstrapping the initial datasets with api calls
# as well as for the final finetuning

@beartype
class PromptDataset(Dataset):
    def __init__(
        self,
        prompt: str,
        prompt_input_tag: str,
        data: List[str],
        tokenizer_encode: Callable
    ):
        self.data = data
        self.prompt = prompt
        self.prompt_input_tag_regex = re.escape(prompt_input_tag)
        self.tokenizer_encode = tokenizer_encode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_string = self.data[idx]
        data_with_prompt = re.sub(self.prompt_input_tag_regex, data_string, self.prompt)
        token_ids = self.tokenizer_encode(data_with_prompt)
        return torch.tensor(token_ids).long(), torch.tensor(len(token_ids)).long()

def prompt_collate_fn(data, padding_value = 0):
    prompts, prompt_lengths = zip(*data)
    prompts = pad_sequence(prompts, batch_first = True, padding_value = padding_value)
    return prompts, torch.stack(prompt_lengths)

def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
    collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
    return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)
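
# usage sketch (illustrative, not part of this commit): each data string is
# substituted into the prompt at the input tag, tokenized, and right-padded
# into a batch, returned together with the original prompt lengths
#
#   ds = PromptDataset(prompt = prompt, prompt_input_tag = '[input]', data = data, tokenizer_encode = tokenizer.encode)
#   dl = PromptDataloader(ds, batch_size = 4)
#   prime, positions = next(iter(dl))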

# classes

@beartype
@@ -521,12 +559,36 @@ def __init__(
        *,
        tool_id: str,
        tool: Callable,
        api_start_str = ' [',
        api_stop_str = ']',
        api_start_id = None,
        api_stop_id = None,
        teach_tool_prompt: str,
        pad_id = 0,
        prompt_batch_size = 4,
        model_seq_len = 2048,
        tokenizer_encode: Callable = tokenizer.encode,
        tokenizer_decode: Callable = tokenizer.decode,
        prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
        exclude_filters: dict[str, Callable[[str], bool]] = dict()
    ):
        super().__init__()
        self.model = model
        self.model_seq_len = model_seq_len

        self.teach_tool_prompt = teach_tool_prompt
        self.prompt_batch_size = prompt_batch_size
        self.prompt_input_tag = prompt_input_tag

        self.tokenizer_encode = tokenizer_encode
        self.tokenizer_decode = tokenizer_decode

        self.api_start_str = api_start_str
        self.api_stop_str = api_stop_str

        self.api_start_id = api_start_id
        self.api_stop_id = api_stop_id
        self.pad_id = pad_id

        self.tool_id = tool_id
        self.tool = tool
@@ -537,8 +599,48 @@ def __init__(
        self.teach_tool_prompt = teach_tool_prompt
        self.exclude_filters = exclude_filters

    def generate_data_with_api_calls(
        self,
        data: List[str],
        temperature: float = 0.9
    ) -> List[str]:

        dataset = PromptDataset(
            data = data,
            prompt_input_tag = self.prompt_input_tag,
            prompt = self.teach_tool_prompt,
            tokenizer_encode = self.tokenizer_encode
        )

        dl = PromptDataloader(
            dataset,
            batch_size = self.prompt_batch_size
        )

        prompted_outputs = []

        for prime, positions in dl:

            sampled_outputs = sample(
                model = self.model,
                prime = prime,
                positions = positions,
                seq_len = self.model_seq_len,
                pad_id = self.pad_id,
                temperature = temperature
            )

            for sample_output, position in zip(sampled_outputs, positions):
                start_position = position.item()

                prompted_output = self.tokenizer_decode(sample_output[start_position:])
                prompted_outputs.append(prompted_output)

        return prompted_outputs

    def forward(
        self,
        data: List[str]
    ):
-        raise NotImplementedError
+        data_with_api_calls = self.generate_data_with_api_calls(data)
+        return data_with_api_calls
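
With `forward` now wired up to the bootstrapping stage, calling the module is equivalent to the `generate_data_with_api_calls` call in the README example. A sketch under those same assumptions:

```python
# forward currently performs only the first stage of the pipeline:
# prompting the model to insert API calls into the raw data
data_with_api_calls = toolformer(data)

# filtering and fine-tuning are still to come (see the README roadmap)
```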
