Skip to content

Commit

Permalink
Initialize Dottxt from client instance and model name
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 24, 2025
1 parent 7195218 commit 8f539f6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
5 changes: 4 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from outlines.templates import Template, prompt

from outlines.models import (
from_dottxt,
from_openai,
from_transformers,
from_gemini,
Expand All @@ -26,12 +27,14 @@

model_list = [
"from_anthropic",
"from_dottxt",
"from_gemini",
"from_llamacpp",
"from_mlxlm",
"from_ollama",
"from_openai",
"from_transformersfrom_vllm",
"from_transformers",
"from_vllm",
]

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .anthropic import from_anthropic, Anthropic
from .base import Model, ModelTypeAdapter
from .dottxt import Dottxt
from .dottxt import Dottxt, from_dottxt
from .exllamav2 import ExLlamaV2Model, exl2
from .gemini import from_gemini, Gemini
from .llamacpp import LlamaCpp, from_llamacpp
Expand Down
25 changes: 20 additions & 5 deletions outlines/models/dottxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import json
from functools import singledispatchmethod
from types import NoneType
from typing import Optional
from typing import Optional, TYPE_CHECKING

from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import JsonType

if TYPE_CHECKING:
from dottxt import Dottxt as DottxtClient

__all__ = ["Dottxt"]


Expand Down Expand Up @@ -63,11 +66,14 @@ class Dottxt(Model):
"""

def __init__(self, model_name: Optional[str] = None, *args, **kwargs):
from dottxt.client import Dottxt

self.client = Dottxt(*args, **kwargs)
def __init__(
self,
client: "Dottxt",
model_name: Optional[str] = None,
model_revision: Optional[str] = None,
):
self.model_name = model_name
self.model_revision = model_revision
self.type_adapter = DottxtTypeAdapter()

def generate(self, model_input, output_type=None, **inference_kwargs):
Expand All @@ -76,10 +82,19 @@ def generate(self, model_input, output_type=None, **inference_kwargs):

if self.model_name:
inference_kwargs["model_name"] = self.model_name
inference_kwargs["model_revision"] = self.model_revision

completion = self.client.json(
prompt,
json_schema,
**inference_kwargs,
)
return completion.data


def from_dottxt(
client: "DottxtClient",
model_name: Optional[str] = None,
model_revision: Optional[str] = None,
):
return Dottxt(client, model_name, model_revision)

0 comments on commit 8f539f6

Please sign in to comment.