Create the Dottxt model

Robin Picard committed Feb 17, 2025
1 parent c4198ee commit 36e764f
Showing 6 changed files with 250 additions and 2 deletions.
53 changes: 53 additions & 0 deletions docs/reference/models/dottxt.md
@@ -0,0 +1,53 @@
# Dottxt

!!! Installation

To use Dottxt in Outlines, you must install the `dottxt` Python SDK.

```bash
pip install dottxt
```

You also need a Dottxt API key. It must either be set as an environment variable called `DOTTXT_API_KEY` or passed to the `outlines.models.Dottxt` class when instantiating it.
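
For instance, you can either rely on the environment variable or pass the key explicitly when instantiating the model (`"YOUR_API_KEY"` below is a placeholder):

```python
from outlines.models import Dottxt

# Uses the DOTTXT_API_KEY environment variable if it is set.
model = Dottxt()

# Or pass the key explicitly; replace the placeholder with your own key.
model = Dottxt(api_key="YOUR_API_KEY")
```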

## Generate text

Dottxt only supports constrained generation with the `Json` output type, so you must always provide an output type. The generation input must be a string; batch generation is not supported.

Before generating text, Dottxt first compiles the provided output type into a schema.
This compilation happens when you create a `Generator` object.
```python
from outlines.models import Dottxt
from outlines.generate import Generator
from pydantic import BaseModel

class Character(BaseModel):
    name: str

model = Dottxt()
generator = Generator(model, Character)
```
You can then use the generator to generate text.
```python
result = generator("Create a character")
```

If you instead call the model directly, this compilation step happens automatically.
```python
from outlines.models import Dottxt
from pydantic import BaseModel, Field

class Character(BaseModel):
    name: str

model = Dottxt()
result = model("Create a character", Character)
```

In either case, compilation happens only once for a given output type: the Dottxt API stores the compiled schema on its servers and reuses it.
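
For instance, reusing the same generator for several prompts does not trigger a new schema compilation:

```python
# Both calls reuse the schema compiled for `Character`.
result_1 = generator("Create a character")
result_2 = generator("Create another character")
```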

## Optional parameters

You can pass the same optional parameters you would give to the `dottxt` SDK client, both when initializing the `Dottxt` class and when generating text.
Consult the [dottxt Python SDK GitHub repository](https://github.com/dottxt-ai/dottxt-python) for the full list of parameters.
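
As a minimal sketch of how this works: initialization keyword arguments are forwarded to the `dottxt.client.Dottxt` constructor, and generation keyword arguments are forwarded to its `create_completion` method. The `max_tokens` argument below is a hypothetical parameter shown for illustration; check the SDK for the parameters it actually accepts.

```python
from outlines.models import Dottxt
from pydantic import BaseModel

class Character(BaseModel):
    name: str

# Keyword arguments here go to the dottxt client constructor.
model = Dottxt(api_key="YOUR_API_KEY")

# Keyword arguments here are forwarded to `create_completion`.
# `max_tokens` is a hypothetical parameter used for illustration.
result = model("Create a character", Character, max_tokens=256)
```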
31 changes: 30 additions & 1 deletion outlines/generate/__init__.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Optional, Union, cast, get_args

from outlines.models import APIModel, LlamaCpp, LocalModel
from outlines.models import APIModel, CompiledAPIModel, LlamaCpp, LocalModel
from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor
from outlines.types import CFG, Choice, Json, List, Regex

@@ -41,6 +41,33 @@ def __call__(self, prompt, **inference_kwargs):
        return self.model.generate(prompt, self.output_type, **inference_kwargs)


@dataclass
class CompiledAPIGenerator:
    """Represents an API-based generator for which the output type is first compiled.
    Attributes
    ----------
    model
        An instance of a model wrapper.
    output_type
        The output type.
    """

    model: CompiledAPIModel
    output_type: Optional[Union[Json, List, Choice, Regex]]

    def __post_init__(self):
        if isinstance(self.output_type, CFG):
            raise NotImplementedError(
                "CFG generation is not supported for compiled API-based models"
            )
        self.compilation_output = self.model.compile_output_type(self.output_type)

    def __call__(self, prompt, **inference_kwargs):
        return self.model.generate(prompt, self.compilation_output, **inference_kwargs)


@dataclass
class LocalGenerator:
    """Represents a local model-based generator.
@@ -85,5 +112,7 @@ def Generator(
):
    if isinstance(model, APIModel):  # type: ignore
        return APIGenerator(model, output_type)  # type: ignore
    elif isinstance(model, CompiledAPIModel):
        return CompiledAPIGenerator(model, output_type)  # type: ignore
    else:
        return LocalGenerator(model, output_type)  # type: ignore
2 changes: 2 additions & 0 deletions outlines/models/__init__.py
@@ -10,6 +10,7 @@

from .anthropic import Anthropic
from .base import Model, ModelTypeAdapter
from .dottxt import Dottxt
from .exllamav2 import ExLlamaV2Model, exl2
from .gemini import Gemini
from .llamacpp import LlamaCpp
@@ -23,3 +24,4 @@

LocalModel = LlamaCpp
APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini]
CompiledAPIModel = Dottxt
96 changes: 96 additions & 0 deletions outlines/models/dottxt.py
@@ -0,0 +1,96 @@
"""Integration with Dottxt's API."""
import json
from functools import singledispatchmethod
from types import NoneType

from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import Json

__all__ = ["Dottxt"]


class DottxtTypeAdapter(ModelTypeAdapter):
    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the `messages` argument to pass to the client.
        Argument
        --------
        model_input
            The input passed by the user.
        """
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with Dottxt. The only available type is `str`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        """Generate the `messages` argument to pass to the client when the user
        only passes a prompt.
        """
        return model_input

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Format the output type to pass to the client."""
        raise NotImplementedError(
            f"The output type {output_type} is not available with Dottxt."
        )

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type: Json):
        """Format the output type to pass to the client."""
        schema = output_type.to_json_schema()
        return json.dumps(schema)

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type: None):
        """Format the output type to pass to the client."""
        raise NotImplementedError(
            "You must provide an output type. Dottxt only supports constrained generation."
        )


class Dottxt(Model):
    """Thin wrapper around the `dottxt.client.Dottxt` client.
    This wrapper is used to convert the input and output types specified by the
    users at a higher level to arguments to the `dottxt.client.Dottxt` client.
    """

    def __init__(self, *args, **kwargs):
        from dottxt.client import Dottxt

        self.client = Dottxt(*args, **kwargs)
        self.type_adapter = DottxtTypeAdapter()

    def compile_output_type(self, output_type) -> str:
        """Call the Dottxt SDK to create a schema for the output type provided.
        The function first checks whether a schema already exists for the output.
        Otherwise, it creates a new schema. In both cases, it returns the js_id
        of the schema to be provided when calling the `create_completion` method.
        """
        json_schema = self.type_adapter.format_output_type(output_type)
        schema_status = self.client.get_schema_status_by_source(json_schema)
        if schema_status is None:
            schema_status = self.client.create_schema(json_schema)
        if schema_status.status == "in_progress":
            schema_status = self.client.poll_schema_status(schema_status.js_id)
        if schema_status.status != "complete":
            raise ValueError(f"Failed to create the schema: {schema_status}")
        return schema_status.js_id

    def generate(self, model_input, output_type=None, **inference_kwargs):
        prompt = self.type_adapter.format_input(model_input)

        # `output_type` is the compiled schema identifier (js_id) produced by
        # `compile_output_type` and forwarded here by the generator.
        completion = self.client.create_completion(
            prompt,
            output_type,
            **inference_kwargs,
        )
        return completion.data
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -55,6 +55,7 @@ anthropic = ["anthropic"]
gemini = ["google-generativeai"]
llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"]
exllamav2 = ["exllamav2"]
dottxt = ["dottxt"]
test = [
"pre-commit",
"pytest",
@@ -76,7 +77,8 @@ test = [
"transformers",
"pillow",
"exllamav2",
"jax"
"jax",
"dottxt"
]
serve = [
"vllm>=0.3.0",
@@ -154,6 +156,7 @@ module = [
"pycountry.*",
"airportsdata.*",
"outlines_core.*",
"dottxt.*",
]
ignore_missing_imports = true

65 changes: 65 additions & 0 deletions tests/models/test_dottxt.py
@@ -0,0 +1,65 @@
import json
import os

import pytest
from pydantic import BaseModel

from outlines.generate import Generator
from outlines.models.dottxt import Dottxt
from outlines.types import Json


class User(BaseModel):
    name: str


@pytest.fixture
def api_key():
    """Get the Dottxt API key from the environment, providing a default value if not found.
    This fixture should be used for tests that do not make actual API calls,
    but still need to initialize the Dottxt client.
    """
    api_key = os.getenv("DOTTXT_API_KEY")
    if not api_key:
        return "MOCK_API_KEY"
    return api_key


def test_dottxt_wrong_init_parameters(api_key):
    with pytest.raises(TypeError, match="got an unexpected"):
        Dottxt(api_key=api_key, foo=10)


def test_dottxt_wrong_inference_parameters(api_key):
    with pytest.raises(TypeError, match="got an unexpected"):
        model = Dottxt(api_key=api_key)
        model("prompt", Json(User), foo=10)


def test_dottxt_wrong_input_type(api_key):
    with pytest.raises(NotImplementedError, match="is not available"):
        model = Dottxt(api_key=api_key)
        model(["prompt"], Json(User))


def test_dottxt_wrong_output_type(api_key):
    with pytest.raises(NotImplementedError, match="must provide an output type"):
        model = Dottxt(api_key=api_key)
        model("prompt")


@pytest.mark.api_call
def test_dottxt_direct_call(api_key):
    model = Dottxt(api_key=api_key)
    result = model("Create a user", Json(User))
    assert "name" in json.loads(result)


@pytest.mark.api_call
def test_dottxt_generator_call(api_key):
    model = Dottxt(api_key=api_key)
    generator = Generator(model, Json(User))
    result = generator("Create a user")
    assert "name" in json.loads(result)
