Commit 723fde7

just need filter stats and fine-tuning logic

lucidrains committed Apr 7, 2023
1 parent 4c2c8f0 commit 723fde7

Showing 3 changed files with 121 additions and 7 deletions.
17 changes: 15 additions & 2 deletions README.md
@@ -76,7 +76,20 @@ toolformer = Toolformer(

data_with_api_calls = toolformer.generate_data_with_api_calls(data)

-# complete the filtering and fine tuning step
+filtered_data, filtered_data_with_api_calls = toolformer.filter_and_keep_only_first_api_call(data, data_with_api_calls)
+
+data_with_api_responses = toolformer.make_api_calls(filtered_data_with_api_calls)
+
+filtered_results = toolformer.filter_by_api_responses(
+    filtered_data,
+    filtered_data_with_api_calls,
+    data_with_api_responses
+)
+
+# then finetune with token ids at
+# -> filtered_results.filtered_tokens_without_api_response
+# complete this with toolformer.finetune(filtered_results)

```

The main novelty of the paper is defining a fitness score for the outputs of a transformer that has been instructed to insert API calls. The score is used to filter the sampled outputs, so that the transformer is finetuned only on API calls that decrease the perplexity of the text that follows them.
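Concretely, the filtering criterion can be sketched as follows (a minimal sketch, not the library's exact code: `weighted_nll`, `keep_api_call`, and the three loss arguments are illustrative names; the default threshold of 1. follows the paper):

```python
import torch

def weighted_nll(logits, tokens, weight):
    # weighted negative log likelihood of each realized next token
    logprobs = logits[:, :-1].log_softmax(dim = -1)
    nll = -logprobs.gather(-1, tokens[:, 1:, None]).squeeze(-1)
    return (nll * weight).sum(dim = -1)

def keep_api_call(loss_plain, loss_call_no_response, loss_call_with_response, threshold = 1.):
    # an api call passes the filter when conditioning on its response lowers the
    # weighted loss on the following text by at least `threshold`, compared to the
    # better of "no call at all" and "call made but response withheld"
    baseline = torch.minimum(loss_plain, loss_call_no_response)
    return (baseline - loss_call_with_response) >= threshold
```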
Expand Down Expand Up @@ -162,7 +175,7 @@ invoke_tools(function_registry, text)
- [ ] Toolformer should eventually calculate all statistics (how many were properly sampled, how many were filtered out by different criteria, the distribution of scores, as well as how many were rejected) before the final fine-tuning
- [ ] do end-to-end training in `Toolformer`
- [x] doing the prompting and bootstrapping the data
-- [ ] prefiltering of bootstrapped data followed by api calls and then another round of filtering
+- [x] prefiltering of bootstrapped data followed by api calls and then another round of filtering
- [ ] keep track of all stats
- [ ] take care of fine-tuning, with the interleaving of datasets + optimizer hyperparams
- [ ] hook up gpt-j
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'toolformer-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.19',
+  version = '0.0.20',
  license='MIT',
  description = 'Toolformer - Pytorch',
  author = 'Phil Wang',
109 changes: 105 additions & 4 deletions toolformer_pytorch/toolformer_pytorch.py
@@ -16,7 +16,7 @@
from toolformer_pytorch.prompts import DEFAULT_PROMPT_INPUT_TAG

from beartype import beartype
-from beartype.typing import Callable, Optional, Union, List
+from beartype.typing import Callable, Optional, Union, List, Tuple

from tqdm import tqdm
from x_clip.tokenizer import tokenizer
@@ -142,7 +142,7 @@ def replace_fn(

    # return original text with the output delimiter and the stringified output

-    return f'{text_without_end_api_token} {delimiter} {str(out)}{end_api_token}'
+    return f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}'
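    # (an illustrative effect of the fix above, with hypothetical values: a completed
    # call now renders as '… [Calculator(1 + 2) → 3 ]' rather than
    # '… [Calculator(1 + 2) → 3]', space-delimiting the output from the end api token)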

# main function, which takes a registry of functions and the text in question, makes all the appropriate api calls, and appends the outputs

@@ -451,11 +451,16 @@ def filter_tokens_with_api_response(
    assert all_contains_id(tokens_with_api_response, api_start_token_id)
    assert all_contains_id(tokens_with_api_response, api_end_token_id)

+    # auto set devices

+    device = next(model.parameters()).device
+    tokens, tokens_without_api_response, tokens_with_api_response = map(lambda t: t.to(device), (tokens, tokens_without_api_response, tokens_with_api_response))

    # get all the logits

    with torch.no_grad():
        model.eval()
-        logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_with_api_response, tokens_without_api_response))
+        logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_without_api_response, tokens_with_api_response))

    # derive all predicted prob of the actual next token id in sequence
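    # (a sketch of one standard way this can be done; the repo's actual helper may differ)
    #   probs = logits[:, :-1].softmax(dim = -1)
    #   probs = probs.gather(-1, tokens[:, 1:, None]).squeeze(-1)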

@@ -472,8 +477,10 @@ def filter_tokens_with_api_response(

    # deriving the weighting for the original passage is more tricky
    # would need to start counting up from <api> start token location
+    # this would also assume that the language model perfectly copied the passage over and that both token ids are aligned except for the inserted API call - but this can be done with the custom filtering functions eventually

    weight = weight_and_mask_fn(tokens_without_api_response[:, 1:], api_start_token_id) # shift to the left by one since <api> does not exist in the original sequence
+    weight = weight[:, :probs.shape[-1]]
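    # (a sketch of one plausible weight_and_mask_fn, following the paper's
    # w_t = max(0, 1 - 0.2 * t) decay counted up from the <api> position;
    # the library's default function may differ)
    #   def weight_and_mask(tokens, api_start_token_id):
    #       after = (tokens == api_start_token_id).cumsum(dim = -1) > 0
    #       t = (after.cumsum(dim = -1) - 1).clamp(min = 0)
    #       return (1. - 0.2 * t).clamp(min = 0.) * after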

    # get the loss L for all three types of sequences

@@ -561,9 +568,11 @@ def __init__(
    tool: Callable,
    api_start_str = ' [',
    api_stop_str = ']',
+    api_response_delimiter = '→',
    api_start_id = None,
    api_stop_id = None,
    teach_tool_prompt: str,
+    filter_threshold = 1.,
    pad_id = 0,
    prompt_batch_size = 4,
    model_seq_len = 2048,
@@ -582,12 +591,28 @@ def __init__(

    self.tokenizer_encode = tokenizer_encode
    self.tokenizer_decode = tokenizer_decode
+    self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()

+    self.filter_threshold = filter_threshold

    self.api_start_str = api_start_str
    self.api_stop_str = api_stop_str
+    self.api_response_delimiter = api_response_delimiter

+    if not exists(api_start_id):
+        api_start_id = tokenizer_encode(api_start_str)
+        assert len(api_start_id) == 1
+        api_start_id = api_start_id[0]

+    self.api_start_id = api_start_id

+    if not exists(api_stop_id):
+        api_stop_id = tokenizer_encode(api_stop_str)
+        assert len(api_stop_id) == 1
+        api_stop_id = api_stop_id[0]

+    self.api_stop_id = api_stop_id
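    # note that the asserts above require the tokenizer to encode api_start_str and
    # api_stop_str to exactly one token id each; with a tokenizer that splits them,
    # api_start_id / api_stop_id must be passed in explicitly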

    self.pad_id = pad_id

    self.tool_id = tool_id
@@ -638,9 +663,85 @@ def generate_data_with_api_calls(

    return prompted_outputs

+def filter_and_keep_only_first_api_call(
+    self,
+    data,
+    data_with_api_calls: List[str],
+    return_excluded = False
+):
+    included = []
+    excluded = []

+    api_start_stop_kwargs = dict(api_start = self.api_start_str, api_stop = self.api_stop_str)

+    has_api_calls_ = partial(has_api_calls, **api_start_stop_kwargs)
+    replace_all_but_first_ = partial(replace_all_but_first, **api_start_stop_kwargs)

+    for datum, data_with_api_call in zip(data, data_with_api_calls):
+        if has_api_calls_(data_with_api_call):
+            data_with_api_call = replace_all_but_first_(data_with_api_call)
+            included.append((datum, data_with_api_call))
+        else:
+            excluded.append((datum, data_with_api_call))

+    included = tuple(map(list, zip(*included)))

+    if not return_excluded:
+        return included

+    excluded = tuple(map(list, zip(*excluded)))
+    return included, excluded
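    # illustrative usage (hypothetical strings, default ' [' / ']' delimiters):
    #   toolformer.filter_and_keep_only_first_api_call(
    #       ['1 + 2 equals 3.', 'no api call here.'],
    #       ['1 + 2 equals [Calculator(1 + 2)] 3.', 'no api call here.']
    #   )
    # only the first pair is included, with any api calls after the first stripped;
    # the included pairs come back as (filtered_data, filtered_data_with_api_calls)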

+def make_api_calls(
+    self,
+    filtered_data_with_api_calls: List[str]
+):
+    invoke_tools_ = partial(
+        invoke_tools,
+        self.registry,
+        api_start = self.api_start_str,
+        api_stop = self.api_stop_str,
+        delimiter = self.api_response_delimiter
+    )

+    data_with_api_responses = []
+    for data in filtered_data_with_api_calls:
+        output = invoke_tools_(data)
+        data_with_api_responses.append(output)

+    return data_with_api_responses
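    # illustrative (assuming a Calculator tool in self.registry): an input of
    # '1 + 2 equals [Calculator(1 + 2)] 3.' would come back as
    # '1 + 2 equals [Calculator(1 + 2) → 3 ] 3.', with each call executed and its
    # stringified output spliced in after the response delimiter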

+def filter_by_api_responses(
+    self,
+    data: List[str],
+    data_with_api_calls: List[str],
+    data_with_api_responses: List[str]
+) -> FilteredResults:

+    to_token_ids = lambda l: pad_sequence([*map(self.tokenizer_encode_to_tensor, l)], batch_first = True, padding_value = self.pad_id)

+    tokens, tokens_without_api_response, tokens_with_api_response = map(to_token_ids, (data, data_with_api_calls, data_with_api_responses))

+    filtered_results = filter_tokens_with_api_response(
+        model = self.model,
+        tokens = tokens,
+        tokens_with_api_response = tokens_with_api_response,
+        tokens_without_api_response = tokens_without_api_response,
+        filter_threshold = self.filter_threshold,
+        api_start_token_id = self.api_start_id,
+        api_end_token_id = self.api_stop_id
+    )

+    return filtered_results
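    # in effect each candidate is scored three ways (plain text, text with the api
    # call, text with call + response) and kept only when the response lowers the
    # weighted loss on the trailing text by at least filter_threshold (1. by default,
    # matching the paper)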

def forward(
    self,
    data: List[str]
):
    data_with_api_calls = self.generate_data_with_api_calls(data)
-    return data_with_api_calls
+    filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)
+
+    assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
+
+    data_with_responses = self.make_api_calls(filtered_data_with_api_calls)
+
+    return data_with_responses
