diff --git a/outlines/__init__.py b/outlines/__init__.py index 581f76b79..4aacb6c7e 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -13,6 +13,7 @@ from outlines.templates import Template, prompt from outlines.models import ( + from_dottxt, from_openai, from_transformers, from_gemini, @@ -26,12 +27,14 @@ model_list = [ "from_anthropic", + "from_dottxt", "from_gemini", "from_llamacpp", "from_mlxlm", "from_ollama", "from_openai", - "from_transformersfrom_vllm", + "from_transformers", + "from_vllm", ] __all__ = [ diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index a0b4a85d3..ee50356b0 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -10,7 +10,7 @@ from .anthropic import from_anthropic, Anthropic from .base import Model, ModelTypeAdapter -from .dottxt import Dottxt +from .dottxt import Dottxt, from_dottxt from .exllamav2 import ExLlamaV2Model, exl2 from .gemini import from_gemini, Gemini from .llamacpp import LlamaCpp, from_llamacpp diff --git a/outlines/models/dottxt.py b/outlines/models/dottxt.py index 4ba2f84a2..7132402ec 100644 --- a/outlines/models/dottxt.py +++ b/outlines/models/dottxt.py @@ -3,11 +3,14 @@ import json from functools import singledispatchmethod from types import NoneType -from typing import Optional +from typing import Optional, TYPE_CHECKING from outlines.models.base import Model, ModelTypeAdapter from outlines.types import JsonType +if TYPE_CHECKING: + from dottxt import Dottxt as DottxtClient + __all__ = ["Dottxt"] @@ -63,11 +66,14 @@ class Dottxt(Model): """ - def __init__(self, model_name: Optional[str] = None, *args, **kwargs): - from dottxt.client import Dottxt - - self.client = Dottxt(*args, **kwargs) + def __init__( + self, + client: "DottxtClient", + model_name: Optional[str] = None, + model_revision: Optional[str] = None, + ): self.model_name = model_name + self.model_revision = model_revision self.type_adapter = DottxtTypeAdapter() def generate(self, model_input, 
output_type=None, **inference_kwargs): @@ -76,6 +82,7 @@ def generate(self, model_input, output_type=None, **inference_kwargs): if self.model_name: inference_kwargs["model_name"] = self.model_name + inference_kwargs["model_revision"] = self.model_revision completion = self.client.json( prompt, @@ -83,3 +90,11 @@ def generate(self, model_input, output_type=None, **inference_kwargs): **inference_kwargs, ) return completion.data + + +def from_dottxt( + client: "DottxtClient", + model_name: Optional[str] = None, + model_revision: Optional[str] = None, +): + return Dottxt(client, model_name, model_revision)