From eb71b4949272102e97799afed698fae1e689ab25 Mon Sep 17 00:00:00 2001
From: Robin Picard
Date: Thu, 27 Feb 2025 09:13:27 +0100
Subject: [PATCH] Refactor the VLLM model

---
 docs/reference/models/vllm.md           |  31 ++-
 outlines/models/vllm.py                 | 159 +++++++--------
 tests/generate/test_generate.py         |   7 +-
 tests/generate/test_integration_vllm.py | 245 ------------------------
 tests/models/test_vllm.py               | 125 ++++++++++++
 5 files changed, 217 insertions(+), 350 deletions(-)
 delete mode 100644 tests/generate/test_integration_vllm.py
 create mode 100644 tests/models/test_vllm.py

diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md
index 1636410a4..7698cd31e 100644
--- a/docs/reference/models/vllm.md
+++ b/docs/reference/models/vllm.md
@@ -11,10 +11,10 @@ Outlines supports models available via vLLM's offline batched inference interface.
 You can load a model using:
-
 ```python
 import outlines
 from vllm import LLM
+
 model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
 ```

@@ -28,18 +28,37 @@ Models are loaded from the [HuggingFace hub](https://huggingface.co/).

 ## Generate text

-In addition to the parameters described in the [text generation section](../text.md) you can pass an instance of `SamplingParams` directly to any generator via the `sampling_params` keyword argument:
+To generate text, call the model with a prompt as its argument:

 ```python
-from vllm.sampling_params import SamplingParams
-from outlines import models, generate
+import outlines
+from vllm import LLM

+model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
+answer = model("Write a short story about a cat.")
+```
+
+You can also use structured generation with the `VLLM` model by providing an output type after the prompt:
+
+```python
+import outlines
+from vllm import LLM
+from outlines.types import JsonType
+from pydantic import BaseModel
+
+class Character(BaseModel):
+    name: str
+
 model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
-params = SamplingParams(n=2, frequency_penalty=1., min_tokens=2)
-answer = model("A prompt", sampling_params=params)
+answer = model("Create a character.", output_type=JsonType(Character))
 ```

+The VLLM model supports batch generation. To use it, pass a list of strings as the prompt instead of a single string.
+
+## Optional parameters
+
+When calling the model, you can provide optional parameters on top of the prompt and the output type. These are passed on to the `LLM.generate` method of the `vllm` library. An optional parameter of particular interest is `sampling_params`, an instance of `SamplingParams`. You can find more information about it in the [vLLM documentation](https://docs.vllm.ai/en/latest/api/inference_params.html).
+
 !!! Warning

     Streaming is not available for the offline vLLM integration.

diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index 8c709a5e2..3f10f6433 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -1,7 +1,9 @@
 import dataclasses
+from functools import singledispatchmethod
 from typing import TYPE_CHECKING, List, Optional, Union

 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models.base import Model, ModelTypeAdapter

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizerBase

@@ -12,21 +14,59 @@
 __all__ = ["VLLM", "from_vllm"]


-class VLLM:
-    """Represents a vLLM model.
+class VLLMTypeAdapter(ModelTypeAdapter):
+    @singledispatchmethod
+    def format_input(self, model_input):
+        """Generate the prompt argument to pass to the model.

-    We wrap models from model providing libraries in order to give all of
-    them the same interface in Outlines and allow users to easily switch
-    between providers. This class wraps the `vllm.LLM` class from the
-    `vllm` library.
+        Argument
+        --------
+        model_input
+            The input passed by the user.

-    """
+        """
+        raise NotImplementedError(
+            f"The input type {type(model_input)} is not available. "
+            "Please use a string or a list of strings."
+        )
+
+    @format_input.register(str)
+    def format_str_input(self, model_input):
+        return model_input
+
+    @format_input.register(list)
+    def format_list_input(self, model_input):
+        return model_input
+
+    def format_output_type(self, output_type):
+        """Generate the logits processor argument to pass to the model.
+
+        Argument
+        --------
+        output_type
+            The logits processor provided.
+
+        """
+        if output_type is not None:
+            return [output_type]
+        return []
+
+
+class VLLM(Model):
+    """Represents a `vllm` model."""

     def __init__(self, model: "LLM"):
-        self.model = model
-        self.lora_request = None
+        """Create a VLLM model instance.
+
+        Parameters
+        ----------
+        model
+            A `vllm.LLM` model instance.

+        """
+        self.model = model
         self.tokenizer = self._get_tokenizer()
+        self.type_adapter = VLLMTypeAdapter()

     def _get_tokenizer(self):
         if hasattr(self.model, "get_tokenizer"):
@@ -43,98 +83,35 @@ def _get_tokenizer(self):
         )
         return adapt_tokenizer(tokenizer=tokenizer)

-    def generate(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: GenerationParameters,
-        logits_processor,
-        sampling_parameters: SamplingParameters,
-        *,
-        sampling_params: Optional["SamplingParams"] = None,
-        use_tqdm: bool = True,
-    ):
-        """Generate text using vLLM.
+    def generate(self, model_input, output_type, **inference_kwargs):
+        """Generate text using `vllm`.

-        Parameters
-        ----------
-        prompts
-            A prompt or list of prompts.
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
+        Arguments
+        ---------
+        model_input
+            The prompt or list of prompts provided by the user.
         logits_processor
            The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-        sampling_params
-            An instance of `vllm.sampling_params.SamplingParams`. The values
-            passed via this dataclass supersede the values of the parameters
-            in `generation_parameters` and `sampling_parameters`. See the
-            vLLM documentation for more details: https://docs.vllm.ai/en/latest/dev/sampling_params.html.
-        use_tqdm
-            A boolean in order to display progress bar while inferencing
+        inference_kwargs
+            The inference kwargs that can be passed to the `LLM.generate` method
+            in the `vllm` library.

         Returns
         -------
-        The generated text, of shape `(n_batch, n_samples)`. If there are only
-        one batch and several samples, the list is of shape `(n_samples)`. If
-        this is a batch with several sequences but only one sample the list is
-        of shape `(n_batch)`. If there is only one sequence and one sample, a
-        string is returned.
+        The generated text: a string for a single prompt and sample, a list of strings for a batch of prompts or several samples, and a list of lists of strings for both.
""" from vllm.sampling_params import SamplingParams + sampling_params = inference_kwargs.pop("sampling_params", None) if sampling_params is None: sampling_params = SamplingParams() - - max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) - - # We only update the values in `sampling_params` if they - # are specified by the user when calling the generator. - if max_tokens is not None: - sampling_params.max_tokens = max_tokens - if stop_at is not None: - if isinstance(stop_at, str): - stop_at = [stop_at] - sampling_params.stop = stop_at - if seed is not None: - sampling_params.seed = seed - - sampling_params.logits_processors = ( - [logits_processor] if logits_processor is not None else [] - ) - - sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( - sampling_parameters - ) - - # We only update the values in `sampling_params` that - # were not specified by the user. - if sampling_params.n == 1: - sampling_params.n = num_samples - sampling_params.best_of = num_samples - if top_p is not None and sampling_params.top_p == 1.0: - sampling_params.top_p = top_p - if top_k is not None and sampling_params.top_k == -1: - sampling_params.top_k = top_k - # TODO: remove this if statement once fixed - # https://github.com/vllm-project/vllm/issues/5404#issuecomment-2175972897 - if top_k == 1: - sampling_params.repetition_penalty = 0 - if temperature is not None and sampling_params.temperature == 1.0: - sampling_params.temperature = temperature - if sampler == "beam_search": - sampling_params.use_beam_search = True + sampling_params.logits_processors = self.type_adapter.format_output_type(output_type) results = self.model.generate( - prompts, + self.type_adapter.format_input(model_input), sampling_params=sampling_params, - lora_request=self.lora_request, - use_tqdm=use_tqdm, + **inference_kwargs, ) results = [[sample.text for sample in batch.outputs] for batch in results] @@ -150,7 +127,7 @@ def generate( return results - def stream(self, *args, **kwargs): + def stream(self, model_input, output_type, **inference_kwargs): """Return a text generator. Streaming is not yet available for `vllm.LLM`. @@ -162,14 +139,6 @@ def stream(self, *args, **kwargs): "Streaming is not available for the vLLM integration." 
) - def load_lora(self, adapter_path: Optional[str]): - from vllm.lora.request import LoRARequest - - if adapter_path is None: - self.lora_request = None - else: - self.lora_request = LoRARequest(adapter_path, 1, adapter_path) - def from_vllm(model: "LLM") -> VLLM: return VLLM(model) diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 1bf3f2a41..84c8b761c 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -143,17 +143,16 @@ def model_t5(tmp_path_factory): ALL_MODEL_FIXTURES = ( # "model_llamacpp", # temporary disabled due to the v1 model refactoring "model_exllamav2", - "model_mlxlm", - "model_mlxlm_phi3", + # "model_mlxlm", + # "model_mlxlm_phi3", # "model_transformers_random", # "model_transformers_opt125m", # "model_mamba", # "model_bart", # "model_transformers_vision", - "model_vllm", + # "model_vllm", ) - class MyEnum(Enum): foo = "foo" bar = "bar" diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py deleted file mode 100644 index 7b9acc240..000000000 --- a/tests/generate/test_integration_vllm.py +++ /dev/null @@ -1,245 +0,0 @@ -import datetime -import re - -import pytest -import torch -from pydantic import BaseModel, constr - -import outlines -import outlines.generate as generate -import outlines.grammars as grammars -import outlines.models as models -import outlines.samplers as samplers - -try: - from vllm import LLM - from vllm.sampling_params import SamplingParams -except ImportError: - pass - -pytestmark = pytest.mark.skipif( - not torch.cuda.is_available(), reason="vLLM models can only be run on GPU." -) - - -@pytest.fixture(scope="module") -def model(): - return outlines.from_vllm(LLM("gpt2", gpu_memory_utilization=0.5)) - - -@pytest.mark.parametrize( - "generator_type,params", ((generate.text, []), (generate.regex, ("[0-9]",))) -) -def test_vllm_generation_api(model, generator_type, params): - generator = generator_type(model, *params) - - res = generator("test") - assert isinstance(res, str) - - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - res = generator("test", stop_at=".") - assert isinstance(res, str) - - res = generator("test", stop_at=[".", "ab"]) - assert isinstance(res, str) - - res = generator("test", stop_at=[".", "ab"]) - assert isinstance(res, str) - - res1 = generator("test", seed=1) - res2 = generator("test", seed=1) - assert isinstance(res1, str) - assert isinstance(res2, str) - assert res1 == res2 - - res = generator(["test", "test1"]) - assert len(res) == 2 - - -def test_vllm_sampling_params(model): - generator = generate.text(model) - - sampling_params = SamplingParams(n=2) - res = generator("test", sampling_params=sampling_params) - assert len(res) == 2 - assert isinstance(res[0], str) - assert isinstance(res[1], str) - - sampling_params = SamplingParams(seed=2) - res1 = generator("test", sampling_params=sampling_params) - res2 = generator("test", sampling_params=sampling_params) - assert res1 == res2 - - -def test_vllm_greedy_sampling(model): - sampler = samplers.greedy() - generator = generate.text(model, sampler) - res = generator("test") - assert isinstance(res, str) - - -def test_vllm_multinomial_sampling(model): - sampler = samplers.multinomial() - generator = generate.text(model, sampler) - res = generator("test") - assert isinstance(res, str) - - sampler = samplers.multinomial(3) - generator = generate.text(model, sampler) - res = generator("test") - assert len(res) == 3 - assert isinstance(res[0], str) - assert 
isinstance(res[1], str) - - sampler = samplers.multinomial(2, top_k=1) - generator = generate.text(model, sampler) - res = generator("test") - assert res[0] == res[1] - - sampler = samplers.multinomial(1, top_p=0.5) - generator = generate.text(model, sampler) - res = generator("test") - assert isinstance(res, str) - - sampler = samplers.multinomial(2, temperature=0.00001) - generator = generate.text(model, sampler) - res = generator("test") - assert res[0] == res[1] - - -def test_vllm_beam_search(model): - sampler = samplers.beam_search(1) - generator = generate.text(model, sampler) - res1 = generator("test") - sampler = samplers.greedy() - generator = generate.text(model, sampler) - res2 = generator("test") - assert res1 == res2 - - sampler = samplers.beam_search(2) - generator = generate.text(model, sampler) - res = generator("test") - assert len(res) == 2 - assert res[0] != res[1] - - -def test_vllm_text_stop(model): - prompt = "Write a short sentence containing 'You': " - sequence = generate.text(model)(prompt, max_tokens=100, seed=10) - assert sequence.find("news") != -1 - - sequence = generate.text(model)(prompt, stop_at="news", max_tokens=100, seed=10) - assert isinstance(sequence, str) - assert sequence.find("news") == -1 - - -def test_vllm_regex(model): - prompt = "Write an email address: " - regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" - generator = generate.regex(model, regex_str) - - # One prompt - sequence = generator(prompts=prompt) - assert isinstance(sequence, str) - assert re.fullmatch(pattern=regex_str, string=sequence) is not None - - -def test_vllm_integer(model): - prompt = "Give me an integer: " - sequence = generate.format(model, int)(prompt, max_tokens=10) - assert isinstance(sequence, int) - assert sequence != "" - int(sequence) - - -def test_vllm_float(model): - prompt = "Give me a floating-point number: " - sequence = generate.format(model, float)(prompt, max_tokens=10) - assert isinstance(sequence, float) - - assert sequence != "" - float(sequence) - - -def test_vllm_bool(model): - prompt = "Is this True or False? " - sequence = generate.format(model, bool)(prompt, max_tokens=10) - assert isinstance(sequence, bool) - - assert sequence != "" - bool(sequence) - - -def test_vllm_date(model): - prompt = "What day is it today? " - sequence = generate.format(model, datetime.date)(prompt, max_tokens=10) - assert isinstance(sequence, datetime.date) - - -def test_vllm_time(model): - prompt = "What time is it? " - sequence = generate.format(model, datetime.time)(prompt, max_tokens=10) - assert isinstance(sequence, datetime.time) - - -def test_vllm_datetime(model): - prompt = "What time is it? " - sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20) - assert isinstance(sequence, datetime.datetime) - - -def test_vllm_choice(model): - prompt = "Which one between 'test' and 'choice'? " - sequence = generate.choice(model, ["test", "choice"])(prompt) - assert sequence == "test" or sequence == "choice" - - -def test_vllm_json_basic(model): - prompt = "Output some JSON. " - - class Spam(BaseModel): - spam: constr(max_length=10) - fuzz: bool - - sampling_params = SamplingParams(temperature=0) - result = generate.json(model, Spam, whitespace_pattern="")( - prompt, max_tokens=100, seed=1, sampling_params=sampling_params - ) - assert isinstance(result, BaseModel) - assert isinstance(result.spam, str) - assert isinstance(result.fuzz, bool) - assert len(result.spam) <= 10 - - -def test_vllm_json_schema(model): - prompt = "Output some JSON. 
" - - schema = """{ - "title": "spam", - "type": "object", - "properties": { - "foo" : {"type": "boolean"}, - "bar": {"type": "string", "maxLength": 4} - }, - "required": ["foo", "bar"] - } - """ - - sampling_params = SamplingParams(temperature=0) - result = generate.json(model, schema, whitespace_pattern="")( - prompt, max_tokens=100, seed=10, sampling_params=sampling_params - ) - assert isinstance(result, dict) - assert isinstance(result["foo"], bool) - assert isinstance(result["bar"], str) - - -@pytest.mark.xfail( - reason="The CFG logits processor for vLLM has not been implemented yet." -) -def test_vllm_cfg(model): - prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n" - result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) - assert isinstance(result, str) diff --git a/tests/models/test_vllm.py b/tests/models/test_vllm.py new file mode 100644 index 000000000..bf598ae0c --- /dev/null +++ b/tests/models/test_vllm.py @@ -0,0 +1,125 @@ +import re +from enum import Enum + +import pytest +from pydantic import BaseModel +from vllm import LLM, SamplingParams + +import outlines +from outlines.models.vllm import VLLM +from outlines.types import Choice, JsonType, Regex + +TEST_MODEL = "erwanf/gpt2-mini" + + +@pytest.fixture(scope="session") +def model(tmp_path_factory): + model = outlines.from_vllm(LLM(TEST_MODEL)) + return model + + +def test_vllm_simple(model): + result = model.generate("Respond with one word. Not more.", None) + assert isinstance(result, str) + + +def test_vllm_call(model): + result = model("Respond with one word. Not more.") + assert isinstance(result, str) + + +def test_vllm_inference_kwargs(model): + result = model("Write a short story about a cat.", sampling_params=SamplingParams(max_tokens=2), use_tqdm=True) + assert isinstance(result, str) + assert len(result) <= 20 + + +def test_vllm_invalid_inference_kwargs(model): + with pytest.raises(TypeError): + model("Respond with one word. Not more.", foo="bar") + + +def test_vllm_regex(model): + result = model("Give a number between 0 and 9.", Regex(r"[0-9]")) + assert isinstance(result, str) + assert re.match(r"[0-9]", result) + + +def test_vllm_json(model): + class Character(BaseModel): + name: str + + result = model("Create a character with a name.", JsonType(Character)) + assert "name" in result + + +def test_vllm_choice(model): + class Foo(Enum): + cat = "cat" + dog = "dog" + + result = model("Cat or dog?", Choice(Foo)) + assert result in ["cat", "dog"] + + +def test_vllm_batch_samples(model): + result = model( + "Respond with one word. Not more.", + sampling_params=SamplingParams(n=2) + ) + assert isinstance(result, list) + assert len(result) == 2 + + result = model( + ["Respond with one word. Not more.", "Respond with one word. Not more."] + ) + assert isinstance(result, list) + assert len(result) == 2 + + result = model( + ["Respond with one word. Not more.", "Respond with one word. 
Not more."], + sampling_params=SamplingParams(n=2) + ) + assert isinstance(result, list) + assert len(result) == 2 + for item in result: + assert isinstance(item, list) + assert len(item) == 2 + + +def test_vllm_batch_samples_constrained(model): + class Foo(Enum): + cat = "cat" + dog = "dog" + + result = model( + "Cat or dog?", + Choice(Foo), + sampling_params=SamplingParams(n=2) + ) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] in ["cat", "dog"] + assert result[1] in ["cat", "dog"] + + result = model( + ["Cat or dog?", "Cat or dog?"], + Choice(Foo), + ) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] in ["cat", "dog"] + assert result[1] in ["cat", "dog"] + + result = model( + ["Cat or dog?", "Cat or dog?"], + Choice(Foo), + sampling_params=SamplingParams(n=2) + ) + assert isinstance(result, list) + assert len(result) == 2 + for item in result: + assert isinstance(item, list) + assert len(item) == 2 + assert item[0] in ["cat", "dog"] + assert item[1] in ["cat", "dog"]