Refactor the VLLM model #1453

Merged 1 commit on Feb 27, 2025
31 changes: 25 additions & 6 deletions docs/reference/models/vllm.md
@@ -11,10 +11,10 @@

Outlines supports models available via vLLM's offline batched inference interface. You can load a model using:


```python
import outlines
from vllm import LLM

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
```

@@ -28,18 +28,37 @@ Models are loaded from the [HuggingFace hub](https://huggingface.co/).

## Generate text

In addition to the parameters described in the [text generation section](../text.md), you can pass an instance of `SamplingParams` directly to any generator via the `sampling_params` keyword argument:
To generate text, you can simply call the model with a prompt as the argument:

```python
from vllm.sampling_params import SamplingParams
from outlines import models, generate
import outlines
from vllm import LLM

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
answer = model("Write a short story about a cat.")
```

You can also use structured generation with the `VLLM` model by providing an output type after the prompt:

```python
import outlines
from vllm import LLM
from outlines.types import JsonType
from pydantic import BaseModel

class Character(BaseModel):
    name: str

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
params = SamplingParams(n=2, frequency_penalty=1., min_tokens=2)
answer = model("A prompt", sampling_params=params)
answer = model("Create a character.", output_type=JsonType(Character))
```

The `VLLM` model supports batch generation: pass a list of strings as the prompt instead of a single string, as in the sketch below.
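A minimal sketch of batch generation, assuming the same model as above (the prompts are illustrative):

```python
import outlines
from vllm import LLM

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))

# Passing a list of prompts returns one completion per prompt.
answers = model(["Write a haiku about cats.", "Write a haiku about dogs."])
```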

## Optional parameters

When calling the model, you can provide optional parameters in addition to the prompt and the output type. These are passed on to the `LLM.generate` method of the `vllm` library. An optional parameter of particular interest is `sampling_params`, an instance of `SamplingParams`. You can find more information about it in the [vLLM documentation](https://docs.vllm.ai/en/latest/api/inference_params.html).
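A minimal sketch of passing `sampling_params`, assuming the same model as above (the parameter values are illustrative, not defaults):

```python
import outlines
from vllm import LLM
from vllm.sampling_params import SamplingParams

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))

# SamplingParams is forwarded to vLLM's LLM.generate via the sampling_params keyword.
params = SamplingParams(max_tokens=100, temperature=0.7)
answer = model("Write a short story about a cat.", sampling_params=params)
```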

!!! Warning

Streaming is not available for the offline vLLM integration.
159 changes: 64 additions & 95 deletions outlines/models/vllm.py
@@ -1,7 +1,9 @@
import dataclasses
from functools import singledispatchmethod
from typing import TYPE_CHECKING, List, Optional, Union

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.base import Model, ModelTypeAdapter

if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase
@@ -12,21 +14,59 @@
__all__ = ["VLLM", "from_vllm"]


class VLLM:
"""Represents a vLLM model.
class VLLMTypeAdapter(ModelTypeAdapter):
@singledispatchmethod
def format_input(self, model_input):
"""Generate the prompt argument to pass to the model.

We wrap models from model-providing libraries in order to give all of
them the same interface in Outlines and allow users to easily switch
between providers. This class wraps the `vllm.LLM` class from the
`vllm` library.
Argument
--------
model_input
The input passed by the user.

"""
"""
raise NotImplementedError(
f"The input type {input} is not available."
"Please use a string or a list of strings."
)

@format_input.register(str)
def format_str_input(self, model_input):
return model_input

@format_input.register(list)
def format_list_input(self, model_input):
return model_input

def format_output_type(self, output_type):
"""Generate the logits processor argument to pass to the model.

Argument
--------
output_type
The logits processor provided.

"""
if output_type is not None:
return [output_type]
return []


class VLLM(Model):
"""Represents a `vllm` model."""

def __init__(self, model: "LLM"):
self.model = model
self.lora_request = None
"""Create a VLLM model instance.

Parameters
----------
model
A `vllm.LLM` model instance.

"""
self.model = model
self.tokenizer = self._get_tokenizer()
self.type_adapter = VLLMTypeAdapter()

def _get_tokenizer(self):
if hasattr(self.model, "get_tokenizer"):
@@ -43,98 +83,35 @@ def _get_tokenizer(self):
)
return adapt_tokenizer(tokenizer=tokenizer)

def generate(
self,
prompts: Union[str, List[str]],
generation_parameters: GenerationParameters,
logits_processor,
sampling_parameters: SamplingParameters,
*,
sampling_params: Optional["SamplingParams"] = None,
use_tqdm: bool = True,
):
"""Generate text using vLLM.
def generate(self, model_input, output_type, **inference_kwargs):
"""Generate text using `vllm`.

Parameters
----------
prompts
A prompt or list of prompts.
generation_parameters
An instance of `GenerationParameters` that contains the prompt,
the maximum number of tokens, stop sequences and seed. All the
arguments to `SequenceGeneratorAdapter`'s `__call__` method.
Arguments
---------
model_input
The prompt provided by the user, either a string or a list of strings.
logits_processor
The logits processor to use when generating text.
sampling_parameters
An instance of `SamplingParameters`, a dataclass that contains
the name of the sampler to use and related parameters as available
in Outlines.
sampling_params
An instance of `vllm.sampling_params.SamplingParams`. The values
passed via this dataclass supersede the values of the parameters
in `generation_parameters` and `sampling_parameters`. See the
vLLM documentation for more details: https://docs.vllm.ai/en/latest/dev/sampling_params.html.
use_tqdm
Whether to display a progress bar during inference.
inference_kwargs
The inference kwargs that can be passed to the `LLM.generate` method
in the `vllm` library.

Returns
-------
The generated text, of shape `(n_batch, n_samples)`. If there are only
one batch and several samples, the list is of shape `(n_samples)`. If
this is a batch with several sequences but only one sample the list is
of shape `(n_batch)`. If there is only one sequence and one sample, a
string is returned.
The generated text.

"""
from vllm.sampling_params import SamplingParams

sampling_params = inference_kwargs.pop("sampling_params", None)
if sampling_params is None:
sampling_params = SamplingParams()

max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)

# We only update the values in `sampling_params` if they
# are specified by the user when calling the generator.
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if stop_at is not None:
if isinstance(stop_at, str):
stop_at = [stop_at]
sampling_params.stop = stop_at
if seed is not None:
sampling_params.seed = seed

sampling_params.logits_processors = (
[logits_processor] if logits_processor is not None else []
)

sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
sampling_parameters
)

# We only update the values in `sampling_params` that
# were not specified by the user.
if sampling_params.n == 1:
sampling_params.n = num_samples
sampling_params.best_of = num_samples
if top_p is not None and sampling_params.top_p == 1.0:
sampling_params.top_p = top_p
if top_k is not None and sampling_params.top_k == -1:
sampling_params.top_k = top_k
# TODO: remove this if statement once fixed
# https://github.com/vllm-project/vllm/issues/5404#issuecomment-2175972897
if top_k == 1:
sampling_params.repetition_penalty = 0
if temperature is not None and sampling_params.temperature == 1.0:
sampling_params.temperature = temperature
if sampler == "beam_search":
sampling_params.use_beam_search = True
sampling_params.logits_processors = self.type_adapter.format_output_type(output_type)

results = self.model.generate(
prompts,
self.type_adapter.format_input(model_input),
sampling_params=sampling_params,
lora_request=self.lora_request,
use_tqdm=use_tqdm,
**inference_kwargs,
)
results = [[sample.text for sample in batch.outputs] for batch in results]

@@ -150,7 +127,7 @@ def generate(

return results

def stream(self, *args, **kwargs):
def stream(self, model_input, output_type, **inference_kwargs):
"""Return a text generator.

Streaming is not yet available for `vllm.LLM`.
@@ -162,14 +139,6 @@ def stream(self, *args, **kwargs):
"Streaming is not available for the vLLM integration."
)

def load_lora(self, adapter_path: Optional[str]):
from vllm.lora.request import LoRARequest

if adapter_path is None:
self.lora_request = None
else:
self.lora_request = LoRARequest(adapter_path, 1, adapter_path)


def from_vllm(model: "LLM") -> VLLM:
return VLLM(model)
7 changes: 3 additions & 4 deletions tests/generate/test_generate.py
@@ -143,17 +143,16 @@ def model_t5(tmp_path_factory):
ALL_MODEL_FIXTURES = (
# "model_llamacpp", # temporary disabled due to the v1 model refactoring
"model_exllamav2",
"model_mlxlm",
"model_mlxlm_phi3",
# "model_mlxlm",
# "model_mlxlm_phi3",
# "model_transformers_random",
# "model_transformers_opt125m",
# "model_mamba",
# "model_bart",
# "model_transformers_vision",
"model_vllm",
# "model_vllm",
)


class MyEnum(Enum):
foo = "foo"
bar = "bar"