From 4b38b60b52e9158bd471905258cca6bec90b04c4 Mon Sep 17 00:00:00 2001 From: Robin Picard Date: Sat, 25 Jan 2025 14:30:41 +0100 Subject: [PATCH] Proposition with a distinct ModelFormatter class --- outlines/models/base.py | 16 +++++++++++++++- outlines/models/openai.py | 13 +++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/outlines/models/base.py b/outlines/models/base.py index e0e18bdd3..a18915e73 100644 --- a/outlines/models/base.py +++ b/outlines/models/base.py @@ -1,9 +1,23 @@ -from abc import ABC +from abc import ABC, abstractmethod + + +class ModelFormatter(ABC): + """Base class for all model formatters.""" + + @abstractmethod + def format_input(self, model_input): + ... + + @abstractmethod + def format_output_type(self, output_type): + ... class Model(ABC): """Base class for all models.""" + formatter: ModelFormatter + def __call__(self, model_input, output_type=None, **inference_kwargs): from outlines.generate import Generator diff --git a/outlines/models/openai.py b/outlines/models/openai.py index d1ea3dc40..2955235a4 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -1,18 +1,18 @@ """Integration with OpenAI's API.""" from functools import singledispatchmethod -from types import NoneType from typing import Optional, Union from pydantic import BaseModel from outlines.models import Model +from outlines.models.base import ModelFormatter from outlines.prompts import Vision from outlines.types import Json __all__ = ["OpenAI"] -class OpenAIBase: +class OpenAIFormatter(ModelFormatter): """Base class for the OpenAI clients. `OpenAI` base is responsible for preparing the arguments to OpenAI's @@ -83,7 +83,7 @@ def format_output_type(self, output_type): f"The type {output_type} is not available with OpenAI. The only output type available is `Json`." ) - @format_output_type.register(NoneType) + @format_output_type.register(type(None)) def format_none_output_type(self, _: None): """Generate the `response_format` argument to the client when no output type is specified by the user. @@ -115,7 +115,7 @@ def format_json_output_type(self, output_type: Json): } -class OpenAI(Model, OpenAIBase): +class OpenAI(Model): """Thin wrapper around the `openai.OpenAI` client. This wrapper is used to convert the input and output types specified by the @@ -128,6 +128,7 @@ def __init__(self, model_name: str, *args, **kwargs): self.client = OpenAI(*args, **kwargs) self.model_name = model_name + self.formatter = OpenAIFormatter() def generate( self, @@ -135,8 +136,8 @@ def generate( output_type: Optional[Union[type[BaseModel], str]] = None, **inference_kwargs, ): - messages = self.format_input(model_input) - response_format = self.format_output_type(output_type) + messages = self.formatter.format_input(model_input) + response_format = self.formatter.format_output_type(output_type) result = self.client.chat.completions.create( model=self.model_name, **messages, **response_format, **inference_kwargs )