From 685aa012edcedaebe61f5960fdfe382f56d5ece3 Mon Sep 17 00:00:00 2001
From: Robin Picard
Date: Mon, 20 Jan 2025 01:15:49 +0100
Subject: [PATCH] Add a __call__ method to models

---
 outlines/models/__init__.py  |  1 +
 outlines/models/anthropic.py |  3 ++-
 outlines/models/base.py      | 10 ++++++++++
 outlines/models/gemini.py    |  3 ++-
 outlines/models/llamacpp.py  |  3 ++-
 outlines/models/openai.py    |  3 ++-
 6 files changed, 19 insertions(+), 4 deletions(-)
 create mode 100644 outlines/models/base.py

diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index 620108bb9..6339a2918 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -9,6 +9,7 @@
 from typing import Union
 
 from .anthropic import Anthropic
+from .base import Model
 from .exllamav2 import ExLlamaV2Model, exl2
 from .gemini import Gemini
 from .llamacpp import LlamaCpp
diff --git a/outlines/models/anthropic.py b/outlines/models/anthropic.py
index e55b2e3e2..a29ffe765 100644
--- a/outlines/models/anthropic.py
+++ b/outlines/models/anthropic.py
@@ -2,6 +2,7 @@
 from functools import singledispatchmethod
 from typing import Union
 
+from outlines.models.base import Model
 from outlines.prompts import Vision
 
 __all__ = ["Anthropic"]
@@ -63,7 +64,7 @@ def format_vision_input(self, model_input: Vision):
         }
 
 
-class Anthropic(AnthropicBase):
+class Anthropic(Model, AnthropicBase):
     def __init__(self, model_name: str, *args, **kwargs):
         from anthropic import Anthropic
 
diff --git a/outlines/models/base.py b/outlines/models/base.py
new file mode 100644
index 000000000..e0e18bdd3
--- /dev/null
+++ b/outlines/models/base.py
@@ -0,0 +1,10 @@
+from abc import ABC
+
+
+class Model(ABC):
+    """Base class for all models."""
+
+    def __call__(self, model_input, output_type=None, **inference_kwargs):
+        from outlines.generate import Generator
+
+        return Generator(self, output_type)(model_input, **inference_kwargs)
diff --git a/outlines/models/gemini.py b/outlines/models/gemini.py
index 1d6bbfd1e..ea6ae9d61 100644
--- a/outlines/models/gemini.py
+++ b/outlines/models/gemini.py
@@ -7,6 +7,7 @@
 from pydantic import BaseModel
 from typing_extensions import _TypedDictMeta  # type: ignore
 
+from outlines.models.base import Model
 from outlines.prompts import Vision
 from outlines.types import Choice, Json, List
 
@@ -91,7 +92,7 @@ def format_enum_output_type(self, output_type):
         }
 
 
-class Gemini(GeminiBase):
+class Gemini(Model, GeminiBase):
     def __init__(self, model_name: str, *args, **kwargs):
         import google.generativeai as genai
 
diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 0e597c5a7..2d4bb659d 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -2,6 +2,7 @@
 import warnings
 from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union
 
+from outlines.models.base import Model
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -93,7 +94,7 @@ def __setstate__(self, state):
         raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
 
 
-class LlamaCpp:
+class LlamaCpp(Model):
     """Wraps a model provided by the `llama-cpp-python` library."""
 
     def __init__(self, model_path: Union[str, "Llama"], **kwargs):
diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index 0fea288b9..d1ea3dc40 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -5,6 +5,7 @@
 
 from pydantic import BaseModel
 
+from outlines.models.base import Model
 from outlines.prompts import Vision
 from outlines.types import Json
 
@@ -114,7 +115,7 @@ def format_json_output_type(self, output_type: Json):
         }
 
 
-class OpenAI(OpenAIBase):
+class OpenAI(Model, OpenAIBase):
     """Thin wrapper around the `openai.OpenAI` client.
 
     This wrapper is used to convert the input and output types specified by the