Implement base model for Anthropic, Gemini, OpenAI and LlamaCpp
Robin Picard authored and rlouf committed Feb 13, 2025
1 parent 0882bcc commit 2b65abd
Showing 8 changed files with 125 additions and 40 deletions.
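The new `outlines/models/base.py` module that `Model` and `ModelTypeAdapter` are imported from is not part of this diff. Below is a minimal sketch of what that interface presumably looks like, inferred only from the call sites in the files that follow; the class and method names come from the diff, everything else is an assumption.

```python
# Hypothetical reconstruction of outlines/models/base.py; that file is
# not shown in this commit, so everything beyond the names Model,
# ModelTypeAdapter, format_input and format_output_type is assumed.
from abc import ABC, abstractmethod


class ModelTypeAdapter(ABC):
    """Prepares a client's arguments from user-facing inputs and types."""

    @abstractmethod
    def format_input(self, model_input):
        """Turn the user's input into the client's input argument."""
        ...

    @abstractmethod
    def format_output_type(self, output_type):
        """Turn the output type into the client's constraint argument."""
        ...


class Model(ABC):
    """Base class for all model wrappers (API-based or local)."""

    @abstractmethod
    def generate(self, model_input, output_type=None, **inference_kwargs):
        ...

    def __call__(self, model_input, output_type=None, **inference_kwargs):
        # The new test_*_direct_call tests below call model instances
        # directly, so the base class plausibly delegates to generate().
        return self.generate(model_input, output_type, **inference_kwargs)
```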
25 changes: 22 additions & 3 deletions outlines/models/anthropic.py
@@ -2,12 +2,23 @@
from functools import singledispatchmethod
from typing import Union

from outlines.models.base import Model, ModelTypeAdapter
from outlines.prompts import Vision

__all__ = ["Anthropic"]


class AnthropicBase:
class AnthropicTypeAdapter(ModelTypeAdapter):
"""Type adapter for the Anthropic clients.
`AnthropicTypeAdapter` is responsible for preparing the arguments to
Anthropic's `messages.create` methods: the input (prompt and possibly
image).
Anthropic does not support defining the output type, so
`format_output_type` is not implemented.
"""

@singledispatchmethod
def format_input(self, model_input):
"""Generate the `messages` argument to pass to the client.
@@ -62,18 +73,26 @@ def format_vision_input(self, model_input: Vision):
]
}

def format_output_type(self, output_type):
"""Not implemented for Anthropic."""
raise NotImplementedError(
f"The output type {output_type} is not available with Anthropic."
)


class Anthropic(AnthropicBase):
class Anthropic(Model):
def __init__(self, model_name: str, *args, **kwargs):
from anthropic import Anthropic

self.client = Anthropic(*args, **kwargs)
self.model_name = model_name
self.model_type = "api"
self.type_adapter = AnthropicTypeAdapter()

def generate(
self, model_input: Union[str, Vision], output_type=None, **inference_kwargs
):
messages = self.format_input(model_input)
messages = self.type_adapter.format_input(model_input)

if output_type is not None:
raise NotImplementedError(
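A minimal usage sketch of the reworked `Anthropic` wrapper. The model name is a placeholder, and `max_tokens` is required by Anthropic's `messages.create`; passing an output type would raise `NotImplementedError`, as the adapter above shows.

```python
from outlines.models.anthropic import Anthropic

# Placeholder model id; any Anthropic model name works here.
model = Anthropic("claude-3-5-sonnet-20240620")

# AnthropicTypeAdapter.format_input turns the string prompt into the
# `messages` argument for client.messages.create.
result = model.generate("Respond with one word. Not more.", max_tokens=16)
print(result)
```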
17 changes: 10 additions & 7 deletions outlines/models/gemini.py
@@ -7,17 +7,18 @@
from pydantic import BaseModel
from typing_extensions import _TypedDictMeta # type: ignore

from outlines.models.base import Model, ModelTypeAdapter
from outlines.prompts import Vision
from outlines.types import Choice, Json, List

__all__ = ["Gemini"]


class GeminiBase:
"""Base class for the Gemini clients.
class GeminiTypeAdapter(ModelTypeAdapter):
"""Type adapter for the Gemini clients.
`GeminiBase` is responsible for preparing the arguments to Gemini's
`generate_contents` methods: the input (prompt and possibly image), as well
`GeminiTypeAdapter` is responsible for preparing the arguments to Gemini's
`generate_content` methods: the input (prompt and possibly image), as well
as the output type (only JSON).
"""
@@ -91,11 +92,13 @@ def format_enum_output_type(self, output_type):
}


class Gemini(GeminiBase):
class Gemini(Model):
def __init__(self, model_name: str, *args, **kwargs):
import google.generativeai as genai

self.client = genai.GenerativeModel(model_name, *args, **kwargs)
self.model_type = "api"
self.type_adapter = GeminiTypeAdapter()

def generate(
self,
@@ -105,9 +108,9 @@ def generate(
):
import google.generativeai as genai

contents = self.format_input(model_input)
contents = self.type_adapter.format_input(model_input)
generation_config = genai.GenerationConfig(
**self.format_output_type(output_type)
**self.type_adapter.format_output_type(output_type)
)
completion = self.client.generate_content(
generation_config=generation_config, **contents, **inference_kwargs
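A corresponding sketch for `Gemini`, assuming `google-generativeai` is installed and an API key is configured. The model name and the exact `Json` constructor are assumptions based on this diff's imports and type hints.

```python
from pydantic import BaseModel

from outlines.models.gemini import Gemini
from outlines.types import Json


class Answer(BaseModel):
    word: str


model = Gemini("gemini-1.5-flash")  # placeholder model name

# Plain text: GeminiTypeAdapter.format_input builds the `contents`
# argument for GenerativeModel.generate_content.
text = model("Respond with one word. Not more.")

# JSON-constrained output: format_output_type turns the Json wrapper
# into a GenerationConfig for generate_content.
structured = model("Give me one word, as JSON.", Json(Answer))
```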
74 changes: 53 additions & 21 deletions outlines/models/llamacpp.py
@@ -1,7 +1,9 @@
import pickle
import warnings
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union

from outlines.models.base import Model, ModelTypeAdapter
from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
@@ -93,7 +95,49 @@ def __setstate__(self, state):
raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")


class LlamaCpp:
class LlamaCppTypeAdapter(ModelTypeAdapter):
"""Type adapter for the `llama-cpp-python` library.
`LlamaCppTypeAdapter` is responsible for preparing the arguments to
`llama-cpp-python`'s `Llama.__call__` method: the input (a string prompt),
as well as the logits processor (an instance of `LogitsProcessorList`).
"""

@singledispatchmethod
def format_input(self, model_input):
"""Generate the prompt argument to pass to the model.
Argument
--------
model_input
The input passed by the user.
"""
raise NotImplementedError(
f"The input type {model_input} is not available. "
"The `llama-cpp-python` library does not support batch inference."
)

@format_input.register(str)
def format_str_input(self, model_input: str):
return model_input

def format_output_type(self, output_type):
"""Generate the logits processor argument to pass to the model.
Argument
--------
output_type
The logits processor provided.
"""
from llama_cpp import LogitsProcessorList

return LogitsProcessorList([output_type])


class LlamaCpp(Model):
"""Wraps a model provided by the `llama-cpp-python` library."""

def __init__(self, model_path: Union[str, "Llama"], **kwargs):
@@ -115,6 +159,8 @@ def __init__(self, model_path: Union[str, "Llama"], **kwargs):
self.model = Llama(model_path, **kwargs)

self.tokenizer = LlamaCppTokenizer(self.model)
self.model_type = "local"
self.type_adapter = LlamaCppTypeAdapter()

@classmethod
def from_pretrained(cls, repo_id, filename, **kwargs):
@@ -134,7 +180,7 @@ def from_pretrained(cls, repo_id, filename, **kwargs):
model = Llama.from_pretrained(repo_id, filename, **kwargs)
return cls(model)

def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str:
def generate(self, model_input, logits_processor, **inference_kwargs):
"""Generate text using `llama-cpp-python`.
Arguments
@@ -152,16 +198,9 @@ def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str:
The generated text.
"""
from llama_cpp import LogitsProcessorList

if not isinstance(prompt, str):
raise NotImplementedError(
"The `llama-cpp-python` library does not support batch inference."
)

completion = self.model(
prompt,
logits_processor=LogitsProcessorList([logits_processor]),
self.type_adapter.format_input(model_input),
logits_processor=self.type_adapter.format_output_type(logits_processor),
**inference_kwargs,
)
result = completion["choices"][0]["text"]
@@ -171,7 +210,7 @@ def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str:
return result

def stream(
self, prompt: str, logits_processor, **inference_kwargs
self, model_input, logits_processor, **inference_kwargs
) -> Iterator[str]:
"""Stream text using `llama-cpp-python`.
@@ -190,16 +229,9 @@ def stream(
A generator that returns strings.
"""
from llama_cpp import LogitsProcessorList

if not isinstance(prompt, str):
raise NotImplementedError(
"The `llama-cpp-python` library does not support batch inference."
)

generator = self.model(
prompt,
logits_processor=LogitsProcessorList([logits_processor]),
self.type_adapter.format_input(model_input),
logits_processor=self.type_adapter.format_output_type(logits_processor),
stream=True,
**inference_kwargs,
)
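A usage sketch for the reworked `LlamaCpp` wrapper. The repo id, the filename, and the `RegexLogitsProcessor` import path are placeholders/assumptions taken from the tests below.

```python
from outlines.models.llamacpp import LlamaCpp
from outlines.processors import RegexLogitsProcessor  # import path assumed

# Placeholder GGUF checkpoint; any repo id / filename pair works.
model = LlamaCpp.from_pretrained(
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
)

# Direct call, as in test_llamacpp_call below.
text = model("Respond with one word. Not more.")

# Constrained generation: LlamaCppTypeAdapter wraps the processor in a
# LogitsProcessorList before handing it to Llama.__call__.
processor = RegexLogitsProcessor(r"[0-9]", model.tokenizer)
digit = model.generate("Pick a digit:", processor, max_tokens=2)
```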
23 changes: 14 additions & 9 deletions outlines/models/openai.py
@@ -5,18 +5,19 @@

from pydantic import BaseModel

from outlines.models.base import Model, ModelTypeAdapter
from outlines.prompts import Vision
from outlines.types import Json

__all__ = ["OpenAI"]


class OpenAIBase:
"""Base class for the OpenAI clients.
class OpenAITypeAdapter(ModelTypeAdapter):
"""Type adapter for the OpenAI clients.
`OpenAI` base is responsible for preparing the arguments to OpenAI's
`completions.create` methods: the input (prompt and possibly image), as well
as the output type (only JSON).
`OpenAITypeAdapter` is responsible for preparing the arguments to OpenAI's
`completions.create` methods: the input (prompt and possibly image), as
well as the output type (only JSON).
"""

@@ -64,7 +65,7 @@ def format_vision_input(self, model_input: Vision):
{
"type": "image_url",
"image_url": {
"url": f"data:{model_input.image_format};base64,{model_input.image_str}"
"url": f"data:{model_input.image_format};base64,{model_input.image_str}" # noqa: E702
},
},
],
@@ -114,7 +115,7 @@ def format_json_output_type(self, output_type: Json):
}


class OpenAI(OpenAIBase):
class OpenAI(Model):
"""Thin wrapper around the `openai.OpenAI` client.
This wrapper is used to convert the input and output types specified by the
@@ -127,15 +128,17 @@ def __init__(self, model_name: str, *args, **kwargs):

self.client = OpenAI(*args, **kwargs)
self.model_name = model_name
self.model_type = "api"
self.type_adapter = OpenAITypeAdapter()

def generate(
self,
model_input: Union[str, Vision],
output_type: Optional[Union[type[BaseModel], str]] = None,
**inference_kwargs,
):
messages = self.format_input(model_input)
response_format = self.format_output_type(output_type)
messages = self.type_adapter.format_input(model_input)
response_format = self.type_adapter.format_output_type(output_type)
result = self.client.chat.completions.create(
model=self.model_name, **messages, **response_format, **inference_kwargs
)
@@ -156,3 +159,5 @@ def __init__(self, deployment_name: str, *args, **kwargs):

self.client = AzureOpenAI(*args, **kwargs)
self.model_name = deployment_name
self.model_type = "api"
self.type_adapter = OpenAITypeAdapter()
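
And a sketch for the `OpenAI` wrapper with a JSON-constrained call; the model name and the exact `Json` constructor are assumptions based on this diff's type hints.

```python
from pydantic import BaseModel

from outlines.models.openai import OpenAI
from outlines.types import Json


class Weather(BaseModel):
    city: str
    temperature_c: float


model = OpenAI("gpt-4o-mini")  # placeholder model name

# Plain call: OpenAITypeAdapter.format_input builds the `messages`
# argument for chat.completions.create.
text = model("Respond with one word. Not more.")

# JSON mode: format_output_type builds the `response_format` argument.
report = model.generate("Weather in Paris, as JSON.", Json(Weather))
```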
7 changes: 7 additions & 0 deletions tests/models/test_anthropic.py
@@ -48,6 +48,13 @@ def test_anthropic_simple_call():
assert isinstance(result, str)


@pytest.mark.api_call
def test_anthropic_direct_call():
model = Anthropic(MODEL_NAME)
result = model("Respond with one word. Not more.")
assert isinstance(result, str)


@pytest.mark.api_call
def test_anthropic_simple_vision():
model = Anthropic(MODEL_NAME)
7 changes: 7 additions & 0 deletions tests/models/test_gemini.py
@@ -33,6 +33,13 @@ def test_gemini_simple_call():
assert isinstance(result, str)


@pytest.mark.api_call
def test_gemini_direct_call():
model = Gemini(MODEL_NAME)
result = model("Respond with one word. Not more.")
assert isinstance(result, str)


@pytest.mark.api_call
def test_gemini_simple_vision():
model = Gemini(MODEL_NAME)
5 changes: 5 additions & 0 deletions tests/models/test_llamacpp.py
@@ -31,6 +31,11 @@ def test_llamacpp_simple(model):
assert isinstance(result, str)


def test_llamacpp_call(model):
result = model("Respond with one word. Not more.")
assert isinstance(result, str)


def test_llamacpp_regex(model):
regex_str = Regex(r"[0-9]").to_regex()
logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
7 changes: 7 additions & 0 deletions tests/models/test_openai.py
@@ -66,6 +66,13 @@ def test_openai_simple_call():
assert isinstance(result, str)


@pytest.mark.api_call
def test_openai_direct_call():
model = OpenAI(MODEL_NAME)
result = model("Respond with one word. Not more.")
assert isinstance(result, str)


@pytest.mark.api_call
def test_openai_simple_vision():
model = OpenAI(MODEL_NAME)
