Skip to content

Commit

Permalink
Proposition with a distinct ModelFormatter class
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinPicard committed Jan 25, 2025
1 parent 685aa01 commit 4b38b60
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
16 changes: 15 additions & 1 deletion outlines/models/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
from abc import ABC
from abc import ABC, abstractmethod


class ModelFormatter(ABC):
"""Base class for all model formatters."""

@abstractmethod
def format_input(self, model_input):
...

@abstractmethod
def format_output_type(self, output_type):
...


class Model(ABC):
"""Base class for all models."""

formatter: ModelFormatter

def __call__(self, model_input, output_type=None, **inference_kwargs):
from outlines.generate import Generator

Expand Down
13 changes: 7 additions & 6 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Integration with OpenAI's API."""
from functools import singledispatchmethod
from types import NoneType
from typing import Optional, Union

from pydantic import BaseModel

from outlines.models import Model
from outlines.models.base import ModelFormatter
from outlines.prompts import Vision
from outlines.types import Json

__all__ = ["OpenAI"]


class OpenAIBase:
class OpenAIFormatter(ModelFormatter):
"""Base class for the OpenAI clients.
`OpenAI` base is responsible for preparing the arguments to OpenAI's
Expand Down Expand Up @@ -83,7 +83,7 @@ def format_output_type(self, output_type):
f"The type {output_type} is not available with OpenAI. The only output type available is `Json`."
)

@format_output_type.register(NoneType)
@format_output_type.register(type(None))
def format_none_output_type(self, _: None):
"""Generate the `response_format` argument to the client when no
output type is specified by the user.
Expand Down Expand Up @@ -115,7 +115,7 @@ def format_json_output_type(self, output_type: Json):
}


class OpenAI(Model, OpenAIBase):
class OpenAI(Model):
"""Thin wrapper around the `openai.OpenAI` client.
This wrapper is used to convert the input and output types specified by the
Expand All @@ -128,15 +128,16 @@ def __init__(self, model_name: str, *args, **kwargs):

self.client = OpenAI(*args, **kwargs)
self.model_name = model_name
self.formatter = OpenAIFormatter()

def generate(
self,
model_input: Union[str, Vision],
output_type: Optional[Union[type[BaseModel], str]] = None,
**inference_kwargs,
):
messages = self.format_input(model_input)
response_format = self.format_output_type(output_type)
messages = self.formatter.format_input(model_input)
response_format = self.formatter.format_output_type(output_type)
result = self.client.chat.completions.create(
model=self.model_name, **messages, **response_format, **inference_kwargs
)
Expand Down

0 comments on commit 4b38b60

Please sign in to comment.