Add the pre-function of handling long prompt and Update the context processor doc (zilliztech#395)

Signed-off-by: SimFG <[email protected]>
SimFG authored May 29, 2023
1 parent 873fca7 commit 03a2787
Showing 8 changed files with 188 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -56,7 +56,7 @@
intersphinx_mapping = {
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://numpy.org/devdocs/", None),
"python": ("https://docs.python.org/3", None),
"python": ("https://docs.python.org/3.8/", None),
}

autodoc_member_order = "bysource"
1 change: 1 addition & 0 deletions gptcache/adapter/adapter.py
@@ -40,6 +40,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
kwargs,
extra_param=context.get("pre_embedding_func", None),
prompts=chat_cache.config.prompts,
cache_config=chat_cache.config,
)
if isinstance(pre_embedding_res, tuple):
pre_store_data = pre_embedding_res[0]
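For context, a minimal sketch of a user-defined pre-embedding function that reads the newly forwarded cache_config keyword; the function and variable names here are illustrative, not from this commit:

from typing import Any, Dict

def my_pre_func(data: Dict[str, Any], **params: Dict[str, Any]) -> str:
    # The adapter above now forwards the Cache's Config object as `cache_config`,
    # so a custom pre-embedding function can read settings such as `template`.
    cache_config = params.get("cache_config", None)
    last_content = data.get("messages")[-1]["content"]
    if cache_config and cache_config.template:
        # e.g. derive the cache key only from the template parameter values
        pass
    return last_content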
6 changes: 5 additions & 1 deletion gptcache/config.py
@@ -13,6 +13,8 @@ class Config:
:type similarity_threshold: float
:param prompts: optional, if the request content will remove the prompt string when the request contains the prompt list
:type prompts: Optional[List[str]]
:param template: optional, if the request content will remove the template string and only keep the parameter value in the template
:type template: Optional[str]
Example:
.. code-block:: python
@@ -26,7 +28,8 @@ def __init__(
self,
log_time_func: Optional[Callable[[str, float], None]] = None,
similarity_threshold: float = 0.8,
- prompts: Optional[List[str]] = None
+ prompts: Optional[List[str]] = None,
+ template: Optional[str] = None,
):
if similarity_threshold < 0 or similarity_threshold > 1:
raise CacheError(
@@ -35,3 +38,4 @@ def __init__(
self.log_time_func = log_time_func
self.similarity_threshold = similarity_threshold
self.prompts = prompts
self.template = template
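A minimal usage sketch of the extended Config; the template string below is illustrative:

from gptcache import Config

# `template` lets the pre-embedding step strip the fixed template text and
# keep only the filled-in parameter values as the cache key.
config = Config(
    similarity_threshold=0.8,
    template="tell me a joke about {subject}",
)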
21 changes: 20 additions & 1 deletion gptcache/processor/context/concat_context.py
@@ -4,7 +4,26 @@


class ConcatContextProcess(ContextProcess):
"""A concat context processor simply concat the context
"""A concat context processor simply concat the context.
Generally used with rwkv embedding, because rwkv can input almost infinitely long
Example:
.. code-block:: python
from gptcache.manager import manager_factory
from gptcache.processor.context.concat_context import ConcatContextProcess
context_process = ConcatContextProcess()
rwkv_embedding = Rwkv()
data_manager = manager_factory(
"sqlite,faiss",
vector_params={"dimension": rwkv_embedding.dimension},
)
cache.init(
pre_embedding_func=context_process.pre_process,
embedding_func=rwkv_embedding.to_embeddings,
data_manager=data_manager,
)
"""

content: str = ""
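The docstring example above omits its imports; a more self-contained sketch, assuming the rwkv embedding is exposed as gptcache.embedding.Rwkv, might look like:

from gptcache import cache
from gptcache.embedding import Rwkv  # assumption: the rwkv embedding lives here
from gptcache.manager import manager_factory
from gptcache.processor.context.concat_context import ConcatContextProcess

# Concatenate the whole conversation and embed it with rwkv,
# which can accept very long inputs.
context_process = ConcatContextProcess()
rwkv_embedding = Rwkv()
data_manager = manager_factory(
    "sqlite,faiss",
    vector_params={"dimension": rwkv_embedding.dimension},
)
cache.init(
    pre_embedding_func=context_process.pre_process,
    embedding_func=rwkv_embedding.to_embeddings,
    data_manager=data_manager,
)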
7 changes: 7 additions & 0 deletions gptcache/processor/context/selective_context.py
@@ -26,6 +26,13 @@ class SelectiveContextProcess(ContextProcess):
more details: https://github.com/liyucheng09/Selective_Context
Example:
.. code-block:: python
from gptcache.processor.context.selective_context import SelectiveContextProcess
context_process = SelectiveContextProcess()
cache.init(pre_embedding_func=context_process.pre_process)
"""

content: str = ""
8 changes: 8 additions & 0 deletions gptcache/processor/context/summarization_context.py
@@ -20,6 +20,14 @@ class SummarizationContextProcess(ContextProcess):
:type tokenizer: transformers.PreTrainedTokenizer
:param target_length: The length of the summarized text.
:type target_length: int
Example:
.. code-block:: python
from gptcache.processor.context.summarization_context import SummarizationContextProcess
context_process = SummarizationContextProcess()
cache.init(pre_embedding_func=context_process.pre_process)
"""
def __init__(self, summarizer=transformers.pipeline("summarization", model="facebook/bart-large-cnn"),
tokenizer=None, target_length=512):
98 changes: 94 additions & 4 deletions gptcache/processor/pre.py
@@ -1,4 +1,5 @@
import re
import string
from typing import Dict, Any


@@ -47,19 +48,108 @@ def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any])
return new_content_str


def _get_pattern_value(pattern_str: str, value_str: str):
literal_text_arr = []
field_name_arr = []
for literal_text, field_name, _, _ in string.Formatter().parse(pattern_str):
literal_text_arr.append(literal_text)
if field_name is not None:
field_name_arr.append(
field_name if field_name else str(len(field_name_arr))
)

pattern_values = {}
last_end = 0
for i, literal_text in enumerate(literal_text_arr):
start = value_str.find(literal_text, last_end)
if i == len(literal_text_arr) - 1:
end = len(value_str)
else:
end = value_str.find(literal_text_arr[i + 1], start + 1)
if start == -1 or end == -1:
break
start += len(literal_text)
pattern_values[field_name_arr[i]] = value_str[start:end]
last_end = end
return pattern_values


def last_content_without_template(data: Dict[str, Any], **params: Dict[str, Any]) -> Any:
"""get the last content's template values of the message list without template content.
When considering a cache agent or chain, the majority of the content consists of template content,
while the essential information is simply a list of parameters within the template.
In this way, the cache key is composed of a string made up of all the parameter values in the list.
WARNING: Two parameters without intervals cannot appear in the template,
for example: template = "{foo}{hoo}" is not supported,
but template = "{foo}:{hoo}" is supported
:param data: the user llm request data
:type data: Dict[str, Any]
:Example with str template:
.. code-block:: python
from gptcache import Config
from gptcache.processor.pre import last_content_without_template
template_obj = "tell me a joke about {subject}"
prompt = template_obj.format(subject="animal")
value = last_content_without_template(
data={"messages": [{"content": prompt}]}, cache_config=Config(template=template_obj)
)
print(value)
# ['animal']
:Example with langchain template:
.. code-block:: python
from langchain import PromptTemplate
from gptcache import Config
from gptcache.processor.pre import last_content_without_template
template_obj = PromptTemplate.from_template("tell me a joke about {subject}")
prompt = template_obj.format(subject="animal")
value = last_content_without_template(
data={"messages": [{"content": prompt}]},
cache_config=Config(template=template_obj.template),
)
print(value)
# ['animal']
NOTE: At present, only the simple PromptTemplate in langchain is supported.
For ChatPromptTemplate, it needs to be adjusted according to the template array.
If you need to use it, you need to pass in the final dialog template yourself.
The reason why it cannot be advanced is that ChatPromptTemplate
does not provide a method to directly return the template string.
"""
last_content_str = data.get("messages")[-1]["content"]
cache_config = params.get("cache_config", None)
if not (cache_config and cache_config.template):
return last_content_str

pattern_value = _get_pattern_value(cache_config.template, last_content_str)
return str(list(pattern_value.values()))


def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
""" get all content of the message list
"""get all content of the message list
:param data: the user llm request data
:type data: Dict[str, Any]
- Example:
+ :Example:
.. code-block:: python
from gptcache.processor.pre import all_content
- content = all_content({"messages": [{"content": "foo1"}, {"content": "foo2"}]})
- # content = "foo1\nfoo2"
+ content = all_content(
+     {"messages": [{"content": "foo1"}, {"content": "foo2"}]}
+ )
+ # content = "foo1\\nfoo2"
"""
s = ""
messages = data.get("messages")
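Beyond the single-parameter docstring example, here is a short sketch (illustrative, not from this commit) of how the extraction behaves with two template fields; note that, per the WARNING above, adjacent fields need literal text between them:

from gptcache import Config
from gptcache.processor.pre import last_content_without_template

template = "tell me a {adjective} joke about {subject}"
prompt = template.format(adjective="dark", subject="animals")

# The fixed template text is stripped and only the filled-in values remain,
# so the cache key depends solely on the parameter values.
value = last_content_without_template(
    data={"messages": [{"content": prompt}]},
    cache_config=Config(template=template),
)
print(value)
# ['dark', 'animals']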
55 changes: 52 additions & 3 deletions tests/unit_tests/adapter/test_langchain_models.py
@@ -1,15 +1,19 @@
import os
import random
from unittest.mock import patch

- from gptcache import Cache
+ from gptcache import Cache, Config
from gptcache.adapter import openai
from gptcache.adapter.api import init_similar_cache, get
from gptcache.adapter.langchain_models import LangChainLLMs, LangChainChat, _cache_msg_data_convert
- from gptcache.processor.pre import get_prompt
+ from gptcache.processor.pre import get_prompt, last_content_without_template
from gptcache.utils import import_pydantic, import_langchain
from gptcache.utils.response import get_message_from_openai_answer

import_pydantic()
import_langchain()

- from langchain import OpenAI
+ from langchain import OpenAI, PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

@@ -102,3 +106,48 @@ def test_langchain_chats():

answer = chat(messages=question, cache_obj=llm_cache)
assert answer == _cache_msg_data_convert(msg).generations[0].message


def test_last_content_without_template():
string_prompt = PromptTemplate.from_template("tell me a joke about {subject}")
template = string_prompt.template
cache_obj = Cache()
data_dir = str(random.random())
init_similar_cache(data_dir=data_dir, cache_obj=cache_obj, pre_func=last_content_without_template, config=Config(template=template))

subject_str = "animal"
expect_answer = "this is a joke"

with patch("openai.ChatCompletion.create") as mock_create:
datas = {
"choices": [
{
"message": {"content": expect_answer, "role": "assistant"},
"finish_reason": "stop",
"index": 0,
}
],
"created": 1677825464,
"id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD",
"model": "gpt-3.5-turbo-0301",
"object": "chat.completion.chunk",
}
mock_create.return_value = datas

response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": string_prompt.format(subject=subject_str)},
],
cache_obj=cache_obj,
)
assert get_message_from_openai_answer(response) == expect_answer, response

cache_obj.flush()

init_similar_cache(data_dir=data_dir, cache_obj=cache_obj)

cache_res = get(str([subject_str]), cache_obj=cache_obj)
print(str([subject_str]))
assert cache_res == expect_answer, cache_res
