forked from dottxt-ai/outlines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Robin Picard
committed
Feb 21, 2025
1 parent
b56acd1
commit 026cb0c
Showing
4 changed files
with
205 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Dottxt | ||
|
||
!!! Installation | ||
|
||
To use Dottxt in Outlines, you must install the `dottxt` Python SDK. | ||
|
||
```bash | ||
pip install dottxt | ||
``` | ||
|
||
You also need to have a Dottxt API key. This API key must either be set as an environment variable called `DOTTXT_API_KEY` or be provided to the `outlines.models.Dottxt` class when instantiating it. | ||
|
||
## Generate text | ||
|
||
Dottxt only supports constrained generation with the `Json` output type, so you must always provide an output type. | ||
The input of the generation must be a string. Batch generation is not supported. | ||
|
||
You can either create a `Generator` object and call it afterward: | ||
```python | ||
from outlines.models import Dottxt | ||
from outlines.generate import Generator | ||
from pydantic import BaseModel | ||
|
||
class Character(BaseModel): | ||
name: str | ||
|
||
model = Dottxt() | ||
generator = Generator(model, Character) | ||
result = generator("Create a character") | ||
``` | ||
|
||
or call the model directly with the output type: | ||
```python | ||
from outlines.models import Dottxt | ||
from pydantic import BaseModel | ||
|
||
class Character(BaseModel): | ||
name: str | ||
|
||
model = Dottxt() | ||
result = model("Create a character", Character) | ||
``` | ||
|
||
In either case, compilation for a given output type happens only once (the first time it is used to generate text). | ||
|
||
## Optional parameters | ||
|
||
You can provide the same optional parameters you would pass to the `dottxt` SDK's client both when initializing the `Dottxt` class and when generating text. | ||
Consult the [dottxt Python SDK GitHub repository](https://github.com/dottxt-ai/dottxt-python) for the full list of parameters. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Integration with Dottxt's API.""" | ||
import json | ||
from functools import singledispatchmethod | ||
from types import NoneType | ||
|
||
from outlines.models.base import Model, ModelTypeAdapter | ||
from outlines.types import Json | ||
|
||
__all__ = ["Dottxt"] | ||
|
||
|
||
class DottxtTypeAdapter(ModelTypeAdapter):
    """Type adapter for the `Dottxt` model.

    Converts the input and output type provided by the user into the
    prompt and JSON schema string handed to the `dottxt` client.
    """

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the client.

        Parameters
        ----------
        model_input
            The input passed by the user.

        Raises
        ------
        NotImplementedError
            If no handler is registered for the type of `model_input`
            (only `str` is registered).
        """
        # Report the type of the offending argument, not the builtin `input`.
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with Dottxt. "
            "The only available type is `str`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        """Return the prompt unchanged when the user passes a string."""
        return model_input

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Format the output type to pass to the client.

        Parameters
        ----------
        output_type
            The output type passed by the user.

        Raises
        ------
        NotImplementedError
            If no handler is registered for the type of `output_type`
            (only `Json` is supported; `None` is rejected explicitly).
        """
        raise NotImplementedError(
            f"The output type {output_type} is not available with Dottxt."
        )

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type: Json):
        """Serialize the JSON schema of a `Json` output type to a string."""
        schema = output_type.to_json_schema()
        return json.dumps(schema)

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type: None):
        """Reject a missing output type.

        Raises
        ------
        NotImplementedError
            Always, since an output type is mandatory with Dottxt.
        """
        raise NotImplementedError(
            "You must provide an output type. Dottxt only supports constrained generation."
        )
|
||
|
||
class Dottxt(Model):
    """Thin wrapper around the `dottxt.client.Dottxt` client.

    This wrapper is used to convert the input and output types specified by the
    users at a higher level to arguments to the `dottxt.client.Dottxt` client.
    """

    def __init__(self, model_name: str | None = None, *args, **kwargs):
        """Create the underlying `dottxt` client.

        Parameters
        ----------
        model_name
            Optional name of the model to use. When set, it is forwarded as
            the `model_name` argument on every generation call.
        *args, **kwargs
            Forwarded unchanged to `dottxt.client.Dottxt`.
        """
        # Imported lazily so `dottxt` is only required when this model is
        # used; aliased to avoid shadowing this class's own name.
        from dottxt.client import Dottxt as DottxtClient

        self.client = DottxtClient(*args, **kwargs)
        self.model_name = model_name
        self.type_adapter = DottxtTypeAdapter()

    def generate(self, model_input, output_type=None, **inference_kwargs):
        """Generate a JSON completion for the given prompt and output type.

        Parameters
        ----------
        model_input
            The prompt; formatted by the type adapter's `format_input`.
        output_type
            The desired output type; formatted by the type adapter's
            `format_output_type` into a JSON schema string.
        **inference_kwargs
            Additional keyword arguments forwarded to the client's `json`
            method.

        Returns
        -------
        The `data` attribute of the completion returned by the client.
        """
        prompt = self.type_adapter.format_input(model_input)
        json_schema = self.type_adapter.format_output_type(output_type)

        # A model name set at initialization takes precedence over any
        # `model_name` given at call time.
        if self.model_name:
            inference_kwargs["model_name"] = self.model_name

        completion = self.client.json(
            prompt,
            json_schema,
            **inference_kwargs,
        )
        return completion.data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import json | ||
import os | ||
|
||
import pytest | ||
from pydantic import BaseModel | ||
|
||
from outlines.generate import Generator | ||
from outlines.models.dottxt import Dottxt | ||
from outlines.types import Json | ||
|
||
|
||
class User(BaseModel):
    """Pydantic schema used as the structured output type in these tests."""

    first_name: str
    last_name: str
    user_id: int
|
||
|
||
@pytest.fixture
def api_key():
    """Return the Dottxt API key from the environment, or a mock value.

    Intended for tests that never reach the API but still need to
    initialize the Dottxt client. A missing or empty `DOTTXT_API_KEY`
    yields the placeholder key.
    """
    return os.getenv("DOTTXT_API_KEY") or "MOCK_API_KEY"
|
||
|
||
def test_dottxt_wrong_init_parameters(api_key):
    """An unknown keyword argument at initialization raises a `TypeError`."""
    init_kwargs = {"api_key": api_key, "foo": 10}
    with pytest.raises(TypeError, match="got an unexpected"):
        Dottxt(**init_kwargs)
|
||
|
||
def test_dottxt_wrong_output_type(api_key):
    """Calling the model without an output type raises `NotImplementedError`."""
    # Build the model outside the `raises` block so that only the call under
    # test can satisfy the expected exception.
    model = Dottxt(api_key=api_key)
    with pytest.raises(NotImplementedError, match="must provide an output type"):
        model("prompt")
|
||
|
||
@pytest.mark.api_call
def test_dottxt_wrong_input_type(api_key):
    """A non-string input raises `NotImplementedError`."""
    # Build the model outside the `raises` block so that only the call under
    # test can satisfy the expected exception.
    model = Dottxt(api_key=api_key)
    with pytest.raises(NotImplementedError, match="is not available"):
        model(["prompt"], Json(User))
|
||
|
||
@pytest.mark.api_call
def test_dottxt_wrong_inference_parameters(api_key):
    """An unknown keyword argument at generation time raises a `TypeError`."""
    # Build the model outside the `raises` block: the constructor is also
    # expected to raise `TypeError` on bad keyword arguments, so keeping it
    # inside could let an initialization failure pass this test.
    model = Dottxt(api_key=api_key)
    with pytest.raises(TypeError, match="got an unexpected"):
        model("prompt", Json(User), foo=10)
|
||
|
||
@pytest.mark.api_call
def test_dottxt_direct_call(api_key):
    """Calling the model directly returns JSON text containing `first_name`."""
    llm = Dottxt(model_name="meta-llama/Llama-3.1-8B-Instruct", api_key=api_key)
    raw_output = llm("Create a user", Json(User))
    parsed = json.loads(raw_output)
    assert "first_name" in parsed
|
||
|
||
@pytest.mark.api_call
def test_dottxt_generator_call(api_key):
    """Generating through a `Generator` returns JSON text containing `first_name`."""
    llm = Dottxt(model_name="meta-llama/Llama-3.1-8B-Instruct", api_key=api_key)
    user_generator = Generator(llm, Json(User))
    raw_output = user_generator("Create a user")
    parsed = json.loads(raw_output)
    assert "first_name" in parsed