Commit
Refactor the VLLM model
RobinPicard committed Feb 27, 2025
1 parent 9cd81ed commit 0b4cb61
Showing 5 changed files with 227 additions and 350 deletions.
31 changes: 25 additions & 6 deletions docs/reference/models/vllm.md
@@ -11,10 +11,10 @@

Outlines supports models available via vLLM's offline batched inference interface. You can load a model using:


```python
import outlines
from vllm import LLM

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
```

@@ -28,18 +28,37 @@ Models are loaded from the [HuggingFace hub](https://huggingface.co/).

## Generate text

In addition to the parameters described in the [text generation section](../text.md) you can pass an instance of `SamplingParams` directly to any generator via the `sampling_params` keyword argument:
To generate text, you can simply call the model with a prompt as the argument:

```python
from vllm.sampling_params import SamplingParams
from outlines import models, generate
import outlines
from vllm import LLM

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
answer = model("Write a short story about a cat.")
```

You can also use structured generation with the `VLLM` model by providing an output type after the prompt:

```python
import outlines
from vllm import LLM
from outlines.types import JsonType
from pydantic import BaseModel

class Character(BaseModel):
    name: str

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))
params = SamplingParams(n=2, frequency_penalty=1., min_tokens=2)
answer = model("A prompt", sampling_params=params)
answer = model("Create a character.", output_type=JsonType(Character))
```

The `VLLM` model supports batch generation. To use it, pass a list of strings as the prompt instead of a single string.
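
As an illustration, here is a minimal sketch of batch generation (the prompts are placeholder values; based on the behavior described above, the result should contain one completion per prompt):

```python
import outlines
from vllm import LLM

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))

# Passing a list of prompts uses vLLM's batched inference under the hood;
# one completion is expected per prompt.
answers = model(["Write a haiku about cats.", "Write a haiku about dogs."])
```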

## Optional parameters

When calling the model, you can provide optional parameters on top of the prompt and the output type. Those will be passed on to the `LLM.generate` method of the `vllm` library. An optional parameter of particular interest is `sampling_params`, which is an instance of `SamplingParams`. You can find more information about it in the [vLLM documentation](https://docs.vllm.ai/en/latest/api/inference_params.html).
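
For instance, here is a hedged sketch of passing `SamplingParams` (the specific parameter values below are illustrative, not recommendations):

```python
import outlines
from vllm import LLM
from vllm.sampling_params import SamplingParams

model = outlines.from_vllm(LLM("microsoft/Phi-3-mini-4k-instruct"))

# The sampling_params keyword is forwarded to `LLM.generate`.
params = SamplingParams(temperature=0.7, max_tokens=100)
answer = model("Write a short story about a cat.", sampling_params=params)
```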

!!! Warning

Streaming is not available for the offline vLLM integration.
159 changes: 64 additions & 95 deletions outlines/models/vllm.py
@@ -1,7 +1,9 @@
import dataclasses
from functools import singledispatchmethod
from typing import TYPE_CHECKING, List, Optional, Union

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.base import Model, ModelTypeAdapter

if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase
@@ -12,21 +14,59 @@
__all__ = ["VLLM", "from_vllm"]


class VLLM:
"""Represents a vLLM model.
class VLLMTypeAdapter(ModelTypeAdapter):
@singledispatchmethod
def format_input(self, model_input):
"""Generate the prompt argument to pass to the model.
We wrap models from model providing libraries in order to give all of
them the same interface in Outlines and allow users to easily switch
between providers. This class wraps the `vllm.LLM` class from the
`vllm` library.
Argument
--------
model_input
The input passed by the user.
"""
"""
raise NotImplementedError(
f"The input type {model_input} is not available. "
"Please use a string or a list of strings."
)

@format_input.register(str)
def format_str_input(self, model_input):
return model_input

@format_input.register(list)
def format_list_input(self, model_input):
return model_input

def format_output_type(self, output_type):
"""Generate the logits processor argument to pass to the model.
Argument
--------
output_type
The logits processor provided.
"""
if output_type is not None:
return [output_type]
return []


class VLLM(Model):
"""Represents a `vllm` model."""

def __init__(self, model: "LLM"):
self.model = model
self.lora_request = None
"""Create a VLLM model instance.
Parameters
----------
model
A `vllm.LLM` model instance.
"""
self.model = model
self.tokenizer = self._get_tokenizer()
self.type_adapter = VLLMTypeAdapter()

def _get_tokenizer(self):
if hasattr(self.model, "get_tokenizer"):
@@ -43,98 +83,35 @@ def _get_tokenizer(self):
)
return adapt_tokenizer(tokenizer=tokenizer)

def generate(
self,
prompts: Union[str, List[str]],
generation_parameters: GenerationParameters,
logits_processor,
sampling_parameters: SamplingParameters,
*,
sampling_params: Optional["SamplingParams"] = None,
use_tqdm: bool = True,
):
"""Generate text using vLLM.
def generate(self, model_input, output_type, **inference_kwargs):
"""Generate text using `vllm`.
Parameters
----------
prompts
A prompt or list of prompts.
generation_parameters
An instance of `GenerationParameters` that contains the prompt,
the maximum number of tokens, stop sequences and seed. All the
arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
Arguments
---------
prompt
The text prompt provided by the user.
logits_processor
The logits processor to use when generating text.
sampling_parameters
An instance of `SamplingParameters`, a dataclass that contains
the name of the sampler to use and related parameters as available
in Outlines.
sampling_params
An instance of `vllm.sampling_params.SamplingParams`. The values
passed via this dataclass supersede the values of the parameters
in `generation_parameters` and `sampling_parameters`. See the
vLLM documentation for more details: https://docs.vllm.ai/en/latest/dev/sampling_params.html.
use_tqdm
A boolean in order to display progress bar while inferencing
inference_kwargs
The inference kwargs that can be passed to the `LLM.generate` method
in the `vllm` library.
Returns
-------
The generated text, of shape `(n_batch, n_samples)`. If there are only
one batch and several samples, the list is of shape `(n_samples)`. If
this is a batch with several sequences but only one sample the list is
of shape `(n_batch)`. If there is only one sequence and one sample, a
string is returned.
The generated text.
"""
from vllm.sampling_params import SamplingParams

sampling_params = inference_kwargs.pop("sampling_params", None)
if sampling_params is None:
sampling_params = SamplingParams()

max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)

# We only update the values in `sampling_params` if they
# are specified by the user when calling the generator.
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if stop_at is not None:
if isinstance(stop_at, str):
stop_at = [stop_at]
sampling_params.stop = stop_at
if seed is not None:
sampling_params.seed = seed

sampling_params.logits_processors = (
[logits_processor] if logits_processor is not None else []
)

sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
sampling_parameters
)

# We only update the values in `sampling_params` that
# were not specified by the user.
if sampling_params.n == 1:
sampling_params.n = num_samples
sampling_params.best_of = num_samples
if top_p is not None and sampling_params.top_p == 1.0:
sampling_params.top_p = top_p
if top_k is not None and sampling_params.top_k == -1:
sampling_params.top_k = top_k
# TODO: remove this if statement once fixed
# https://github.com/vllm-project/vllm/issues/5404#issuecomment-2175972897
if top_k == 1:
sampling_params.repetition_penalty = 0
if temperature is not None and sampling_params.temperature == 1.0:
sampling_params.temperature = temperature
if sampler == "beam_search":
sampling_params.use_beam_search = True
sampling_params.logits_processors = self.type_adapter.format_output_type(output_type)

results = self.model.generate(
prompts,
self.type_adapter.format_input(model_input),
sampling_params=sampling_params,
lora_request=self.lora_request,
use_tqdm=use_tqdm,
**inference_kwargs,
)
results = [[sample.text for sample in batch.outputs] for batch in results]

@@ -150,7 +127,7 @@ def generate(

return results

def stream(self, *args, **kwargs):
def stream(self, model_input, output_type, **inference_kwargs):
"""Return a text generator.
Streaming is not yet available for `vllm.LLM`.
@@ -162,14 +139,6 @@ def stream(self, *args, **kwargs):
"Streaming is not available for the vLLM integration."
)

def load_lora(self, adapter_path: Optional[str]):
from vllm.lora.request import LoRARequest

if adapter_path is None:
self.lora_request = None
else:
self.lora_request = LoRARequest(adapter_path, 1, adapter_path)


def from_vllm(model: "LLM") -> VLLM:
return VLLM(model)
7 changes: 3 additions & 4 deletions tests/generate/test_generate.py
@@ -143,17 +143,16 @@ def model_t5(tmp_path_factory):
ALL_MODEL_FIXTURES = (
# "model_llamacpp", # temporary disabled due to the v1 model refactoring
"model_exllamav2",
"model_mlxlm",
"model_mlxlm_phi3",
# "model_mlxlm",
# "model_mlxlm_phi3",
# "model_transformers_random",
# "model_transformers_opt125m",
# "model_mamba",
# "model_bart",
# "model_transformers_vision",
"model_vllm",
# "model_vllm",
)


class MyEnum(Enum):
foo = "foo"
bar = "bar"