Skip to content

Commit

Permalink
Dispatch template on model name
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jul 29, 2024
1 parent a9a5a63 commit b7694f1
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 18 deletions.
66 changes: 54 additions & 12 deletions prompts/templates.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,43 @@
import inspect
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, cast
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, Hashable, Optional, cast

from jinja2 import Environment, StrictUndefined


@dataclass
class Template:
"""Represents a prompt function.
"""Represents a prompt template.
We return a `Prompt` class instead of a simple function so the
template defined in prompt functions can be accessed.
A prompt template is a callable that, given a Jinja2 template and a set of values,
renders the template using those values. It is recommended to instantiate `Temaplate`
using the `template` decorator, which extracts the template from the function's
docstring and its variables from the function's signature.
It is not uncommon that, for the same taks, different models will perform
better with different prompt. Here we thus allow to dispatch to associate a
prompt with a task and dispatch the prompt based on the model being used; a
`Template` instance is thus also a registry that associates model names to
other templates.
TODO: Store variables instead of signature?
Attributes
----------
template
The template to render.
model
The model used for inference.
signature
The prompt function's signature.
registry
Registry that maps function names to their respective `Template`
instances.
"""

template: str
signature: inspect.Signature

def __post_init__(self):
self.parameters: List[str] = list(self.signature.parameters.keys())
registry: Dict[str, Callable] = field(default_factory=dict)

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -47,6 +54,40 @@ def __call__(self, *args, **kwargs) -> str:
def __str__(self):
return self.template

def __getitem__(self, model_name: str):
"""Get the prompt template corresponding to a model name.
We return the default template when trying to fetch a
template for a model that was not registered.
Parameters
----------
model_name
The name of the model whose prompt template we want to retrieve.
Returns
-------
The template registered for the model name.
"""
try:
return self.registry[model_name]
except KeyError:
return self

def register(self, model_name: str):
"""Register the prompt template, as represented by a prompt function,
for the model name.
"""

def wrapper(fn: Callable):
tpl = template(fn)
self.registry[model_name] = tpl
return tpl

return wrapper


def template(fn: Callable) -> Template:
"""Decorate a function that contains a prompt template.
Expand Down Expand Up @@ -96,7 +137,8 @@ def template(fn: Callable) -> Template:
return Template(template, signature)


def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
@lru_cache
def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
This function removes extra whitespaces and linebreaks from templates to
Expand Down
34 changes: 28 additions & 6 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_render_jinja():
"""

# Notice the newline after the end of the loop
examples = ["one", "two"]
examples = ("one", "two")
prompt = render(
"""
{% for e in examples %}
Expand All @@ -78,7 +78,7 @@ def test_render_jinja():
assert prompt == "Example: one\nExample: two\n"

# We can remove the newline by cloing with -%}
examples = ["one", "two"]
examples = ("one", "two")
prompt = render(
"""
{% for e in examples %}
Expand All @@ -102,7 +102,7 @@ def test_render_jinja():
assert render(tpl, is_true=False) == "final"

# Ignore leading white spaces
examples = ["one", "two"]
examples = ("one", "two")
prompt = render(
"""
{% for e in examples %}
Expand All @@ -114,7 +114,7 @@ def test_render_jinja():
assert prompt == "one\ntwo\n"

# Do not ignore leading white spaces
examples = ["one", "two"]
examples = ("one", "two")
prompt = render(
"""
{% for e in examples %}
Expand All @@ -132,7 +132,7 @@ def test_tpl(variable):
"""{{variable}} test"""

assert test_tpl.template == "{{variable}} test"
assert test_tpl.parameters == ["variable"]
assert list(test_tpl.signature.parameters.keys()) == ["variable"]

with pytest.raises(TypeError):
test_tpl(v="test")
Expand All @@ -157,7 +157,7 @@ def test_kwarg_tpl(var, other_var="other"):
"""{{var}} and {{other_var}}"""

assert test_kwarg_tpl.template == "{{var}} and {{other_var}}"
assert test_kwarg_tpl.parameters == ["var", "other_var"]
assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"]

p = test_kwarg_tpl("test")
assert p == "test and other"
Expand All @@ -181,3 +181,25 @@ def test_empty(variable):
@prompts.template
def test_only_code(variable):
return variable


def test_dispatch():

@prompts.template
def simple_prompt(query: str):
"""{{ query }}"""

@simple_prompt.register("provider/name")
def simple_prompt_name(query: str):
"""name: {{ query }}"""

assert list(simple_prompt.registry.keys()) == ["provider/name"]
assert callable(simple_prompt)
assert callable(simple_prompt["provider/name"])

assert simple_prompt("test") == "test"
assert simple_prompt_name("test") == "name: test"

assert simple_prompt("test") == "test"
assert simple_prompt["gpt2"]("test") == "test"
assert simple_prompt["provider/name"]("test") == "name: test"

0 comments on commit b7694f1

Please sign in to comment.