llm cache
Mika committed Nov 21, 2024
1 parent 3c86c4e commit 7e7245c
Showing 2 changed files with 69 additions and 14 deletions.
16 changes: 12 additions & 4 deletions test_uralicnlp.py
@@ -159,8 +159,16 @@
#print(result)
#print(llm_output)
#llm = get_llm("mistral", open_read(os.path.expanduser("~/.mistralapikey")).read().strip(), model="mistral-embed")
#llm = get_llm("microsoft/Phi-3.5-mini-instruct")
#print(llm.prompt("What is Livonian?"))
llm = get_llm("roneneldan/TinyStories-33M")
#llm.load_cache("cache.bin")
#"microsoft/Phi-3.5-mini-instruct")
prompts = ["What is Livonian?", "Look at this dog", "WTF are you talking about", "Yeah, right"]
for prompt in prompts:
    print(llm.prompt(prompt))
    llm.embed(prompt)

#llm.save_cache("cache.bin")


#llm = get_llm("claude", open_read(os.path.expanduser("~/.claudeapikey")).read().strip())
#print(llm.prompt("What is Tundra Nenets?"))
@@ -169,11 +177,11 @@
#print(llm.embed("Super great text to embed"))

#print(llm.embed_endangered("Näʹde täävtõõđi âʹtte peeʹlljid pärnnses täävtõõđi.", "sms", "fin"))
llm = get_llm("google-bert/bert-base-uncased")
#llm = get_llm("google-bert/bert-base-uncased")
texts = ["dogs are funny", "cats play around", "cars go fast", "planes fly around", "parrots like to eat", "eagles soar in the skies", "moon is big", "saturn is a planet"]
endangered_texts = ["Ёртозь ёртовсь кудостонть.", "Теке сялгонзояк те касовксонть арасть.", "Истяяк арсеват.", "Атякштне, кунсолан, сыргойсть омбоцеде.", "Вальмаванть неявить ульцява ардыцят.", "Морат эрзянь моро?"]
#print(semantics.cluster(texts, llm, return_ids=True))
#print(semantics.cluster(texts, llm))
#print(semantics.cluster(texts, llm, hierarchical_clustering=True))
#print(semantics.cluster_endangered(endangered_texts, llm, "myv", "fin"))
print(semantics.cluster_endangered(endangered_texts, llm, "myv", "fin", hierarchical_clustering=True, method="hdbscan"))
#print(semantics.cluster_endangered(endangered_texts, llm, "myv", "fin", hierarchical_clustering=True, method="hdbscan"))
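
The commented-out `load_cache`/`save_cache` calls above sketch the intended round trip for the new cache. A minimal usage sketch under the same setup (the `cache.bin` path is the one from the test; setting `llm.cache = True` by hand is an assumption, since in this commit only `load_cache` switches the flag on):

from uralicNLP.llm import get_llm

llm = get_llm("roneneldan/TinyStories-33M")
llm.cache = True                          # assumption: enable caching without loading a file
first = llm.prompt("What is Livonian?")   # computed by the model, then stored
second = llm.prompt("What is Livonian?")  # identical call, served from the in-memory cache
llm.save_cache("cache.bin")               # persist both caches to disk

# later, or in another process:
llm = get_llm("roneneldan/TinyStories-33M")
llm.load_cache("cache.bin")               # restores both caches and sets llm.cache = True
print(llm.prompt("What is Livonian?"))    # answered from the cache, no generation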
67 changes: 57 additions & 10 deletions uralicNLP/llm.py
@@ -38,6 +38,8 @@

import json

from mikatools import pickle_dump, pickle_load

class ModuleNotInstalled(Exception):
    pass

@@ -72,14 +74,52 @@ def get_llm(llm_name, *args, **kwargs):
class LLM(object):
    """docstring for LLM"""
    def __init__(self):
        self.cache = False
        self._embed_cache_dict = {}
        self._prompt_cache_dict = {}
        super(LLM, self).__init__()

    def _embed_cache(func):
        def inner(*args, **kwargs):
            self = args[0]
            if self.cache and "_".join(args[1:]) in self._embed_cache_dict:
                return self._embed_cache_dict["_".join(args[1:])]
            else:
                r = func(*args, **kwargs)
                if self.cache:
                    self._embed_cache_dict["_".join(args[1:])] = r
                return r
        return inner

    def _prompt_cache(func):
        def inner(*args, **kwargs):
            self = args[0]
            if self.cache and "_".join(args[1:]) in self._prompt_cache_dict:
                return self._prompt_cache_dict["_".join(args[1:])]
            else:
                r = func(*args, **kwargs)
                if self.cache:
                    self._prompt_cache_dict["_".join(args[1:])] = r
                return r
        return inner
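    # The two wrappers above memoize results: the cache key is the positional
    # string arguments joined with "_", so embed_endangered(text, "myv", "fin")
    # is keyed as text + "_myv_fin". Note (an observation, not part of the
    # commit): keyword arguments are not included in the key, so calls that
    # differ only in kwargs will share a cache entry.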


    @_prompt_cache
    def prompt(self, text):
        return self._prompt(text)

    def _prompt(self, text):
        raise NotImplementedException("LLM does not support prompting")

    @_embed_cache
    def embed(self, text):
        return self._embed(text)

    def _embed(self, text):
        raise NotImplementedException("LLM does not support embeddings")

    @_embed_cache
    def embed_endangered(self, text, lang, dict_lang, backend=TinyDictionary):
        r = []
        for word in tokenize_words(text):
@@ -96,6 +136,13 @@ def embed_endangered(self, text, lang, dict_lang, backend=TinyDictionary):
text = " ".join(r)
return self.embed(text)

    def save_cache(self, file, *args, **kwargs):
        pickle_dump([self._embed_cache_dict, self._prompt_cache_dict], file, *args, **kwargs)

    def load_cache(self, file, *args, **kwargs):
        self.cache = True
        self._embed_cache_dict, self._prompt_cache_dict = pickle_load(file, *args, **kwargs)
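
The cache file written by `save_cache` is a pickled two-element list, `[embed_cache, prompt_cache]`. Assuming `mikatools.pickle_dump`/`pickle_load` are thin wrappers around the standard `pickle` module (an assumption worth verifying against mikatools), the file can be inspected directly:

import pickle

# assumption: the file is plain pickle via mikatools
with open("cache.bin", "rb") as f:
    embed_cache, prompt_cache = pickle.load(f)
print(len(embed_cache), "cached embeddings;", len(prompt_cache), "cached prompts")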


class ChatGPT(LLM):
    """docstring for ChatGPT"""
@@ -107,7 +154,7 @@ def __init__(self, api_key, model="gpt-4o"):
raise ModuleNotInstalled("OpenAI Python library is not installed. Run pip install openai. If you do have the library installed, check your API key.")
self.model = model

def prompt(self, prompt, temperature=1):
def _prompt(self, prompt, temperature=1):
chat_completion = self.client.chat.completions.create(
messages=[
{
Expand All @@ -120,7 +167,7 @@ def prompt(self, prompt, temperature=1):
        )
        return chat_completion.choices[0].message.content

    def embed(self, text):
    def _embed(self, text):
        response = self.client.embeddings.create(input=text, model=self.model)
        return response.data[0].embedding

@@ -136,11 +183,11 @@ def __init__(self, api_key, model="gemini-1.5-flash", task_type="retrieval_docum
        self.model_name = model
        self.task_type = task_type

    def prompt(self, prompt):
    def _prompt(self, prompt):
        response = self.model.generate_content(prompt)
        return response.text

    def embed(self, text):
    def _embed(self, text):
        result = genai.embed_content(model=self.model_name, content=text, task_type=self.task_type)
        return result['embedding']

@@ -153,13 +200,13 @@ def __init__(self, model, max_length=1000, device=-1):
        self.embedder = None
        self.device = device

    def prompt(self, prompt):
    def _prompt(self, prompt):
        if self.model is None:
            self.model = pipeline('text-generation', model=self.model_name, device=self.device)
        r = self.model(prompt, max_length=self.max_length, truncation=True)
        return " ".join([x['generated_text'] for x in r])

    def embed(self, text):
    def _embed(self, text):
        if self.embedder is None:
            self.embedder = pipeline('feature-extraction', model=self.model_name, device=self.device)
        r = self.embedder(text, return_tensors="pt")[0].numpy().mean(axis=0)
@@ -174,11 +221,11 @@ def __init__(self, api_key, model="mistral-small-latest"):
raise ModuleNotInstalled("Mistral library is not installed. Run pip install mistralai. If you do have the library installed, check your API key.")
self.model = model

    def prompt(self, prompt):
    def _prompt(self, prompt):
        r = self.s.chat.complete(model=self.model, messages=[{"content": prompt, "role": "user"}])
        return r.choices[0].message.content

    def embed(self, text):
    def _embed(self, text):
        embeddings_batch_response = self.s.embeddings.create(model=self.model, inputs=[text])
        return embeddings_batch_response.data[0].embedding

@@ -194,7 +241,7 @@ def __init__(self, api_key, model="claude-3-5-sonnet-latest", max_length=1024):
        self.model = model
        self.max_length = max_length

    def prompt(self, prompt, temperature=1):
    def _prompt(self, prompt, temperature=1):
        chat_completion = self.client.messages.create(model=self.model, messages=[{"role": "user", "content": prompt}], max_tokens=self.max_length)
        return " ".join([x.text for x in chat_completion.content])

@@ -209,7 +256,7 @@ def __init__(self, api_key, model="voyage-3"):
raise ModuleNotInstalled("Voyage Python library is not installed. Run pip install voyageai. If you do have the library installed, check your API key.")
self.model = model

    def embed(self, text):
    def _embed(self, text):
        result = self.vo.embed([text], model=self.model, input_type="document")
        return result.embeddings[0]
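
Every backend above follows the same template-method pattern introduced by this commit: the public `prompt`/`embed` on `LLM` stay as thin cached wrappers, and concrete classes override only the private `_prompt`/`_embed` hooks. A minimal sketch of a custom backend against this interface (the `EchoLLM` class and its toy outputs are illustrative, not part of the library):

from uralicNLP.llm import LLM

class EchoLLM(LLM):
    # toy backend: a real subclass would call an API or a local model here
    def _prompt(self, text):
        return "echo: " + text

    def _embed(self, text):
        return [float(len(text))]  # stand-in for a real embedding vector

llm = EchoLLM()
llm.cache = True               # assumption: enable the cache manually
print(llm.prompt("hello"))     # computed once...
print(llm.prompt("hello"))     # ...then served by the _prompt_cache wrapper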
