
Create the Ollama model
Robin Picard authored and rlouf committed Feb 20, 2025
1 parent 70955ee commit b56acd1
Showing 6 changed files with 258 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
@@ -29,6 +29,10 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh
          ollama --version
      - name: Set up test environment
        run: |
          python -m pip install --upgrade pip
72 changes: 72 additions & 0 deletions docs/reference/models/ollama.md
@@ -0,0 +1,72 @@
# Ollama

!!! Installation

    To use Ollama with Outlines, you must install both Ollama itself and the `ollama` Python SDK.

    - To download Ollama: https://ollama.com/download
    - To install the ollama Python SDK: `pip install ollama`

## Ollama models

You must provide a model name when instantiating the `outlines.models.Ollama` class. This model must be available on your system.
```python
from outlines.models import Ollama

model = Ollama("tinyllama")
```

To download a new model from the Ollama model hub, use the `from_pretrained` class method; it pulls the weights and returns an `Ollama` instance:
```python
from outlines.models import Ollama

model = Ollama.from_pretrained("llama3.1:8b")
```

You can find the list of available models on the [Ollama library](https://ollama.com/library).

## Generate text

As with other models, you can either first create a `Generator` object and then call it:
```python
from outlines.models import Ollama
from outlines.generate import Generator

model = Ollama("tinyllama")
generator = Generator(model)
answer = generator("Write a sentence about a cat.")
```
or call the model directly:
```python
from outlines.models import Ollama

model = Ollama("tinyllama")
answer = model("Write a sentence about a cat.")
```

The generation input must be a single string; batch generation is not supported.
The only structured output type supported is `Json`.
```python
from outlines.models import Ollama
from outlines.types import Json
from pydantic import BaseModel

class Character(BaseModel):
    name: str

model = Ollama("tinyllama")
answer = model("Create a character.", output_type=Json(Character))
```
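
Since the generated output is a JSON string, you can load it back into the Pydantic model. A small follow-up sketch, assuming Pydantic v2 and that the model returned JSON matching the schema:
```python
character = Character.model_validate_json(answer)
print(character.name)
```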

You can also stream the tokens:
```python
from outlines.models import Ollama

model = Ollama("tinyllama")
tokens = model.stream("Write a sentence about a cat.")
for token in tokens:
    print(token, end="")
```

## Optional parameters

You can pass the same optional parameters you would provide to the `ollama` SDK's client, both when initializing the `Ollama` class (forwarded to `ollama.Client`) and when generating text (forwarded to the client's `generate` method).
Consult the [ollama Python SDK GitHub repository](https://github.com/ollama/ollama-python) for the full list of parameters.
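
For instance, a minimal sketch, assuming a local Ollama server and the `host` and `options` parameters as documented in the `ollama` SDK (check the SDK for your version):
```python
from outlines.models import Ollama

# Initialization kwargs are forwarded to ollama.Client.
model = Ollama("tinyllama", host="http://localhost:11434")

# Generation kwargs are forwarded to Client.generate.
answer = model.generate(
    "Write a sentence about a cat.",
    options={"temperature": 0.2, "num_predict": 64},
)
```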
7 changes: 5 additions & 2 deletions outlines/models/__init__.py
@@ -14,12 +14,15 @@
from .gemini import Gemini
from .llamacpp import LlamaCpp
from .mlxlm import MLXLM, mlxlm
from .ollama import Ollama
from .openai import AzureOpenAI, OpenAI
from .transformers import Transformers, TransformerTokenizer, mamba, transformers
from .transformers_vision import TransformersVision, transformers_vision
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM]
LogitsGenerator = Union[
Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM, Ollama
]

LocalModel = LlamaCpp
APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini]
APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama]
95 changes: 95 additions & 0 deletions outlines/models/ollama.py
@@ -0,0 +1,95 @@
from functools import singledispatchmethod
from types import NoneType
from typing import Iterator

from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import Json


class OllamaTypeAdapter(ModelTypeAdapter):
    """Type adapter for the Ollama model."""

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the model.

        Argument
        --------
        model_input
            The input passed by the user.
        """
        raise NotImplementedError(
            f"The input type {model_input} is not available. "
            "Ollama does not support batch inference."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        return model_input

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Generate the `format` argument to pass to the model.

        Argument
        --------
        output_type
            The output type passed by the user.
        """
        raise NotImplementedError(
            f"The output type {output_type} is not available. "
            "Ollama only supports structured output with `Json`."
        )

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type: None):
        return ""

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type: Json):
        return output_type.to_json_schema()


class Ollama(Model):
    """Thin wrapper around the `ollama` client.

    This wrapper is used to convert the input and output types specified by
    the user at a higher level into arguments to the `ollama` client.
    """

    def __init__(self, model_name: str, *args, **kwargs):
        from ollama import Client

        self.client = Client(*args, **kwargs)
        self.model_name = model_name
        self.type_adapter = OllamaTypeAdapter()

    @classmethod
    def from_pretrained(cls, model_name: str, *args, **kwargs):
        """Download the model weights from Ollama and create an `Ollama` instance."""
        from ollama import pull

        pull(model_name)
        return cls(model_name, *args, **kwargs)

    def generate(self, model_input, output_type=None, **kwargs) -> str:
        response = self.client.generate(
            model=self.model_name,
            prompt=self.type_adapter.format_input(model_input),
            format=self.type_adapter.format_output_type(output_type),
            **kwargs,
        )
        return response.response

    def stream(self, model_input, output_type=None, **kwargs) -> Iterator[str]:
        response = self.client.generate(
            model=self.model_name,
            prompt=self.type_adapter.format_input(model_input),
            format=self.type_adapter.format_output_type(output_type),
            stream=True,
            **kwargs,
        )
        for chunk in response:
            yield chunk.response
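
For readers following the diff, here is a minimal sketch (not part of the commit) of what `OllamaTypeAdapter` produces for the client's `prompt` and `format` arguments, relying only on the behavior shown above:
```python
from pydantic import BaseModel

from outlines.models.ollama import OllamaTypeAdapter
from outlines.types import Json

class Character(BaseModel):
    name: str

adapter = OllamaTypeAdapter()
# Strings pass through unchanged; any other input type raises NotImplementedError.
print(adapter.format_input("Create a character."))
# Json output types are converted to a JSON schema for Ollama's `format` argument.
print(adapter.format_output_type(Json(Character)))
```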
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ anthropic = ["anthropic"]
gemini = ["google-generativeai"]
llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"]
exllamav2 = ["exllamav2"]
ollama = ["ollama"]
test = [
"pre-commit",
"pytest",
@@ -77,6 +78,7 @@ test = [
"pillow",
"exllamav2",
"jax",
"ollama"
]
test-gpu=["outlines[test]", "vllm; sys_platform == 'linux'"]
serve = [
@@ -156,7 +158,11 @@ module = [
"iso3166.*",
"airportsdata.*",
"outlines_core.*",
"genson",
"ollama.*",
]
ignore_missing_imports = true

76 changes: 76 additions & 0 deletions tests/models/test_ollama.py
@@ -0,0 +1,76 @@
import json
from enum import Enum

import pytest
from pydantic import BaseModel

from outlines.models import Ollama
from outlines.types import Choice, Json

MODEL_NAME = "tinyllama"


def test_pull_model():
    model = Ollama.from_pretrained(MODEL_NAME)
    assert isinstance(model, Ollama)


def test_ollama_wrong_init_parameters():
    with pytest.raises(TypeError, match="got an unexpected"):
        Ollama(MODEL_NAME, foo=10)


def test_wrong_inference_parameters():
    with pytest.raises(TypeError, match="got an unexpected"):
        Ollama(MODEL_NAME).generate("Respond with one word. Not more.", None, foo=10)


def test_ollama_simple():
    result = Ollama(MODEL_NAME).generate("Respond with one word. Not more.", None)
    assert isinstance(result, str)


def test_ollama_direct():
    result = Ollama(MODEL_NAME)("Respond with one word. Not more.", None)
    assert isinstance(result, str)


def test_ollama_json():
    class Foo(BaseModel):
        foo: str

    result = Ollama(MODEL_NAME)("Respond with one word. Not more.", Json(Foo))
    assert isinstance(result, str)
    assert "foo" in json.loads(result)


def test_ollama_wrong_output_type():
    class Foo(Enum):
        bar = "Bar"
        foor = "Foo"

    with pytest.raises(NotImplementedError, match="is not available"):
        Ollama(MODEL_NAME).generate("foo?", Choice(Foo))


def test_ollama_wrong_input_type():
    with pytest.raises(NotImplementedError, match="is not available"):
        Ollama(MODEL_NAME).generate(["foo?", "bar?"], None)


def test_ollama_stream():
    model = Ollama(MODEL_NAME)
    generator = model.stream("Write a sentence about a cat.")
    assert isinstance(next(generator), str)


def test_ollama_stream_json():
    class Foo(BaseModel):
        foo: str

    model = Ollama(MODEL_NAME)
    generator = model.stream("Create a character.", Json(Foo))
    generated_text = []
    for text in generator:
        generated_text.append(text)
    assert "foo" in json.loads("".join(generated_text))
