From 2b65abd267608b754f471863fd33ab7746b46481 Mon Sep 17 00:00:00 2001
From: Robin Picard
Date: Tue, 11 Feb 2025 21:02:53 +0100
Subject: [PATCH] Implement base model for Anthropic, Gemini, OpenAI and
 LlamaCpp

---
 outlines/models/anthropic.py   | 25 ++++++++++--
 outlines/models/gemini.py      | 17 ++++----
 outlines/models/llamacpp.py    | 74 ++++++++++++++++++++++++----------
 outlines/models/openai.py      | 23 ++++++-----
 tests/models/test_anthropic.py |  7 ++++
 tests/models/test_gemini.py    |  7 ++++
 tests/models/test_llamacpp.py  |  5 +++
 tests/models/test_openai.py    |  7 ++++
 8 files changed, 125 insertions(+), 40 deletions(-)

diff --git a/outlines/models/anthropic.py b/outlines/models/anthropic.py
index e55b2e3e2..830c0cf77 100644
--- a/outlines/models/anthropic.py
+++ b/outlines/models/anthropic.py
@@ -2,12 +2,23 @@
 from functools import singledispatchmethod
 from typing import Union
 
+from outlines.models.base import Model, ModelTypeAdapter
 from outlines.prompts import Vision
 
 __all__ = ["Anthropic"]
 
 
-class AnthropicBase:
+class AnthropicTypeAdapter(ModelTypeAdapter):
+    """Type adapter for the Anthropic clients.
+
+    `AnthropicTypeAdapter` is responsible for preparing the arguments to
+    Anthropic's `messages.create` methods: the input (prompt and possibly
+    image).
+    Anthropic does not support defining the output type, so
+    `format_output_type` is not implemented.
+
+    """
+
     @singledispatchmethod
     def format_input(self, model_input):
         """Generate the `messages` argument to pass to the client.
@@ -62,18 +73,26 @@ def format_vision_input(self, model_input: Vision):
             ]
         }
 
+    def format_output_type(self, output_type):
+        """Not implemented for Anthropic."""
+        raise NotImplementedError(
+            f"The output type {output_type} is not available with Anthropic."
+        )
+
 
-class Anthropic(AnthropicBase):
+class Anthropic(Model):
     def __init__(self, model_name: str, *args, **kwargs):
         from anthropic import Anthropic
 
         self.client = Anthropic(*args, **kwargs)
         self.model_name = model_name
+        self.model_type = "api"
+        self.type_adapter = AnthropicTypeAdapter()
 
     def generate(
         self, model_input: Union[str, Vision], output_type=None, **inference_kwargs
     ):
-        messages = self.format_input(model_input)
+        messages = self.type_adapter.format_input(model_input)
 
         if output_type is not None:
             raise NotImplementedError(
diff --git a/outlines/models/gemini.py b/outlines/models/gemini.py
index 1d6bbfd1e..5ab6a515c 100644
--- a/outlines/models/gemini.py
+++ b/outlines/models/gemini.py
@@ -7,17 +7,18 @@
 from pydantic import BaseModel
 from typing_extensions import _TypedDictMeta  # type: ignore
 
+from outlines.models.base import Model, ModelTypeAdapter
 from outlines.prompts import Vision
 from outlines.types import Choice, Json, List
 
 __all__ = ["Gemini"]
 
 
-class GeminiBase:
-    """Base class for the Gemini clients.
+class GeminiTypeAdapter(ModelTypeAdapter):
+    """Type adapter for the Gemini clients.
 
-    `GeminiBase` is responsible for preparing the arguments to Gemini's
-    `generate_contents` methods: the input (prompt and possibly image), as well
+    `GeminiTypeAdapter` is responsible for preparing the arguments to Gemini's
+    `generate_content` methods: the input (prompt and possibly image), as well
     as the output type (only JSON).
 
""" @@ -91,11 +92,13 @@ def format_enum_output_type(self, output_type): } -class Gemini(GeminiBase): +class Gemini(Model): def __init__(self, model_name: str, *args, **kwargs): import google.generativeai as genai self.client = genai.GenerativeModel(model_name, *args, **kwargs) + self.model_type = "api" + self.type_adapter = GeminiTypeAdapter() def generate( self, @@ -105,9 +108,9 @@ def generate( ): import google.generativeai as genai - contents = self.format_input(model_input) + contents = self.type_adapter.format_input(model_input) generation_config = genai.GenerationConfig( - **self.format_output_type(output_type) + **self.type_adapter.format_output_type(output_type) ) completion = self.client.generate_content( generation_config=generation_config, **contents, **inference_kwargs diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 0e597c5a7..6d8cd1ca5 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,7 +1,9 @@ import pickle import warnings +from functools import singledispatchmethod from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union +from outlines.models.base import Model, ModelTypeAdapter from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: @@ -93,7 +95,49 @@ def __setstate__(self, state): raise NotImplementedError("Cannot load a pickled llamacpp tokenizer") -class LlamaCpp: +class LlamaCppTypeAdapter(ModelTypeAdapter): + """Type adapter for the `llama-cpp-python` library. + + `LlamaCppTypeAdapter` is responsible for preparing the arguments to + `llama-cpp-python`'s `Llama.__call__` method: the input (a string prompt), + as well as the logits processor (an instance of `LogitsProcessorList`). + + """ + + @singledispatchmethod + def format_input(self, model_input): + """Generate the prompt argument to pass to the model. + + Argument + -------- + model_input + The input passed by the user. + + """ + raise NotImplementedError( + f"The input type {input} is not available. " + "The `llama-cpp-python` library does not support batch inference. " + ) + + @format_input.register(str) + def format_str_input(self, model_input: str): + return model_input + + def format_output_type(self, output_type): + """Generate the logits processor argument to pass to the model. + + Argument + -------- + output_type + The logits processor provided. + + """ + from llama_cpp import LogitsProcessorList + + return LogitsProcessorList([output_type]) + + +class LlamaCpp(Model): """Wraps a model provided by the `llama-cpp-python` library.""" def __init__(self, model_path: Union[str, "Llama"], **kwargs): @@ -115,6 +159,8 @@ def __init__(self, model_path: Union[str, "Llama"], **kwargs): self.model = Llama(model_path, **kwargs) self.tokenizer = LlamaCppTokenizer(self.model) + self.model_type = "local" + self.type_adapter = LlamaCppTypeAdapter() @classmethod def from_pretrained(cls, repo_id, filename, **kwargs): @@ -134,7 +180,7 @@ def from_pretrained(cls, repo_id, filename, **kwargs): model = Llama.from_pretrained(repo_id, filename, **kwargs) return cls(model) - def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str: + def generate(self, model_input, logits_processor, **inference_kwargs): """Generate text using `llama-cpp-python`. Arguments @@ -152,16 +198,9 @@ def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str: The generated text. 
""" - from llama_cpp import LogitsProcessorList - - if not isinstance(prompt, str): - raise NotImplementedError( - "The `llama-cpp-python` library does not support batch inference." - ) - completion = self.model( - prompt, - logits_processor=LogitsProcessorList([logits_processor]), + self.type_adapter.format_input(model_input), + logits_processor=self.type_adapter.format_output_type(logits_processor), **inference_kwargs, ) result = completion["choices"][0]["text"] @@ -171,7 +210,7 @@ def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str: return result def stream( - self, prompt: str, logits_processor, **inference_kwargs + self, model_input, logits_processor, **inference_kwargs ) -> Iterator[str]: """Stream text using `llama-cpp-python`. @@ -190,16 +229,9 @@ def stream( A generator that return strings. """ - from llama_cpp import LogitsProcessorList - - if not isinstance(prompt, str): - raise NotImplementedError( - "The `llama-cpp-python` library does not support batch inference." - ) - generator = self.model( - prompt, - logits_processor=LogitsProcessorList([logits_processor]), + self.type_adapter.format_input(model_input), + logits_processor=self.type_adapter.format_output_type(logits_processor), stream=True, **inference_kwargs, ) diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 0fea288b9..f082dbc4c 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -5,18 +5,19 @@ from pydantic import BaseModel +from outlines.models.base import Model, ModelTypeAdapter from outlines.prompts import Vision from outlines.types import Json __all__ = ["OpenAI"] -class OpenAIBase: - """Base class for the OpenAI clients. +class OpenAITypeAdapter(ModelTypeAdapter): + """Type adapter for the OpenAI clients. - `OpenAI` base is responsible for preparing the arguments to OpenAI's - `completions.create` methods: the input (prompt and possibly image), as well - as the output type (only JSON). + `OpenAITypeAdapter` is responsible for preparing the arguments to OpenAI's + `completions.create` methods: the input (prompt and possibly image), as + well as the output type (only JSON). """ @@ -64,7 +65,7 @@ def format_vision_input(self, model_input: Vision): { "type": "image_url", "image_url": { - "url": f"data:{model_input.image_format};base64,{model_input.image_str}" + "url": f"data:{model_input.image_format};base64,{model_input.image_str}" # noqa: E702 }, }, ], @@ -114,7 +115,7 @@ def format_json_output_type(self, output_type: Json): } -class OpenAI(OpenAIBase): +class OpenAI(Model): """Thin wrapper around the `openai.OpenAI` client. 
     This wrapper is used to convert the input and output types specified by the
@@ -127,6 +128,8 @@ def __init__(self, model_name: str, *args, **kwargs):
 
         self.client = OpenAI(*args, **kwargs)
         self.model_name = model_name
+        self.model_type = "api"
+        self.type_adapter = OpenAITypeAdapter()
 
     def generate(
         self,
@@ -134,8 +137,8 @@ def generate(
         output_type: Optional[Union[type[BaseModel], str]] = None,
         **inference_kwargs,
     ):
-        messages = self.format_input(model_input)
-        response_format = self.format_output_type(output_type)
+        messages = self.type_adapter.format_input(model_input)
+        response_format = self.type_adapter.format_output_type(output_type)
         result = self.client.chat.completions.create(
             model=self.model_name, **messages, **response_format, **inference_kwargs
         )
@@ -156,3 +159,5 @@ def __init__(self, deployment_name: str, *args, **kwargs):
 
         self.client = AzureOpenAI(*args, **kwargs)
         self.model_name = deployment_name
+        self.model_type = "api"
+        self.type_adapter = OpenAITypeAdapter()
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index 44a176084..1d5b5c533 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -48,6 +48,13 @@ def test_anthropic_simple_call():
     assert isinstance(result, str)
 
 
+@pytest.mark.api_call
+def test_anthropic_direct_call():
+    model = Anthropic(MODEL_NAME)
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+
+
 @pytest.mark.api_call
 def test_anthropic_simple_vision():
     model = Anthropic(MODEL_NAME)
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index 4718b12d0..fb7a79def 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -33,6 +33,13 @@ def test_gemini_simple_call():
     assert isinstance(result, str)
 
 
+@pytest.mark.api_call
+def test_gemini_direct_call():
+    model = Gemini(MODEL_NAME)
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+
+
 @pytest.mark.api_call
 def test_gemini_simple_vision():
     model = Gemini(MODEL_NAME)
diff --git a/tests/models/test_llamacpp.py b/tests/models/test_llamacpp.py
index ffb0110cd..e9af3cd2a 100644
--- a/tests/models/test_llamacpp.py
+++ b/tests/models/test_llamacpp.py
@@ -31,6 +31,11 @@ def test_llamacpp_simple(model):
     assert isinstance(result, str)
 
 
+def test_llamacpp_call(model):
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+
+
 def test_llamacpp_regex(model):
     regex_str = Regex(r"[0-9]").to_regex()
     logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index d2b8d1a40..d7d074573 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -66,6 +66,13 @@ def test_openai_simple_call():
     assert isinstance(result, str)
 
 
+@pytest.mark.api_call
+def test_openai_direct_call():
+    model = OpenAI(MODEL_NAME)
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+
+
 @pytest.mark.api_call
 def test_openai_simple_vision():
     model = OpenAI(MODEL_NAME)
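
Note for reviewers: below is a minimal usage sketch of the interface this patch converges on, based on the new tests above. The OpenAI model name, the GGUF repo id and filename, and the `RegexLogitsProcessor` import path are illustrative assumptions rather than part of the patch; the direct call also assumes the shared `outlines.models.base.Model` (not shown in this diff) provides a `__call__` that delegates to `generate`.

```python
# Review sketch only -- mirrors what the new tests exercise; names marked
# "hypothetical" are placeholders, not values taken from this patch.
from outlines.models.llamacpp import LlamaCpp
from outlines.models.openai import OpenAI
from outlines.processors import RegexLogitsProcessor  # assumed import path

# API-backed model: calling the model directly goes through the base Model's
# __call__/generate, and OpenAITypeAdapter builds the `messages` argument.
model = OpenAI("gpt-4o-mini")  # hypothetical model name
print(model("Respond with one word. Not more."))

# Local model: generate() takes a logits processor, which
# LlamaCppTypeAdapter.format_output_type wraps in a LogitsProcessorList
# before Llama.__call__ is invoked.
local = LlamaCpp.from_pretrained(
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",  # hypothetical repo id
    "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",  # hypothetical filename
)
processor = RegexLogitsProcessor(r"[0-9]", local.tokenizer)
print(local.generate("Respond with one digit.", processor, max_tokens=2))
```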