forked from dottxt-ai/outlines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create base classes for the models and model type adapters
- Loading branch information
Showing
3 changed files
with
90 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Literal, Union | ||
|
||
|
||
class ModelTypeAdapter(ABC): | ||
"""Base class for all model type adapters. | ||
A type adapter instance must be given as a value to the `type_adapter` | ||
attribute when instantiating a model. | ||
The type adapter is responsible for formatting the input and output types | ||
passed to the model to match the specific format expected by the | ||
associated model. | ||
""" | ||
|
||
@abstractmethod | ||
def format_input(self, model_input): | ||
"""Format the user input to the expected format of the model. | ||
For API-based models, it typically means creating the `messages` | ||
argument passed to the client. For local models, it can mean casting | ||
the input from str to list for instance. | ||
This method is also used to validate that the input type provided by | ||
the user is supported by the model. | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def format_output_type(self, output_type): | ||
"""Format the output type to the expected format of the model. | ||
For API-based models, this typically means creating a `response_format` | ||
argument. For local models, it means formatting the logits processor to | ||
create the object type expected by the model. | ||
""" | ||
... | ||
|
||
|
||
class Model(ABC): | ||
"""Base class for all models. | ||
This class defines a shared `__call__` method that can be used to call the | ||
model directly. | ||
All models inheriting from this class must define a `type_adapter` | ||
attribute of type `ModelTypeAdapter`. The methods of the `type_adapter` | ||
attribute are used in the `generate` method to format the input and output | ||
types received by the model. | ||
""" | ||
|
||
model_type: Union[Literal["api"], Literal["local"]] | ||
type_adapter: ModelTypeAdapter | ||
|
||
def __call__(self, model_input, output_type=None, **inference_kwargs): | ||
"""Call the model. | ||
Users can call the model directly, in which case we will create a | ||
generator instance with the output type provided and call it. | ||
Thus, those commands are equivalent: | ||
```python | ||
generator = Generator(model, Foo) | ||
generator("prompt") | ||
``` | ||
and | ||
```python | ||
model("prompt", Foo) | ||
``` | ||
""" | ||
from outlines.generate import Generator | ||
|
||
return Generator(self, output_type)(model_input, **inference_kwargs) | ||
|
||
@abstractmethod | ||
def generate(self, model_input, output_type=None, **inference_kwargs): | ||
"""Generate a response from the model. | ||
The output_type argument contains a logits processor for local models | ||
while it contains a type (Json, Enum...) for the API-based models. | ||
""" | ||
... |