Refactor the TransformersVision model
Robin Picard committed Feb 19, 2025
1 parent baa8fcd commit 4866f7f
Showing 2 changed files with 254 additions and 120 deletions.
258 changes: 138 additions & 120 deletions outlines/models/transformers_vision.py
@@ -1,138 +1,156 @@
from functools import singledispatchmethod

from outlines.models import Transformers
from outlines.models.base import ModelTypeAdapter


class TransformersVisionTypeAdapter(ModelTypeAdapter):
    """Type adapter for TransformersVision models."""

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the model.

        Argument
        --------
        model_input
            The input passed by the user.
        """
        raise NotImplementedError(
            f"The input type {model_input} is not available. "
            "Please provide a 2-element tuple containing "
            "a prompt or list of prompts and a PIL image "
            "or a list of PIL images. The number of prompts "
            "must match the number of images."
        )

    @format_input.register(tuple)
    def format_list_input(self, model_input):
        try:
            prompts, images = model_input
        except ValueError:
            raise TypeError("The input must be a tuple of exactly 2 elements.")
        if (isinstance(prompts, str) and isinstance(images, list)) or (
            isinstance(prompts, list) and not isinstance(images, list)
        ):
            raise TypeError(
                "You must either provide a single prompt and a single image "
                "or a list of prompts and a list of images."
            )
        if (
            isinstance(prompts, list)
            and isinstance(images, list)
            and len(prompts) != len(images)
        ):
            raise TypeError("The number of prompts must match the number of images.")
        return prompts, images

    def format_output_type(self, output_type):
        """Generate the logits processor argument to pass to the model.

        Argument
        --------
        output_type
            The logits processor provided.
        """
        from transformers import LogitsProcessorList

        if output_type is not None:
            return LogitsProcessorList([output_type])
        return None
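
# Illustrative usage of the adapter (names here are hypothetical; `img` stands
# for any PIL image). A single prompt pairs with a single image, lists must
# match in length, and anything else raises:
#
#     adapter = TransformersVisionTypeAdapter()
#     adapter.format_input(("Describe this image.", img))  # returns the pair
#     adapter.format_input((["p1", "p2"], [img, img]))     # matched lists are fine
#     adapter.format_input((["p1", "p2"], [img]))          # TypeError: counts differ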


class TransformersVision(Transformers):
    """Represents a `transformers` model with a vision processor."""

    def __init__(
        self,
        model_name: str,
        model_class,
        model_kwargs: dict = {},
        tokenizer_class=None,
        tokenizer_kwargs: dict = {},
        processor_class=None,
        processor_kwargs: dict = {},
    ):
        """Create a TransformersVision model instance.

        We rely on the `__init__` method of the `Transformers` class to handle
        most of the initialization and then add elements specific to vision
        models.

        Parameters
        ----------
        model_name
            The name of the transformers model to use.
        model_class
            The Transformers model class from which to create the model.
        model_kwargs
            A dictionary of keyword arguments to pass to the `from_pretrained`
            method of the model class.
        tokenizer_class
            The Transformers tokenizer class from which to create the tokenizer.
            If not provided, `AutoTokenizer` will be used. If you gave the name
            of a model that is not compatible with `AutoTokenizer`, you must
            provide a value for this parameter.
        tokenizer_kwargs
            A dictionary of keyword arguments to pass to the `from_pretrained`
            method of the tokenizer class.
        processor_class
            The Transformers processor class from which to create the processor.
            If not provided, `AutoProcessor` will be used. If you gave the name
            of a model that is not compatible with `AutoProcessor`, you must
            provide a value for this parameter.
        processor_kwargs
            A dictionary of keyword arguments to pass to the `from_pretrained`
            method of the processor class.
        """
        if processor_class is None:
            try:
                from transformers import AutoProcessor

                processor_class = AutoProcessor
            except ImportError:
                raise ImportError(
                    "The `transformers` library needs to be installed in order to use `transformers` models."
                )

        processor_kwargs.setdefault("padding_side", "left")
        processor_kwargs.setdefault("pad_token", "[PAD]")
        self.processor = processor_class.from_pretrained(model_name, **processor_kwargs)

        if tokenizer_class is None and getattr(self.processor, "tokenizer", None):
            tokenizer_class = type(self.processor.tokenizer)

        super().__init__(
            model_name,
            model_class,
            model_kwargs,
            tokenizer_class,
            tokenizer_kwargs,
        )

        self.type_adapter = TransformersVisionTypeAdapter()

    def generate(self, model_input, output_type, **inference_kwargs):
        prompts, images = self.type_adapter.format_input(model_input)

        inputs = self.processor(
            text=prompts, images=images, padding=True, return_tensors="pt"
        ).to(self.model.device)

        generated_ids = self._generate_output_seq(prompts, inputs, **inference_kwargs)

        # if single str input, convert to a 1D output
        if isinstance(prompts, str):
            # Should always be true until NotImplementedError above is fixed
            generated_ids = generated_ids.squeeze(0)

        return self._decode_generation(generated_ids)
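
    # Illustrative calls (mirroring the tests below; `img` is any PIL image):
    #
    #     model.generate(("Describe this image in one sentence:", img), None)
    #     model.generate((["p1", "p2"], [img, img]), None)  # batched, returns a list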

    def stream(self, model_input, output_type, **inference_kwargs):
        raise NotImplementedError(
            "Streaming is not implemented for TransformersVision models."
        )
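
Taken together, the new `__init__` and `generate` replace the old
`transformers_vision` factory function: the model class is now instantiated
directly. A minimal end-to-end sketch, assuming the tiny test model used in
the test file below and a placeholder image:

import numpy as np
from PIL import Image
from transformers import Blip2ForConditionalGeneration

from outlines.models.transformers_vision import TransformersVision

# Instantiate directly; AutoProcessor and the processor's tokenizer class
# are resolved automatically when no explicit classes are passed.
model = TransformersVision(
    model_name="hf-internal-testing/tiny-random-Blip2Model",
    model_class=Blip2ForConditionalGeneration,
)

# A blank 32x32 image stands in for real input.
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))

# output_type=None means unconstrained generation.
result = model.generate(("Describe this image in one sentence:", image), None)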
116 changes: 116 additions & 0 deletions tests/models/test_transformers_vision.py
@@ -0,0 +1,116 @@
from enum import Enum

import numpy as np
import pytest
from PIL import Image
from transformers import Blip2ForConditionalGeneration, CLIPModel, CLIPProcessor

from outlines.models.transformers_vision import TransformersVision
from outlines.processors import RegexLogitsProcessor
from outlines.types import Choice

TEST_MODEL = "hf-internal-testing/tiny-random-Blip2Model"
TEST_CLIP_MODEL = "openai/clip-vit-base-patch32"


@pytest.fixture
def image():
    return Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))


@pytest.fixture
def model():
    return TransformersVision(
        model_name=TEST_MODEL, model_class=Blip2ForConditionalGeneration
    )


def test_transformers_vision_instantiate_simple():
    model = TransformersVision(
        model_name=TEST_MODEL,
        model_class=Blip2ForConditionalGeneration,
    )
    assert isinstance(model, TransformersVision)


def test_transformers_vision_instantiate_other_processor_class():
    model = TransformersVision(
        model_name=TEST_CLIP_MODEL, model_class=CLIPModel, processor_class=CLIPProcessor
    )
    assert isinstance(model, TransformersVision)


def test_transformers_vision_simple(model, image):
    result = model.generate(("Describe this image in one sentence:", image), None)
    assert isinstance(result, str)


def test_transformers_vision_call(model, image):
    result = model(("Describe this image in one sentence:", image))
    assert isinstance(result, str)


def test_transformers_vision_wrong_input_type(model):
    with pytest.raises(NotImplementedError):
        model.generate("invalid input", None)


def test_transformers_inference_kwargs(model, image):
    result = model(("Describe this image in one sentence:", image), max_new_tokens=100)
    assert isinstance(result, str)


def test_transformers_invalid_inference_kwargs(model, image):
    with pytest.raises(ValueError):
        model(("Describe this image in one sentence:", image), foo="bar")


def test_transformers_vision_choice(model, image):
    class Foo(Enum):
        white = "white"
        black = "black"

    regex_str = Choice(Foo).to_regex()
    logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
    result = model.generate(("Is this image white or black?", image), logits_processor)

    assert isinstance(result, str)


def test_transformers_vision_batch_input_samples(model, image):
    result = model(("Describe this image in one sentence.", image))
    assert isinstance(result, str)
    result = model(
        ("Describe this image in one sentence.", image),
        num_return_sequences=2,
        num_beams=2,
    )
    assert isinstance(result, list)
    assert len(result) == 2
    result = model(
        (
            [
                "Describe this image in one sentence.",
                "Describe this image in one sentence.",
            ],
            [image, image],
        )
    )
    assert isinstance(result, list)
    assert len(result) == 2
    result = model(
        (
            [
                "Describe this image in one sentence.",
                "Describe this image in one sentence.",
            ],
            [image, image],
        ),
        num_return_sequences=2,
        num_beams=2,
    )
    assert isinstance(result, list)
    assert len(result) == 2
    for item in result:
        assert isinstance(item, list)
        assert len(item) == 2
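
For structured output, the choice test above pairs an `Enum` with a
`RegexLogitsProcessor`. A hedged sketch of the same path outside pytest,
reusing the `model` and `image` objects from the earlier sketch:

from enum import Enum

from outlines.processors import RegexLogitsProcessor
from outlines.types import Choice


class Color(Enum):
    white = "white"
    black = "black"


# Turn the enum into a regex and constrain decoding to it.
regex_str = Choice(Color).to_regex()
logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
result = model.generate(("Is this image white or black?", image), logits_processor)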
