diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0565d0f3c..9390f8769 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,6 +33,7 @@ jobs: run: | curl -fsSL https://ollama.com/install.sh | sh ollama --version + ollama pull tinyllama - name: Set up test environment run: | python -m pip install --upgrade pip @@ -50,7 +51,7 @@ jobs: echo "::set-output name=id::$MATRIX_ID" - name: Run tests run: | - uv run pytest -x --cov=outlines + uv run pytest -x --cov=outlines -m 'not api_call' env: COVERAGE_FILE: .coverage.${{ steps.matrix-id.outputs.id }} - name: Upload coverage data diff --git a/README.md b/README.md index 01d7b9c98..059cf2b30 100644 --- a/README.md +++ b/README.md @@ -80,10 +80,15 @@ Please see [the documentation](https://dottxt-ai.github.io/outlines/latest/refer You can reduce the completion to a choice between multiple possibilities: ``` python +from transformers import AutoModelForCausalLM, AutoTokenizer + import outlines model_name = "HuggingFaceTB/SmolLM2-360M-Instruct" -model = outlines.models.transformers(model_name) +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) # You must apply the chat template tokens to the prompt! # See below for an example. @@ -100,9 +105,7 @@ Text: I really really really want pizza. <|im_start|>assistant """ -generator = outlines.generate.choice(model, ["Pizza", "Pasta", "Salad", "Dessert"]) -answer = generator(prompt) - +answer = model(prompt, outlines.Choice(["Pizza", "Pasta", "Salad", "Dessert"])) # Likely answer: Pizza ``` @@ -117,8 +120,7 @@ class Food(str, Enum): salad = "Salad" dessert = "Dessert" -generator = outlines.generate.choice(model, Food) -answer = generator(prompt) +answer = model(prompt, outlines.Choice(Food)) # Likely answer: Pizza ```` @@ -128,9 +130,14 @@ You can instruct the model to only return integers or floats: ``` python +from transformers import AutoModelForCausalLM, AutoTokenizer import outlines -model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1") +model_name = "WizardLM/WizardMath-7B-V1.1" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) prompt = "result of 9 + 9 = 18result of 1 + 2 = " answer = outlines.generate.format(model, int)(prompt) @@ -138,8 +145,7 @@ print(answer) # 3 prompt = "sqrt(2)=" -generator = outlines.generate.format(model, float) -answer = generator(prompt, max_tokens=10) +answer = model(prompt, outlines.types.number, max_tokens=10) print(answer) # 1.41421356 ``` @@ -151,9 +157,14 @@ Outlines also comes with fast regex-structured generation. In fact, the `choice` hood: ``` python +from transformers import AutoModelForCausalLM, AutoTokenizer import outlines -model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct") +model_name = "microsoft/Phi-3-mini-4k-instruct" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) prompt = """ <|im_start|>system You are a helpful assistant. @@ -167,15 +178,13 @@ The IP address of a Google DNS server is """ -generator = outlines.generate.text(model) -unstructured = generator(prompt, max_tokens=30) +unstructured = model(prompt, max_tokens=10) -generator = outlines.generate.regex( - model, - r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)", - sampler=outlines.samplers.greedy(), +structured = model( + prompt, + outlines.Regex(r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"), + max_tokens=30 ) -structured = generator(prompt, max_tokens=30) print(unstructured) # 8.8.8.8 @@ -196,6 +205,7 @@ Outlines users can guide the generation process so the output is *guaranteed* to ```python from enum import Enum from pydantic import BaseModel, constr +from transformers import AutoModelForCausalLM, AutoTokenizer import outlines @@ -222,20 +232,25 @@ class Character(BaseModel): strength: int -model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct") +model_name = "microsoft/Phi-3-mini-4k-instruct" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) # Construct structured sequence generator -generator = outlines.generate.json(model, Character) # Draw a sample seed = 789001 -character = generator("Give me a character description", seed=seed) +prompt = "Give me a character description" +character = model(prompt, outlines.JsonType(Character), seed=seed) print(repr(character)) # Character(name='Anderson', age=28, armor=, weapon=, strength=8) -character = generator("Give me an interesting character description") +prompt = "Give me an interesting character description" +character = model(prompt, outlines.JsonType(Character), seed=seed) print(repr(character)) # Character(name='Vivian Thr', age=44, armor=, weapon=, strength=125) @@ -248,6 +263,8 @@ The method works with union types, optional types, arrays, nested schemas, etc. Sometimes you just want to be able to pass a JSON Schema instead of a Pydantic model. We've got you covered: ``` python +from transformers import AutoModelForCausalLM, AutoTokenizer + import outlines schema = '''{ @@ -287,9 +304,14 @@ schema = '''{ } }''' -model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct") -generator = outlines.generate.json(model, schema) -character = generator("Give me a character description") +model_name = "microsoft/Phi-3-mini-4k-instruct" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) + +prompt = "Give me a character description" +character = model(prompt, outlines.JsonType(schema)) ``` ### Using context-free grammars to guide generation @@ -297,6 +319,8 @@ character = generator("Give me a character description") Formal grammars rule the world, and Outlines makes them rule LLMs too. You can pass any context-free grammar in the EBNF format and Outlines will generate an output that is valid to this grammar: ``` python +from transformers import AutoModelForCausalLM, AutoTokenizer + import outlines arithmetic_grammar = """ @@ -313,9 +337,14 @@ arithmetic_grammar = """ %import common.NUMBER """ -model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1") -generator = outlines.generate.cfg(model, arithmetic_grammar) -sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:") +model_name = "WizardLM/WizardMath-7B-V1.1" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) + +prompt = "Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:" +sequence = model(prompt, outlines.types.Cfg(arithmetic_grammar)) print(sequence) # (8-2) @@ -328,17 +357,24 @@ This was a very simple grammar, and you can use `outlines.generate.cfg` to gener Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`: ```python +from transformers import AutoModelForCausalLM, AutoTokenizer + import outlines def add(a: int, b: int): return a + b -model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1") -generator = outlines.generate.json(model, add) + +model_name = "WizardLM/WizardMath-7B-V1.1" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) +generator = outlines.generate.json(model, outlines.types.JsonType(add)) result = generator("Return json with two integers named a and b respectively. a is odd and b even.") -print(add(**result)) +# print(add(**result)) # 3 ``` @@ -350,6 +386,8 @@ You can also embed various functions into an enum to generate params: from enum import Enum from functools import partial +from transformers import AutoModelForCausalLM, AutoTokenizer + import outlines @@ -363,11 +401,15 @@ class Operation(Enum): add = partial(add) mul = partial(mul) -model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1") -generator = outlines.generate.json(model, Operation) +model_name = "WizardLM/WizardMath-7B-V1.1" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) +) +generator = outlines.generate.json(model, outlines.types.JsonType(Operation)) result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.") -print(result) +#print(result) # {'c': -3.14, 'd': 1.5} ``` @@ -393,6 +435,8 @@ You are a sentiment-labelling assistant. You can then load it and call it with: ``` python +from transformers import AutoModelForCausalLM, AutoTokenizer + import outlines examples = [ diff --git a/docs/cookbook/extract_event_details.py b/docs/cookbook/extract_event_details.py index b51f8d921..a3231cd00 100644 --- a/docs/cookbook/extract_event_details.py +++ b/docs/cookbook/extract_event_details.py @@ -1,11 +1,15 @@ from datetime import datetime +from mlx_lm import load from pydantic import BaseModel, Field -from outlines import generate, models +import outlines +from outlines.generate import Generator +from outlines.types import JsonType + # Load the model -model = models.mlxlm("mlx-community/Hermes-3-Llama-3.1-8B-8bit") +model = outlines.from_mlxlm(*load("mlx-community/Hermes-3-Llama-3.1-8B-8bit")) # Define the event schema using Pydantic @@ -34,7 +38,7 @@ class Event(BaseModel): see you 😘 """ # Create the generator -generator = generate.json(model, Event) +generator = Generator(model, JsonType(Event)) # Extract the event information event = generator(prompt + message) diff --git a/docs/reference/models/anthropic.md b/docs/reference/models/anthropic.md index ffc510f43..d3db67ddb 100644 --- a/docs/reference/models/anthropic.md +++ b/docs/reference/models/anthropic.md @@ -2,44 +2,43 @@ !!! Installation - You need to install the `anthropic` library to be able to use the Anthropic API in Outlines. Or alternatively you can run: - - ```bash - pip install "outlines[anthropic]" - ``` + You need to install the `anthropic` library to be able to use the Anthropic API in Outlines: `pip install anthropic`. ## Anthropic models Outlines supports models available via the Anthropic API, e.g. Claude 3.5 Haiku or Claude 3.5 Sonner. You can initialize the model by passing the model name to `outlines.models.Anthropic`: ```python -from outlines import models +from anthropic import Anthropic +import outlines -model = models.Anthropic("claude-3-5-haiku-latest") -model = models.Anthropic("claude-3-5-sonnet-latest") +model = outlines.from_anthropic(Anthropic("claude-3-5-haiku-latest")) +model = outlines.from_anthropic(Anthropic("claude-3-5-sonnet-latest")) ``` -Check the [Anthropic documentation](https://docs.anthropic.com/en/docs/about-claude/models) for an up-to-date list of available models. You can pass any paramater you would pass to the Anthropic SDK as keyword arguments: - -```python -model = models.Anthropic( - "claude-3.5-haiku-latest", - api_key="" -) -``` +Check the [Anthropic documentation](https://docs.anthropic.com/en/docs/about-claude/models) for an up-to-date list of available models. ## Text generation To generate text using an Anthropic model you need to build a `Generator` object, possibly with the desired output type. You can then call the model by calling the `Generator`. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: ```python -from outlines import models, Generator +from anthropic import Anthropic +import outlines -model = models.Anthropic("claude-3-5-haiku-latest") +model = outlines.from_anthropic(Anthropic("claude-3-5-haiku-latest")) generator = Generator(model) result = generator("Prompt", max_tokens=1024) + +# Call the model directly +result = model("Prompt", max_tokens=1024) ``` + +!!! Warning + + You must set a value for `max_tokens` with Anthropic models. + See the [Anthropic SDK documentation](https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/resources/messages.py) for the list of available arguments. The Anthropic API currently does not support structured generation. diff --git a/docs/reference/models/gemini.md b/docs/reference/models/gemini.md index 07ec9aa5a..ba83d59ec 100644 --- a/docs/reference/models/gemini.md +++ b/docs/reference/models/gemini.md @@ -2,21 +2,19 @@ !!! Installation - You need to install the `google-generativeai` library to be able to use the Gemini API in Outlines. Or alternatively you can run: + You need to install the `google-generativeai` library to be able to use the Gemini API in Outlines: `pip install google-generativeai` - ```bash - pip install "outlines[gemini]" - ``` ## Gemini models Outlines supports models available via the Gemini API, e.g. Gemini 1.5. You can initialize the model by passing the model name to `outlines.models.Gemini`: ```python -from outlines import models +import outlines +import google.generativeai as genai -model = models.Gemini("gemini-1-5-flash") -model = models.Gemini("gemini-1-5-pro") +model = outlines.from_gemini(genai.GenerativeModel("gemini-1-5-flash")) +model = outlines.from_gemini(genai.GenerativeModel("gemini-1-5-pro")) ``` Check the [Gemini documentation](https://ai.google.dev/gemini-api/docs/models/gemini) for an up-to-date list of available models. @@ -26,11 +24,15 @@ Check the [Gemini documentation](https://ai.google.dev/gemini-api/docs/models/ge To generate text using a Gemini model you need to build a `Generator` object, possibly with the desired output type. You can then call the model by calling the `Generator`. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: ```python -from outlines import models, Generator +import outlines +import google.generativeai as genai -model = models.Gemini("gemini-1-5-flash") -generator = Generator(model) +model = outlines.from_gemini(genai.GenerativeModel("gemini-1-5-flash")) +generator = outlines.Generator(model) result = generator("Prompt", max_tokens=1024) + +# Call the model directly +result = model("Prompt", max_tokens=1024) ``` ### Structured generation @@ -43,17 +45,22 @@ Outlines provides support for JSON Schema-based structured generation with the G ```python from collections import TypedDict -from outlines import Generator, models -from outlines.types import Json -model = models.Gemini("gemini-1-5-flash") +import google.generativeai as genai + +import outlines +from outlines import Generator +from outlines.types import JsonType + + +model = outlines.from_gemini(genai.GenerativeModel("gemini-1-5-flash")) class Person(TypedDict): first_name: str last_name: str age: int -generator = Generator(model, Json(Person)) +generator = Generator(model, JsonType(Person)) generator("current indian prime minister on january 1st 2023") # Person(first_name='Narendra', last_name='Modi', age=72) ``` @@ -68,10 +75,13 @@ Outlines provides support for multiple-choices structured generation. Enums and ```python from enum import Enum -from outlines import Generator, models + +import google.generativeai as genai + +from outlines import Generator from outlines.types import Choice -model = models.Gemini("gemini-1-5-flash") +model = outlines.from_gemini(genai.GenerativeModel("gemini-1-5-flash")) class Foo(Enum): foo = "Foo" diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md index 52e892b76..313f58eb9 100644 --- a/docs/reference/models/llamacpp.md +++ b/docs/reference/models/llamacpp.md @@ -2,49 +2,36 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/llama.cpp) using the [llama-cpp-python library][llamacpp]. Llamacpp allows to run quantized models on machines with limited compute. -!!! Note "Installation" +!!! Note "Documentation" - You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends. To get started quickly you can also run: + To be able to use llama.cpp in Outlines, you must install the `llama-cpp-python` library, `pip install llama-cpp-python` + + Consult the [`llama-cpp-python` documentation](https://llama-cpp-python.readthedocs.io/en/latest) for detailed informations about how to initialize model and the available options. - ```bash - pip install "outlines[llamacpp]" - ``` ## Load the model -To load a model you can use the same interface as you would using `llamap-cpp-python` directly. The default method is to initialize the model by passing the path to the weights on your machine. Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory: +You can use `outlines.from_llamacpp` to load a `llama-cpp-python` model. Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory: ```python -from outlines import models +from llama_cpp import Llama +import outlines -llm = models.LlamaCpp("./phi-2.Q4_K_M.gguf") +llm = outlines.from_llamacpp(Llama("./phi-2.Q4_K_M.gguf")) ``` You can initialize the model by passing the name of the repository on the HuggingFace Hub, and the filenames (or glob pattern): ```python -from outlines import models +from llama_cpp import Llama +import outlines -model = models.LlamaCpp.from_pretrained("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") +model = outlines.from_llamacpp(Llama.from_pretrained("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf")) ``` -This will download the model files to the hub cache folder and load the weights in memory. - - -You can pass the same keyword arguments to the model as you would pass in the [llama-ccp-library][llamacpp]: - -```python -from outlines import models - -model = models.LlamaCpp( - "TheBloke/phi-2-GGUF", - "phi-2.Q4_K_M.gguf" - n_ctx=512, # to set the context length value -) -``` +This will download the model files to the hub cache folder and load the weights in memory. See the [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__) for the full list of parameters you can pass to the `Llama` class. -See the [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__) for the full list of parameters. ### Load the model on GPU @@ -55,13 +42,14 @@ See the [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io To load the model on GPU, pass `n_gpu_layers=-1`: ```python -from outlines import models +from llama_cpp import Llama +import outlines -model = models.LlamaCpp( +model = outlines.from_llamacpp(Llama( "TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", n_gpu_layers=-1, # to use GPU acceleration -) +)) ``` @@ -71,11 +59,15 @@ model = models.LlamaCpp( To generate text you must first create a `Generator` object by passing the model instance and, possibley, the expected output type: ```python -from outlines import models, generate +import outlines -model = models.LlamaCpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") -generator = Generator(model) +model = outlines.from_llamacpp(Llama( + "TheBloke/phi-2-GGUF", + "phi-2.Q4_K_M.gguf", + n_gpu_layers=-1, # to use GPU acceleration +)) +generator = outlines.Generator(model) ``` You can pass to the generator the same keyword arguments you would pass in `llama-cpp-python`: @@ -91,78 +83,6 @@ tokens = generator.stream("A prompt") ``` -## Installation - -You need to install the `llama-cpp-python` library to use the llama.cpp integration. - -### CPU - -For a *CPU-only* installation run: - -```bash -pip install llama-cpp-python -``` - -!!! Warning - - Do not run this command if you want support for BLAS, Metal or CUDA. Follow the instructions below instead. - -### CUDA - -```bash -CMAKE_ARGS="-DLLAMA_CUDA=on" pip install llama-cpp-python -``` - -It is also possible to install pre-built wheels with CUDA support (Python 3.10 and above): - -```bash -pip install llama-cpp-python \ - --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/ -``` - -Where `` is one of the following, depending on the version of CUDA installed on your system: - -- `cu121` for CUDA 12.1 -- `cu122` for CUDA 12.2 -- `cu123` CUDA 12.3 - -### Metal - -```bash -CMAKE_ARGS="-DLLAMA_METAL=on" pip install llama-cpp-python -``` - -It is also possible to install pre-build wheels with Metal support (Python 3.10 or above, MacOS 11.0 and above): - -```bash -pip install llama-cpp-python \ - --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/metal -``` - -### OpenBLAS - -```bash -CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-python -``` - -### Other backend - -`llama.cpp` supports many other backends. Refer to the [llama.cpp documentation][llama-cpp-python-install] to use the following backends: - -- CLBast (OpenCL) -- hipBLAS (ROCm) -- Vulkan -- Kompute -- SYCL - - [llamacpp]: https://github.com/abetlen/llama-cpp-python [llama-cpp-python-call]: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ [llama-cpp-python-install]: https://github.com/abetlen/llama-cpp-python/tree/08b16afe11e7b42adec2fed0a781123383476045?tab=readme-ov-file#supported-backends -[llama-cpp-sampling-params]: https://github.com/ggerganov/llama.cpp/blob/e11a8999b5690f810c2c99c14347f0834e68c524/common/sampling.h#L22 -[mirostat]: https://arxiv.org/abs/2007.14966 -[degeneration]: https://arxiv.org/abs/1904.09751 -[top-k]: https://arxiv.org/abs/1805.04833 -[minimum-p]: https://github.com/ggerganov/llama.cpp/pull/3841 -[locally-typical]: https://arxiv.org/abs/2202.00666 -[tail-free]: https://www.trentonbricken.com/Tail-Free-Sampling diff --git a/docs/reference/models/mlxlm.md b/docs/reference/models/mlxlm.md index d435b9c1f..33953c11f 100644 --- a/docs/reference/models/mlxlm.md +++ b/docs/reference/models/mlxlm.md @@ -4,11 +4,9 @@ Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx !!! Note "Installation" - You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration. To get started quickly you can also run: + You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration: `pip install mlx mlx-lm`. - ```bash - pip install "outlines[mlxlm]" - ``` + Consult the [`mlx-lm` documentation](https://github.com/ml-explore/mlx-examples/tree/main/llms) for detailed informations about how to initialize OpenAI clients and the available options. ## Load the model @@ -16,34 +14,14 @@ Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx You can initialize the model by passing the name of the repository on the HuggingFace Hub. The official repository for mlx-lm supported models is [mlx-community](https://huggingface.co/mlx-community). ```python -from outlines import models +from mlx_lm import load +import outlines -model = models.mlxlm("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit") +model = outlines.from_mlxlm(*load("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit")) ``` This will download the model files to the hub cache folder and load the weights in memory. -The arguments `model_config` and `tokenizer_config` are available to modify loading behavior. For example, per the `mlx-lm` [documentation](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models), you must set an eos_token for `qwen/Qwen-7B`. In outlines you may do so via - -``` -model = models.mlxlm( - "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit", - tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True}, -) -``` - -**Main parameters:** - -(Subject to change. Table based on [mlx-lm.load docstring](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py#L429)) - -| Parameters | Type | Description | Default | -|--------------------|--------|--------------------------------------------------------------------------------------------------|---------| -| `tokenizer_config` | `dict` | Configuration parameters specifically for the tokenizer. Defaults to an empty dictionary. | `{}` | -| `model_config` | `dict` | Configuration parameters specifically for the model. Defaults to an empty dictionary. | `{}` | -| `adapter_path` | `str` | Path to the LoRA adapters. If provided, applies LoRA layers to the model. | `None` | -| `lazy` | `bool` | If False, evaluate the model parameters to make sure they are loaded in memory before returning. | `False` | - - ## Generate text You may generate text using the parameters described in the [text generation documentation](../text.md). @@ -51,12 +29,11 @@ You may generate text using the parameters described in the [text generation doc With the loaded model, you can generate text or perform structured generation, e.g. ```python -from outlines import models, generate +from mlx_lm import load +import outlines -model = models.mlxlm("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit") -generator = generate.text(model) - -answer = generator("A prompt", temperature=2.0) +model = outlines.from_mlxlm(*load("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit")) +answer = model("A prompt", temperature=2.0) ``` ## Streaming @@ -64,10 +41,11 @@ answer = generator("A prompt", temperature=2.0) You may creating a streaming iterable with minimal changes ```python -from outlines import models, generate +from mlx_lm import load +import outlines -model = models.mlxlm("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit") -generator = generate.text(model) +model = outlines.from_mlxlm(*load("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit")) +generator = outlines.Generator(model) for token_str in generator.text("A prompt", temperature=2.0): print(token_str) @@ -80,14 +58,13 @@ You may perform structured generation with mlxlm to guarantee your output will m Example: Phone number generation with pattern `"\\+?[1-9][0-9]{7,14}"`: ```python -from outlines import models, generate +from mlx_lm import load +import outlines -model = models.mlxlm("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit") +model = outlines.from_mlxlm(*load("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit")) phone_number_pattern = "\\+?[1-9][0-9]{7,14}" -generator = generate.regex(model, phone_number_pattern) - -model_output = generator("What's Jennys Number?\n") +model_output = model("What's Jennys Number?\n", outlines.Regex(phone_number_pattern)) print(model_output) # '8675309' ``` diff --git a/docs/reference/models/models.md b/docs/reference/models/models.md index eab70f2a1..5d85c67f8 100644 --- a/docs/reference/models/models.md +++ b/docs/reference/models/models.md @@ -4,53 +4,49 @@ title: Models # Models -Outlines supports generation using a number of inference engines (`outlines.models`). Loading a model using outlines follows a similar interface between inference engines: - -```python -import outlines - -model = outlines.models.transformers("microsoft/Phi-3-mini-128k-instruct") -model = outlines.models.transformers_vision("llava-hf/llava-v1.6-mistral-7b-hf") -model = outlines.models.vllm("microsoft/Phi-3-mini-128k-instruct") -model = outlines.models.llamacpp( - "microsoft/Phi-3-mini-4k-instruct-gguf", "Phi-3-mini-4k-instruct-q4.gguf" -) -model = outlines.models.exllamav2("bartowski/Phi-3-mini-128k-instruct-exl2") -model = outlines.models.mlxlm("mlx-community/Phi-3-mini-4k-instruct-4bit") - -model = outlines.models.openai( - "gpt-4o-mini", - api_key=os.environ["OPENAI_API_KEY"] -) -``` - - -# Feature Matrix -| | [Transformers](transformers.md) | [Transformers Vision](transformers_vision.md) | [vLLM](vllm.md) | [llama.cpp](llamacpp.md) | [ExLlamaV2](exllamav2.md) | [MLXLM](mlxlm.md) | [OpenAI](openai.md)* | -|-------------------|--------------|---------------------|------|-----------|-----------|-------|---------| -| **Device** | | | | | | | | -| Cuda | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | -| Apple Silicon | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | N/A | -| x86 / AMD64 | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | N/A | -| **Sampling** | | | | | | | | -| Greedy | ✅ | ✅ | ✅ | ✅* | ✅ | ✅ | ❌ | -| Multinomial | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| Multiple Samples | ✅ | ✅ | | ❌ | | ❌ | ✅ | -| Beam Search | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | -| **Generation** | | | | | | | | -| Batch | ✅ | ✅ | ✅ | ❌ | ? | ❌ | ❌ | -| Stream | ✅ | ❌ | ❌ | ✅ | ? | ✅ | ❌ | -| Text | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Structured** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| JSON Schema | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| Choice | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| Regex | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| Grammar | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | - - -## Caveats - -- OpenAI doesn't support structured generation due to limitations in their API and server implementation. -- `outlines.generate` ["Structured"](../generation/generation.md) includes methods such as `outlines.generate.regex`, `outlines.generate.json`, `outlines.generate.cfg`, etc. -- MLXLM only supports Apple Silicon. -- llama.cpp greedy sampling available via multinomial with `temperature = 0.0`. +This section provides detailed information about using structured generation with a number of inference engines. + +## Supported providers + +### Open Source models + +Open source models offer more flexibility for structured generation as we have control over the sampling loop: + +- [transformers](transformers.md) - Run open-source models locally. +- [llama-cpp-python](llamacpp.md) - Python bindings for llama.cpp. +- [mlx-lm](mlxlm.md) - Run open-source models on Metal hardware. +- [vllm](vlm.md) - Run open-source models on the vLLM engine. +- [exllamaV2](exllamav2.md) + +Ollama only supports Json Schema: + +- [ollama](ollama.md) - Python client library for Ollama. + +### Cloud AI providers + +OpenAI has recently integrated [structured outputs][structured-outputs] in its API, and only JSON Schema-based structured generation is available. Google's Gemini API supports both Json Schema and multiple choices: + +- [OpenAI](openai.md) - GPT-4o, o1, o3-mini and other OpenAI models. +- [Azure OpenAI](openai.md) - Microsoft's Azure-hosted OpenAI models. +- [Gemini](gemini.md) - Run Google's Gemini model. + + +## Structure generation coverage + +Integrations differ in their coverage of structured generation. Here is a summary: + + +| | [Transformers](transformers.md) | [vLLM](vllm.md) | [llama.cpp](llamacpp.md) | [ExLlamaV2](exllamav2.md) | [MLXLM](mlxlm.md) | [OpenAI](openai.md) | [Gemini](gemini.md) +|-------------------|--------------|------|-----------|-----------|-------|---------|-------| +| **Supported HW** | | | | | | | | +| CUDA | ✅ | ✅ | ✅ | ✅ | ❌ | N/A | N/A | +| Apple Silicon | ✅ | ❌ | ✅ | ✅ | ✅ | N/A | N/A | +| x86 / AMD64 | ✅ | ❌ | ✅ | ✅ | ❌ | N/A | N/A | +| **Structure** | | | | | | | | +| JSON Schema | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Choice | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| Regex | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| Grammar | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | + + +[structured-outputs]: https://platform.openai.com/docs/guides/structured-outputs diff --git a/docs/reference/models/ollama.md b/docs/reference/models/ollama.md index 0d2961832..4cc3bcd2d 100644 --- a/docs/reference/models/ollama.md +++ b/docs/reference/models/ollama.md @@ -7,6 +7,8 @@ - To download Ollama: https://ollama.com/download - To install the ollama python sdk: `pip install ollama` + Consult the [`ollama` documentation](https://github.com/ollama/ollama-python) for detailed informations about how to initialize models and the available options. + ## Ollama models You must provide a model name when instantiating the `outlines.models.Ollama` class. This model must be available on your system. diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md index eae8dde39..354994861 100644 --- a/docs/reference/models/openai.md +++ b/docs/reference/models/openai.md @@ -1,56 +1,34 @@ # OpenAI and compatible APIs -!!! Installation +!!! Note "Documentation" - You need to install the `openai` library to be able to use the OpenAI API in Outlines. Or alternatively you can run: + To be able to use OpenAI in Outlines, you must install the OpenAI Python SDK with `pip install openai` - ```bash - pip install "outlines[openai]" - ``` + Consult the [OpenAI SDK documentation](https://github.com/openai/openai-python) for detailed informations about how to initialize OpenAI clients and the available options. ## OpenAI models Outlines supports models available via the OpenAI Chat API, e.g. GPT-4o, ChatGPT and GPT-4. You can initialize the model by passing the model name to `outlines.models.OpenAI`: ```python -from outlines import models +from openai import OpenAI +import outlines +# OpenAI models +client = OpenAI() +model = outlines.from_openai(client, "gpt-4o-mini") +model = outlines.from_openai(client, "gpt-4o") -model = models.OpenAI("gpt-4o-mini") -model = models.OpenAI("gpt-4o") -``` - -Check the [OpenAI documentation](https://platform.openai.com/docs/models/gpt-4o) for an up-to-date list of available models. You can pass any parameter you would pass to `openai.OpenAI` as keyword arguments: - -```python -import os -from outlines import models - - -model = models.OpenAI( - "gpt-4o-mini", - api_key=os.environ["OPENAI_API_KEY"] -) -``` - -Refer to the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/_client.py) for an up-to-date list of the initialization parameters. - -## Azure OpenAI models - -Outlines also supports Azure OpenAI models: - -```python -from outlines import models - - -model = models.AzureOpenAI( - "azure-deployment-name", +# OpenAI models deployed on Azure +client = AzureOpenAI( api_version="2024-07-18", azure_endpoint="https://example-endpoint.openai.azure.com", ) +model = outlines.from_openai(client, "azure-deployment-name") ``` -You can pass any parameter you would pass to `openai.AzureOpenAI`. You can consult the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/lib/azure.py) for an up-to-date list. +Check the [OpenAI documentation](https://platform.openai.com/docs/models/gpt-4o) for an up-to-date list of available models, and the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/_client.py) for an up-to-date list of the initialization parameters. + ## Advanced configuration @@ -59,17 +37,17 @@ For more advanced configuration option, such as support proxy, please consult th ```python from openai import AsyncOpenAI, DefaultHttpxClient -from outlines import models -from outlines.models.openai import OpenAIConfig - +import outlines client = AsyncOpenAI( + api_key="my key", base_url="http://my.test.server.example.com:8083", http_client=DefaultHttpxClient( proxies="http://my.test.proxy.example.com", transport=httpx.HTTPTransport(local_address="0.0.0.0"), ), ) +model = outlines.from_openai(client, "model_name") ``` ## Models that follow the OpenAI standard @@ -79,27 +57,25 @@ Outlines supports models that follow the OpenAI standard. You will need to initi ```python import os from openai import AsyncOpenAI -from outlines import models -from outlines.models.openai import OpenAIConfig - +import outlines -model = models.OpenAI( - "model_name", +client = AsyncOpenAI( api_key=os.environ.get("PROVIDER_KEY"), - base_url="http://other.provider.server.com" + base_url="http://other.provider.server.com", ) +model = outlines.from_openai(client, "model_name") ``` -## Text generation +## Calling the model -To generate text using an OpenAI model you need to build a `Generator` object, possibly with the desired output type. You can then call the model by calling the `Generator`. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: +You can call the model directly. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: ```python -from outlines import models, Generator +from openai import OpenAI +import outlines -model = models.OpenAI("gpt-4o-mini") -generator = Generator(model) -result = generator("Prompt", seed=10) +model = outlines.from_openai(OpenAI(), "gpt-4o-mini") +result = model("Prompt", seed=10) ``` See the [OpenAI SDK documentation](https://github.com/openai/openai-python/blob/6974a981aec1814b5abba429a8ea21be9ac58538/src/openai/types/completion_create_params.py#L13) for the list of available arguments. @@ -109,11 +85,13 @@ See the [OpenAI SDK documentation](https://github.com/openai/openai-python/blob/ Outlines provides support for [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs/json-mode). Currently only JSON-Schema is supported: ```python +from openai import OpenAI from pydantic import BaseModel -from outlines import models, Generator -from outlines.types import Json -model = models.OpenAI("gpt-4o-mini") +import outlines +from outlines.types import JsonType + +model = outlines.from_openai(OpenAI(), "gpt-4o-mini") class Person(BaseModel): first_name: str diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md index f24591436..724a724c3 100644 --- a/docs/reference/models/transformers.md +++ b/docs/reference/models/transformers.md @@ -1,92 +1,34 @@ # Transformers +!!! Note "Documentation" -!!! Installation + To be able to use Transformers models in Outlines, you must install the `transformers` library, `pip install transformers` - You need to install the `transformer` library to be able to use these models in Outlines, or alternatively: - - ```bash - pip install "outlines[transformers]" - ``` + Consult the [`transformers` documentation](https://huggingface.co/docs/transformers/en/index) for detailed informations about how to initialize models and the available options. ## Create a `Transformers` model -The only mandatory argument to instantiate a `Transformers` model is the name of the model to use. -```python -from outlines import models - -model = models.Transformers("microsoft/Phi-3-mini-4k-instruct") -``` - -The model name must be a valid `transformers` model name. You can find a list of all in the HuggingFace library [here](https://huggingface.co/models). - -When instantiating a `Transformers` model as such, the class creates a model from the transformers libray using the class `AutoModelForCausalLM` by default (`transformers.AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)`). - -You can also provide keyword arguments in an optional `model_kwargs` parameter. Those will be passed to the `from_pretrained` method of the model class. One such argument is `device_map`, which allows you to specify the device on which the model will be loaded. - -For instance: -```python -from outlines import models - -model = models.Transformers("microsoft/Phi-3-mini-4k-instruct", model_kwargs={"device_map": "cuda"}) -``` - -## Alternative model classes - -If the model you want to use is not compatible with `AutoModelForCausalLM`, you must provide a value for the `model_class` parameter. This value must be a valid `transformers` model class. - -For instance: -```python -from outlines import models -from transformers import AutoModelForSeq2SeqLM - -model = models.Transformers("facebook/bart-large", model_class=AutoModelForSeq2SeqLM) -``` - -When you instantiate a `Transformers` model, the class also creates a `Tokenizer` instance from the `AutoTokenizer` class. You can provide keyword arguments in an optional `tokenizer_kwargs` parameter. Those will be passed to the `from_pretrained` method of the tokenizer class as such: `tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)`. - -Similarly, if your model is not compatible with `AutoTokenizer`, you must provide a value for the `tokenizer_class` parameter. +You can use `outlines.from_transformers` to load a `transformers` model and tokenizer: ```python +from transformers import AutoModelForCausalLM, AutoTokenizer from outlines import models -from transformers import T5ForConditionalGeneration, T5Tokenizer -model_pile_t5 = models.Transformers( - model_name="EleutherAI/pile-t5-large", - model_class=T5ForConditionalGeneration, - tokenizer_class=T5Tokenizer +model_name = "microsoft/Phi-3-mini-4k-instruct" +model = models.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name) ) ``` -### Mamba - -[Mamba](https://github.com/state-spaces/mamba) is a transformers alternative which employs memory efficient, linear-time decoding. - -To use Mamba with outlines you must first install the necessary requirements: -``` -pip install causal-conv1d>=1.2.0 mamba-ssm torch transformers -``` - -Then you can create an `Mamba` Outlines model via: -```python -from outlines import models - -model = models.Mamba("state-spaces/mamba-2.8b-hf", model_kwargs={"device_map": "cuda"}, tokenizer_kwargs={"padding_side": "left"}) -``` - -Alternatively, you can use the `Transformers` class to create an `Mamba` model by providing the appropriate `model_class` and `tokenizer_class` arguments. - -Read [`transformers`'s documentation](https://huggingface.co/docs/transformers/en/model_doc/mamba) for more information. - -### Encoder-Decoder Models - -You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines. +The model name must be a valid `transformers` model name. You can find a list of all in the HuggingFace library [here](https://huggingface.co/models). We currently support `CausalLM`, `Seq2Seq`, `Mamba` and vision models. Be cautious with model selection though, some models such as `t5-base` don't include certain characters (`{`) and you may get an error when trying to perform structured generation. ## Use the model to generate text Once you have created a `Transformers` model, you can use it to generate text by calling the instance of the model. + ```python model("Hello, how are you?") ``` diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md index 8789b588e..1636410a4 100644 --- a/docs/reference/models/vllm.md +++ b/docs/reference/models/vllm.md @@ -3,11 +3,9 @@ !!! Note "Installation" - You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm. To get started you can also run: + You need to install the `vllm` library to use the vLLM integration: `pip install vllm`. The default installation only works on machines with a GPU, follow the [installation section][vllm-install-cpu] for instructions to install vLLM for CPU or ROCm. - ```bash - pip install "outlines[vllm]" - ``` + Consult the [vLLM documentation][vllm-docs] for detailed informations about how to initialize OpenAI clients and the available options. ## Load the model @@ -15,106 +13,19 @@ Outlines supports models available via vLLM's offline batched inference interfac ```python -from outlines import models - -model = models.vllm("microsoft/Phi-3-mini-4k-instruct") +import outlines +from vllm import LLM +model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct")) ``` -Or alternatively: - -```python -import vllm -from outlines import models - -llm = vllm.LLM("microsoft/Phi-3-mini-4k-instruct") -model = models.VLLM(llm) -``` - - Models are loaded from the [HuggingFace hub](https://huggingface.co/). !!! Warning "Device" - The default installation of vLLM only allows to load models on GPU. See the [installation instructions](#installation) to run models on CPU. - - -You can pass any parameter that you would normally pass to `vllm.LLM`, as keyword arguments: - -```python -from outlines import models - -model = models.vllm( - "microsoft/Phi-3-mini-4k-instruct", - trust_remote_code=True, - gpu_memory_utilization=0.7 -) -``` - -**Main parameters:** - -| **Parameters** | **Type** | **Description** | **Default** | -|----------------|:---------|:----------------|:------------| -| `tokenizer_mode`| `str` | "auto" will use the fast tokenizer if available and "slow" will always use the slow tokenizer. | `auto` -| `trust_remote_code`| `bool` | Trust remote code when downloading the model and tokenizer. | `False` | -| `tensor_parallel_size`| `int` | The number of GPUs to use for distributed execution with tensor parallelism.| `1` | -| `dtype`| `str` | The data type for the model weights and activations. Currently, we support `float32`, `float16`, and `bfloat16`. If `auto`, we use the `torch_dtype` attribute specified in the model config file. However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead.| `auto` | -| `quantization`| `Optional[str]` | The method used to quantize the model weights. Currently, we support "awq", "gptq" and "squeezellm". If None, we first check the `quantization_config` attribute in the model config file. If that is None, we assume the model weights are not quantized and use `dtype` to determine the data type of the weights.| `None` | -| `revision`| `Optional[str]` | The specific model version to use. It can be a branch name, a tag name, or a commit id.| `None` | -| `tokenizer_revision`| `Optional[str]`| The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id.| `None` | -| `gpu_memory_utilization`| `float` | The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. Higher values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors.| `0.9` | -| `swap_space`| `int` | The size (GiB) of CPU memory per GPU to use as swap space. This can be used for temporarily storing the states of the requests when their `best_of` sampling parameters are larger than 1. If all requests will have `best_of=1`, you can safely set this to 0. Otherwise, too small values may cause out-of-memory (OOM) errors.| 4 | -| `enforce_eager`| `bool` | Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid.| `False` | -| `enable_lora` | `bool` | Whether to enable loading LoRA adapters | `False` | - -See the [vLLM code](https://github.com/vllm-project/vllm/blob/8f44facdddcf3c704f7d6a2719b6e85efc393449/vllm/entrypoints/llm.py#L72) for a list of all the available parameters. - -### Use quantized models - -vLLM supports AWQ, GPTQ and SqueezeLLM quantized models: + The default installation of vLLM only allows to load models on GPU. See the [installation instructions][vllm-install-cpu] to run models on CPU. -```python -from outlines import models - -model = models.vllm("TheBloke/Llama-2-7B-Chat-AWQ", quantization="awq") -model = models.vllm("TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq") -model = models.vllm("https://huggingface.co/squeeze-ai-lab/sq-llama-30b-w4-s5", quantization="squeezellm") -``` - -!!! Warning "Dependencies" - - To use AWQ model you need to install the autoawq library `pip install autoawq`. - - To use GPTQ models you need to install the autoGTPQ and optimum libraries `pip install auto-gptq optimum`. - - -### Multi-GPU usage - -To run multi-GPU inference with vLLM you need to set the `tensor_parallel_size` argument to the number of GPUs available when initializing the model. For instance to run inference on 2 GPUs: - - -```python -from outlines import models - -model = models.vllm( - "microsoft/Phi-3-mini-4k-instruct" - tensor_parallel_size=2 -) -``` - -### Load LoRA adapters - -You can load LoRA adapters and alternate between them dynamically: - -```python -from outlines import models - -model = models.vllm("facebook/opt-350m", enable_lora=True) -model.load_lora("ybelkaa/opt-350m-lora") # Load LoRA adapter -model.load_lora(None) # Unload LoRA adapter -``` - ## Generate text In addition to the parameters described in the [text generation section](../text.md) you can pass an instance of `SamplingParams` directly to any generator via the `sampling_params` keyword argument: @@ -124,112 +35,16 @@ from vllm.sampling_params import SamplingParams from outlines import models, generate -model = models.vllm("microsoft/Phi-3-mini-4k-instruct") -generator = generate.text(model) - +model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct")) params = SamplingParams(n=2, frequency_penalty=1., min_tokens=2) -answer = generator("A prompt", sampling_params=params) +answer = model("A prompt", sampling_params=params) ``` -This also works with generators built with `generate.regex`, `generate.json`, `generate.cfg`, `generate.format` and `generate.choice`. - -!!! Note - - The values passed via the `SamplingParams` instance supersede the other arguments to the generator or the samplers. - -**`SamplingParams` attributes:** - -| Parameters | Type | Description | Default | -|:-----------|------------------|:-----------------------|---------| -| `n` | `int` | Number of output sequences to return for the given prompt. | `1` | -| `best_of` | `Optional[int]` | Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. This is treated as the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`. | `None` | -| `presence_penalty` | `float` | Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.| `0.0` | -| `frequency_penalty` | `float` | Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens. | `0.0` -| `repetition_penalty` | `float` | Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens. | `1.0` | -| `temperature` | `float` | Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling. | `1.0` | -| `top_p` | `float` | Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens. | `1.0` | -| `top_k` | `int` | Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens. | `-1` | -| `min_p` |`float` | Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. | `0.0` | -| `seed` | `Optional[int]` | Random seed to use for the generation. | `None` | -| `use_beam_search` | `bool` | Whether to use beam search instead of sampling. | `False` | -| `length_penalty` | `float` | Float that penalizes sequences based on their length. Used in beam search. | `1.0` | -| `early_stopping` | `Union[bool, str]` | Controls the stopping condition for beam search. It accepts the following values: `True`, where the generation stops as soon as there are `best_of` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). | `False` | -| `stop` | `Optional[Union[str, List[str]]]` | List of strings that stop the generation when they are generated. The returned output will not contain the stop strings. | `None` | -| `stop_token_ids` | `Optional[List[int]]` | List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens. | `None` | -| `include_stop_str_in_output` | `bool` | Whether to include the stop strings in output text. Defaults to False. | `False` | -| `ignore_eos` | `bool` | Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. | `False` | -| `max_tokens` | `int` | Maximum number of tokens to generate per output sequence. | `16` | -| `min_tokens` | `int` | Minimum number of tokens to generate per output sequence before EOS or stop_token_ids can be generated | `0` | -| `skip_special_tokens` | `bool` | Whether to skip special tokens in the output. | `True` | -| `spaces_between_special_tokens` | `bool` | Whether to add spaces between special tokens in the output. Defaults to True. | `True` | - -### Streaming - !!! Warning Streaming is not available for the offline vLLM integration. - -## Installation - -By default the vLLM library is installed with pre-commpiled C++ and CUDA binaries and will only run on GPU: - -```python -pip install vllm -``` - -### CPU - -You need to have the `gcc` compiler installed on your system. Then you will need to install vLLM from source. First clone the repository: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -``` - -Install the Python packages needed for the installation: - -```bash -pip install --upgrade pip -pip install wheel packaging ninja setuptools>=49.4.0 numpy -pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu -``` - -and finally run: - -```bash -VLLM_TARGET_DEVICE=cpu python setup.py install -``` - -See the [vLLM documentation][vllm-install-cpu] for more details, alternative installation methods (Docker) and performance tips. - -### ROCm - - -You will need to install vLLM from source. First install Pytorch on ROCm: - -```bash -pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version -``` - -You will then need to install flash attention for ROCm following [these instructions][rocm-flash-attention]. You can then install `xformers=0.0.23` and apply the patches needed to adapt Flash Attention for ROCm: - -```bash -pip install xformers==0.0.23 --no-deps -bash patch_xformers.rocm.sh -``` - -And finally build vLLM: - -```bash -cd vllm -pip install -U -r requirements-rocm.txt -python setup.py install # This may take 5-10 minutes. -``` - -See the [vLLM documentation][vllm-install-rocm] for alternative installation methods (Docker). - - +[vllm-docs]:https://docs.vllm.ai/en/latest/ [vllm-install-cpu]: https://docs.vllm.ai/en/latest/getting_started/cpu-installation.html [vllm-install-rocm]: https://docs.vllm.ai/en/latest/getting_started/amd-installation.html [rocm-flash-attention]: https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support diff --git a/examples/babyagi.py b/examples/babyagi.py index 0a7a0b13b..2b647999b 100644 --- a/examples/babyagi.py +++ b/examples/babyagi.py @@ -4,14 +4,16 @@ The original repo can be found at https://github.com/yoheinakajima/babyagi """ + from collections import deque from typing import Deque, List +from openai import OpenAI + import outlines -import outlines.models as models -model = models.openai("gpt-4o-mini") -complete = outlines.generate.text(model) +model = outlines.from_openai(OpenAI(), "gpt-4o-mini") +complete = outlines.Generator(model) ################# diff --git a/examples/bentoml/service.py b/examples/bentoml/service.py index 98d46894a..884904a73 100644 --- a/examples/bentoml/service.py +++ b/examples/bentoml/service.py @@ -1,7 +1,7 @@ import typing as t import bentoml -from import_model import BENTO_MODEL_TAG +from import_model import BENTO_MODEL_TAG, MODEL_ID DEFAULT_SCHEMA = """{ "title": "Character", @@ -55,13 +55,17 @@ class Outlines: def __init__(self) -> None: import torch + from transformers import AutoModelForCausalLM, AutoTokenizer import outlines - self.model = outlines.models.transformers( - self.bento_model_ref.path, - device="cuda", - model_kwargs={"torch_dtype": torch.float16}, + self.model = outlines.from_transformers( + AutoTokenizer.from_pretrained(MODEL_ID), + AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ) ) @bentoml.api diff --git a/examples/cerebrium/main.py b/examples/cerebrium/main.py index b61cdb6e5..5ea1b61f3 100644 --- a/examples/cerebrium/main.py +++ b/examples/cerebrium/main.py @@ -1,6 +1,11 @@ import outlines +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2"), + AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2"), +) -model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2") schema = { "title": "Character", @@ -29,14 +34,13 @@ }, } -generator = outlines.generate.json(model, schema) - def generate( prompt: str = "Amiri, a 53 year old warrior woman with a sword and leather armor.", ): - character = generator( - f"[INST]Give me a character description. Describe {prompt}.[/INST]" + character = model( + f"[INST]Give me a character description. Describe {prompt}.[/INST]", + outlines.JsonType(schema), ) print(character) diff --git a/examples/cfg.py b/examples/cfg.py index 99edf35f1..22fdf7ed0 100644 --- a/examples/cfg.py +++ b/examples/cfg.py @@ -1,5 +1,7 @@ -import outlines.generate as generate -import outlines.models as models +from transformers import AutoModelForCausalLM, AutoTokenizer + +import outlines +from outlines.types import Cfg nlamb_grammar = r""" start: sentence @@ -75,11 +77,16 @@ %ignore WS """ -model = models.transformers("hf-internal-testing/tiny-random-gpt2") +model_name = "hf-internal-testing/tiny-random-gpt2" +model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(model_name), + AutoTokenizer.from_pretrained(model_name), +) + batch_size = 10 for grammar in [nlamb_grammar, calc_grammar, dyck_grammar, json_grammar]: - generator = generate.cfg(model, grammar, max_tokens=model.model.config.n_positions) - sequences = generator([" "] * batch_size) + generator = outlines.Generator(model, Cfg(grammar)) + sequences = generator([" "] * batch_size, max_tokens=model.model.config.n_positions) for seq in sequences: try: parse = generator.fsm.parser.parse(seq) diff --git a/examples/dating_profile.py b/examples/dating_profile.py index 504ec943d..29bc78e60 100644 --- a/examples/dating_profile.py +++ b/examples/dating_profile.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, conlist import outlines -from outlines import models class QuestionChoice(str, Enum): @@ -103,25 +102,27 @@ def dating_profile_prompt(description: str, examples: list[Example]): # Below requires ~13GB of GPU memory # https://huggingface.co/mosaicml/mpt-7b-8k-instruct # Motivation: Reasonably large model that fits on a single GPU and has been fine-tuned for a larger context window +model_name = "mosaicml/mpt-7b-8k-instruct" config = transformers.AutoConfig.from_pretrained( "mosaicml/mpt-7b-8k-instruct", trust_remote_code=True ) config.init_device = "meta" -model = models.transformers( - model_name="mosaicml/mpt-7b-8k-instruct", - device="cuda", - model_kwargs={ - "config": config, - "trust_remote_code": True, - "torch_dtype": torch.bfloat16, - "device_map": {"": 0}, - }, +model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained( + model_name, + device="cuda", + config=config, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + device_map={"": 0}, + ), + transformers.AutoTokenizer.from_pretrained(model_name), ) new_description = "I'm a laid-back lawyer who spends a lot of his free-time gaming. I work in a corporate office, but ended up here after the start-up I cofounded got acquired, so still play ping pong with my cool coworkers every day. I have a bar at home where I make cocktails, which is great for entertaining friends. I secretly like to wear suits and get a new one tailored every few months. I also like weddings because I get to wear those suits, and it's a good excuse for a date. I watch the latest series because I'm paying, with my hard-earned money, for every streaming service." prompt = dating_profile_prompt(description=new_description, examples=samples) -profile = outlines.generate.json(model, DatingProfile)(prompt) # type: ignore +profile = model(prompt, outlines.JsonType(DatingProfile)) # type: ignore print(profile) # Sample generated profiles diff --git a/examples/llamacpp_example.py b/examples/llamacpp_example.py index 22d0da3ba..cf64a934a 100644 --- a/examples/llamacpp_example.py +++ b/examples/llamacpp_example.py @@ -1,6 +1,7 @@ from enum import Enum from pydantic import BaseModel, constr +from llama_cpp import Llama import outlines @@ -30,10 +31,10 @@ class Character(BaseModel): if __name__ == "__main__": # curl -L -o mistral-7b-instruct-v0.2.Q5_K_M.gguf https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q5_K_M.gguf - model = outlines.models.llamacpp("./mistral-7b-instruct-v0.2.Q5_K_M.gguf") + model = outlines.from_llamacpp(Llama("./mistral-7b-instruct-v0.2.Q5_K_M.gguf")) # Construct structured sequence generator - generator = outlines.generate.json(model, Character) + generator = outlines.Generator(model, outlines.JsonType(Character)) # Draw a sample seed = 789005 diff --git a/examples/math_generate_code.py b/examples/math_generate_code.py index 7eb1651a7..8fd6a22c2 100644 --- a/examples/math_generate_code.py +++ b/examples/math_generate_code.py @@ -1,6 +1,7 @@ """Example from https://dust.tt/spolu/a/d12ac33169""" + import outlines -import outlines.models as models +import openai examples = [ {"question": "What is 37593 * 67?", "code": "37593 * 67"}, @@ -35,7 +36,7 @@ def execute_code(code): prompt = answer_with_code_prompt(question, examples) -model = models.openai("gpt-4o-mini") -answer = outlines.generate.text(model)(prompt) +model = outlines.from_openai(openai.OpenAI(), "gpt-4o-mini") +answer = model(prompt) result = execute_code(answer) print(f"It takes Carla {result:.0f} minutes to download the file.") diff --git a/examples/meta_prompting.py b/examples/meta_prompting.py index cba18b5fe..e05487001 100644 --- a/examples/meta_prompting.py +++ b/examples/meta_prompting.py @@ -9,10 +9,14 @@ https://arxiv.org/abs/2102.07350. """ + import argparse +import openai import outlines -import outlines.models as models + + +client = openai.OpenAI() def split_into_steps(question, model_name: str): @@ -22,16 +26,15 @@ def solve(question): Rephrase : : as a true or false statement, identify an Object, relationship and subject """ - model = models.openai(model_name) - generator = outlines.generate.text(model) + model = outlines.from_openai(client, model_name) prompt = solve(question) - answer = generator(prompt, 500) + answer = model(prompt, 500) prompt += ( answer + "\n what is the only option that displays the same type of relationship as : :?" ) - answer = generator(prompt, 500) + answer = model(prompt, 500) completed = prompt + answer return completed @@ -49,13 +52,12 @@ def determine_goal(question): def solve(memory): """{{memory}}. Let's begin.""" - model = models.openai(model_name) - generator = outlines.generate.text(model) + model = outlines.from_openai(client, model_name) prompt = determine_goal(question) - answer = generator(prompt, stop_at=["."]) + answer = model(prompt, stop_at=["."]) prompt = solve(prompt + answer) - answer = generator(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500) completed = prompt + answer return completed @@ -89,13 +91,12 @@ def get_answer(question, expert, memory): {{question}} """ - model = models.openai(model_name) - generator = outlines.generate.text(model) + model = outlines.from_openai(client, model_name) prompt = find_expert(question) - expert = generator(prompt, stop_at=['"']) + expert = model(prompt, stop_at=['"']) prompt = get_answer(question, expert, prompt + expert) - answer = generator(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500) completed = prompt + answer return completed @@ -117,13 +118,12 @@ def get_answer(expert, memory): For instance, {{expert}} would answer """ - model = models.openai(model_name) - generator = outlines.generate.text(model) + model = outlines.from_openai(client, model_name) prompt = find_expert(question) - expert = generator(prompt, stop_at=["\n", "."]) + expert = model(prompt, stop_at=["\n", "."]) prompt = get_answer(expert, prompt + expert) - answer = generator(prompt, max_tokens=500) + answer = model(prompt, max_tokens=500) completed = prompt + answer return completed diff --git a/examples/modal_example.py b/examples/modal_example.py index 34d4ee606..44b66e15c 100644 --- a/examples/modal_example.py +++ b/examples/modal_example.py @@ -4,7 +4,7 @@ outlines_image = modal.Image.debian_slim(python_version="3.11").pip_install( - "outlines==0.0.37", + "outlines==1.0.0", "transformers==4.38.2", "datasets==2.18.0", "accelerate==0.27.2", @@ -12,9 +12,11 @@ def import_model(): - import outlines + from transformers import AutoModelForCausalLM, AutoTokenizer - outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2") + model_id = "mistralai/Mistral-7B-Instruct-v0.2" + _ = AutoTokenizer.from_pretrained(model_id) + _ = AutoModelForCausalLM.from_pretrained(model_id) outlines_image = outlines_image.run_function(import_model) @@ -63,12 +65,17 @@ def generate( prompt: str = "Amiri, a 53 year old warrior woman with a sword and leather armor.", ): import outlines + from transformers import AutoModelForCausalLM, AutoTokenizer - model = outlines.models.transformers("mistralai/Mistral-7B-v0.1", device="cuda") + model_id = "mistralai/Mistral-7B-Instruct-v0.2" + model = outlines.from_transformers( + tokenizer=AutoTokenizer.from_pretrained(model_id), + model=AutoModelForCausalLM.from_pretrained(model_id, device="cuda"), + ) - generator = outlines.generate.json(model, schema) - character = generator( - f"[INST]Give me a character description. Describe {prompt}.[/INST]" + character = model( + f"[INST]Give me a character description. Describe {prompt}.[/INST]", + outlines.JsonType(schema), ) print(character) diff --git a/examples/parsing.py b/examples/parsing.py index a10da4ebe..9e08eae14 100644 --- a/examples/parsing.py +++ b/examples/parsing.py @@ -1,4 +1,5 @@ """An example illustrating parser-based masking.""" + import math import time from copy import copy diff --git a/examples/pick_odd_one_out.py b/examples/pick_odd_one_out.py index 6cd9f1daf..4feb3ae21 100644 --- a/examples/pick_odd_one_out.py +++ b/examples/pick_odd_one_out.py @@ -9,8 +9,9 @@ arXiv preprint arXiv:2212.06094. """ + +import openai import outlines -import outlines.models as models @outlines.prompt @@ -31,7 +32,7 @@ def build_ooo_prompt(options): options = ["sea", "mountains", "plains", "sock"] -model = models.openai("gpt-4o-mini") +model = outlines.from_openai(openai.OpenAI(), "gpt-4o-mini") gen_text = outlines.generate.text(model) gen_choice = outlines.generate.choice(model, options) diff --git a/examples/react.py b/examples/react.py index 34b3c6eb2..136755e14 100644 --- a/examples/react.py +++ b/examples/react.py @@ -10,11 +10,12 @@ .. [2] Yao, S., Zhao, J., Yu, D., Du, N., Shafran, I., Narasimhan, K., & Cao, Y. (2022). React: Synergizing reasoning and acting in language models. arXiv preprint arXiv:2210.03629. """ + +from openai import OpenAI import requests # type: ignore import outlines -import outlines.generate as generate -import outlines.models as models +from outlines import Generator, Choice @outlines.prompt @@ -46,11 +47,11 @@ def search_wikipedia(query: str): prompt = build_reAct_prompt("Where is Apple Computers headquarted? ") -model = models.openai("gpt-4o-mini") +model = outlines.from_openai(OpenAI(), "gpt-4o-mini") -mode_generator = generate.choice(model, choices=["Tho", "Act"]) -action_generator = generate.choice(model, choices=["Search", "Finish"]) -text_generator = generate.text(model) +mode_generator = Generator(model, Choice(["Tho", "Act"])) +action_generator = Generator(model, Choice(["Search", "Finish"])) +text_generator = Generator(model) for i in range(1, 10): mode = mode_generator(prompt, max_tokens=128) diff --git a/examples/sampling.ipynb b/examples/sampling.ipynb index bcbcca1e2..d6b50e33d 100644 --- a/examples/sampling.ipynb +++ b/examples/sampling.ipynb @@ -85,6 +85,7 @@ " },\n", "]\n", "\n", + "\n", "@text.prompt\n", "def few_shot_prompt(question, examples):\n", " \"\"\"\n", @@ -96,6 +97,7 @@ " A:\n", " \"\"\"\n", "\n", + "\n", "# Prompt functions can be partially evaluated like any other function\n", "gsm8k_prompt = ft.partial(few_shot_prompt, examples=examples)" ] @@ -155,31 +157,40 @@ " digits.append(digit)\n", " except AttributeError:\n", " print(f\"Could not parse the completion: '{answer}'\")\n", - " \n", + "\n", " unique_digits, counts = np.unique(digits, return_counts=True)\n", " return {d: c for d, c in zip(unique_digits, counts)}\n", "\n", + "\n", "def plot_counts(counts):\n", - " fig = plt.figure(figsize=(12,8))\n", + " fig = plt.figure(figsize=(12, 8))\n", " ax = fig.add_subplot(111)\n", - " \n", + "\n", " bar = ax.bar(counts.keys(), counts.values())\n", " ax.spines[[\"right\", \"top\", \"left\"]].set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", - " \n", + "\n", " for rect in bar:\n", " height = rect.get_height()\n", - " plt.text(rect.get_x() + rect.get_width() / 2.0, height, f'{height:.0f}', ha='center', va='bottom', fontsize=20)\n", - " \n", + " plt.text(\n", + " rect.get_x() + rect.get_width() / 2.0,\n", + " height,\n", + " f\"{height:.0f}\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=20,\n", + " )\n", + "\n", " ax.set_xticks(list(counts.keys()))\n", " ax.set_xlabel(\"Answer\")\n", "\n", + "\n", "def entropy(counts):\n", " counts = np.array(list(counts.values()))\n", " probs = counts / np.sum(counts)\n", " log_probs = np.log(probs)\n", - " return - np.sum(probs * log_probs)" + " return -np.sum(probs * log_probs)" ] }, { diff --git a/examples/self_consistency.py b/examples/self_consistency.py index f1bbe2a18..e251f6411 100644 --- a/examples/self_consistency.py +++ b/examples/self_consistency.py @@ -1,9 +1,9 @@ import re import numpy as np +import openai import outlines -import outlines.models as models examples = [ { @@ -55,8 +55,8 @@ def few_shots(question, examples): """ -model = models.openai("gpt-4o-mini") -generator = outlines.generate.text(model) +model = outlines.from_openai(openai.OpenAI(), "gpt-4o-mini") +generator = outlines.Generator(model) prompt = few_shots(question, examples) answers = generator(prompt, samples=10) @@ -78,5 +78,5 @@ def few_shots(question, examples): answer_value = [key for key, value in results.items() if value == max_count][0] total_count = sum(results.values()) print( - f"The most likely answer is {answer_value} ({max_count/total_count*100}% consensus)" + f"The most likely answer is {answer_value} ({max_count / total_count * 100}% consensus)" ) diff --git a/examples/simulation_based_inference.ipynb b/examples/simulation_based_inference.ipynb index e6b999582..6ad52e58e 100644 --- a/examples/simulation_based_inference.ipynb +++ b/examples/simulation_based_inference.ipynb @@ -61,7 +61,9 @@ "metadata": {}, "outputs": [], "source": [ - "result = requests.get(\"https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl\")\n", + "result = requests.get(\n", + " \"https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl\"\n", + ")\n", "lines = result.iter_lines()" ] }, @@ -121,8 +123,10 @@ " A:\n", " \"\"\"\n", "\n", + "\n", "model = models.text_completion.openai(\"text-davinci-003\", max_tokens=128)\n", "\n", + "\n", "# TODO: This could largely benefit from vectorization in #52\n", "def one_train_example(problem, example_set):\n", " example_ids = random.choices(range(0, len(example_set)), k=5)\n", @@ -216,7 +220,7 @@ "\n", "example_ids, counts = np.unique(samples, return_counts=True)\n", "\n", - "fig = plt.figure(figsize=(12,8))\n", + "fig = plt.figure(figsize=(12, 8))\n", "ax = fig.add_subplot(111)\n", "ax.bar(example_ids, counts)\n", "\n", @@ -224,7 +228,7 @@ "\n", "ax.set_xticks(range(10))\n", "ax.set_xlabel(\"Example #\")\n", - "ax.set_ylabel(\"Counts\")\n" + "ax.set_ylabel(\"Counts\")" ] }, { diff --git a/outlines/__init__.py b/outlines/__init__.py index df20a5dd9..4aacb6c7e 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -5,18 +5,49 @@ import outlines.models import outlines.processors import outlines.types +from outlines.types import Choice, Regex, JsonType from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function +from outlines.generate import Generator from outlines.templates import Template, prompt +from outlines.models import ( + from_dottxt, + from_openai, + from_transformers, + from_gemini, + from_anthropic, + from_ollama, + from_llamacpp, + from_mlxlm, + from_vllm, +) + + +model_list = [ + "from_anthropic", + "from_dottxt", + "from_gemini", + "from_llamacpp", + "from_mlxlm", + "from_ollama", + "from_openai", + "from_transformers", + "from_vllm", +] + __all__ = [ "clear_cache", "disable_cache", "get_cache", "Function", + "Generator", + "JsonType", + "Cfg", + "Regex", "prompt", "Template", "vectorize", "grammars", -] +] + model_list diff --git a/outlines/function.py b/outlines/function.py index aff21d68f..d7f6dec68 100644 --- a/outlines/function.py +++ b/outlines/function.py @@ -4,7 +4,7 @@ import requests -from outlines import generate, models +import outlines if TYPE_CHECKING: from outlines.generate.api import SequenceGenerator @@ -37,8 +37,14 @@ def from_github(cls, program_path: str, function_name: str = "fn"): def init_generator(self): """Load the model and initialize the generator.""" - model = models.transformers(self.model_name) - self.generator = generate.json(model, self.schema) + from transformers import AutoModelForCausalLM, AutoTokenizer + + model = outlines.from_transformers( + AutoModelForCausalLM.from_pretrained(self.model_name), + AutoTokenizer.from_pretrained(self.model_name), + ) + + self.generator = outlines.Generator(model, outlines.JsonType(self.schema)) def __call__(self, *args, **kwargs): """Call the function. diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py index 2fb99a030..0d4f3c8e3 100644 --- a/outlines/generate/__init__.py +++ b/outlines/generate/__init__.py @@ -3,7 +3,7 @@ from outlines.models import APIModel, LlamaCpp, LocalModel from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor -from outlines.types import CFG, Choice, Json, List, Regex +from outlines.types import CFG, Choice, JsonType, List, Regex from .api import SequenceGenerator from .cfg import cfg @@ -29,7 +29,7 @@ class APIGenerator: """ model: APIModel - output_type: Optional[Union[Json, List, Choice, Regex]] = None + output_type: Optional[Union[JsonType, List, Choice, Regex]] = None def __post_init__(self): if isinstance(self.output_type, CFG): @@ -58,7 +58,7 @@ class LocalGenerator: """ model: LocalModel - output_type: Optional[Union[Json, List, Choice, Regex]] + output_type: Optional[Union[JsonType, List, Choice, Regex]] def __post_init__(self): if self.output_type is None: @@ -81,7 +81,7 @@ def __call__(self, prompt, **inference_kwargs): def Generator( model: Union[LocalModel, APIModel], - output_type: Optional[Union[Json, List, Choice, Regex, CFG]] = None, + output_type: Optional[Union[JsonType, List, Choice, Regex, CFG]] = None, ): if isinstance(model, APIModel): # type: ignore return APIGenerator(model, output_type) # type: ignore diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 7adf0d0b5..ee50356b0 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -8,22 +8,26 @@ from typing import Union -from .anthropic import Anthropic +from .anthropic import from_anthropic, Anthropic from .base import Model, ModelTypeAdapter -from .dottxt import Dottxt +from .dottxt import Dottxt, from_dottxt from .exllamav2 import ExLlamaV2Model, exl2 -from .gemini import Gemini -from .llamacpp import LlamaCpp -from .mlxlm import MLXLM, mlxlm -from .ollama import Ollama -from .openai import AzureOpenAI, OpenAI -from .transformers import Mamba, Transformers, TransformerTokenizer -from .transformers_vision import TransformersVision -from .vllm import VLLM, vllm +from .gemini import from_gemini, Gemini +from .llamacpp import LlamaCpp, from_llamacpp +from .mlxlm import MLXLM, from_mlxlm +from .ollama import Ollama, from_ollama +from .openai import from_openai, OpenAI +from .transformers import ( + Transformers, + TransformerTokenizer, + TransformersVision, + from_transformers, +) +from .vllm import VLLM, from_vllm LogitsGenerator = Union[ Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM, Ollama ] -LocalModel = Union[LlamaCpp, Transformers] -APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama, Dottxt] +LocalModel = Union[LlamaCpp, Transformers, MLXLM, VLLM] +APIModel = Union[OpenAI, Anthropic, Gemini, Ollama, Dottxt] diff --git a/outlines/models/anthropic.py b/outlines/models/anthropic.py index 7084215d4..c7ed702dd 100644 --- a/outlines/models/anthropic.py +++ b/outlines/models/anthropic.py @@ -1,12 +1,15 @@ """Integration with Anthropic's API.""" from functools import singledispatchmethod -from typing import Union +from typing import Union, TYPE_CHECKING from outlines.models.base import Model, ModelTypeAdapter from outlines.templates import Vision -__all__ = ["Anthropic"] +if TYPE_CHECKING: + from anthropic import Anthropic as AnthropicClient + +__all__ = ["Anthropic", "from_anthropic"] class AnthropicTypeAdapter(ModelTypeAdapter): @@ -82,12 +85,9 @@ def format_output_type(self, output_type): class Anthropic(Model): - def __init__(self, model_name: str, *args, **kwargs): - from anthropic import Anthropic - - self.client = Anthropic(*args, **kwargs) + def __init__(self, client: "AnthropicClient", model_name: str): + self.client = client self.model_name = model_name - self.model_type = "api" self.type_adapter = AnthropicTypeAdapter() def generate( @@ -106,3 +106,7 @@ def generate( **inference_kwargs, ) return completion.content[0].text + + +def from_anthropic(client: "AnthropicClient", model_name: str) -> Anthropic: + return Anthropic(client, model_name) diff --git a/outlines/models/dottxt.py b/outlines/models/dottxt.py index b7fb3722b..576a4c4cd 100644 --- a/outlines/models/dottxt.py +++ b/outlines/models/dottxt.py @@ -3,10 +3,13 @@ import json from functools import singledispatchmethod from types import NoneType -from typing import Optional +from typing import Optional, TYPE_CHECKING from outlines.models.base import Model, ModelTypeAdapter -from outlines.types import Json +from outlines.types import JsonType + +if TYPE_CHECKING: + from dottxt import Dottxt as DottxtClient __all__ = ["Dottxt"] @@ -41,8 +44,8 @@ def format_output_type(self, output_type): f"The input type {input} is not available with Dottxt." ) - @format_output_type.register(Json) - def format_json_output_type(self, output_type: Json): + @format_output_type.register(JsonType) + def format_json_output_type(self, output_type: JsonType): """Format the output type to pass to the client.""" schema = output_type.to_json_schema() return json.dumps(schema) @@ -63,11 +66,15 @@ class Dottxt(Model): """ - def __init__(self, model_name: Optional[str] = None, *args, **kwargs): - from dottxt.client import Dottxt - - self.client = Dottxt(*args, **kwargs) + def __init__( + self, + client: "Dottxt", + model_name: Optional[str] = None, + model_revision: str = "", + ): + self.client = client self.model_name = model_name + self.model_revision = model_revision self.type_adapter = DottxtTypeAdapter() def generate(self, model_input, output_type=None, **inference_kwargs): @@ -76,6 +83,7 @@ def generate(self, model_input, output_type=None, **inference_kwargs): if self.model_name: inference_kwargs["model_name"] = self.model_name + inference_kwargs["model_revision"] = self.model_revision completion = self.client.json( prompt, @@ -83,3 +91,11 @@ def generate(self, model_input, output_type=None, **inference_kwargs): **inference_kwargs, ) return completion.data + + +def from_dottxt( + client: "DottxtClient", + model_name: Optional[str] = None, + model_revision: str = "", +): + return Dottxt(client, model_name, model_revision) diff --git a/outlines/models/gemini.py b/outlines/models/gemini.py index 401bd401a..f0fd8d78a 100644 --- a/outlines/models/gemini.py +++ b/outlines/models/gemini.py @@ -3,16 +3,20 @@ from enum import EnumMeta from functools import singledispatchmethod from types import NoneType -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING from pydantic import BaseModel from typing_extensions import _TypedDictMeta # type: ignore from outlines.models.base import Model, ModelTypeAdapter from outlines.templates import Vision -from outlines.types import Choice, Json, List +from outlines.types import JsonType, Choice, List -__all__ = ["Gemini"] + +if TYPE_CHECKING: + from google.generativeai import GenerativeModel as GeminiClient + +__all__ = ["Gemini", "from_gemini"] class GeminiTypeAdapter(ModelTypeAdapter): @@ -69,7 +73,7 @@ def format_list_output_type(self, output_type): def format_none_output_type(self, output_type): return {} - @format_output_type.register(Json) + @format_output_type.register(JsonType) def format_json_output_type(self, output_type): """Gemini only accepts Pydantic models and TypeDicts to define the JSON structure.""" if issubclass(output_type.definition, BaseModel): @@ -94,17 +98,14 @@ def format_enum_output_type(self, output_type): class Gemini(Model): - def __init__(self, model_name: str, *args, **kwargs): - import google.generativeai as genai - - self.client = genai.GenerativeModel(model_name, *args, **kwargs) - self.model_type = "api" + def __init__(self, client: "GeminiClient"): + self.client = client self.type_adapter = GeminiTypeAdapter() def generate( self, model_input: Union[str, Vision], - output_type: Optional[Union[Json, EnumMeta]] = None, + output_type: Optional[Union[JsonType, EnumMeta]] = None, **inference_kwargs, ): import google.generativeai as genai @@ -118,3 +119,7 @@ def generate( ) return completion.text + + +def from_gemini(client: "GeminiClient"): + return Gemini(client) diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 8823f3358..3657b669c 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -141,46 +141,14 @@ def format_output_type(self, output_type): class LlamaCpp(Model): """Wraps a model provided by the `llama-cpp-python` library.""" - def __init__(self, model_path: Union[str, "Llama"], **kwargs): + def __init__(self, model: "Llama"): from llama_cpp import Llama - if isinstance(model_path, Llama): - self.model = model_path - else: - # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved - if "tokenizer" not in kwargs: - warnings.warn( - "The pre-tokenizer in `llama.cpp` handles unicode improperly " - + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" - + "Outlines may raise a `RuntimeError` when building the regex index.\n" - + "To circumvent this error when using `models.llamacpp()` you may pass the argument" - + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" - ) - - self.model = Llama(model_path, **kwargs) - + self.model = model self.tokenizer = LlamaCppTokenizer(self.model) self.model_type = "local" self.type_adapter = LlamaCppTypeAdapter() - @classmethod - def from_pretrained(cls, repo_id, filename, **kwargs): - """Download the model weights from Hugging Face and create a `Llama` instance""" - from llama_cpp import Llama - - # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved - if "tokenizer" not in kwargs: - warnings.warn( - "The pre-tokenizer in `llama.cpp` handles unicode improperly " - + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" - + "Outlines may raise a `RuntimeError` when building the regex index.\n" - + "To circumvent this error when using `models.llamacpp()` you may pass the argument" - + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" - ) - - model = Llama.from_pretrained(repo_id, filename, **kwargs) - return cls(model) - def generate(self, model_input, logits_processor, **inference_kwargs): """Generate text using `llama-cpp-python`. @@ -257,3 +225,7 @@ def token_generator() -> Iterator[str]: return return token_generator() + + +def from_llamacpp(model: "Llama"): + return LlamaCpp(model) diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py index 843107d66..380fb0150 100644 --- a/outlines/models/mlxlm.py +++ b/outlines/models/mlxlm.py @@ -12,6 +12,9 @@ from outlines.processors import OutlinesLogitsProcessor +__all__ = ["MLXLM", "from_mlxlm"] + + class MLXLM: """ Represents an `mlx_lm` model @@ -196,52 +199,5 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]: unprocessed_input_ids = new_token_single -def mlxlm( - model_name: str, - tokenizer_config: dict = {}, - model_config: dict = {}, - adapter_path: Optional[str] = None, - lazy: bool = False, -): - """Instantiate a model from the `mlx_lm` library and its tokenizer. - - Signature adapted from - https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422 - - Parameters - ---------- - Args: - path_or_hf_repo (Path): The path or the huggingface repository to load the model from. - tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. - Defaults to an empty dictionary. - model_config(dict, optional): Configuration parameters specifically for the model. - Defaults to an empty dictionary. - adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers - to the model. Default: ``None``. - lazy (bool): If False eval the model parameters to make sure they are - loaded in memory before returning, otherwise they will be loaded - when needed. Default: ``False`` - - Returns - ------- - A `MLXLM` model instance. - - """ - try: - import mlx.core as mx - import mlx_lm - except ImportError: - raise ImportError( - "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models." - ) - if not mx.metal.is_available(): - raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)") - - model, tokenizer = mlx_lm.load( - model_name, - tokenizer_config=tokenizer_config, - model_config=model_config, - adapter_path=adapter_path, - lazy=lazy, - ) +def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM: return MLXLM(model, tokenizer) diff --git a/outlines/models/ollama.py b/outlines/models/ollama.py index c97460200..2eec91879 100644 --- a/outlines/models/ollama.py +++ b/outlines/models/ollama.py @@ -1,9 +1,15 @@ from functools import singledispatchmethod from types import NoneType -from typing import Iterator +from typing import Iterator, TYPE_CHECKING from outlines.models.base import Model, ModelTypeAdapter -from outlines.types import Json +from outlines.types import JsonType + +if TYPE_CHECKING: + from ollama import Client as OllamaClient + + +__all__ = ["Ollama", "from_ollama"] class OllamaTypeAdapter(ModelTypeAdapter): @@ -46,8 +52,8 @@ def format_output_type(self, output_type): def format_none_output_type(self, output_type: None): return "" - @format_output_type.register(Json) - def format_json_output_type(self, output_type: Json): + @format_output_type.register(JsonType) + def format_json_output_type(self, output_type: JsonType): return output_type.to_json_schema() @@ -59,21 +65,11 @@ class Ollama(Model): """ - def __init__(self, model_name: str, *args, **kwargs): - from ollama import Client - - self.client = Client(*args, **kwargs) + def __init__(self, client: "OllamaClient", model_name: str): + self.client = client self.model_name = model_name self.type_adapter = OllamaTypeAdapter() - @classmethod - def from_pretrained(cls, model_name: str, *args, **kwargs): - """Download the model weights from Ollama and create a `Ollama` instance.""" - from ollama import pull - - pull(model_name) - return cls(model_name, *args, **kwargs) - def generate(self, model_input, output_type=None, **kwargs) -> str: response = self.client.generate( model=self.model_name, @@ -93,3 +89,7 @@ def stream(self, model_input, output_type=None, **kwargs) -> Iterator[str]: ) for chunk in response: yield chunk.response + + +def from_ollama(client: "OllamaClient", model_name: str) -> Ollama: + return Ollama(client, model_name) diff --git a/outlines/models/openai.py b/outlines/models/openai.py index b809742e7..6038ad1a1 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -1,16 +1,20 @@ """Integration with OpenAI's API.""" + from functools import singledispatchmethod from types import NoneType -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING from pydantic import BaseModel from outlines.models.base import Model, ModelTypeAdapter from outlines.templates import Vision -from outlines.types import Json +from outlines.types import JsonType + +if TYPE_CHECKING: + from openai import OpenAI as OpenAIClient, AzureOpenAI as AzureOpenAIClient -__all__ = ["OpenAI"] +__all__ = ["OpenAI", "from_openai"] class OpenAITypeAdapter(ModelTypeAdapter): @@ -92,8 +96,8 @@ def format_none_output_type(self, _: None): """ return {} - @format_output_type.register(Json) - def format_json_output_type(self, output_type: Json): + @format_output_type.register(JsonType) + def format_json_output_type(self, output_type: JsonType): """Generate the `response_format` argument to the client when the user specified a `Json` output type. @@ -124,12 +128,13 @@ class OpenAI(Model): """ - def __init__(self, model_name: str, *args, **kwargs): + def __init__( + self, client: Union["OpenAIClient", "AzureOpenAIClient"], model_name: str + ): from openai import OpenAI - self.client = OpenAI(*args, **kwargs) + self.client = client self.model_name = model_name - self.model_type = "api" self.type_adapter = OpenAITypeAdapter() def generate( @@ -147,18 +152,7 @@ def generate( return result.choices[0].message.content -class AzureOpenAI(OpenAI): - """Thin wrapper around the `openai.AzureOpenAI` client. - - This wrapper is used to convert the input and output types specified by the - users at a higher level to arguments to the `openai.AzureOpenAI` client. - - """ - - def __init__(self, deployment_name: str, *args, **kwargs): - from openai import AzureOpenAI - - self.client = AzureOpenAI(*args, **kwargs) - self.model_name = deployment_name - self.model_type = "api" - self.type_adapter = OpenAITypeAdapter() +def from_openai( + client: Union["OpenAIClient", "AzureOpenAIClient"], model_name: str +) -> OpenAI: + return OpenAI(client, model_name) diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 3d87443f6..b65d0cefb 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -6,12 +6,10 @@ if TYPE_CHECKING: import torch - from transformers import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, PreTrainedModel -__all__ = ["Transformers", "Mamba"] - -KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...] +__all__ = ["Transformers", "from_transformers"] def get_llama_tokenizer_types(): @@ -168,55 +166,31 @@ class Transformers(Model): def __init__( self, - model_name: str, - model_class=None, - model_kwargs: dict = {}, - tokenizer_class=None, - tokenizer_kwargs: dict = {}, + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", ): - """Create a Transformers model instance + """Create a Transformers model instance. + + `outlines` supports `PreTrainedModelForCausalLM`, + `PreTrainedMambaForCausalLM`, `PreTrainedModelForSeq2Seq` and any model + that implements the `transformers` model API. Parameters: ---------- - model_name - The name of the transformers model to use; - model_class - The Transformers model class from which to create the model. - If not provided,`AutoModelForCausalLM` will be used. - If you gave the name of a non-causal language model, - you must provide a value for this parameter. - model_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the model class. - tokenizer_class - The Transformers tokenizer class from which to create the tokenizer. - If not provided,`AutoTokenizer` will be used. - If you gave the name of a model that is not compatible with `AutoTokenizer`, - you must provide a value for this parameter. - tokenizer_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the tokenizer class. + model + A `PreTrainedModel`, or any model that is compatible with the + `transformers` API for models. + tokenizer + A `PreTrainedTokenizer`, or any tokenizer that is compatible with + the `transformers` API for tokenizers. """ - if model_class is None or tokenizer_class is None: - try: - from transformers import AutoModelForCausalLM, AutoTokenizer - except ImportError: - raise ImportError( - "The `transformers` library needs to be installed in order to use `transformers` models." - ) - if model_class is None: - model_class = AutoModelForCausalLM - if tokenizer_class is None: - tokenizer_class = AutoTokenizer - self.model = model_class.from_pretrained(model_name, **model_kwargs) - tokenizer_kwargs.setdefault("padding_side", "left") - self.tokenizer = TransformerTokenizer( - tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs) - ) + tokenizer.padding_size = "left" + self.model = model + self.tokenizer = TransformerTokenizer(tokenizer) self.type_adapter = TransformersTypeAdapter() - def generate(self, model_input, output_type, **inference_kwargs): + def _prepare_model_inputs(self, model_input, output_type): prompts = self.type_adapter.format_input(model_input) input_ids, attention_mask = self.tokenizer.encode(prompts) inputs = { @@ -224,6 +198,10 @@ def generate(self, model_input, output_type, **inference_kwargs): "attention_mask": attention_mask.to(self.model.device), } + return prompts, inputs + + def generate(self, model_input, output_type, **inference_kwargs): + prompts, inputs = self._prepare_model_inputs(model_input, output_type) logits_processor = self.type_adapter.format_output_type(output_type) generated_ids = self._generate_output_seq( @@ -231,7 +209,7 @@ def generate(self, model_input, output_type, **inference_kwargs): ) # if single str input, convert to a 1D outputt - if isinstance(model_input, str): + if isinstance(prompts, str): generated_ids = generated_ids.squeeze(0) return self._decode_generation(generated_ids) @@ -278,41 +256,125 @@ def _decode_generation(self, generated_ids: "torch.Tensor"): ) -class Mamba(Transformers): - """Represents a Mamba model.""" +class TransformersVisionTypeAdapter(ModelTypeAdapter): + """Type adapter for TransformersVision models.""" - def __init__( - self, - model_name: str, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, - ): - """ - Create a Mamba model instance + @singledispatchmethod + def format_input(self, model_input): + """Generate the prompt argument to pass to the model. + + Argument + -------- + model_input + The input passed by the user. - Parameters: - ---------- - model_name - The name of the transformers model to use. It will be passed to - the `from_pretrained` method of the `MambaForCausalLM` class. - model_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the `MambaForCausalLM` class. - tokenizer_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the `AutoTokenizer` class. """ - try: - from transformers import MambaForCausalLM + raise NotImplementedError( + f"The input type {input} is not available. Please provide a " + "dictionary with the following format: " + "{'prompts': Union[str, List[str]], 'images': Union[Any, List[Any]]}" + "Make sure the number of image tags in the prompts is equal to the " + "number of images provided." + ) - except ImportError: - raise ImportError( - "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba." + @format_input.register(dict) + def format_list_input(self, model_input): + if "prompts" not in model_input or "images" not in model_input: + raise ValueError( + "The input must contain the following keys: 'prompts' and 'images'." ) + return model_input["prompts"], model_input["images"] - return super().__init__( - model_name=model_name, - model_class=MambaForCausalLM, - model_kwargs=model_kwargs, - tokenizer_kwargs=tokenizer_kwargs, + def format_output_type(self, output_type): + """Generate the logits processor argument to pass to the model. + + Argument + -------- + output_type + The logits processor provided. + + """ + from transformers import LogitsProcessorList + + if output_type is not None: + return LogitsProcessorList([output_type]) + return None + + +class TransformersVision(Transformers): + """Represents a `transformers` model with a vision processor.""" + + def __init__(self, model: "PreTrainedModel", processor): + """Create a TransformersVision model instance + + We rely on the `__init__` method of the `Transformers` class to handle + most of the initialization and then add elements specific to vision + models. + + """ + self.processor = processor + self.processor.padding_side = "left" + self.processor.pad_token = "[PAD]" + + tokenizer: "PreTrainedTokenizer" = self.processor.tokenizer + + super().__init__(model, tokenizer) + + self.type_adapter = TransformersVisionTypeAdapter() + + def _prepare_model_inputs(self, model_input, output_type): + prompts, images = self.type_adapter.format_input(model_input) + inputs = self.processor( + text=prompts, images=images, padding=True, return_tensors="pt" + ).to(self.model.device) + + return prompts, inputs + + +def from_transformers( + model: "PreTrainedModel", + tokenizer_or_processor: "PreTrainedTokenizer", +): + from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + Blip2Processor, + LlavaProcessor, + IdeficsProcessor, + CLIPProcessor, + Qwen2_5_VLProcessor, + Qwen2VLProcessor, + NougatProcessor, + LlavaNextProcessor, + PixtralProcessor, + PaliGemmaProcessor, + ) + + vision_processors = ( + Blip2Processor, + LlavaProcessor, + IdeficsProcessor, + CLIPProcessor, + Qwen2_5_VLProcessor, + Qwen2VLProcessor, + NougatProcessor, + LlavaNextProcessor, + PixtralProcessor, + PaliGemmaProcessor, + ) + + if isinstance( + tokenizer_or_processor, (PreTrainedTokenizer, PreTrainedTokenizerFast) + ): + tokenizer = tokenizer_or_processor + return Transformers(model, tokenizer) + elif isinstance(tokenizer_or_processor, vision_processors): + processor = tokenizer_or_processor + return TransformersVision(model, processor) + else: + raise ValueError( + "We could determine whether the model passed to `from_transformers`" + + " is a text-2-text of vision language model. If you passed a" + + " vision language model please open an issue pasting the error" + + " with the name of the vision language model." ) diff --git a/outlines/models/transformers_vision.py b/outlines/models/transformers_vision.py index b1c1c1de4..e69de29bb 100644 --- a/outlines/models/transformers_vision.py +++ b/outlines/models/transformers_vision.py @@ -1,146 +0,0 @@ -from functools import singledispatchmethod - -from outlines.models import Transformers -from outlines.models.base import ModelTypeAdapter - - -class TransformersVisionTypeAdapter(ModelTypeAdapter): - """Type adapter for TransformersVision models.""" - - @singledispatchmethod - def format_input(self, model_input): - """Generate the prompt argument to pass to the model. - - Argument - -------- - model_input - The input passed by the user. - - """ - raise NotImplementedError( - f"The input type {input} is not available. Please provide a " - "dictionary with the following format: " - "{'prompts': Union[str, List[str]], 'images': Union[Any, List[Any]]}" - "Make sure the number of image tags in the prompts is equal to the " - "number of images provided." - ) - - @format_input.register(dict) - def format_list_input(self, model_input): - if "prompts" not in model_input or "images" not in model_input: - raise ValueError( - "The input must contain the following keys: 'prompts' and 'images'." - ) - return model_input["prompts"], model_input["images"] - - def format_output_type(self, output_type): - """Generate the logits processor argument to pass to the model. - - Argument - -------- - output_type - The logits processor provided. - - """ - from transformers import LogitsProcessorList - - if output_type is not None: - return LogitsProcessorList([output_type]) - return None - - -class TransformersVision(Transformers): - """Represents a `transformers` model with a vision processor.""" - - def __init__( - self, - model_name: str, - model_class, - model_kwargs: dict = {}, - tokenizer_class=None, - tokenizer_kwargs: dict = {}, - processor_class=None, - processor_kwargs: dict = {}, - ): - """Create a TransformersVision model instance - - We rely on the `__init__` method of the `Transformers` class to handle - most of the initialization and then add elements specific to vision - models. - - Parameters - ---------- - model_name - The name of the transformers model to use; - model_class - The Transformers model class from which to create the model. - model_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the model class. - tokenizer_class - The Transformers tokenizer class from which to create the tokenizer. - If not provided,`AutoTokenizer` will be used. - If you gave the name of a model that is not compatible with `AutoTokenizer`, - you must provide a value for this parameter. - tokenizer_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the tokenizer class. - processor_class - The Transformers processor class from which to create the processor. - If not provided,`AutoProcessor` will be used. - If you gave the name of a model that is not compatible with `AutoProcessor`, - you must provide a value for this parameter. - processor_kwargs - A dictionary of keyword arguments to pass to the `from_pretrained` - method of the processor class. - - """ - if processor_class is None: - try: - from transformers import AutoProcessor - - processor_class = AutoProcessor - except ImportError: - raise ImportError( - "The `transformers` library needs to be installed in order to use `transformers` models." - ) - - processor_kwargs.setdefault("padding_side", "left") - processor_kwargs.setdefault("pad_token", "[PAD]") - self.processor = processor_class.from_pretrained(model_name, **processor_kwargs) - - if tokenizer_class is None and getattr(self.processor, "tokenizer", None): - tokenizer_class = type(self.processor.tokenizer) - - super().__init__( - model_name, - model_class, - model_kwargs, - tokenizer_class, - tokenizer_kwargs, - ) - - self.type_adapter = TransformersVisionTypeAdapter() - - def generate(self, model_input, output_type, **inference_kwargs): - prompts, images = self.type_adapter.format_input(model_input) - - inputs = self.processor( - text=prompts, images=images, padding=True, return_tensors="pt" - ).to(self.model.device) - logits_processor = self.type_adapter.format_output_type(output_type) - - generated_ids = self._generate_output_seq( - prompts, inputs, logits_processor=logits_processor, **inference_kwargs - ) - - # if single str input, convert to a 1D outputt - if isinstance(prompts, str): - generated_ids = generated_ids.squeeze(0) - - return self._decode_generation(generated_ids) - - def stream(self, model_input, output_type, **inference_kwargs): - raise NotImplementedError( - "Streaming is not implemented for TransformersVision models." - ) diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py index b9b035d1f..8c709a5e2 100644 --- a/outlines/models/vllm.py +++ b/outlines/models/vllm.py @@ -9,6 +9,9 @@ from vllm.sampling_params import SamplingParams +__all__ = ["VLLM", "from_vllm"] + + class VLLM: """Represents a vLLM model. @@ -168,22 +171,7 @@ def load_lora(self, adapter_path: Optional[str]): self.lora_request = LoRARequest(adapter_path, 1, adapter_path) -def vllm(model_name: str, **vllm_model_params): - """Load a vLLM model. - - Parameters - --------- - model_name - The name of the model to load from the HuggingFace hub. - vllm_model_params - vLLM-specific model parameters. See the vLLM code for the full list: - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py - - """ - from vllm import LLM - - model = LLM(model_name, **vllm_model_params) - +def from_vllm(model: "LLM") -> VLLM: return VLLM(model) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 0788d9dca..bbab3fde9 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -59,7 +59,7 @@ @dataclass -class Json: +class JsonType: """Represents a JSON object. The structure of JSON object can be defined using a JSON Schema @@ -73,7 +73,7 @@ class Json: """ - definition: Union[str, dict] + definition: Union[str, dict, type[BaseModel]] whitespace_pattern: str = " " def to_json_schema(self): diff --git a/pyproject.toml b/pyproject.toml index 9b6d5174b..2c5382e79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,16 +47,6 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -vllm = ["vllm", "transformers", "numpy<2"] -transformers = ["transformers", "accelerate", "datasets", "numpy<2"] -mlxlm = ["mlx-lm", "datasets"] -openai = ["openai"] -anthropic = ["anthropic"] -gemini = ["google-generativeai"] -llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"] -exllamav2 = ["exllamav2"] -ollama = ["ollama"] -dottxt = ["dottxt"] test = [ "pre-commit", "pytest", diff --git a/tests/fsm/test_cfg_guide.py b/tests/fsm/test_cfg_guide.py index d92afa625..b925fefea 100644 --- a/tests/fsm/test_cfg_guide.py +++ b/tests/fsm/test_cfg_guide.py @@ -4,7 +4,8 @@ import pytest from transformers import AutoTokenizer -from outlines import grammars, models +from outlines import grammars +from outlines.models.transformers import TransformerTokenizer from outlines.fsm.guide import CFGGuide @@ -341,12 +342,12 @@ def decode(self, token_ids): @pytest.fixture(scope="session") def tokenizer_sentencepiece_gpt2(): - return models.TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2")) + return TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2")) @pytest.fixture(scope="session") def tokenizer_sentencepiece_llama1(): - return models.TransformerTokenizer( + return TransformerTokenizer( AutoTokenizer.from_pretrained( "trl-internal-testing/tiny-random-LlamaForCausalLM" ) @@ -355,14 +356,14 @@ def tokenizer_sentencepiece_llama1(): @pytest.fixture(scope="session") def tokenizer_tiktoken_llama3(): - return models.TransformerTokenizer( + return TransformerTokenizer( AutoTokenizer.from_pretrained("yujiepan/llama-3-tiny-random") ) @pytest.fixture(scope="session") def tokenizer_character_level_byt5(): - return models.TransformerTokenizer( + return TransformerTokenizer( AutoTokenizer.from_pretrained("google/byt5-small") ) diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index aaa4509bf..1bf3f2a41 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -4,8 +4,8 @@ import pytest +import outlines import outlines.generate as generate -import outlines.models as models import outlines.samplers as samplers ########################################## @@ -22,11 +22,12 @@ def model_llamacpp(tmp_path_factory): filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", verbose=False, ) - return models.LlamaCpp(llm) + return outlines.from_llamacpp(llm) @pytest.fixture(scope="session") def model_exllamav2(tmp_path_factory): + from outlines.models.exllamav2 import exl2 from huggingface_hub import snapshot_download tmp_dir = tmp_path_factory.mktemp("model_download") @@ -35,7 +36,7 @@ def model_exllamav2(tmp_path_factory): cache_dir=tmp_dir, ) - return models.exl2( + return exl2( model_path=model_path, cache_q4=True, paged=False, @@ -44,58 +45,79 @@ def model_exllamav2(tmp_path_factory): @pytest.fixture(scope="session") def model_mlxlm(tmp_path_factory): - return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit") + from mlx_lm import load + + return outlines.from_mlxlm(*load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")) @pytest.fixture(scope="session") def model_mlxlm_phi3(tmp_path_factory): - return models.mlxlm("mlx-community/Phi-3-mini-4k-instruct-4bit") + from mlx_lm import load + + return outlines.from_mlxlm(*load("mlx-community/Phi-3-mini-4k-instruct-4bit")) @pytest.fixture(scope="session") def model_transformers_random(tmp_path_factory): - return models.Transformers("hf-internal-testing/tiny-random-gpt2") + from transformers import AutoModelForCausalLM, AutoTokenizer + + return outlines.from_transformers( + AutoModelForCausalLM.fromt_pretrained("hf-internal-testing/tiny-random-gpt2"), + AutoTokenizer.fromt_pretrained("hf-internal-testing/tiny-random-gpt2"), + ) @pytest.fixture(scope="session") def model_transformers_opt125m(tmp_path_factory): - return models.Transformers("facebook/opt-125m") + from transformers import AutoModelForCausalLM, AutoTokenizer + + return outlines.from_transformers( + AutoModelForCausalLM.fromt_pretrained("facebook/opt-125m"), + AutoTokenizer.fromt_pretrained("facebook/opt-125m"), + ) @pytest.fixture(scope="session") def model_mamba(tmp_path_factory): - return models.Mamba(model_name="state-spaces/mamba-130m-hf") + from transformers import MambaModel, AutoTokenizer + + return outlines.from_transformers( + MambaModel.from_pretrained(model_name="state-spaces/mamba-130m-hf"), + AutoTokenizer.from_pretrained(model_name="state-spaces/mamba-130m-hf"), + ) @pytest.fixture(scope="session") def model_bart(tmp_path_factory): - from transformers import AutoModelForSeq2SeqLM + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer - return models.Transformers( - "facebook/bart-base", model_class=AutoModelForSeq2SeqLM + return outlines.from_transformers( + AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base"), + AutoTokenizer.from_pretrained("facebook/bart-base"), ) @pytest.fixture(scope="session") def model_transformers_vision(tmp_path_factory): import torch - from transformers import LlavaNextForConditionalGeneration + from transformers import LlavaNextForConditionalGeneration, AutoTokenizer - return models.transformers_vision( - "llava-hf/llava-v1.6-mistral-7b-hf", - model_class=LlavaNextForConditionalGeneration, - device="cuda", - model_kwargs=dict( + return outlines.from_transformers( + LlavaNextForConditionalGeneration.from_pretrained( + "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.bfloat16, load_in_4bit=True, - low_cpu_mem_usage=True, + low_mem_cpu_usage=True, ), + AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf"), ) @pytest.fixture(scope="session") def model_vllm(tmp_path_factory): - return models.vllm("facebook/opt-125m", gpu_memory_utilization=0.1) + from vllm import LLM + + return outlines.from_vllm(LLM("facebook/opt-125m", gpu_memory_utilization=0.1)) # TODO: exllamav2 failing in main, address in https://github.com/dottxt-ai/outlines/issues/808 @@ -281,7 +303,7 @@ def test_generate_json(request, model_fixture, sample_schema): # TODO: add support for genson in the Regex type of v1.0 -#def test_integrate_genson_generate_json(request): +# def test_integrate_genson_generate_json(request): # from genson import SchemaBuilder # # builder = SchemaBuilder() diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py index d812d3cf2..7b9acc240 100644 --- a/tests/generate/test_integration_vllm.py +++ b/tests/generate/test_integration_vllm.py @@ -5,16 +5,18 @@ import torch from pydantic import BaseModel, constr -try: - from vllm.sampling_params import SamplingParams -except ImportError: - pass - +import outlines import outlines.generate as generate import outlines.grammars as grammars import outlines.models as models import outlines.samplers as samplers +try: + from vllm import LLM + from vllm.sampling_params import SamplingParams +except ImportError: + pass + pytestmark = pytest.mark.skipif( not torch.cuda.is_available(), reason="vLLM models can only be run on GPU." ) @@ -22,7 +24,7 @@ @pytest.fixture(scope="module") def model(): - return models.vllm("gpt2", gpu_memory_utilization=0.5) + return outlines.from_vllm(LLM("gpt2", gpu_memory_utilization=0.5)) @pytest.mark.parametrize( diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3aabba53d..8aef257cc 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1,23 +1,20 @@ import io +from anthropic import Anthropic as AnthropicClient import PIL import pytest import requests +import outlines from outlines.models.anthropic import Anthropic from outlines.templates import Vision MODEL_NAME = "claude-3-haiku-20240307" -def test_anthropic_wrong_init_parameters(): - with pytest.raises(TypeError, match="got an unexpected"): - Anthropic(MODEL_NAME, foo=10) - - def test_anthropic_wrong_inference_parameters(): with pytest.raises(TypeError, match="got an unexpected"): - model = Anthropic(MODEL_NAME) + model = Anthropic(AnthropicClient(), MODEL_NAME) model.generate("prompt", foo=10, max_tokens=1024) @@ -27,37 +24,45 @@ def __init__(self, foo): self.foo = foo with pytest.raises(NotImplementedError, match="is not available"): - model = Anthropic(MODEL_NAME) + model = Anthropic(AnthropicClient(), MODEL_NAME) model.generate(Foo("prompt")) +def test_init_from_client(): + client = AnthropicClient() + model = outlines.from_anthropic(client, MODEL_NAME) + assert isinstance(model, Anthropic) + assert model.client == client + + def test_anthropic_wrong_output_type(): class Foo: def __init__(self, foo): self.foo = foo with pytest.raises(NotImplementedError, match="is not available"): - model = Anthropic(MODEL_NAME) + model = Anthropic(AnthropicClient(), MODEL_NAME) model.generate("prompt", Foo(1)) @pytest.mark.api_call def test_anthropic_simple_call(): - model = Anthropic(MODEL_NAME) + model = Anthropic(AnthropicClient(), MODEL_NAME) result = model.generate("Respond with one word. Not more.", max_tokens=1024) assert isinstance(result, str) +@pytest.mark.xfail(reason="Anthropic requires the `max_tokens` parameter to be set") @pytest.mark.api_call def test_anthropic_direct_call(): - model = Anthropic(MODEL_NAME) + model = Anthropic(AnthropicClient(), MODEL_NAME) result = model("Respond with one word. Not more.", max_tokens=1024) assert isinstance(result, str) @pytest.mark.api_call def test_anthropic_simple_vision(): - model = Anthropic(MODEL_NAME) + model = Anthropic(AnthropicClient(), MODEL_NAME) url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" r = requests.get(url, stream=True) diff --git a/tests/models/test_dottxt.py b/tests/models/test_dottxt.py index fe1aa9be5..6c3ea612b 100644 --- a/tests/models/test_dottxt.py +++ b/tests/models/test_dottxt.py @@ -4,9 +4,12 @@ import pytest from pydantic import BaseModel +from dottxt.client import Dottxt as DottxtClient + +import outlines from outlines.generate import Generator from outlines.models.dottxt import Dottxt -from outlines.types import Json +from outlines.types import JsonType class User(BaseModel): @@ -31,39 +34,50 @@ def api_key(): def test_dottxt_wrong_init_parameters(api_key): with pytest.raises(TypeError, match="got an unexpected"): - Dottxt(api_key=api_key, foo=10) + client = DottxtClient(api_key=api_key) + Dottxt(client, foo=10) def test_dottxt_wrong_output_type(api_key): with pytest.raises(NotImplementedError, match="must provide an output type"): - model = Dottxt(api_key=api_key) + client = DottxtClient(api_key=api_key) + model = Dottxt(client) model("prompt") +def test_dottxt_init_from_client(api_key): + client = DottxtClient(api_key=api_key) + model = outlines.from_dottxt(client) + assert isinstance(model, Dottxt) + assert model.client == client @pytest.mark.api_call def test_dottxt_wrong_input_type(api_key): with pytest.raises(NotImplementedError, match="is not available"): - model = Dottxt(api_key=api_key) - model(["prompt"], Json(User)) + client = DottxtClient(api_key=api_key) + model = Dottxt(client) + model(["prompt"], JsonType(User)) @pytest.mark.api_call def test_dottxt_wrong_inference_parameters(api_key): with pytest.raises(TypeError, match="got an unexpected"): - model = Dottxt(api_key=api_key) - model("prompt", Json(User), foo=10) + client = DottxtClient(api_key=api_key) + model = Dottxt(client) + model("prompt", JsonType(User), foo=10) @pytest.mark.api_call def test_dottxt_direct_call(api_key): - model = Dottxt(api_key=api_key, model_name="meta-llama/Llama-3.1-8B-Instruct") - result = model("Create a user", Json(User)) + client = DottxtClient(api_key=api_key) + model = Dottxt(client, model_name="meta-llama/Llama-3.1-8B-Instruct") + result = model("Create a user", JsonType(User)) assert "first_name" in json.loads(result) @pytest.mark.api_call def test_dottxt_generator_call(api_key): - model = Dottxt(api_key=api_key, model_name="meta-llama/Llama-3.1-8B-Instruct") - generator = Generator(model, Json(User)) + client = DottxtClient(api_key=api_key) + model = Dottxt(client, model_name="meta-llama/Llama-3.1-8B-Instruct") + generator = Generator(model, JsonType(User)) result = generator("Create a user") assert "first_name" in json.loads(result) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 72e488084..cc54bd71a 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -7,42 +7,46 @@ import requests from pydantic import BaseModel from typing_extensions import TypedDict +import google.generativeai as genai +import outlines from outlines.models.gemini import Gemini from outlines.templates import Vision -from outlines.types import Choice, Json, List +from outlines.types import Choice, JsonType, List MODEL_NAME = "gemini-1.5-flash-latest" -def test_gemini_wrong_init_parameters(): - with pytest.raises(TypeError, match="got an unexpected"): - Gemini(MODEL_NAME, foo=10) - - def test_gemini_wrong_inference_parameters(): with pytest.raises(TypeError, match="got an unexpected"): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) model.generate("prompt", foo=10) +def test_gemini_init_from_client(): + client = genai.GenerativeModel(MODEL_NAME) + model = outlines.from_gemini(client) + assert isinstance(model, Gemini) + assert model.client == client + + @pytest.mark.api_call def test_gemini_simple_call(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) result = model.generate("Respond with one word. Not more.") assert isinstance(result, str) @pytest.mark.api_call def test_gemini_direct_call(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) result = model("Respond with one word. Not more.") assert isinstance(result, str) @pytest.mark.api_call def test_gemini_simple_vision(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" r = requests.get(url, stream=True) @@ -55,12 +59,12 @@ def test_gemini_simple_vision(): @pytest.mark.api_call def test_gemini_simple_pydantic(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) class Foo(BaseModel): bar: int - result = model.generate("foo?", Json(Foo)) + result = model.generate("foo?", JsonType(Foo)) assert isinstance(result, str) assert "bar" in json.loads(result) @@ -68,7 +72,7 @@ class Foo(BaseModel): @pytest.mark.xfail(reason="Vision models do not work with structured outputs.") @pytest.mark.api_call def test_gemini_simple_vision_pydantic(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" r = requests.get(url, stream=True) @@ -86,7 +90,7 @@ class Logo(BaseModel): @pytest.mark.xfail(reason="Gemini seems to be unable to follow nested schemas.") @pytest.mark.api_call def test_gemini_nested_pydantic(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) class Bar(BaseModel): fu: str @@ -95,7 +99,7 @@ class Foo(BaseModel): sna: int bar: Bar - result = model.generate("foo?", Json(Foo)) + result = model.generate("foo?", JsonType(Foo)) assert isinstance(result, str) assert "sna" in json.loads(result) assert "bar" in json.loads(result) @@ -107,7 +111,7 @@ class Foo(BaseModel): ) @pytest.mark.api_call def test_gemini_simple_json_schema_dict(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) schema = { "properties": {"bar": {"title": "Bar", "type": "integer"}}, @@ -115,7 +119,7 @@ def test_gemini_simple_json_schema_dict(): "title": "Foo", "type": "object", } - result = model.generate("foo?", Json(schema)) + result = model.generate("foo?", JsonType(schema)) assert isinstance(result, str) assert "bar" in json.loads(result) @@ -125,29 +129,29 @@ def test_gemini_simple_json_schema_dict(): ) @pytest.mark.api_call def test_gemini_simple_json_schema_string(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) schema = "{'properties': {'bar': {'title': 'Bar', 'type': 'integer'}}, 'required': ['bar'], 'title': 'Foo', 'type': 'object'}" - result = model.generate("foo?", Json(schema)) + result = model.generate("foo?", JsonType(schema)) assert isinstance(result, str) assert "bar" in json.loads(result) @pytest.mark.api_call def test_gemini_simple_typed_dict(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) class Foo(TypedDict): bar: int - result = model.generate("foo?", Json(Foo)) + result = model.generate("foo?", JsonType(Foo)) assert isinstance(result, str) assert "bar" in json.loads(result) @pytest.mark.api_call def test_gemini_simple_choice_enum(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) class Foo(Enum): bar = "Bar" @@ -160,7 +164,7 @@ class Foo(Enum): @pytest.mark.api_call def test_gemini_simple_choice_list(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) choices = ["Foo", "Bar"] result = model.generate("foo?", Choice(choices)) @@ -170,12 +174,12 @@ def test_gemini_simple_choice_list(): @pytest.mark.api_call def test_gemini_simple_list_pydantic(): - model = Gemini(MODEL_NAME) + model = Gemini(genai.GenerativeModel(MODEL_NAME)) class Foo(BaseModel): bar: int - result = model.generate("foo?", List(Json(Foo))) + result = model.generate("foo?", List(JsonType(Foo))) assert isinstance(json.loads(result), list) assert isinstance(json.loads(result)[0], dict) assert "bar" in json.loads(result)[0] diff --git a/tests/models/test_llamacpp.py b/tests/models/test_llamacpp.py index ffb0110cd..225ec558d 100644 --- a/tests/models/test_llamacpp.py +++ b/tests/models/test_llamacpp.py @@ -2,27 +2,33 @@ from enum import Enum import pytest +from llama_cpp import Llama from pydantic import BaseModel from outlines.models import LlamaCpp from outlines.processors import RegexLogitsProcessor -from outlines.types import Choice, Json, Regex +from outlines.types import Choice, JsonType, Regex def test_load_model(): - model = LlamaCpp.from_pretrained( - repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", - filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + model = LlamaCpp( + Llama.from_pretrained( + repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", + filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + ) ) assert isinstance(model, LlamaCpp) + assert isinstance(model.model, Llama) @pytest.fixture(scope="session") def model(tmp_path_factory): - return LlamaCpp.from_pretrained( - repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", - filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + return LlamaCpp( + Llama.from_pretrained( + repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", + filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + ) ) @@ -32,7 +38,7 @@ def test_llamacpp_simple(model): def test_llamacpp_regex(model): - regex_str = Regex(r"[0-9]").to_regex() + regex_str = Regex(r"[0-9]").pattern logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) result = model.generate("Respond with one word. Not more.", logits_processor) assert isinstance(result, str) @@ -42,7 +48,7 @@ def test_llamacpp_json(model): class Foo(BaseModel): bar: str - regex_str = Json(Foo).to_regex() + regex_str = JsonType(Foo).to_regex() logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) result = model.generate( "foo? Respond with one word.", logits_processor, max_tokens=1000 @@ -85,7 +91,7 @@ def test_llamacpp_stream_simple(model): def test_llamacpp_stream_regex(model): - regex_str = Regex(r"[0-9]").to_regex() + regex_str = Regex(r"[0-9]").pattern logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) generator = model.stream("Respond with one word. Not more.", logits_processor) @@ -97,7 +103,7 @@ def test_llamacpp_stream_json(model): class Foo(BaseModel): bar: int - regex_str = Json(Foo).to_regex() + regex_str = JsonType(Foo).to_regex() logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) generator = model.stream("foo?", logits_processor) diff --git a/tests/models/test_mlxlm.py b/tests/models/test_mlxlm.py index 20e59da81..10280b756 100644 --- a/tests/models/test_mlxlm.py +++ b/tests/models/test_mlxlm.py @@ -1,9 +1,10 @@ import pytest -from outlines.models.mlxlm import mlxlm +import outlines from outlines.models.transformers import TransformerTokenizer try: + import mlx_lm import mlx.core as mx HAS_MLX = mx.metal.is_available() @@ -14,18 +15,21 @@ TEST_MODEL = "mlx-community/SmolLM-135M-Instruct-4bit" +@pytest.fixture(scope="session") +def model(tmp_path_factory): + model, tokenizer = mlx_lm.load(TEST_MODEL) + return outlines.from_mlxlm(model, tokenizer) + + @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") -def test_mlxlm_model(): - model = mlxlm(TEST_MODEL) +def test_mlxlm_model(model): assert hasattr(model, "model") assert hasattr(model, "tokenizer") assert isinstance(model.tokenizer, TransformerTokenizer) @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") -def test_mlxlm_tokenizer(): - model = mlxlm(TEST_MODEL) - +def test_mlxlm_tokenizer(model): # Test single string encoding/decoding test_text = "Hello, world!" token_ids = mx.array(model.mlx_tokenizer.encode(test_text)) @@ -33,10 +37,9 @@ def test_mlxlm_tokenizer(): @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") -def test_mlxlm_generate(): +def test_mlxlm_generate(model): from outlines.generate.api import GenerationParameters, SamplingParameters - model = mlxlm(TEST_MODEL) prompt = "Write a haiku about programming:" # Test with basic generation parameters @@ -54,10 +57,9 @@ def test_mlxlm_generate(): @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") -def test_mlxlm_stream(): +def test_mlxlm_stream(model): from outlines.generate.api import GenerationParameters, SamplingParameters - model = mlxlm(TEST_MODEL) prompt = "Count from 1 to 5:" gen_params = GenerationParameters(max_tokens=20, stop_at=None, seed=None) @@ -83,9 +85,7 @@ def test_mlxlm_stream(): @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") -def test_mlxlm_errors(): - model = mlxlm(TEST_MODEL) - +def test_mlxlm_errors(model): # Test batch inference (should raise NotImplementedError) with pytest.raises(NotImplementedError): from outlines.generate.api import GenerationParameters, SamplingParameters diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py index ed05f3e21..68afa5640 100644 --- a/tests/models/test_ollama.py +++ b/tests/models/test_ollama.py @@ -2,36 +2,39 @@ from enum import Enum import pytest +from ollama import Client as OllamaClient from pydantic import BaseModel +import outlines from outlines.models import Ollama -from outlines.types import Choice, Json +from outlines.types import Choice, JsonType MODEL_NAME = "tinyllama" +CLIENT = OllamaClient() -def test_pull_model(): - model = Ollama.from_pretrained(MODEL_NAME) - assert isinstance(model, Ollama) - - -def test_ollama_wrong_init_parameters(): +def test_wrong_inference_parameters(): with pytest.raises(TypeError, match="got an unexpected"): - Ollama(MODEL_NAME, foo=10) + Ollama(CLIENT, MODEL_NAME).generate( + "Respond with one word. Not more.", None, foo=10 + ) -def test_wrong_inference_parameters(): - with pytest.raises(TypeError, match="got an unexpected"): - Ollama(MODEL_NAME).generate("Respond with one word. Not more.", None, foo=10) +def test_init_from_client(): + model = outlines.from_ollama(CLIENT, MODEL_NAME) + assert isinstance(model, Ollama) + assert model.client == CLIENT def test_ollama_simple(): - result = Ollama(MODEL_NAME).generate("Respond with one word. Not more.", None) + result = Ollama(CLIENT, MODEL_NAME).generate( + "Respond with one word. Not more.", None + ) assert isinstance(result, str) def test_ollama_direct(): - result = Ollama(MODEL_NAME)("Respond with one word. Not more.", None) + result = Ollama(CLIENT, MODEL_NAME)("Respond with one word. Not more.", None) assert isinstance(result, str) @@ -39,7 +42,9 @@ def test_ollama_json(): class Foo(BaseModel): foo: str - result = Ollama(MODEL_NAME)("Respond with one word. Not more.", Json(Foo)) + result = Ollama(CLIENT, MODEL_NAME)( + "Respond with one word. Not more.", JsonType(Foo) + ) assert isinstance(result, str) assert "foo" in json.loads(result) @@ -50,16 +55,16 @@ class Foo(Enum): foor = "Foo" with pytest.raises(NotImplementedError, match="is not available"): - Ollama(MODEL_NAME).generate("foo?", Choice(Foo)) + Ollama(CLIENT, MODEL_NAME).generate("foo?", Choice(Foo)) def test_ollama_wrong_input_type(): with pytest.raises(NotImplementedError, match="is not available"): - Ollama(MODEL_NAME).generate(["foo?", "bar?"], None) + Ollama(CLIENT, MODEL_NAME).generate(["foo?", "bar?"], None) def test_ollama_stream(): - model = Ollama(MODEL_NAME) + model = Ollama(CLIENT, MODEL_NAME) generator = model.stream("Write a sentence about a cat.") assert isinstance(next(generator), str) @@ -68,8 +73,8 @@ def test_ollama_stream_json(): class Foo(BaseModel): foo: str - model = Ollama(MODEL_NAME) - generator = model.stream("Create a character.", Json(Foo)) + model = Ollama(CLIENT, MODEL_NAME) + generator = model.stream("Create a character.", JsonType(Foo)) generated_text = [] for text in generator: generated_text.append(text) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 25f399402..57ed323de 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -5,11 +5,13 @@ import PIL import pytest import requests +from openai import OpenAI as OpenAIClient from pydantic import BaseModel +import outlines from outlines.models.openai import OpenAI from outlines.templates import Vision -from outlines.types import Json +from outlines.types import JsonType MODEL_NAME = "gpt-4o-mini-2024-07-18" @@ -28,14 +30,16 @@ def api_key(): return api_key -def test_openai_wrong_init_parameters(api_key): - with pytest.raises(TypeError, match="got an unexpected"): - OpenAI(MODEL_NAME, api_key=api_key, foo=10) +def test_init_from_client(api_key): + client = OpenAIClient(api_key=api_key) + model = outlines.from_openai(client, "gpt-4o") + assert isinstance(model, OpenAI) + assert model.client == client def test_openai_wrong_inference_parameters(api_key): with pytest.raises(TypeError, match="got an unexpected"): - model = OpenAI(MODEL_NAME, api_key=api_key) + model = OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME) model.generate("prompt", foo=10) @@ -45,7 +49,7 @@ def __init__(self, foo): self.foo = foo with pytest.raises(NotImplementedError, match="is not available"): - model = OpenAI(MODEL_NAME, api_key=api_key) + model = OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME) model.generate(Foo("prompt")) @@ -55,27 +59,27 @@ def __init__(self, foo): self.foo = foo with pytest.raises(NotImplementedError, match="is not available"): - model = OpenAI(MODEL_NAME, api_key=api_key) + model = OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME) model.generate("prompt", Foo(1)) @pytest.mark.api_call def test_openai_simple_call(): - model = OpenAI(MODEL_NAME) + model = OpenAI(OpenAIClient(), MODEL_NAME) result = model.generate("Respond with one word. Not more.") assert isinstance(result, str) @pytest.mark.api_call def test_openai_direct_call(): - model = OpenAI(MODEL_NAME) + model = OpenAI(OpenAIClient(), MODEL_NAME) result = model("Respond with one word. Not more.") assert isinstance(result, str) @pytest.mark.api_call def test_openai_simple_vision(): - model = OpenAI(MODEL_NAME) + model = OpenAI(OpenAIClient(), MODEL_NAME) url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" r = requests.get(url, stream=True) @@ -88,19 +92,19 @@ def test_openai_simple_vision(): @pytest.mark.api_call def test_openai_simple_pydantic(): - model = OpenAI(MODEL_NAME) + model = OpenAI(OpenAIClient(), MODEL_NAME) class Foo(BaseModel): bar: int - result = model.generate("foo?", Json(Foo)) + result = model.generate("foo?", JsonType(Foo)) assert isinstance(result, str) assert "bar" in json.loads(result) @pytest.mark.api_call def test_openai_simple_vision_pydantic(): - model = OpenAI(MODEL_NAME) + model = OpenAI(OpenAIClient(), MODEL_NAME) url = "https://raw.githubusercontent.com/dottxt-ai/outlines/refs/heads/main/docs/assets/images/logo.png" r = requests.get(url, stream=True) @@ -110,20 +114,22 @@ def test_openai_simple_vision_pydantic(): class Logo(BaseModel): name: int - result = model.generate(Vision("What does this logo represent?", image), Json(Logo)) + result = model.generate( + Vision("What does this logo represent?", image), JsonType(Logo) + ) assert isinstance(result, str) assert "name" in json.loads(result) @pytest.mark.api_call def test_openai_simple_json_schema(): - model = OpenAI(MODEL_NAME) + model = OpenAI(OpenAIClient(), MODEL_NAME) class Foo(BaseModel): bar: int schema = json.dumps(Foo.model_json_schema()) - result = model.generate("foo?", Json(schema)) + result = model.generate("foo?", JsonType(schema)) assert isinstance(result, str) assert "bar" in json.loads(result) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 249893ff9..f2fb90424 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -3,10 +3,11 @@ import pytest from pydantic import BaseModel -from transformers import AutoModelForSeq2SeqLM +import transformers -from outlines.models.transformers import Mamba, Transformers -from outlines.types import Choice, Json, Regex +import outlines +from outlines.models.transformers import Transformers +from outlines.types import Choice, JsonType, Regex TEST_MODEL = "erwanf/gpt2-mini" TEST_MODEL_SEQ2SEQ = "hf-internal-testing/tiny-random-t5" @@ -14,42 +15,36 @@ def test_transformers_instantiate_simple(): - model = Transformers(TEST_MODEL) + model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL), + transformers.AutoTokenizer.from_pretrained(TEST_MODEL), + ) assert isinstance(model, Transformers) -def test_transformers_instantiate_wrong_kwargs(): - with pytest.raises(TypeError): - Transformers(TEST_MODEL, model_kwargs={"foo": "bar"}) - - def test_transformers_instantiate_other_model_class(): - model = Transformers( - model_name=TEST_MODEL_SEQ2SEQ, model_class=AutoModelForSeq2SeqLM + model = outlines.from_transformers( + transformers.AutoModelForSeq2SeqLM.from_pretrained(TEST_MODEL_SEQ2SEQ), + transformers.AutoTokenizer.from_pretrained(TEST_MODEL), ) assert isinstance(model, Transformers) def test_transformers_instantiate_mamba(): - model = Mamba( - model_name=TEST_MODEL_MAMBA, + model = outlines.from_transformers( + transformers.MambaForCausalLM.from_pretrained(TEST_MODEL_MAMBA), + transformers.AutoTokenizer.from_pretrained(TEST_MODEL), ) - assert isinstance(model, Mamba) assert isinstance(model, Transformers) -def test_transformers_instantiate_tokenizer_kwargs(): - model = Transformers( - TEST_MODEL, - tokenizer_kwargs={"additional_special_tokens": ["", ""]} - ) - assert "" in model.tokenizer.special_tokens - assert "" in model.tokenizer.special_tokens - - @pytest.fixture def model(): - return Transformers(TEST_MODEL) + model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL), + transformers.AutoTokenizer.from_pretrained(TEST_MODEL), + ) + return model def test_transformers_simple(model): @@ -82,7 +77,7 @@ def test_transformers_json(model): class Character(BaseModel): name: str - result = model("Create a character with a name.", Json(Character)) + result = model("Create a character with a name.", JsonType(Character)) assert "name" in result diff --git a/tests/models/test_transformers_vision.py b/tests/models/test_transformers_vision.py index 44ca4e050..57dedcb69 100644 --- a/tests/models/test_transformers_vision.py +++ b/tests/models/test_transformers_vision.py @@ -11,10 +11,12 @@ CLIPModel, CLIPProcessor, LlavaForConditionalGeneration, + AutoProcessor, ) -from outlines.models.transformers_vision import TransformersVision -from outlines.types import Choice, Json, Regex +import outlines +from outlines.models.transformers import TransformersVision +from outlines.types import Choice, JsonType, Regex TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration" TEST_CLIP_MODEL = "openai/clip-vit-base-patch32" @@ -35,22 +37,16 @@ def img_from_url(url): @pytest.fixture def model(): - return TransformersVision( - model_name=TEST_MODEL, model_class=LlavaForConditionalGeneration + return outlines.from_transformers( + LlavaForConditionalGeneration.from_pretrained(TEST_MODEL), + AutoProcessor.from_pretrained(TEST_MODEL), ) def test_transformers_vision_instantiate_simple(): - model = TransformersVision( - model_name=TEST_MODEL, - model_class=Blip2ForConditionalGeneration, - ) - assert isinstance(model, TransformersVision) - - -def test_transformers_vision_instantiate_other_processor_class(): - model = TransformersVision( - model_name=TEST_CLIP_MODEL, model_class=CLIPModel, processor_class=CLIPProcessor + model = outlines.from_transformers( + Blip2ForConditionalGeneration.from_pretrained(TEST_MODEL), + AutoProcessor.from_pretrained(TEST_MODEL), ) assert isinstance(model, TransformersVision) @@ -121,7 +117,7 @@ class Foo(BaseModel): result = model( {"prompts": "Give a name to this animal.", "images": images[0]}, - Json(Foo), + JsonType(Foo), ) assert "name" in result diff --git a/tests/test_function.py b/tests/test_function.py index 8bb12976e..42e121c60 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,6 +1,9 @@ +from typing import Annotated + +import json import pytest import responses -from pydantic import BaseModel +from pydantic import BaseModel, Field from requests.exceptions import HTTPError import outlines @@ -15,14 +18,16 @@ def test_template(text: str): """{{ text }}""" class Foo(BaseModel): - id: int + id: Annotated[str, Field(max_length=5)] fn = Function(test_template, Foo, "hf-internal-testing/tiny-random-GPTJForCausalLM") assert fn.generator is None result = fn("test") - assert isinstance(result, BaseModel) + Foo.parse_raw(result) + assert isinstance(json.loads(result), dict) + assert "id" in json.loads(result) def test_download_from_github_invalid():