diff --git a/docs/reference/models/dottxt.md b/docs/reference/models/dottxt.md
new file mode 100644
index 000000000..bfb22a5a2
--- /dev/null
+++ b/docs/reference/models/dottxt.md
@@ -0,0 +1,51 @@
+# Dottxt
+
+!!! Installation
+
+    To be able to use Dottxt in Outlines, you must install the `dottxt` Python SDK.
+
+    ```bash
+    pip install dottxt
+    ```
+
+    You also need a Dottxt API key. This API key must either be set as an environment variable called `DOTTXT_API_KEY` or be provided to the `outlines.models.Dottxt` class when instantiating it.
+
+## Generate text
+
+Dottxt only supports constrained generation with the `Json` output type. The input of the generation must be a string. Batch generation is not supported.
+You must therefore always provide an output type.
+
+You can either create a `Generator` object and call it afterward:
+```python
+from outlines.models import Dottxt
+from outlines.generate import Generator
+from outlines.types import Json
+from pydantic import BaseModel
+
+class Character(BaseModel):
+    name: str
+
+model = Dottxt()
+generator = Generator(model, Json(Character))
+result = generator("Create a character")
+```
+
+or call the model directly with the output type:
+```python
+from outlines.models import Dottxt
+from outlines.types import Json
+from pydantic import BaseModel
+
+class Character(BaseModel):
+    name: str
+
+model = Dottxt()
+result = model("Create a character", Json(Character))
+```
+
+In any case, compilation for a given output type happens only once (the first time it is used to generate text).
+
+## Optional parameters
+
+You can provide the same optional parameters you would pass to the `dottxt` SDK's client both during the initialization of the `Dottxt` class and when generating text.
+Consult the [dottxt Python SDK GitHub repository](https://github.com/dottxt-ai/dottxt-python) for the full list of parameters.
diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index ea228c5c6..5816bbcb1 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -10,6 +10,7 @@
 
 from .anthropic import Anthropic
 from .base import Model, ModelTypeAdapter
+from .dottxt import Dottxt
 from .exllamav2 import ExLlamaV2Model, exl2
 from .gemini import Gemini
 from .llamacpp import LlamaCpp
@@ -25,4 +26,4 @@
 ]
 
 LocalModel = LlamaCpp
-APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama]
+APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama, Dottxt]
diff --git a/outlines/models/dottxt.py b/outlines/models/dottxt.py
new file mode 100644
index 000000000..bef605286
--- /dev/null
+++ b/outlines/models/dottxt.py
@@ -0,0 +1,133 @@
+"""Integration with Dottxt's API."""
+import json
+from functools import singledispatchmethod
+from types import NoneType
+from typing import Optional
+
+from outlines.models.base import Model, ModelTypeAdapter
+from outlines.types import Json
+
+__all__ = ["Dottxt"]
+
+
+class DottxtTypeAdapter(ModelTypeAdapter):
+    """Convert user inputs and output types into `dottxt` client arguments."""
+
+    @singledispatchmethod
+    def format_input(self, model_input):
+        """Format the prompt to pass to the client.
+
+        Parameters
+        ----------
+        model_input
+            The input passed by the user.
+
+        Raises
+        ------
+        NotImplementedError
+            If no formatter is registered for the type of `model_input`.
+
+        """
+        raise NotImplementedError(
+            f"The input type {type(model_input)} is not available with Dottxt. "
+            "The only available type is `str`."
+        )
+
+    @format_input.register(str)
+    def format_str_input(self, model_input: str) -> str:
+        """Return a string prompt unchanged."""
+        return model_input
+
+    @singledispatchmethod
+    def format_output_type(self, output_type):
+        """Format the output type to pass to the client.
+
+        Parameters
+        ----------
+        output_type
+            The output type passed by the user.
+
+        Raises
+        ------
+        NotImplementedError
+            If no formatter is registered for the type of `output_type`.
+
+        """
+        raise NotImplementedError(
+            f"The output type {output_type} is not available with Dottxt. "
+            "The only available type is `Json`."
+        )
+
+    @format_output_type.register(Json)
+    def format_json_output_type(self, output_type: Json) -> str:
+        """Serialize the JSON schema of `output_type` to a string."""
+        schema = output_type.to_json_schema()
+        return json.dumps(schema)
+
+    @format_output_type.register(NoneType)
+    def format_none_output_type(self, output_type: None):
+        """Reject a missing output type."""
+        raise NotImplementedError(
+            "You must provide an output type. Dottxt only supports constrained generation."
+        )
+
+
+class Dottxt(Model):
+    """Thin wrapper around the `dottxt.client.Dottxt` client.
+
+    This wrapper is used to convert the input and output types specified by the
+    users at a higher level to arguments to the `dottxt.client.Dottxt` client.
+
+    """
+
+    def __init__(self, model_name: Optional[str] = None, *args, **kwargs):
+        """Create the underlying `dottxt` client.
+
+        Parameters
+        ----------
+        model_name
+            The name of the model to use. When set, it is passed as
+            `model_name` on every call to `generate`.
+        *args, **kwargs
+            Forwarded unchanged to `dottxt.client.Dottxt`.
+
+        """
+        # Imported here rather than at module level: `outlines.models` imports
+        # this module eagerly, and the `dottxt` SDK is installed separately.
+        from dottxt.client import Dottxt as DottxtClient
+
+        self.client = DottxtClient(*args, **kwargs)
+        self.model_name = model_name
+        self.type_adapter = DottxtTypeAdapter()
+
+    def generate(self, model_input, output_type=None, **inference_kwargs):
+        """Generate a JSON document constrained by `output_type`.
+
+        Parameters
+        ----------
+        model_input
+            The prompt; only `str` is supported.
+        output_type
+            The `Json` output type describing the expected structure.
+        **inference_kwargs
+            Extra keyword arguments forwarded to the client's `json` method.
+
+        Returns
+        -------
+        The `data` attribute of the completion returned by the client.
+
+        """
+        prompt = self.type_adapter.format_input(model_input)
+        json_schema = self.type_adapter.format_output_type(output_type)
+
+        # The model name set on the instance overrides any `model_name`
+        # passed at call time.
+        if self.model_name:
+            inference_kwargs["model_name"] = self.model_name
+
+        completion = self.client.json(
+            prompt,
+            json_schema,
+            **inference_kwargs,
+        )
+        return completion.data
diff --git a/tests/models/test_dottxt.py b/tests/models/test_dottxt.py
new file mode 100644
index 000000000..fe1aa9be5
--- /dev/null
+++ b/tests/models/test_dottxt.py
@@ -0,0 +1,69 @@
+import json
+import os
+
+import pytest
+from pydantic import BaseModel
+
+from outlines.generate import Generator
+from outlines.models.dottxt import Dottxt
+from outlines.types import Json
+
+
+class User(BaseModel):
+    first_name: str
+    last_name: str
+    user_id: int
+
+
+@pytest.fixture
+def api_key():
+    """Get the Dottxt API key from the environment, providing a default value if not found.
+
+    This fixture should be used for tests that do not make actual API calls,
+    but still need to initialize the Dottxt client.
+
+    """
+    api_key = os.getenv("DOTTXT_API_KEY")
+    if not api_key:
+        return "MOCK_API_KEY"
+    return api_key
+
+
+def test_dottxt_wrong_init_parameters(api_key):
+    with pytest.raises(TypeError, match="got an unexpected"):
+        Dottxt(api_key=api_key, foo=10)
+
+
+def test_dottxt_wrong_output_type(api_key):
+    model = Dottxt(api_key=api_key)
+    with pytest.raises(NotImplementedError, match="must provide an output type"):
+        model("prompt")
+
+
+@pytest.mark.api_call
+def test_dottxt_wrong_input_type(api_key):
+    model = Dottxt(api_key=api_key)
+    with pytest.raises(NotImplementedError, match="is not available"):
+        model(["prompt"], Json(User))
+
+
+@pytest.mark.api_call
+def test_dottxt_wrong_inference_parameters(api_key):
+    model = Dottxt(api_key=api_key)
+    with pytest.raises(TypeError, match="got an unexpected"):
+        model("prompt", Json(User), foo=10)
+
+
+@pytest.mark.api_call
+def test_dottxt_direct_call(api_key):
+    model = Dottxt(api_key=api_key, model_name="meta-llama/Llama-3.1-8B-Instruct")
+    result = model("Create a user", Json(User))
+    assert "first_name" in json.loads(result)
+
+
+@pytest.mark.api_call
+def test_dottxt_generator_call(api_key):
+    model = Dottxt(api_key=api_key, model_name="meta-llama/Llama-3.1-8B-Instruct")
+    generator = Generator(model, Json(User))
+    result = generator("Create a user")
+    assert "first_name" in json.loads(result)