forked from dottxt-ai/outlines
Refactor the TransformersVision model
Robin Picard committed on Feb 19, 2025
Commit 4866f7f (1 parent: baa8fcd)
Showing 2 changed files with 254 additions and 120 deletions.
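In short, the commit replaces the module-level `transformers_vision` factory with a `TransformersVision` class that loads the model, tokenizer and processor itself, and moves input validation into a new `TransformersVisionTypeAdapter`. Below is a minimal before/after usage sketch; the model name and prompt are borrowed from the new tests, and the import paths are assumptions read off the diff rather than documented API.

import numpy as np
from PIL import Image
from transformers import Blip2ForConditionalGeneration

image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))

# Before this commit: a module-level factory built the pieces and
# wrapped them (removed in the diff below, shown here for contrast):
# model = transformers_vision(
#     "hf-internal-testing/tiny-random-Blip2Model",
#     model_class=Blip2ForConditionalGeneration,
# )

# After this commit: the constructor loads everything via `from_pretrained`,
# and generation takes a (prompt, image) tuple plus an optional output type.
from outlines.models.transformers_vision import TransformersVision

model = TransformersVision(
    model_name="hf-internal-testing/tiny-random-Blip2Model",
    model_class=Blip2ForConditionalGeneration,
)
result = model.generate(("Describe this image in one sentence:", image), None)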
File: outlines/models/transformers_vision.py
@@ -1,138 +1,156 @@
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
+from functools import singledispatchmethod

-from outlines.generate.api import GenerationParameters, SamplingParameters
 from outlines.models import Transformers
+from outlines.models.base import ModelTypeAdapter

-if TYPE_CHECKING:
-    from outlines.processors import OutlinesLogitsProcessor
+
+class TransformersVisionTypeAdapter(ModelTypeAdapter):
+    """Type adapter for TransformersVision models."""
+
+    @singledispatchmethod
+    def format_input(self, model_input):
+        """Generate the prompt argument to pass to the model.
+
+        Argument
+        --------
+        model_input
+            The input passed by the user.
+        """
+        raise NotImplementedError(
+            f"The input type {model_input} is not available. "
+            "Please provide a 2 element tuple containing "
+            "a prompt or list of prompts and a PIL image "
+            "or a list of PIL images. The number of prompts "
+            "must match the number of images."
+        )
+
+    @format_input.register(tuple)
+    def format_list_input(self, model_input):
+        try:
+            prompts, images = model_input
+        except ValueError:
+            raise TypeError("The input must be a tuple of exactly 2 elements.")
+        if (isinstance(prompts, str) and isinstance(images, list)) or (
+            isinstance(prompts, list) and not isinstance(images, list)
+        ):
+            raise TypeError(
+                "You must either provide a single prompt and a single image "
+                "or a list of prompts and a list of images."
+            )
+        if (
+            isinstance(prompts, list)
+            and isinstance(images, list)
+            and len(prompts) != len(images)
+        ):
+            raise TypeError("The number of prompts must match the number of images.")
+        return prompts, images
+
+    def format_output_type(self, output_type):
+        """Generate the logits processor argument to pass to the model.
+
+        Argument
+        --------
+        output_type
+            The logits processor provided.
+        """
+        from transformers import LogitsProcessorList
+
+        if output_type is not None:
+            return LogitsProcessorList([output_type])
+        return None


 class TransformersVision(Transformers):
-    def __init__(self, model, tokenizer, processor):
-        super().__init__(model, tokenizer)
-        self.processor = processor
+    """Represents a `transformers` model with a vision processor."""

-    def generate(  # type: ignore
+    def __init__(
         self,
-        prompts: Union[str, List[str]],
-        media: Union[List[Any], List[List[Any]]],
-        generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
-        sampling_parameters: SamplingParameters,
-    ) -> Union[str, List[str], List[List[str]]]:
-        """Generate text using `transformers`.
-
-        Arguments
-        ---------
-        prompts
-            A prompt or list of prompts.
-        media
-            A List[PIL.Image] or List[List[PIL.Image]]
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
-        logits_processor
-            The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-
-        Returns
-        -------
-        The generated text
-        """
+        model_name: str,
+        model_class,
+        model_kwargs: dict = {},
+        tokenizer_class=None,
+        tokenizer_kwargs: dict = {},
+        processor_class=None,
+        processor_kwargs: dict = {},
+    ):
+        """Create a TransformersVision model instance.
+
+        We rely on the `__init__` method of the `Transformers` class to handle
+        most of the initialization and then add elements specific to vision
+        models.
+
+        Parameters
+        ----------
+        model_name
+            The name of the transformers model to use.
+        model_class
+            The Transformers model class from which to create the model.
+        model_kwargs
+            A dictionary of keyword arguments to pass to the `from_pretrained`
+            method of the model class.
+        tokenizer_class
+            The Transformers tokenizer class from which to create the tokenizer.
+            If not provided, `AutoTokenizer` will be used. If you gave the name
+            of a model that is not compatible with `AutoTokenizer`, you must
+            provide a value for this parameter.
+        tokenizer_kwargs
+            A dictionary of keyword arguments to pass to the `from_pretrained`
+            method of the tokenizer class.
+        processor_class
+            The Transformers processor class from which to create the processor.
+            If not provided, `AutoProcessor` will be used. If you gave the name
+            of a model that is not compatible with `AutoProcessor`, you must
+            provide a value for this parameter.
+        processor_kwargs
+            A dictionary of keyword arguments to pass to the `from_pretrained`
+            method of the processor class.
+        """
+        if processor_class is None:
+            try:
+                from transformers import AutoProcessor
+
+                processor_class = AutoProcessor
+            except ImportError:
+                raise ImportError(
+                    "The `transformers` library needs to be installed in order to use `transformers` models."
+                )
+
+        processor_kwargs.setdefault("padding_side", "left")
+        processor_kwargs.setdefault("pad_token", "[PAD]")
+        self.processor = processor_class.from_pretrained(model_name, **processor_kwargs)
+
+        if tokenizer_class is None and getattr(self.processor, "tokenizer", None):
+            tokenizer_class = type(self.processor.tokenizer)
+
+        super().__init__(
+            model_name,
+            model_class,
+            model_kwargs,
+            tokenizer_class,
+            tokenizer_kwargs,
+        )
+
+        self.type_adapter = TransformersVisionTypeAdapter()
+
+    def generate(self, model_input, output_type, **inference_kwargs):
+        prompts, images = self.type_adapter.format_input(model_input)

         inputs = self.processor(
-            text=prompts, images=media, padding=True, return_tensors="pt"
+            text=prompts, images=images, padding=True, return_tensors="pt"
         ).to(self.model.device)

-        generation_kwargs = self._get_generation_kwargs(
-            prompts,
-            generation_parameters,
-            logits_processor,
-            sampling_parameters,
-        )
-        generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
+        generated_ids = self._generate_output_seq(prompts, inputs, **inference_kwargs)

-        # if single str input and single sample per input, convert to a 1D output
+        # if single str input, convert to a 1D output
         if isinstance(prompts, str):
-            # Should always be true until NotImplementedError above is fixed
             generated_ids = generated_ids.squeeze(0)

         return self._decode_generation(generated_ids)

-    def stream(  # type: ignore
-        self,
-        prompts: Union[str, List[str]],
-        media: Union[Any, List[Any]],  # TODO: docstring
-        generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
-        sampling_parameters: SamplingParameters,
-    ) -> Iterator[Union[str, List[str]]]:
-        raise NotImplementedError
-
-
-def transformers_vision(
-    model_name: str,
-    model_class,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    processor_kwargs: dict = {},
-    tokenizer_class=None,
-    processor_class=None,
-):
-    """Instantiate a model from the `transformers` library and its tokenizer.
-
-    Parameters
-    ----------
-    model_name
-        The name of the model as listed on Hugging Face's model page.
-    model_class
-        The `PreTrainedModel` class from transformers to use in initializing
-        the vision model from `model_name`.
-        https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel
-    device
-        The device(s) on which the model should be loaded. This overrides
-        the `device_map` entry in `model_kwargs` when provided.
-    model_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the model.
-    processor_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the processor.
-
-    Returns
-    -------
-    A `TransformersVision` model instance.
-    """
-    if processor_class is None or tokenizer_class is None:
-        try:
-            from transformers import AutoProcessor, AutoTokenizer
-        except ImportError:
-            raise ImportError(
-                "The `transformers` library needs to be installed in order to use `transformers` models."
-            )
-    if processor_class is None:
-        processor_class = AutoProcessor
-    if tokenizer_class is None:
-        tokenizer_class = AutoTokenizer
-
-    if device is not None:
-        model_kwargs["device_map"] = device
-
-    model = model_class.from_pretrained(model_name, **model_kwargs)
-
-    processor_kwargs.setdefault("padding_side", "left")
-    processor_kwargs.setdefault("pad_token", "[PAD]")
-    processor = processor_class.from_pretrained(model_name, **processor_kwargs)
-
-    if tokenizer_class is None:
-        if getattr(processor, "tokenizer", None):
-            tokenizer = processor.tokenizer
-        else:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs)
-    else:
-        tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs)
-
-    return TransformersVision(model, tokenizer, processor)
+    def stream(self, model_input, output_type, **inference_kwargs):
+        raise NotImplementedError(
+            "Streaming is not implemented for TransformersVision models."
+        )
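The `(prompts, images)` contract that `generate` relies on is enforced by the new `TransformersVisionTypeAdapter`. As a quick sketch of which input shapes pass and which raise, assuming the validation logic shown in the diff above:

import numpy as np
import pytest
from PIL import Image

from outlines.models.transformers_vision import TransformersVisionTypeAdapter

adapter = TransformersVisionTypeAdapter()
img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))

# Accepted: a single prompt with a single image, or two lists of equal length.
adapter.format_input(("Describe this image:", img))
adapter.format_input((["p1", "p2"], [img, img]))

# Rejected: mixed single/list shapes, mismatched lengths, wrong arity, wrong type.
with pytest.raises(TypeError):
    adapter.format_input(("p", [img, img]))  # single prompt, list of images
with pytest.raises(TypeError):
    adapter.format_input((["p1"], [img, img]))  # 1 prompt vs 2 images
with pytest.raises(TypeError):
    adapter.format_input(("p",))  # not a 2-element tuple
with pytest.raises(NotImplementedError):
    adapter.format_input("p")  # not a tuple at all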
New file: tests for the TransformersVision model (every line added)
@@ -0,0 +1,116 @@
from enum import Enum

import numpy as np
import pytest
from PIL import Image
from transformers import Blip2ForConditionalGeneration, CLIPModel, CLIPProcessor

from outlines.models.transformers_vision import TransformersVision
from outlines.processors import RegexLogitsProcessor
from outlines.types import Choice

TEST_MODEL = "hf-internal-testing/tiny-random-Blip2Model"
TEST_CLIP_MODEL = "openai/clip-vit-base-patch32"


@pytest.fixture
def image():
    return Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))


@pytest.fixture
def model():
    return TransformersVision(
        model_name=TEST_MODEL, model_class=Blip2ForConditionalGeneration
    )


def test_transformers_vision_instantiate_simple():
    model = TransformersVision(
        model_name=TEST_MODEL,
        model_class=Blip2ForConditionalGeneration,
    )
    assert isinstance(model, TransformersVision)


def test_transformers_vision_instantiate_other_processor_class():
    model = TransformersVision(
        model_name=TEST_CLIP_MODEL, model_class=CLIPModel, processor_class=CLIPProcessor
    )
    assert isinstance(model, TransformersVision)


def test_transformers_vision_simple(model, image):
    result = model.generate(("Describe this image in one sentence:", image), None)
    assert isinstance(result, str)


def test_transformers_vision_call(model, image):
    result = model(("Describe this image in one sentence:", image))
    assert isinstance(result, str)


def test_transformers_vision_wrong_input_type(model):
    with pytest.raises(NotImplementedError):
        model.generate("invalid input", None)


def test_transformers_inference_kwargs(model, image):
    result = model(("Describe this image in one sentence:", image), max_new_tokens=100)
    assert isinstance(result, str)


def test_transformers_invalid_inference_kwargs(model, image):
    with pytest.raises(ValueError):
        model(("Describe this image in one sentence:", image), foo="bar")


def test_transformers_vision_choice(model, image):
    class Foo(Enum):
        white = "white"
        black = "black"

    regex_str = Choice(Foo).to_regex()
    logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
    result = model.generate(("Is this image white or black?", image), logits_processor)

    assert isinstance(result, str)


def test_transformers_vision_batch_input_samples(model, image):
    result = model(("Describe this image in one sentence.", image))
    assert isinstance(result, str)
    result = model(
        ("Describe this image in one sentence.", image),
        num_return_sequences=2,
        num_beams=2,
    )
    assert isinstance(result, list)
    assert len(result) == 2
    result = model(
        (
            [
                "Describe this image in one sentence.",
                "Describe this image in one sentence.",
            ],
            [image, image],
        )
    )
    assert isinstance(result, list)
    assert len(result) == 2
    result = model(
        (
            [
                "Describe this image in one sentence.",
                "Describe this image in one sentence.",
            ],
            [image, image],
        ),
        num_return_sequences=2,
        num_beams=2,
    )
    assert isinstance(result, list)
    assert len(result) == 2
    for item in result:
        assert isinstance(item, list)
        assert len(item) == 2
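Taken together, the tests pin down the refactored call contract: extra keyword arguments such as `max_new_tokens`, `num_beams` and `num_return_sequences` are forwarded to the underlying `transformers` generation call (unknown ones raise `ValueError`), and the output shape follows the input shape. A summary sketch, reusing the `model` and `image` fixtures defined above:

# One prompt, one sample -> a plain string.
out = model(("Describe this image in one sentence.", image))
assert isinstance(out, str)

# One prompt, several samples -> a flat list with one string per sample.
out = model(
    ("Describe this image in one sentence.", image),
    num_return_sequences=2,
    num_beams=2,
)
assert isinstance(out, list) and len(out) == 2

# Several prompts, several samples -> one list of samples per prompt.
out = model(
    (["Describe this image.", "Describe this image."], [image, image]),
    num_return_sequences=2,
    num_beams=2,
)
assert all(isinstance(group, list) and len(group) == 2 for group in out)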