Skip to content

Commit

Permalink
Create the Dottxt model
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin Picard committed Feb 21, 2025
1 parent b56acd1 commit 026cb0c
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 1 deletion.
49 changes: 49 additions & 0 deletions docs/reference/models/dottxt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Dottxt

!!! Installation

To use Dottxt in Outlines, you must install the `dottxt` Python SDK.

```bash
pip install dottxt
```

You also need to have a Dottxt API key. This API key must either be set as an environment variable called `DOTTXT_API_KEY` or be provided to the `outlines.models.Dottxt` class when instantiating it.

## Generate text

Dottxt only supports constrained generation with the `Json` output type, so you must always provide an output type.
The input of the generation must be a string. Batch generation is not supported.

You can either create a `Generator` object and call it afterward:
```python
from outlines.models import Dottxt
from outlines.generate import Generator
from pydantic import BaseModel

class Character(BaseModel):
name: str

model = Dottxt()
generator = Generator(model, Character)
result = generator("Create a character")
```

or call the model directly with the output type:
```python
from outlines.models import Dottxt
from pydantic import BaseModel

class Character(BaseModel):
name: str

model = Dottxt()
result = model("Create a character", Character)
```

In any case, compilation for a given output type happens only once (the first time it is used to generate text).

## Optional parameters

You can provide the same optional parameters you would pass to the `dottxt` SDK's client, both when initializing the `Dottxt` class and when generating text.
Consult the [dottxt Python SDK GitHub repository](https://github.com/dottxt-ai/dottxt-python) for the full list of parameters.
3 changes: 2 additions & 1 deletion outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .anthropic import Anthropic
from .base import Model, ModelTypeAdapter
from .dottxt import Dottxt
from .exllamav2 import ExLlamaV2Model, exl2
from .gemini import Gemini
from .llamacpp import LlamaCpp
Expand All @@ -25,4 +26,4 @@
]

LocalModel = LlamaCpp
APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama]
APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama, Dottxt]
85 changes: 85 additions & 0 deletions outlines/models/dottxt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Integration with Dottxt's API."""
import json
from functools import singledispatchmethod
from types import NoneType

from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import Json

__all__ = ["Dottxt"]


class DottxtTypeAdapter(ModelTypeAdapter):
    """Convert user-provided inputs and output types into Dottxt client arguments.

    Both conversions dispatch on the type of their argument: only `str`
    inputs and `Json` output types have a handler, every other type raises
    `NotImplementedError`.
    """

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the client.

        Argument
        --------
        model_input
            The input passed by the user.

        Raises
        ------
        NotImplementedError
            If no handler is registered for the type of `model_input`.
        """
        # Report the type actually received (the previous message formatted
        # the builtin `input` function by mistake).
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with Dottxt. "
            "The only available type is `str`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        """Return the prompt unchanged when the user passes a string."""
        return model_input

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Format the output type to pass to the client.

        Argument
        --------
        output_type
            The output type passed by the user.

        Raises
        ------
        NotImplementedError
            If no handler is registered for the type of `output_type`.
        """
        raise NotImplementedError(
            f"The output type {output_type} is not available with Dottxt."
        )

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type: Json):
        """Serialize the JSON schema of a `Json` output type to a string."""
        schema = output_type.to_json_schema()
        return json.dumps(schema)

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type: None):
        """Reject a missing output type: an output type is mandatory here."""
        raise NotImplementedError(
            "You must provide an output type. Dottxt only supports constrained generation."
        )


class Dottxt(Model):
    """Thin wrapper around the `dottxt.client.Dottxt` client.

    This wrapper is used to convert the input and output types specified by the
    users at a higher level to arguments to the `dottxt.client.Dottxt` client.
    """

    def __init__(self, model_name: "str | None" = None, *args, **kwargs):
        """Create the underlying client and the type adapter.

        Parameters
        ----------
        model_name
            Name of the model to use. When set, it is forwarded to the client
            as the `model_name` argument of every generation call.
        *args, **kwargs
            Forwarded unchanged to the `dottxt.client.Dottxt` constructor.
        """
        # Imported lazily so the `dottxt` package is only needed when the model
        # is instantiated; aliased to avoid shadowing this class' own name.
        from dottxt.client import Dottxt as DottxtClient

        self.client = DottxtClient(*args, **kwargs)
        self.model_name = model_name
        self.type_adapter = DottxtTypeAdapter()

    def generate(self, model_input, output_type=None, **inference_kwargs):
        """Generate a JSON completion constrained by `output_type`.

        Parameters
        ----------
        model_input
            The prompt; the type adapter only accepts `str`.
        output_type
            The output type; the type adapter only accepts `Json` and raises
            when it is `None`.
        **inference_kwargs
            Extra keyword arguments forwarded to `client.json`.

        Returns
        -------
        The `data` attribute of the completion returned by the client.

        Raises
        ------
        NotImplementedError
            If the input or output type is not supported by the type adapter.
        """
        prompt = self.type_adapter.format_input(model_input)
        json_schema = self.type_adapter.format_output_type(output_type)

        # The model name chosen at initialization takes precedence over any
        # `model_name` passed at call time.
        if self.model_name:
            inference_kwargs["model_name"] = self.model_name

        completion = self.client.json(
            prompt,
            json_schema,
            **inference_kwargs,
        )
        return completion.data
69 changes: 69 additions & 0 deletions tests/models/test_dottxt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json
import os

import pytest
from pydantic import BaseModel

from outlines.generate import Generator
from outlines.models.dottxt import Dottxt
from outlines.types import Json


# Pydantic model used as the structured output type (`Json(User)`) in the
# tests below. Documented with comments rather than a docstring so that
# nothing is added to the class body the JSON schema is generated from.
class User(BaseModel):
    first_name: str
    last_name: str
    user_id: int


@pytest.fixture
def api_key():
    """Return the Dottxt API key to use in the tests.

    The key is read from the `DOTTXT_API_KEY` environment variable. When the
    variable is unset or empty, the placeholder `"MOCK_API_KEY"` is returned
    so that tests which never reach the API can still build a client.
    """
    return os.getenv("DOTTXT_API_KEY") or "MOCK_API_KEY"


def test_dottxt_wrong_init_parameters(api_key):
    """Check that an unknown keyword argument at initialization raises `TypeError`."""
    unexpected_kwargs = {"foo": 10}
    with pytest.raises(TypeError, match="got an unexpected"):
        Dottxt(api_key=api_key, **unexpected_kwargs)


def test_dottxt_wrong_output_type(api_key):
    """Check that calling the model without an output type is rejected."""
    # Build the model outside of `pytest.raises` so that only the call under
    # test can satisfy the expected exception.
    model = Dottxt(api_key=api_key)
    with pytest.raises(NotImplementedError, match="must provide an output type"):
        model("prompt")


@pytest.mark.api_call
def test_dottxt_wrong_input_type(api_key):
    """Check that a non-string input (here a list) is rejected."""
    # Build the model outside of `pytest.raises` so that only the call under
    # test can satisfy the expected exception.
    model = Dottxt(api_key=api_key)
    with pytest.raises(NotImplementedError, match="is not available"):
        model(["prompt"], Json(User))


@pytest.mark.api_call
def test_dottxt_wrong_inference_parameters(api_key):
    """Check that an unknown keyword argument at generation time raises `TypeError`."""
    # Build the model outside of `pytest.raises`: a `TypeError` raised by the
    # constructor must not be mistaken for the one expected from the call.
    model = Dottxt(api_key=api_key)
    with pytest.raises(TypeError, match="got an unexpected"):
        model("prompt", Json(User), foo=10)


@pytest.mark.api_call
def test_dottxt_direct_call(api_key):
    """Call the model directly and check the result parses to JSON with a `first_name` key."""
    dottxt_model = Dottxt(
        api_key=api_key,
        model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    raw_output = dottxt_model("Create a user", Json(User))
    parsed = json.loads(raw_output)
    assert "first_name" in parsed


@pytest.mark.api_call
def test_dottxt_generator_call(api_key):
    """Generate through a `Generator` and check the result parses to JSON with a `first_name` key."""
    dottxt_model = Dottxt(
        api_key=api_key,
        model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    user_generator = Generator(dottxt_model, Json(User))
    raw_output = user_generator("Create a user")
    parsed = json.loads(raw_output)
    assert "first_name" in parsed

0 comments on commit 026cb0c

Please sign in to comment.