forked from dottxt-ai/outlines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Robin Picard
committed
Feb 17, 2025
1 parent
c4198ee
commit 414bd00
Showing
6 changed files
with
278 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Dottxt | ||
|
||
!!! Installation | ||
|
||
To be able to use Dottxt in Outlines, you must install the `dottxt` python sdk. | ||
|
||
```bash | ||
pip install dottxt | ||
``` | ||
|
||
You also need to have a Dottxt API key. This API key must either be set as an environment variable called `DOTTXT_API_KEY` or be provided to the `outlines.models.Dottxt` class when instantiating it. | ||
|
||
## Generate text | ||
|
||
Dottxt only supports constrained generation with the `Json` output type. The input of the generation must be a string. Batch generation is not supported. | ||
Thus, you must always provide an output type. | ||
|
||
Before generating text, Dottxt first compiles the output type provided into a schema. | ||
This step happens when you create a `Generator` object. | ||
```python | ||
from outlines.models import Dottxt | ||
from outlines.generate import Generator | ||
from pydantic import BaseModel | ||
|
||
class Character(BaseModel): | ||
name: str | ||
|
||
model = Dottxt() | ||
generator = Generator(model, Character) | ||
``` | ||
You can then use the generator to generate text. | ||
```python | ||
result = generator("Create a character") | ||
``` | ||
|
||
If you instead call the model directly, this compilation step happens automatically. | ||
```python | ||
from outlines.models import Dottxt | ||
from pydantic import BaseModel, Field | ||
|
||
class Character(BaseModel): | ||
name: str | ||
|
||
model = Dottxt() | ||
result = model("Create a character", Character) | ||
``` | ||
|
||
In any case, compilation happens only once for a given output type. The Dottxt API handles storing the compiled schema on their servers. | ||
|
||
## Optional parameters | ||
|
||
You can provide the same optional parameters you would pass to the `dottxt` sdk's client both during the initialization of the `Dottxt` class and when generating text. | ||
Consult the [dottxt python sdk Github repository](https://github.com/dottxt-ai/dottxt-python) for the full list of parameters. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Integration with Dottxt's API.""" | ||
import json | ||
from functools import singledispatchmethod | ||
from types import NoneType | ||
|
||
from outlines.models.base import Model, ModelTypeAdapter | ||
from outlines.types import Json | ||
|
||
__all__ = ["Dottxt"] | ||
|
||
DEFAULT_DOTTXT_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" | ||
|
||
|
||
class DottxtTypeAdapter(ModelTypeAdapter): | ||
@singledispatchmethod | ||
def format_input(self, model_input): | ||
"""Generate the `messages` argument to pass to the client. | ||
Argument | ||
-------- | ||
model_input | ||
The input passed by the user. | ||
""" | ||
raise NotImplementedError( | ||
f"The input type {input} is not available with Dottxt. The only available type is `str`." | ||
) | ||
|
||
@format_input.register(str) | ||
def format_str_input(self, model_input: str): | ||
"""Generate the `messages` argument to pass to the client when the user | ||
only passes a prompt. | ||
""" | ||
return model_input | ||
|
||
@singledispatchmethod | ||
def format_output_type(self, output_type): | ||
"""Format the output type to pass to the client.""" | ||
raise NotImplementedError( | ||
f"The input type {input} is not available with Dottxt." | ||
) | ||
|
||
@format_output_type.register(Json) | ||
def format_json_output_type(self, output_type: Json): | ||
"""Format the output type to pass to the client.""" | ||
schema = output_type.to_json_schema() | ||
return json.dumps(schema) | ||
|
||
@format_output_type.register(NoneType) | ||
def format_none_output_type(self, output_type: None): | ||
"""Format the output type to pass to the client.""" | ||
raise NotImplementedError( | ||
"You must provide an output type. Dottxt only supports constrained generation." | ||
) | ||
|
||
|
||
class Dottxt(Model): | ||
"""Thin wrapper around the `dottxt.client.Dottxt` client. | ||
This wrapper is used to convert the input and output types specified by the | ||
users at a higher level to arguments to the `dottxt.client.Dottxt` client. | ||
""" | ||
|
||
def __init__(self, model_name: str = DEFAULT_DOTTXT_MODEL_NAME, *args, **kwargs): | ||
from dottxt.client import Dottxt | ||
|
||
self.client = Dottxt(*args, **kwargs) | ||
self.model_name = model_name | ||
self.type_adapter = DottxtTypeAdapter() | ||
|
||
def compile_output_type(self, output_type) -> str: | ||
"""Call the Dottxt sdk to create a schema for the output type provided. | ||
The funtion first checks whether a schema already exists for the output. | ||
Otherwise, it creates a new schema. In both cases, it returns the js_id | ||
of the schema to be provided when calling the `create_completion` method. | ||
""" | ||
json_schema = self.type_adapter.format_output_type(output_type) | ||
schema_status_response = self.client.get_schema_status_by_source(json_schema) | ||
# the sdk can return None (no schema found), a single schema object | ||
# or a list of schema objects (if the schema has been compiled for multiple models) | ||
if isinstance(schema_status_response, list): | ||
schema_status = next( | ||
( | ||
element | ||
for element in schema_status_response | ||
if element.model_name == self.model_name | ||
), | ||
None, | ||
) | ||
else: | ||
schema_status = schema_status_response | ||
if schema_status is None: | ||
schema_status = self.client.create_schema(json_schema, self.model_name) | ||
if schema_status.status == "in_progress": | ||
schema_status = self.client.poll_schema_status(schema_status.js_id) | ||
if schema_status.status != "complete": | ||
raise ValueError(f"Failed to create the schema: {schema_status}") | ||
return schema_status.js_id | ||
|
||
def generate(self, model_input, output_type=None, **inference_kwargs): | ||
"""Call the Dottxt API to generate text. | ||
Here the `output_type` argument corresponds to the `js_id` | ||
""" | ||
prompt = self.type_adapter.format_input(model_input) | ||
|
||
completion = self.client.create_completion( | ||
prompt, | ||
output_type, | ||
**inference_kwargs, | ||
) | ||
return completion.data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import json | ||
import os | ||
|
||
import pytest | ||
from pydantic import BaseModel | ||
|
||
from outlines.generate import Generator | ||
from outlines.models.dottxt import Dottxt | ||
from outlines.types import Json | ||
|
||
|
||
class User(BaseModel): | ||
name: str | ||
|
||
|
||
@pytest.fixture | ||
def api_key(): | ||
"""Get the Dottxt API key from the environment, providing a default value if not found. | ||
This fixture should be used for tests that do not make actual api calls, | ||
but still require to initialize the Dottxt client. | ||
""" | ||
api_key = os.getenv("DOTTXT_API_KEY") | ||
if not api_key: | ||
return "MOCK_API_KEY" | ||
return api_key | ||
|
||
|
||
def test_dottxt_wrong_init_parameters(api_key): | ||
with pytest.raises(TypeError, match="got an unexpected"): | ||
Dottxt(api_key=api_key, foo=10) | ||
|
||
|
||
def test_dottxt_wrong_inference_parameters(api_key): | ||
with pytest.raises(TypeError, match="got an unexpected"): | ||
model = Dottxt(api_key=api_key) | ||
model("prompt", Json(User), foo=10) | ||
|
||
|
||
def test_dottxt_wrong_input_type(api_key): | ||
with pytest.raises(NotImplementedError, match="is not available"): | ||
model = Dottxt(api_key=api_key) | ||
model(["prompt"], Json(User)) | ||
|
||
|
||
def test_dottxt_wrong_output_type(api_key): | ||
with pytest.raises(NotImplementedError, match="must provide an output type"): | ||
model = Dottxt(api_key=api_key) | ||
model("prompt") | ||
|
||
|
||
@pytest.mark.api_call | ||
def test_dottxt_output_type_compilation(api_key): | ||
model = Dottxt(api_key=api_key) | ||
js_id = model.compile_output_type(Json(User)) | ||
assert isinstance(js_id, str) | ||
|
||
|
||
@pytest.mark.api_call | ||
def test_dottxt_direct_call(api_key): | ||
model = Dottxt(api_key=api_key, model_name="meta-llama/Llama-3.1-8B-Instruct") | ||
result = model("Create a user", Json(User)) | ||
assert "name" in json.loads(result) | ||
|
||
|
||
@pytest.mark.api_call | ||
def test_dottxt_generator_call(api_key): | ||
model = Dottxt(api_key=api_key, model_name="mistralai/Mistral-7B-Instruct-v0.2") | ||
generator = Generator(model, Json(User)) | ||
result = generator("Create a user") | ||
assert "name" in json.loads(result) |