From b7694f1d11d77e795f8c178411aa977a9e2b8a32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 26 Jul 2024 13:17:41 +0200 Subject: [PATCH] Dispatch template on model name --- prompts/templates.py | 66 +++++++++++++++++++++++++++++++++-------- tests/test_templates.py | 34 +++++++++++++++++---- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/prompts/templates.py b/prompts/templates.py index 9e9c12f..82d5cc3 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -1,36 +1,43 @@ import inspect import re -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, cast +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Callable, Dict, Hashable, Optional, cast from jinja2 import Environment, StrictUndefined @dataclass class Template: - """Represents a prompt function. + """Represents a prompt template. - We return a `Prompt` class instead of a simple function so the - template defined in prompt functions can be accessed. + A prompt template is a callable that, given a Jinja2 template and a set of values, + renders the template using those values. It is recommended to instantiate `Temaplate` + using the `template` decorator, which extracts the template from the function's + docstring and its variables from the function's signature. + + It is not uncommon that, for the same taks, different models will perform + better with different prompt. Here we thus allow to dispatch to associate a + prompt with a task and dispatch the prompt based on the model being used; a + `Template` instance is thus also a registry that associates model names to + other templates. - TODO: Store variables instead of signature? Attributes ---------- template The template to render. - model - The model used for inference. signature The prompt function's signature. + registry + Registry that maps function names to their respective `Template` + instances. """ template: str signature: inspect.Signature - - def __post_init__(self): - self.parameters: List[str] = list(self.signature.parameters.keys()) + registry: Dict[str, Callable] = field(default_factory=dict) def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -47,6 +54,40 @@ def __call__(self, *args, **kwargs) -> str: def __str__(self): return self.template + def __getitem__(self, model_name: str): + """Get the prompt template corresponding to a model name. + + We return the default template when trying to fetch a + template for a model that was not registered. + + Parameters + ---------- + model_name + The name of the model whose prompt template we want to retrieve. + + Returns + ------- + The template registered for the model name. + + """ + try: + return self.registry[model_name] + except KeyError: + return self + + def register(self, model_name: str): + """Register the prompt template, as represented by a prompt function, + for the model name. + + """ + + def wrapper(fn: Callable): + tpl = template(fn) + self.registry[model_name] = tpl + return tpl + + return wrapper + def template(fn: Callable) -> Template: """Decorate a function that contains a prompt template. @@ -96,7 +137,8 @@ def template(fn: Callable) -> Template: return Template(template, signature) -def render(template: str, **values: Optional[Dict[str, Any]]) -> str: +@lru_cache +def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str: r"""Parse a Jinaj2 template and translate it into an Outlines graph. This function removes extra whitespaces and linebreaks from templates to diff --git a/tests/test_templates.py b/tests/test_templates.py index 6be76af..6c2af9c 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -67,7 +67,7 @@ def test_render_jinja(): """ # Notice the newline after the end of the loop - examples = ["one", "two"] + examples = ("one", "two") prompt = render( """ {% for e in examples %} @@ -78,7 +78,7 @@ def test_render_jinja(): assert prompt == "Example: one\nExample: two\n" # We can remove the newline by cloing with -%} - examples = ["one", "two"] + examples = ("one", "two") prompt = render( """ {% for e in examples %} @@ -102,7 +102,7 @@ def test_render_jinja(): assert render(tpl, is_true=False) == "final" # Ignore leading white spaces - examples = ["one", "two"] + examples = ("one", "two") prompt = render( """ {% for e in examples %} @@ -114,7 +114,7 @@ def test_render_jinja(): assert prompt == "one\ntwo\n" # Do not ignore leading white spaces - examples = ["one", "two"] + examples = ("one", "two") prompt = render( """ {% for e in examples %} @@ -132,7 +132,7 @@ def test_tpl(variable): """{{variable}} test""" assert test_tpl.template == "{{variable}} test" - assert test_tpl.parameters == ["variable"] + assert list(test_tpl.signature.parameters.keys()) == ["variable"] with pytest.raises(TypeError): test_tpl(v="test") @@ -157,7 +157,7 @@ def test_kwarg_tpl(var, other_var="other"): """{{var}} and {{other_var}}""" assert test_kwarg_tpl.template == "{{var}} and {{other_var}}" - assert test_kwarg_tpl.parameters == ["var", "other_var"] + assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"] p = test_kwarg_tpl("test") assert p == "test and other" @@ -181,3 +181,25 @@ def test_empty(variable): @prompts.template def test_only_code(variable): return variable + + +def test_dispatch(): + + @prompts.template + def simple_prompt(query: str): + """{{ query }}""" + + @simple_prompt.register("provider/name") + def simple_prompt_name(query: str): + """name: {{ query }}""" + + assert list(simple_prompt.registry.keys()) == ["provider/name"] + assert callable(simple_prompt) + assert callable(simple_prompt["provider/name"]) + + assert simple_prompt("test") == "test" + assert simple_prompt_name("test") == "name: test" + + assert simple_prompt("test") == "test" + assert simple_prompt["gpt2"]("test") == "test" + assert simple_prompt["provider/name"]("test") == "name: test"