Skip to content

Commit

Permalink
Implement AI tools and refine prompt removal
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohb committed Sep 30, 2024
1 parent c309f08 commit cf08026
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
74 changes: 61 additions & 13 deletions fmtr/tools/ai_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from datetime import datetime
from peft import PeftConfig, PeftModel
from statistics import mean, stdev
from statistics import mean
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List

Expand Down Expand Up @@ -151,6 +151,8 @@ class BulkInferenceManager:
BATCH_STABLE_RESET = None
BATCH_FACTOR_REDUCTION = None

TOOLS = None

def __init__(self):
"""
Expand Down Expand Up @@ -212,6 +214,7 @@ def encode(self, prompts: List[str]):
ids_input = self.tokenizer.apply_chat_template(
messages,
return_tensors="pt",
tools=self.TOOLS,
padding=True,
return_attention_mask=True,
return_dict=True
Expand Down Expand Up @@ -243,15 +246,20 @@ def generate(self, prompts, **params):

batch_encoding = self.encode(prompts).to(self.model.device)

ids_input, ids_attention = batch_encoding.data['input_ids'], batch_encoding.data['attention_mask']

ids_output = self.model.generate(
pad_token_id=self.tokenizer.eos_token_id,
**batch_encoding,
input_ids=ids_input,
attention_mask=ids_attention,
**params
)
ids_output = ids_output.to(self.DEVICE_INACTIVE)

batcher.batch_complete()
yield prompts, ids_output

ids_output = self.remove_prompt_ids(ids_input, ids_output)

yield ids_output

except RuntimeError as exception:
if "CUDA out of memory" in str(exception):
Expand All @@ -261,15 +269,25 @@ def generate(self, prompts, **params):
raise

self.deactivate()
logger.info(f'Generation complete.')

def remove_prompt_ids(self, ids_input, ids_output):
    """
    Strip the prompt tokens from generated sequences.

    Outputs contain the prompt, and the prompt is left-padded, so every sequence shares the same prompt
    width. The prompt is therefore removed by dropping that many leading positions from the last axis.

    Args:
        ids_input: Encoded prompt IDs; only the size of the last axis is read.
        ids_output: Generated IDs whose last axis begins with the prompt.

    Returns:
        `ids_output` with the first `ids_input.shape[-1]` positions of the last axis removed.
    """
    width = ids_input.shape[-1]
    # Ellipsis targets the last axis whatever the rank, so an unbatched 1-D sequence is handled
    # the same way as a 2-D batch (for which this is identical to `[:, width:]`).
    return ids_output[..., width:]

def decode(self, prompts, ids_output):
def decode(self, ids_output):
"""
Decode outputs to text
"""
texts_prompts = self.tokenizer.batch_decode(ids_output, skip_special_tokens=True)
texts = [text_prompt.removeprefix(prompt).strip() for prompt, text_prompt in zip(prompts, texts_prompts)]
texts = self.tokenizer.batch_decode(ids_output, skip_special_tokens=True)
return texts

def get_outputs(self, prompts: List[str], **params):
Expand All @@ -281,22 +299,22 @@ def get_outputs(self, prompts: List[str], **params):

params = params or dict(do_sample=False)

for prompts_batch, ids_output in self.generate(prompts, **params):
texts = self.decode(prompts_batch, ids_output)
for ids_output in self.generate(prompts, **params):
texts = self.decode(ids_output)

lengths = [len(text) // 5 for text in texts]
msg = f'Text statistics: {min(lengths)=} {max(lengths)=} {mean(lengths)=} {stdev(lengths)=}.'
msg = f'Text statistics: {min(lengths)=} {max(lengths)=} {mean(lengths)=}.'
logger.info(msg)

yield from texts

def get_output(self, prompt, **params):
def get_output(self, prompt, **kwargs):
"""
Get a singleton output
"""
outputs = self.get_outputs([prompt], **params)
outputs = self.get_outputs([prompt], **kwargs)
output = next(iter(outputs))
return output

Expand All @@ -318,6 +336,36 @@ def tst():
return data


def tst_tool():
    """
    Smoke-test tool usage: send one weather question through a manager subclass that registers a single tool.

    Returns the generated texts as a list.
    """

    # NOTE(review): this function is handed to the tokenizer's chat template via `TOOLS`; its name,
    # signature and docstring are presumably what the model sees as the tool schema, so they are
    # left exactly as written - confirm before editing the wording.
    def get_current_weather(location: str, format: str):
        """
        Get the current weather
        Args:
            location: The city and state, e.g. San Francisco, CA
            format: The temperature unit to use. Infer this from the users location. (choices: ["celsius", "fahrenheit"])
        """
        return "It's 25 degrees and sunny!"

    class BiTools(BulkInferenceManager):
        # Overrides the base class default of `TOOLS = None`.
        TOOLS = [get_current_weather]

    question = "What's the weather like in Paris?"
    sampling = dict(max_new_tokens=200, do_sample=True, temperature=1.2, top_p=0.5, top_k=50)

    # `get_outputs` yields texts lazily; materialise them so the caller gets a plain list.
    outputs = BiTools().get_outputs([question], **sampling)
    return list(outputs)




if __name__ == '__main__':
texts = tst()
texts = tst_tool()
texts
2 changes: 1 addition & 1 deletion fmtr/tools/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.0
0.9.1

0 comments on commit cf08026

Please sign in to comment.