Skip to content

Commit

Permalink
Create base classes for the models and model type adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin Picard authored and rlouf committed Feb 13, 2025
1 parent 5c5c18a commit 0882bcc
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 8 deletions.
10 changes: 5 additions & 5 deletions outlines/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args

from outlines.models import APIModel, LocalModel
from outlines.models.base import Model
from outlines.types import Choice, Json, List, Regex

from .api import SequenceGenerator
Expand Down Expand Up @@ -30,7 +30,7 @@ class APIGenerator:
"""

model: APIModel
model: Model
output_type: Optional[Union[Json, List, Choice, Regex]] = None

def __call__(self, prompt, **inference_kwargs):
Expand All @@ -53,7 +53,7 @@ class LocalGenerator:
"""

model: LocalModel
model: Model
output_type: Optional[Union[Json, List, Choice, Regex]]

def __post_init__(self):
Expand All @@ -70,10 +70,10 @@ def __call__(self, prompt, **inference_kwargs):


def Generator(
model: Union[LocalModel, APIModel],
model: Model,
output_type: Optional[Union[Json, List, Choice, Regex]] = None,
):
if isinstance(model, APIModel): # type: ignore
if model.model_type == "api": # type: ignore
return APIGenerator(model, output_type) # type: ignore
else:
return LocalGenerator(model, output_type) # type: ignore
4 changes: 1 addition & 3 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Union

from .anthropic import Anthropic
from .base import Model, ModelTypeAdapter
from .exllamav2 import ExLlamaV2Model, exl2
from .gemini import Gemini
from .llamacpp import LlamaCpp
Expand All @@ -19,6 +20,3 @@
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM]

LocalModel = LlamaCpp
APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini]
84 changes: 84 additions & 0 deletions outlines/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from abc import ABC, abstractmethod
from typing import Literal, Union


class ModelTypeAdapter(ABC):
"""Base class for all model type adapters.
A type adapter instance must be given as a value to the `type_adapter`
attribute when instantiating a model.
The type adapter is responsible for formatting the input and output types
passed to the model to match the specific format expected by the
associated model.
"""

@abstractmethod
def format_input(self, model_input):
"""Format the user input to the expected format of the model.
For API-based models, it typically means creating the `messages`
argument passed to the client. For local models, it can mean casting
the input from str to list for instance.
This method is also used to validate that the input type provided by
the user is supported by the model.
"""
...

@abstractmethod
def format_output_type(self, output_type):
"""Format the output type to the expected format of the model.
For API-based models, this typically means creating a `response_format`
argument. For local models, it means formatting the logits processor to
create the object type expected by the model.
"""
...


class Model(ABC):
"""Base class for all models.
This class defines a shared `__call__` method that can be used to call the
model directly.
All models inheriting from this class must define a `type_adapter`
attribute of type `ModelTypeAdapter`. The methods of the `type_adapter`
attribute are used in the `generate` method to format the input and output
types received by the model.
"""

model_type: Union[Literal["api"], Literal["local"]]
type_adapter: ModelTypeAdapter

def __call__(self, model_input, output_type=None, **inference_kwargs):
"""Call the model.
Users can call the model directly, in which case we will create a
generator instance with the output type provided and call it.
Thus, those commands are equivalent:
```python
generator = Generator(model, Foo)
generator("prompt")
```
and
```python
model("prompt", Foo)
```
"""
from outlines.generate import Generator

return Generator(self, output_type)(model_input, **inference_kwargs)

@abstractmethod
def generate(self, model_input, output_type=None, **inference_kwargs):
"""Generate a response from the model.
The output_type argument contains a logits processor for local models
while it contains a type (Json, Enum...) for the API-based models.
"""
...

0 comments on commit 0882bcc

Please sign in to comment.