From ed8722045c94a57d5b236cec2678427f7efedf7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Sat, 22 Feb 2025 16:33:48 +0100
Subject: [PATCH] Fix tests

---
 .github/workflows/tests.yml             |  1 +
 outlines/__init__.py                    |  5 +--
 outlines/function.py                    | 12 +++--
 tests/fsm/test_cfg_guide.py             | 11 ++---
 tests/generate/test_generate.py         | 60 +++++++++++++++++--------
 tests/generate/test_integration_vllm.py | 12 ++---
 tests/models/test_openai.py             | 10 ++---
 tests/test_function.py                  |  4 ++
 8 files changed, 74 insertions(+), 41 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 0565d0f3c..e8e70910c 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -33,6 +33,7 @@ jobs:
       run: |
         curl -fsSL https://ollama.com/install.sh | sh
         ollama --version
+        ollama pull tinyllama
     - name: Set up test environment
       run: |
         python -m pip install --upgrade pip
diff --git a/outlines/__init__.py b/outlines/__init__.py
index c44abd725..581f76b79 100644
--- a/outlines/__init__.py
+++ b/outlines/__init__.py
@@ -5,7 +5,6 @@
 import outlines.models
 import outlines.processors
 import outlines.types
-from outlines.generate import Generator
 from outlines.types import Choice, Regex, JsonType
 from outlines.base import vectorize
 from outlines.caching import clear_cache, disable_cache, get_cache
@@ -25,7 +24,7 @@
 )
 
 
-models = [
+model_list = [
     "from_anthropic",
     "from_gemini",
     "from_llamacpp",
@@ -48,4 +47,4 @@
     "Template",
     "vectorize",
     "grammars",
-] + models
+] + model_list
diff --git a/outlines/function.py b/outlines/function.py
index 1f5e80595..d7f6dec68 100644
--- a/outlines/function.py
+++ b/outlines/function.py
@@ -5,8 +5,6 @@
 import requests
 
 import outlines
-from outlines import models
-from outlines.types import JsonType
 
 if TYPE_CHECKING:
     from outlines.generate.api import SequenceGenerator
@@ -39,8 +37,14 @@ def from_github(cls, program_path: str, function_name: str = "fn"):
 
     def init_generator(self):
         """Load the model and initialize the generator."""
-        model = models.Transformers(self.model_name)
-        self.generator = outlines.generate.Generator(model, JsonType(self.schema))
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model = outlines.from_transformers(
+            AutoModelForCausalLM.from_pretrained(self.model_name),
+            AutoTokenizer.from_pretrained(self.model_name),
+        )
+
+        self.generator = outlines.Generator(model, outlines.JsonType(self.schema))
 
     def __call__(self, *args, **kwargs):
         """Call the function.
diff --git a/tests/fsm/test_cfg_guide.py b/tests/fsm/test_cfg_guide.py
index d92afa625..b925fefea 100644
--- a/tests/fsm/test_cfg_guide.py
+++ b/tests/fsm/test_cfg_guide.py
@@ -4,7 +4,8 @@
 import pytest
 from transformers import AutoTokenizer
 
-from outlines import grammars, models
+from outlines import grammars
+from outlines.models.transformers import TransformerTokenizer
 from outlines.fsm.guide import CFGGuide
 
 
@@ -341,12 +342,12 @@ def decode(self, token_ids):
 
 @pytest.fixture(scope="session")
 def tokenizer_sentencepiece_gpt2():
-    return models.TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))
+    return TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))
 
 
 @pytest.fixture(scope="session")
 def tokenizer_sentencepiece_llama1():
-    return models.TransformerTokenizer(
+    return TransformerTokenizer(
         AutoTokenizer.from_pretrained(
             "trl-internal-testing/tiny-random-LlamaForCausalLM"
         )
@@ -355,14 +356,14 @@ def tokenizer_sentencepiece_llama1():
 
 @pytest.fixture(scope="session")
 def tokenizer_tiktoken_llama3():
-    return models.TransformerTokenizer(
+    return TransformerTokenizer(
         AutoTokenizer.from_pretrained("yujiepan/llama-3-tiny-random")
     )
 
 
 @pytest.fixture(scope="session")
 def tokenizer_character_level_byt5():
-    return models.TransformerTokenizer(
+    return TransformerTokenizer(
         AutoTokenizer.from_pretrained("google/byt5-small")
     )
 
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 64dee6e4b..1bf3f2a41 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -4,8 +4,8 @@
 
 import pytest
 
+import outlines
 import outlines.generate as generate
-import outlines.models as models
 import outlines.samplers as samplers
 
 ##########################################
@@ -22,11 +22,12 @@ def model_llamacpp(tmp_path_factory):
         filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
         verbose=False,
     )
-    return models.LlamaCpp(llm)
+    return outlines.from_llamacpp(llm)
 
 
 @pytest.fixture(scope="session")
 def model_exllamav2(tmp_path_factory):
+    from outlines.models.exllamav2 import exl2
     from huggingface_hub import snapshot_download
 
     tmp_dir = tmp_path_factory.mktemp("model_download")
@@ -35,7 +36,7 @@ def model_exllamav2(tmp_path_factory):
         cache_dir=tmp_dir,
     )
-    return models.exl2(
+    return exl2(
         model_path=model_path,
         cache_q4=True,
         paged=False,
     )
@@ -44,56 +45,79 @@
 @pytest.fixture(scope="session")
 def model_mlxlm(tmp_path_factory):
-    return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
+    from mlx_lm import load
+
+    return outlines.from_mlxlm(*load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit"))
 
 
 @pytest.fixture(scope="session")
 def model_mlxlm_phi3(tmp_path_factory):
-    return models.mlxlm("mlx-community/Phi-3-mini-4k-instruct-4bit")
+    from mlx_lm import load
+
+    return outlines.from_mlxlm(*load("mlx-community/Phi-3-mini-4k-instruct-4bit"))
 
 
 @pytest.fixture(scope="session")
 def model_transformers_random(tmp_path_factory):
-    return models.Transformers("hf-internal-testing/tiny-random-gpt2")
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    return outlines.from_transformers(
+        AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2"),
+        AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2"),
+    )
 
 
 @pytest.fixture(scope="session")
 def model_transformers_opt125m(tmp_path_factory):
-    return models.Transformers("facebook/opt-125m")
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    return outlines.from_transformers(
+        AutoModelForCausalLM.from_pretrained("facebook/opt-125m"),
+        AutoTokenizer.from_pretrained("facebook/opt-125m"),
+    )
 
 
 @pytest.fixture(scope="session")
 def model_mamba(tmp_path_factory):
-    return models.Mamba(model_name="state-spaces/mamba-130m-hf")
+    from transformers import MambaForCausalLM, AutoTokenizer
+
+    return outlines.from_transformers(
+        MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf"),
+        AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf"),
+    )
 
 
 @pytest.fixture(scope="session")
 def model_bart(tmp_path_factory):
-    from transformers import AutoModelForSeq2SeqLM
+    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
-    return models.Transformers("facebook/bart-base", model_class=AutoModelForSeq2SeqLM)
+    return outlines.from_transformers(
+        AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base"),
+        AutoTokenizer.from_pretrained("facebook/bart-base"),
+    )
 
 
 @pytest.fixture(scope="session")
 def model_transformers_vision(tmp_path_factory):
     import torch
-    from transformers import LlavaNextForConditionalGeneration
+    from transformers import LlavaNextForConditionalGeneration, AutoTokenizer
 
-    return models.transformers_vision(
-        "llava-hf/llava-v1.6-mistral-7b-hf",
-        model_class=LlavaNextForConditionalGeneration,
-        device="cuda",
-        model_kwargs=dict(
+    return outlines.from_transformers(
+        LlavaNextForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-v1.6-mistral-7b-hf",
             torch_dtype=torch.bfloat16,
             load_in_4bit=True,
             low_cpu_mem_usage=True,
         ),
+        AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf"),
     )
 
 
 @pytest.fixture(scope="session")
 def model_vllm(tmp_path_factory):
-    return models.vllm("facebook/opt-125m", gpu_memory_utilization=0.1)
+    from vllm import LLM
+
+    return outlines.from_vllm(LLM("facebook/opt-125m", gpu_memory_utilization=0.1))
 
 
 # TODO: exllamav2 failing in main, address in https://github.com/dottxt-ai/outlines/issues/808
 
diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py
index faa8a404a..7b9acc240 100644
--- a/tests/generate/test_integration_vllm.py
+++ b/tests/generate/test_integration_vllm.py
@@ -4,12 +4,6 @@
 import pytest
 import torch
 from pydantic import BaseModel, constr
-from vllm import LLM
-
-try:
-    from vllm.sampling_params import SamplingParams
-except ImportError:
-    pass
 
 import outlines
 import outlines.generate as generate
@@ -17,6 +11,12 @@
 import outlines.models as models
 import outlines.samplers as samplers
 
+try:
+    from vllm import LLM
+    from vllm.sampling_params import SamplingParams
+except ImportError:
+    pass
+
 pytestmark = pytest.mark.skipif(
     not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
 )
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index b8c857213..57ed323de 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -30,8 +30,8 @@ def api_key():
     return api_key
 
 
-def test_init_from_client():
-    client = OpenAIClient()
+def test_init_from_client(api_key):
+    client = OpenAIClient(api_key=api_key)
     model = outlines.from_openai(client, "gpt-4o")
     assert isinstance(model, OpenAI)
     assert model.client == client
@@ -39,7 +39,7 @@ def test_init_from_client():
 
 def test_openai_wrong_inference_parameters(api_key):
     with pytest.raises(TypeError, match="got an unexpected"):
-        model = OpenAI(OpenAIClient(), MODEL_NAME)
+        model = OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME)
         model.generate("prompt", foo=10)
 
 
@@ -49,7 +49,7 @@ def __init__(self, foo):
             self.foo = foo
 
     with pytest.raises(NotImplementedError, match="is not available"):
-        model = OpenAI(OpenAIClient(), MODEL_NAME)
+        model = OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME)
         model.generate(Foo("prompt"))
 
 
@@ -59,7 +59,7 @@ def __init__(self, foo):
            self.foo = foo
 
    with pytest.raises(NotImplementedError, match="is not available"):
-        model = OpenAI(OpenAIClient, MODEL_NAME)
+        model = OpenAI(OpenAIClient(api_key=api_key), MODEL_NAME)
         model.generate("prompt", Foo(1))
 
 
diff --git a/tests/test_function.py b/tests/test_function.py
index cc8d7493e..42e121c60 100644
--- a/tests/test_function.py
+++ b/tests/test_function.py
@@ -1,4 +1,6 @@
 from typing import Annotated
+
+import json
 import pytest
 import responses
 from pydantic import BaseModel, Field
@@ -24,6 +26,8 @@ class Foo(BaseModel):
 
     result = fn("test")
     Foo.parse_raw(result)
+    assert isinstance(json.loads(result), dict)
+    assert "id" in json.loads(result)
 
 
 def test_download_from_github_invalid():