Merge branch 'v1.0' into add_dottxt_sdk
Showing 6 changed files with 256 additions and 3 deletions.
@@ -0,0 +1,72 @@
# Ollama

!!! Installation

    To use Ollama with Outlines, you must install both Ollama and the `ollama` Python SDK.

    - To download Ollama: https://ollama.com/download
    - To install the ollama Python SDK: `pip install ollama`
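As a quick sanity check (a minimal sketch, assuming the Ollama server is running locally), you can ask the SDK to list the models available on your machine:

```python
import ollama

# Contacts the local Ollama server; this fails if the server is not running.
print(ollama.list())
```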
## Ollama models

You must provide a model name when instantiating the `outlines.models.Ollama` class. This model must be available on your system.

```python
from outlines.models import Ollama

model = Ollama("tinyllama")
```
To download a new model from the Ollama model hub, you can use the `from_pretrained` class method (it returns an `Ollama` instance):

```python
from outlines.models import Ollama

model = Ollama.from_pretrained("llama3.1:8b")
```

You can find the list of available models on the [Ollama library](https://ollama.com/library).
## Generate text

As with other models, you can either first create a `Generator` object and then call it:

```python
from outlines.models import Ollama
from outlines.generate import Generator

model = Ollama("tinyllama")
generator = Generator(model)
answer = generator("Write a sentence about a cat.")
```

or directly call the model:

```python
from outlines.models import Ollama

model = Ollama("tinyllama")
answer = model("Write a sentence about a cat.")
```
The generation input must be a string; batch generation is not supported. The only output type supported is `Json`:

```python
from outlines.models import Ollama
from outlines.types import Json
from pydantic import BaseModel

class Character(BaseModel):
    name: str

model = Ollama("tinyllama")
answer = model("Create a character.", output_type=Json(Character))
```
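The result comes back as a JSON string (the tests below parse it with `json.loads`), so you can parse or validate it yourself. For instance, continuing the example above (a minimal sketch, assuming Pydantic v2):

```python
import json

# `answer` is a JSON string matching the `Character` schema.
character = Character.model_validate_json(answer)  # validate into the model
data = json.loads(answer)                          # or parse with the stdlib
```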
You can also stream the tokens:

```python
from outlines.models import Ollama

model = Ollama("tinyllama")
tokens = model.stream("Write a sentence about a cat.")
```
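`stream` returns an iterator that yields chunks of text as they are produced, so you can, for example, print them as they arrive:

```python
from outlines.models import Ollama

model = Ollama("tinyllama")
# Each iteration yields the next chunk of generated text.
for token in model.stream("Write a sentence about a cat."):
    print(token, end="")
```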
## Optional parameters

You can provide the same optional parameters you would pass to the `ollama` SDK's client, both when initializing the `Ollama` class and when generating text. Consult the [ollama Python SDK's GitHub repository](https://github.com/ollama/ollama-python) for the full list of parameters.
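For instance, here is a sketch that forwards a client option at initialization and a sampling option at generation time (`host` and `options` are parameters of the `ollama` SDK's `Client` constructor and `generate` method respectively, assumed from its current API):

```python
from outlines.models import Ollama

# Keyword arguments at init are forwarded to `ollama.Client`.
model = Ollama("tinyllama", host="http://localhost:11434")

# Keyword arguments at generation time are forwarded to `client.generate`.
answer = model("Write a sentence about a cat.", options={"temperature": 0.5})
```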
@@ -0,0 +1,95 @@
from functools import singledispatchmethod
from types import NoneType
from typing import Iterator

from outlines.models.base import Model, ModelTypeAdapter
from outlines.types import Json


class OllamaTypeAdapter(ModelTypeAdapter):
    """Type adapter for the Ollama model."""

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the model.

        Argument
        --------
        model_input
            The input passed by the user.
        """
        raise NotImplementedError(
            f"The input type {model_input} is not available. "
            "Ollama does not support batch inference."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str):
        return model_input

    @singledispatchmethod
    def format_output_type(self, output_type):
        """Generate the `format` argument to pass to the model.

        Argument
        --------
        output_type
            The output type passed by the user.
        """
        raise NotImplementedError(
            f"The output type {output_type} is not available. "
            "Ollama only supports structured output with `Json`."
        )

    @format_output_type.register(NoneType)
    def format_none_output_type(self, output_type: None):
        return ""

    @format_output_type.register(Json)
    def format_json_output_type(self, output_type: Json):
        return output_type.to_json_schema()
class Ollama(Model):
    """Thin wrapper around the `ollama` client.

    This wrapper is used to convert the input and output types specified by
    the user at a higher level to arguments to the `ollama` client.
    """

    def __init__(self, model_name: str, *args, **kwargs):
        # Imported lazily so the `ollama` package is only required
        # when this model is actually used.
        from ollama import Client

        self.client = Client(*args, **kwargs)
        self.model_name = model_name
        self.type_adapter = OllamaTypeAdapter()

    @classmethod
    def from_pretrained(cls, model_name: str, *args, **kwargs):
        """Download the model weights from Ollama and create an `Ollama` instance."""
        from ollama import pull

        pull(model_name)
        return cls(model_name, *args, **kwargs)

    def generate(self, model_input, output_type=None, **kwargs) -> str:
        response = self.client.generate(
            model=self.model_name,
            prompt=self.type_adapter.format_input(model_input),
            format=self.type_adapter.format_output_type(output_type),
            **kwargs,
        )
        return response.response

    def stream(self, model_input, output_type=None, **kwargs) -> Iterator[str]:
        # With `stream=True` the client returns an iterator of partial responses.
        response = self.client.generate(
            model=self.model_name,
            prompt=self.type_adapter.format_input(model_input),
            format=self.type_adapter.format_output_type(output_type),
            stream=True,
            **kwargs,
        )
        for chunk in response:
            yield chunk.response
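As a rough illustration of what the type adapter does (a sketch; the `outlines.models.ollama` import path is an assumption based on this diff), `format_input` passes string prompts through unchanged, while `format_output_type` turns a `Json` output type into the schema passed to the client's `format` argument:

```python
from pydantic import BaseModel

from outlines.models.ollama import OllamaTypeAdapter  # import path assumed
from outlines.types import Json

class Character(BaseModel):
    name: str

adapter = OllamaTypeAdapter()
print(adapter.format_input("Create a character."))  # passed through as-is
print(adapter.format_output_type(Json(Character)))  # JSON schema for `format`
```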
@@ -0,0 +1,76 @@
import json
from enum import Enum

import pytest
from pydantic import BaseModel

from outlines.models import Ollama
from outlines.types import Choice, Json

MODEL_NAME = "tinyllama"


def test_pull_model():
    model = Ollama.from_pretrained(MODEL_NAME)
    assert isinstance(model, Ollama)


def test_ollama_wrong_init_parameters():
    with pytest.raises(TypeError, match="got an unexpected"):
        Ollama(MODEL_NAME, foo=10)


def test_wrong_inference_parameters():
    with pytest.raises(TypeError, match="got an unexpected"):
        Ollama(MODEL_NAME).generate("Respond with one word. Not more.", None, foo=10)


def test_ollama_simple():
    result = Ollama(MODEL_NAME).generate("Respond with one word. Not more.", None)
    assert isinstance(result, str)


def test_ollama_direct():
    result = Ollama(MODEL_NAME)("Respond with one word. Not more.", None)
    assert isinstance(result, str)


def test_ollama_json():
    class Foo(BaseModel):
        foo: str

    result = Ollama(MODEL_NAME)("Respond with one word. Not more.", Json(Foo))
    assert isinstance(result, str)
    assert "foo" in json.loads(result)


def test_ollama_wrong_output_type():
    class Foo(Enum):
        bar = "Bar"
        foor = "Foo"

    with pytest.raises(NotImplementedError, match="is not available"):
        Ollama(MODEL_NAME).generate("foo?", Choice(Foo))


def test_ollama_wrong_input_type():
    with pytest.raises(NotImplementedError, match="is not available"):
        Ollama(MODEL_NAME).generate(["foo?", "bar?"], None)


def test_ollama_stream():
    model = Ollama(MODEL_NAME)
    generator = model.stream("Write a sentence about a cat.")
    assert isinstance(next(generator), str)


def test_ollama_stream_json():
    class Foo(BaseModel):
        foo: str

    model = Ollama(MODEL_NAME)
    generator = model.stream("Create a character.", Json(Foo))
    generated_text = []
    for text in generator:
        generated_text.append(text)
    assert "foo" in json.loads("".join(generated_text))