From 038a70adc62830ca39bd22f37197d5a098d35bc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 30 Jul 2024 12:49:38 +0200 Subject: [PATCH] Automatically render special EOS and BOS tokens We define Jinja2 `eos` and `bos` global variables that are rendered as EOS and BOS tokens. --- docs/reference/special_tokens.md | 31 +++++++++++++++++++++++ mkdocs.yml | 1 + prompts/templates.py | 42 +++++++++++++++++++++++++++++--- tests/test_templates.py | 22 +++++++++++++++++ 4 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 docs/reference/special_tokens.md diff --git a/docs/reference/special_tokens.md b/docs/reference/special_tokens.md new file mode 100644 index 0000000..b98c736 --- /dev/null +++ b/docs/reference/special_tokens.md @@ -0,0 +1,31 @@ +# Handle special tokens + +Tokens that indicate the beginning of a sequence, an end of sequence, that +delineate user and assistant turns in a conversation, etc. are model-specific. +This means that one needs to write a new prompt each time they use a new model, +only replacing these special tokens. This is error-prone and leads to duplicated +work. + +`prompts` provides special variables in its templates that allow users to use special tokens in their prompts in a model-agnostic way: + +```python +import prompts + + +@prompts.template +def a_simple_prompt(query: str): +    """{{ bos + query + eos }}""" + + +print(a_simple_prompt["mistralai/Mistral-7B-v0.1"]("question")) +# <s>question</s> + +print(a_simple_prompt["google/gemma-2-9b"]("question")) +# <bos>question<eos> +``` + + +!!! note "Registry" + +    The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you +    want to use `prompts` with a model that is not currently in the registry. 
diff --git a/mkdocs.yml b/mkdocs.yml index d15ce98..9d1f2b8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,3 +75,4 @@ nav: - reference/index.md - Prompt template: reference/template.md - Dispatch: reference/dispatch.md + - Special tokens: reference/special_tokens.md diff --git a/prompts/templates.py b/prompts/templates.py index 052e983..393db19 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -1,8 +1,9 @@ import inspect import re +import warnings from dataclasses import dataclass, field from functools import lru_cache -from typing import Callable, Dict, Hashable, Optional, cast +from typing import Callable, Dict, Hashable, Optional, Tuple, cast from jinja2 import Environment, StrictUndefined @@ -29,6 +30,8 @@ class Template: The template to render. signature The prompt function's signature. + model + The model the `Template` is associated with. Defaults to `None`. registry Registry that maps function names to their respective `Template` instances. @@ -50,7 +53,7 @@ def __call__(self, *args, **kwargs) -> str: """ bound_arguments = self.signature.bind(*args, **kwargs) bound_arguments.apply_defaults() - return render(self.template, **bound_arguments.arguments) + return render(self.template, self.model, **bound_arguments.arguments) def __str__(self): return self.template @@ -74,6 +77,7 @@ def __getitem__(self, model_name: str): try: return self.registry[model_name] except KeyError: + self.model = model_name return self def register(self, model_name: str): @@ -140,13 +144,21 @@ def template(fn: Callable) -> Template: @lru_cache -def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: +def render( + template: str, + model_name: Optional[str] = None, + **values: Optional[Dict[str, Hashable]], +) -> str: r"""Parse a Jinaj2 template and translate it into an Outlines graph. This function removes extra whitespaces and linebreaks from templates to allow users to enter prompts more naturally than if they used Python's constructs directly. 
See the examples for a detailed explanation. + We also define the `bos` and `eos` special variables which, when used, will + be replaced by the model's BOS and EOS tokens respectively. This allows you + to write prompts that are model-agnostic. + Examples -------- @@ -223,6 +235,8 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: ---------- template A string that contains a template written with the Jinja2 syntax. + model_name + The name of the model to which the rendered string will be passed. **values Map from the variables in the template to their value. @@ -245,12 +259,34 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: # used to continue to the next line without linebreak. cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + # Warn the user when the model is not present in the special token registry + if model_name not in SPECIAL_TOKENS: + warnings.warn( + UserWarning( + f"The model {model_name} is not present in the special token registry. " + "As a result, EOS and BOS tokens will be rendered as the empty string. " + "Please open an issue: https://github.com/outlines-dev/prompts/issues " + "and ask for the model to be added to the registry." 
+            ) + ) + env = Environment( trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True, undefined=StrictUndefined, ) + env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0] + env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1] jinja_template = env.from_string(cleaned_template) return jinja_template.render(**values) + + +# (BOS, EOS) +SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = { + None: ("", ""), + "google/gemma-2-9b": ("<bos>", "<eos>"), + "openai-community/gpt2": ("", "<|endoftext|>"), + "mistralai/Mistral-7B-v0.1": ("<s>", "</s>"), +} diff --git a/tests/test_templates.py b/tests/test_templates.py index 896a231..8bd295d 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -183,6 +183,7 @@ def test_only_code(variable): return variable +@pytest.mark.filterwarnings("ignore: The model") def test_dispatch(): @prompts.template @@ -207,3 +208,24 @@ def simple_prompt_name(query: str): assert simple_prompt("test") == "test" assert simple_prompt["gpt2"]("test") == "test" assert simple_prompt["provider/name"]("test") == "name: test" + + +def test_special_tokens(): + + @prompts.template + def simple_prompt(query: str): + """{{ bos + query + eos }}""" + + assert simple_prompt("test") == "test" + assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>" + assert simple_prompt["mistralai/Mistral-7B-v0.1"]("test") == "<s>test</s>" + + +def test_warn(): + + @prompts.template + def simple_prompt(): + """test""" + + with pytest.warns(UserWarning, match="not present in the special token"): + simple_prompt["non-existent-model"]()