Fix tests
rlouf committed Feb 22, 2025
1 parent 361c4d8 commit 1a638f3
Showing 6 changed files with 69 additions and 32 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
@@ -33,6 +33,8 @@ jobs:
run: |
curl -fsSL https://ollama.com/install.sh | sh
ollama --version
ollama pull tinyllama
ollama serve
- name: Set up test environment
run: |
python -m pip install --upgrade pip
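The workflow now pulls tinyllama and starts the Ollama server so the suite can exercise an Ollama-backed model in CI. A minimal sketch of what such a test could look like follows; the outlines.from_ollama constructor and the Generator call signature are assumptions made by analogy with the other from_* helpers exported in outlines/__init__.py, and are not shown in this commit.

# Hypothetical test against the Ollama server started in CI; from_ollama
# and the Generator call signature are assumptions, not part of this commit.
import ollama
import pytest

import outlines
from outlines import Generator


@pytest.fixture(scope="session")
def model_ollama():
    # The server started by the workflow listens on the default
    # http://localhost:11434 endpoint.
    return outlines.from_ollama(ollama.Client(), "tinyllama")


def test_ollama_text(model_ollama):
    generator = Generator(model_ollama)
    result = generator("Answer in one word: what color is the sky?")
    assert isinstance(result, str)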
4 changes: 2 additions & 2 deletions outlines/__init__.py
@@ -24,7 +24,7 @@
)


models = [
model_list = [
"from_anthropic",
"from_gemini",
"from_llamacpp",
@@ -47,4 +47,4 @@
"Prompt",
"vectorize",
"grammars",
] + models
] + model_list
13 changes: 10 additions & 3 deletions outlines/function.py
@@ -4,7 +4,8 @@

import requests

from outlines import generate, models
import outlines
from outlines import Generator, JsonType

if TYPE_CHECKING:
from outlines.generate.api import SequenceGenerator
@@ -37,8 +38,14 @@ def from_github(cls, program_path: str, function_name: str = "fn"):

def init_generator(self):
"""Load the model and initialize the generator."""
model = models.transformers(self.model_name)
self.generator = generate.json(model, self.schema)
from transformers import AutoModelForCausalLM, AutoTokenizer

model = outlines.from_transformers(
AutoModelForCausalLM.from_pretrained(self.model_name),
AutoTokenizer.from_pretrained(self.model_name)
)

self.generator = Generator(model, JsonType(self.schema))

def __call__(self, *args, **kwargs):
"""Call the function.
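init_generator now builds the model from an explicit transformers model/tokenizer pair and constrains generation with Generator and JsonType instead of generate.json. Below is a minimal sketch of the resulting flow, assuming a Generator instance is called with the rendered prompt and returns the raw JSON text (which is how the updated tests/test_function.py consumes the result via json.loads); the Character schema and prompt are illustrative only.

# Sketch of the pattern init_generator follows after this change.
import json

from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

import outlines
from outlines import Generator, JsonType


class Character(BaseModel):  # illustrative schema
    name: str
    age: int


model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = outlines.from_transformers(
    AutoModelForCausalLM.from_pretrained(model_name),
    AutoTokenizer.from_pretrained(model_name),
)
generator = Generator(model, JsonType(Character))

result = generator("Describe a character as JSON.")  # call signature assumed
data = json.loads(result)  # the generator returns JSON text, not a pydantic object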
60 changes: 42 additions & 18 deletions tests/generate/test_generate.py
@@ -4,8 +4,8 @@

import pytest

import outlines
import outlines.generate as generate
import outlines.models as models
import outlines.samplers as samplers

##########################################
@@ -22,11 +22,12 @@ def model_llamacpp(tmp_path_factory):
filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
verbose=False,
)
return models.LlamaCpp(llm)
return outlines.from_llamacpp(llm)


@pytest.fixture(scope="session")
def model_exllamav2(tmp_path_factory):
from outlines.models.exllamav2 import exl2
from huggingface_hub import snapshot_download

tmp_dir = tmp_path_factory.mktemp("model_download")
@@ -35,7 +36,7 @@ def model_exllamav2(tmp_path_factory):
cache_dir=tmp_dir,
)

return models.exl2(
return exl2(
model_path=model_path,
cache_q4=True,
paged=False,
@@ -44,56 +45,79 @@

@pytest.fixture(scope="session")
def model_mlxlm(tmp_path_factory):
return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
from mlx_lm import load

return outlines.from_mlxlm(*load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit"))


@pytest.fixture(scope="session")
def model_mlxlm_phi3(tmp_path_factory):
return models.mlxlm("mlx-community/Phi-3-mini-4k-instruct-4bit")
from mlx_lm import load

return outlines.from_mlxlm(*load("mlx-community/Phi-3-mini-4k-instruct-4bit"))


@pytest.fixture(scope="session")
def model_transformers_random(tmp_path_factory):
return models.Transformers("hf-internal-testing/tiny-random-gpt2")
from transformers import AutoModelForCausalLM, AutoTokenizer

return outlines.from_transformers(
AutoModelForCausalLM.fromt_pretrained("hf-internal-testing/tiny-random-gpt2"),
AutoTokenizer.fromt_pretrained("hf-internal-testing/tiny-random-gpt2"),
)


@pytest.fixture(scope="session")
def model_transformers_opt125m(tmp_path_factory):
return models.Transformers("facebook/opt-125m")
from transformers import AutoModelForCausalLM, AutoTokenizer

return outlines.from_transformers(
AutoModelForCausalLM.fromt_pretrained("facebook/opt-125m"),
AutoTokenizer.fromt_pretrained("facebook/opt-125m"),
)


@pytest.fixture(scope="session")
def model_mamba(tmp_path_factory):
return models.Mamba(model_name="state-spaces/mamba-130m-hf")
from transformers import MambaModel, AutoTokenizer

return outlines.from_transformers(
MambaModel.from_pretrained("state-spaces/mamba-130m-hf"),
AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf"),
)


@pytest.fixture(scope="session")
def model_bart(tmp_path_factory):
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

return models.Transformers("facebook/bart-base", model_class=AutoModelForSeq2SeqLM)
return outlines.from_transformers(
AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base"),
AutoTokenizer.from_pretrained("facebook/bart-base"),
)


@pytest.fixture(scope="session")
def model_transformers_vision(tmp_path_factory):
import torch
from transformers import LlavaNextForConditionalGeneration
from transformers import LlavaNextForConditionalGeneration, AutoTokenizer

return models.transformers_vision(
"llava-hf/llava-v1.6-mistral-7b-hf",
model_class=LlavaNextForConditionalGeneration,
device="cuda",
model_kwargs=dict(
return outlines.from_transformers(
LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
torch_dtype=torch.bfloat16,
load_in_4bit=True,
low_cpu_mem_usage=True,
),
AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf"),
)


@pytest.fixture(scope="session")
def model_vllm(tmp_path_factory):
return models.vllm("facebook/opt-125m", gpu_memory_utilization=0.1)
from vllm import LLM

return outlines.from_vllm(LLM("facebook/opt-125m", gpu_memory_utilization=0.1))


# TODO: exllamav2 failing in main, address in https://github.com/dottxt-ai/outlines/issues/808
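These fixtures all wrap backend objects with the new outlines.from_* constructors instead of the old outlines.models helpers. How the suite consumes them is outside this hunk; the sketch below shows a common pytest pattern for running one test body across several session-scoped model fixtures, as an illustration rather than a copy of this file.

# Illustrative only: resolve model fixtures by name with request.getfixturevalue.
import pytest

from outlines import Generator


@pytest.mark.parametrize(
    "model_fixture",
    ["model_transformers_random", "model_transformers_opt125m"],
)
def test_generate_text(request, model_fixture):
    model = request.getfixturevalue(model_fixture)
    generator = Generator(model)  # unconstrained text generation; signature assumed
    result = generator("Hello")
    assert isinstance(result, str)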
12 changes: 6 additions & 6 deletions tests/generate/test_integration_vllm.py
@@ -4,19 +4,19 @@
import pytest
import torch
from pydantic import BaseModel, constr
from vllm import LLM

try:
from vllm.sampling_params import SamplingParams
except ImportError:
pass

import outlines
import outlines.generate as generate
import outlines.grammars as grammars
import outlines.models as models
import outlines.samplers as samplers

try:
from vllm import LLM
from vllm.sampling_params import SamplingParams
except ImportError:
pass

pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
)
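Moving from vllm import LLM inside the try/except lets the module import on machines without vLLM installed, while the skipif marker keeps the GPU-only tests from running on CPU. An alternative guard, not used by this commit, is pytest.importorskip, which skips the whole module when the import fails:

# Alternative to the try/except guard; skips every test in the module
# when vllm is not installed.
import pytest
import torch

vllm = pytest.importorskip("vllm")
LLM = vllm.LLM
SamplingParams = vllm.SamplingParams

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
)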
10 changes: 7 additions & 3 deletions tests/test_function.py
@@ -1,6 +1,9 @@
from typing import Annotated

import json
import pytest
import responses
from pydantic import BaseModel
from pydantic import BaseModel, Field
from requests.exceptions import HTTPError

import outlines
@@ -13,14 +16,15 @@ def test_template(text: str):
"""{{ text }}"""

class Foo(BaseModel):
id: int
id: Annotated[str, Field(min_length=10, max_length=10)]

fn = Function(test_template, Foo, "hf-internal-testing/tiny-random-GPTJForCausalLM")

assert fn.generator is None

result = fn("test")
assert isinstance(result, BaseModel)
assert isinstance(json.loads(result), dict)
assert "id" in json.loads(result)


def test_download_from_github_invalid():
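The Foo model in the test now pins id to an exactly 10-character string, and the assertions parse the returned JSON text instead of expecting a pydantic instance. For reference, the JSON Schema pydantic derives from this model, which is what JsonType hands to the structured generator, looks roughly like the sketch below; exact key order and title fields vary with the pydantic version.

from typing import Annotated

from pydantic import BaseModel, Field


class Foo(BaseModel):
    id: Annotated[str, Field(min_length=10, max_length=10)]


print(Foo.model_json_schema())
# Roughly:
# {
#   "properties": {"id": {"maxLength": 10, "minLength": 10, "title": "Id", "type": "string"}},
#   "required": ["id"],
#   "title": "Foo",
#   "type": "object",
# }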
