From 32c4883ec6c37d13290a2ae1e1d19d88a70b35b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 30 Jul 2024 12:10:45 +0200 Subject: [PATCH 1/3] Add `model_name` property to templates When rendering templates with special variables representing, e.g., BOS tokens we need to know which model the prompt is going to be passed to so as to render the template with the appropriate token. This PR adds a `model_name` attribute to the `Template` class. --- prompts/templates.py | 2 ++ tests/test_templates.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/prompts/templates.py b/prompts/templates.py index 82d5cc3..052e983 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -37,6 +37,7 @@ class Template: template: str signature: inspect.Signature + model: Optional[str] = None registry: Dict[str, Callable] = field(default_factory=dict) def __call__(self, *args, **kwargs) -> str: @@ -83,6 +84,7 @@ def register(self, model_name: str): def wrapper(fn: Callable): tpl = template(fn) + tpl.model = model_name self.registry[model_name] = tpl return tpl diff --git a/tests/test_templates.py b/tests/test_templates.py index 6c2af9c..896a231 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -197,6 +197,10 @@ def simple_prompt_name(query: str): assert callable(simple_prompt) assert callable(simple_prompt["provider/name"]) + assert simple_prompt.model is None + assert simple_prompt_name.model == "provider/name" + assert simple_prompt["provider/name"].model == "provider/name" + assert simple_prompt("test") == "test" assert simple_prompt_name("test") == "name: test" From ec36ba82bdeeeaf5b33c387832b31e45343d5ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 30 Jul 2024 12:49:38 +0200 Subject: [PATCH 2/3] Automatically render special EOS and BOS tokens We define Jinja2 `eos` and `bos` global variables that are rendered as EOS and BOS tokens. --- docs/reference/special_tokens.md | 31 +++++++++++++++++++++++ mkdocs.yml | 1 + prompts/templates.py | 42 +++++++++++++++++++++++++++++--- tests/test_templates.py | 22 +++++++++++++++++ 4 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 docs/reference/special_tokens.md diff --git a/docs/reference/special_tokens.md b/docs/reference/special_tokens.md new file mode 100644 index 0000000..b98c736 --- /dev/null +++ b/docs/reference/special_tokens.md @@ -0,0 +1,31 @@ +# Handle special tokens + +Tokens that indicate the beginnning of a sequence, an end of sequence, that +delineate user and assistant turns in a conversation, etc. are model-specific. +This means that one needs to write a new prompt each time they use a new model, +only replacing these special tokens. This is error-prone and leads to duplicated +work. + +`prompts` provides special variables in its templates that allows user to use special tokens in their prompts in a model-agnotic way: + +```python +import prompts + + +@prompts.template +def a_simple_prompt(query: str): + """{{ bos + query + eos }}""" + + +print(a_simple_prompt["mistralai/Mistral-7B-v0.1"]("question")) +# question + +print(a_simple_prompt["google/gemma-2-9b"]("question")) +# question +``` + + +!!! note "Registry" + + The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you + want to use `prompts` with a model that is not currently in the registry. diff --git a/mkdocs.yml b/mkdocs.yml index d15ce98..9d1f2b8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,3 +75,4 @@ nav: - reference/index.md - Prompt template: reference/template.md - Dispatch: reference/dispatch.md + - Special tokens: reference/special_tokens.md diff --git a/prompts/templates.py b/prompts/templates.py index 052e983..393db19 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -1,8 +1,9 @@ import inspect import re +import warnings from dataclasses import dataclass, field from functools import lru_cache -from typing import Callable, Dict, Hashable, Optional, cast +from typing import Callable, Dict, Hashable, Optional, Tuple, cast from jinja2 import Environment, StrictUndefined @@ -29,6 +30,8 @@ class Template: The template to render. signature The prompt function's signature. + model + The model the `Template` is associated with. Defaults to `None`. registry Registry that maps function names to their respective `Template` instances. @@ -50,7 +53,7 @@ def __call__(self, *args, **kwargs) -> str: """ bound_arguments = self.signature.bind(*args, **kwargs) bound_arguments.apply_defaults() - return render(self.template, **bound_arguments.arguments) + return render(self.template, self.model, **bound_arguments.arguments) def __str__(self): return self.template @@ -74,6 +77,7 @@ def __getitem__(self, model_name: str): try: return self.registry[model_name] except KeyError: + self.model = model_name return self def register(self, model_name: str): @@ -140,13 +144,21 @@ def template(fn: Callable) -> Template: @lru_cache -def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: +def render( + template: str, + model_name: Optional[str] = None, + **values: Optional[Dict[str, Hashable]], +) -> str: r"""Parse a Jinaj2 template and translate it into an Outlines graph. This function removes extra whitespaces and linebreaks from templates to allow users to enter prompts more naturally than if they used Python's constructs directly. See the examples for a detailed explanation. + We also define the `bos` and `eos` special variables which, when used, will + be replaced by the model's BOS and EOS tokens respectively. This allows you + to write prompts that are model-agnostic. + Examples -------- @@ -223,6 +235,8 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: ---------- template A string that contains a template written with the Jinja2 syntax. + model_name + The name of the model to which the rendered string will be passed. **values Map from the variables in the template to their value. @@ -245,12 +259,34 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: # used to continue to the next line without linebreak. cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + # Warn the user when the model is not present in the special token registry + if model_name not in SPECIAL_TOKENS: + warnings.warn( + UserWarning( + f"The model {model_name} is not present in the special token registry." + "As a result, EOS and BOS tokens will be rendered as the empty string." + "Please open an issue: https://github.com/outlines-dev/prompts/issues" + "And ask for the model to be added to the registry." + ) + ) + env = Environment( trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True, undefined=StrictUndefined, ) + env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0] + env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1] jinja_template = env.from_string(cleaned_template) return jinja_template.render(**values) + + +# (BOS, EOS) +SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = { + None: ("", ""), + "google/gemma-2-9b": ("", ""), + "openai-community/gpt2": ("", "<|endoftext|>"), + "mistralai/Mistral-7B-v0.1": ("", ""), +} diff --git a/tests/test_templates.py b/tests/test_templates.py index 896a231..8bd295d 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -183,6 +183,7 @@ def test_only_code(variable): return variable +@pytest.mark.filterwarnings("ignore: The model") def test_dispatch(): @prompts.template @@ -207,3 +208,24 @@ def simple_prompt_name(query: str): assert simple_prompt("test") == "test" assert simple_prompt["gpt2"]("test") == "test" assert simple_prompt["provider/name"]("test") == "name: test" + + +def test_special_tokens(): + + @prompts.template + def simple_prompt(query: str): + """{{ bos + query + eos }}""" + + assert simple_prompt("test") == "test" + assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>" + assert simple_prompt["mistralai/Mistral-7B-v0.1"]("test") == "test" + + +def test_warn(): + + @prompts.template + def simple_prompt(): + """test""" + + with pytest.warns(UserWarning, match="not present in the special token"): + simple_prompt["non-existent-model"]() From 1a780b6bb20049ef3a61f1494c4c4f94b7ccb337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 30 Jul 2024 13:43:23 +0200 Subject: [PATCH 3/3] Add documentation for model dispatch --- docs/reference/dispatch.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/reference/dispatch.md b/docs/reference/dispatch.md index 1728e5b..b301341 100644 --- a/docs/reference/dispatch.md +++ b/docs/reference/dispatch.md @@ -1 +1,24 @@ # Model-based prompt dispatching + + +Different models often require different prompts to achieve a given task. They are, in essence, not different prompts in the sense that they are supposed to perform the same operation. In the same way we use `functools.singledispatch` to dispatch a functionality on the type of the first +argument, it can be useful to dispatch the prompt on the model that is being used. + +`prompts` provides a way to dispatch the prompt on the model: + + +```python +import prompts + + +@prompts.template +def a_simple_prompt(query: str): + """{{ query }}""" + +@a_simple_prompt.register("google/gemma-2-9b") +def a_simple_prompt_gemma(query: str): + """{{ query }}""" +``` + +!!! note + Choosing BOS and EOS based on the model is better achieved by using [special variables](special_tokens.md).