diff --git a/outlines/models/transformers_vision.py b/outlines/models/transformers_vision.py
index 772645b80..22b673815 100644
--- a/outlines/models/transformers_vision.py
+++ b/outlines/models/transformers_vision.py
@@ -1,138 +1,156 @@
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
+from functools import singledispatchmethod

-from outlines.generate.api import GenerationParameters, SamplingParameters
 from outlines.models import Transformers
+from outlines.models.base import ModelTypeAdapter

-if TYPE_CHECKING:
-    from outlines.processors import OutlinesLogitsProcessor
+
+
+class TransformersVisionTypeAdapter(ModelTypeAdapter):
+    """Type adapter for TransformersVision models."""
+
+    @singledispatchmethod
+    def format_input(self, model_input):
+        """Generate the prompt argument to pass to the model.
+
+        Argument
+        --------
+        model_input
+            The input passed by the user.
+
+        """
+        raise NotImplementedError(
+            f"The input type {model_input} is not available. "
+            "Please provide a 2 element tuple containing "
+            "a prompt or list of prompts and a PIL image "
+            "or a list of PIL images. The number of prompts "
+            "must match the number of images."
+        )
+
+    @format_input.register(tuple)
+    def format_list_input(self, model_input):
+        try:
+            prompts, images = model_input
+        except ValueError:
+            raise TypeError("The input must be a tuple of exactly 2 elements.")
+        if (isinstance(prompts, str) and isinstance(images, list)) or (
+            isinstance(prompts, list) and not isinstance(images, list)
+        ):
+            raise TypeError(
+                "You must either provide a single prompt and a single image "
+                "or a list of prompts and a list of images."
+            )
+        if (
+            isinstance(prompts, list)
+            and isinstance(images, list)
+            and len(prompts) != len(images)
+        ):
+            raise TypeError("The number of prompts must match the number of images.")
+        return prompts, images
+
+    def format_output_type(self, output_type):
+        """Generate the logits processor argument to pass to the model.
+
+        Argument
+        --------
+        output_type
+            The logits processor provided.
+
+        """
+        from transformers import LogitsProcessorList
+
+        if output_type is not None:
+            return LogitsProcessorList([output_type])
+        return None


 class TransformersVision(Transformers):
-    def __init__(self, model, tokenizer, processor):
-        super().__init__(model, tokenizer)
-        self.processor = processor
+    """Represents a `transformers` model with a vision processor."""

-    def generate(  # type: ignore
+    def __init__(
         self,
-        prompts: Union[str, List[str]],
-        media: Union[List[Any], List[List[Any]]],
-        generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
-        sampling_parameters: SamplingParameters,
-    ) -> Union[str, List[str], List[List[str]]]:
-        """Generate text using `transformers`.
-
-        Arguments
-        ---------
-        prompts
-            A prompt or list of prompts.
-        media
-            A List[PIL.Image] or List[List[PIL.Image]]
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
-        logits_processor
-            The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-
-        Returns
-        -------
-        The generated text
+        model_name: str,
+        model_class,
+        model_kwargs: dict = {},
+        tokenizer_class=None,
+        tokenizer_kwargs: dict = {},
+        processor_class=None,
+        processor_kwargs: dict = {},
+    ):
+        """Create a TransformersVision model instance.
+
+        We rely on the `__init__` method of the `Transformers` class to handle
+        most of the initialization and then add elements specific to vision
+        models.
+
+        Parameters
+        ----------
+        model_name
+            The name of the transformers model to use.
+        model_class
+            The Transformers model class from which to create the model.
+        model_kwargs
+            A dictionary of keyword arguments to pass to the `from_pretrained`
+            method of the model class.
+        tokenizer_class
+            The Transformers tokenizer class from which to create the tokenizer.
+            If not provided, `AutoTokenizer` will be used.
+            If you gave the name of a model that is not compatible with `AutoTokenizer`,
+            you must provide a value for this parameter.
+        tokenizer_kwargs
+            A dictionary of keyword arguments to pass to the `from_pretrained`
+            method of the tokenizer class.
+        processor_class
+            The Transformers processor class from which to create the processor.
+            If not provided, `AutoProcessor` will be used.
+            If you gave the name of a model that is not compatible with `AutoProcessor`,
+            you must provide a value for this parameter.
+        processor_kwargs
+            A dictionary of keyword arguments to pass to the `from_pretrained`
+            method of the processor class.
+        """
+        if processor_class is None:
+            try:
+                from transformers import AutoProcessor
+
+                processor_class = AutoProcessor
+            except ImportError:
+                raise ImportError(
+                    "The `transformers` library needs to be installed in order to use `transformers` models."
+                )
+
+        processor_kwargs.setdefault("padding_side", "left")
+        processor_kwargs.setdefault("pad_token", "[PAD]")
+        self.processor = processor_class.from_pretrained(model_name, **processor_kwargs)
+
+        if tokenizer_class is None and getattr(self.processor, "tokenizer", None):
+            tokenizer_class = type(self.processor.tokenizer)
+
+        super().__init__(
+            model_name,
+            model_class,
+            model_kwargs,
+            tokenizer_class,
+            tokenizer_kwargs,
+        )
+
+        self.type_adapter = TransformersVisionTypeAdapter()
+
+    def generate(self, model_input, output_type, **inference_kwargs):
+        prompts, images = self.type_adapter.format_input(model_input)
+
         inputs = self.processor(
-            text=prompts, images=media, padding=True, return_tensors="pt"
+            text=prompts, images=images, padding=True, return_tensors="pt"
         ).to(self.model.device)

-        generation_kwargs = self._get_generation_kwargs(
-            prompts,
-            generation_parameters,
-            logits_processor,
-            sampling_parameters,
-        )
-        generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
+        generated_ids = self._generate_output_seq(prompts, inputs, **inference_kwargs)

-        # if single str input and single sample per input, convert to a 1D output
+        # if single str input, convert to a 1D output
         if isinstance(prompts, str):
-            # Should always be true until NotImplementedError above is fixed
             generated_ids = generated_ids.squeeze(0)

         return self._decode_generation(generated_ids)

-    def stream(  # type: ignore
-        self,
-        prompts: Union[str, List[str]],
-        media: Union[Any, List[Any]],  # TODO: docstring
-        generation_parameters: GenerationParameters,
-        logits_processor: Optional["OutlinesLogitsProcessor"],
-        sampling_parameters: SamplingParameters,
-    ) -> Iterator[Union[str, List[str]]]:
-        raise NotImplementedError
-
-
-def transformers_vision(
-    model_name: str,
-    model_class,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    processor_kwargs: dict = {},
-    tokenizer_class=None,
-    processor_class=None,
-):
-    """Instantiate a model from the `transformers` library and its tokenizer.
-
-    Parameters
-    ----------
-    model_name
-        The name of the model as listed on Hugging Face's model page.
-    model_class
-        The `PreTrainedModel` class from transformers to use in initializing the vision model from `model_name`.
-        https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel
-    device
-        The device(s) on which the model should be loaded. This overrides
-        the `device_map` entry in `model_kwargs` when provided.
-    model_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the model.
-    processor_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the processor.
-
-    Returns
-    -------
-    A `TransformersModel` model instance.
-
-    """
-    if processor_class is None or tokenizer_class is None:
-        try:
-            from transformers import AutoProcessor, AutoTokenizer
-        except ImportError:
-            raise ImportError(
-                "The `transformers` library needs to be installed in order to use `transformers` models."
-            )
-        if processor_class is None:
-            processor_class = AutoProcessor
-        if tokenizer_class is None:
-            tokenizer_class = AutoTokenizer
-
-    if device is not None:
-        model_kwargs["device_map"] = device
-
-    model = model_class.from_pretrained(model_name, **model_kwargs)
-
-    processor_kwargs.setdefault("padding_side", "left")
-    processor_kwargs.setdefault("pad_token", "[PAD]")
-    processor = processor_class.from_pretrained(model_name, **processor_kwargs)
-
-    if tokenizer_class is None:
-        if getattr(processor, "tokenizer", None):
-            tokenizer = processor.tokenizer
-        else:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs)
-    else:
-        tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs)
-
-    return TransformersVision(model, tokenizer, processor)
+    def stream(self, model_input, output_type, **inference_kwargs):
+        raise NotImplementedError(
+            "Streaming is not implemented for TransformersVision models."
+        )
diff --git a/tests/models/test_transformers_vision.py b/tests/models/test_transformers_vision.py
new file mode 100644
index 000000000..b96e418d7
--- /dev/null
+++ b/tests/models/test_transformers_vision.py
@@ -0,0 +1,116 @@
+from enum import Enum
+
+import numpy as np
+import pytest
+from PIL import Image
+from transformers import Blip2ForConditionalGeneration, CLIPModel, CLIPProcessor
+
+from outlines.models.transformers_vision import TransformersVision
+from outlines.processors import RegexLogitsProcessor
+from outlines.types import Choice
+
+TEST_MODEL = "hf-internal-testing/tiny-random-Blip2Model"
+TEST_CLIP_MODEL = "openai/clip-vit-base-patch32"
+
+
+@pytest.fixture
+def image():
+    return Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
+
+
+@pytest.fixture
+def model():
+    return TransformersVision(
+        model_name=TEST_MODEL, model_class=Blip2ForConditionalGeneration
+    )
+
+
+def test_transformers_vision_instantiate_simple():
+    model = TransformersVision(
+        model_name=TEST_MODEL,
+        model_class=Blip2ForConditionalGeneration,
+    )
+    assert isinstance(model, TransformersVision)
+
+
+def test_transformers_vision_instantiate_other_processor_class():
+    model = TransformersVision(
+        model_name=TEST_CLIP_MODEL, model_class=CLIPModel, processor_class=CLIPProcessor
+    )
+    assert isinstance(model, TransformersVision)
+
+
+def test_transformers_vision_simple(model, image):
+    result = model.generate(("Describe this image in one sentence:", image), None)
+    assert isinstance(result, str)
+
+
+def test_transformers_vision_call(model, image):
+    result = model(("Describe this image in one sentence:", image))
+    assert isinstance(result, str)
+
+
+def test_transformers_vision_wrong_input_type(model):
+    with pytest.raises(NotImplementedError):
+        model.generate("invalid input", None)
+
+
+def test_transformers_inference_kwargs(model, image):
+    result = model(("Describe this image in one sentence:", image), max_new_tokens=100)
+    assert isinstance(result, str)
+
+
+def test_transformers_invalid_inference_kwargs(model, image):
+    with pytest.raises(ValueError):
+        model(("Describe this image in one sentence:", image), foo="bar")
+
+
+def test_transformers_vision_choice(model, image):
+    class Foo(Enum):
+        white = "white"
+        black = "black"
+
+    regex_str = Choice(Foo).to_regex()
+    logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
+    result = model.generate(("Is this image white or black?", image), logits_processor)
+
+    assert isinstance(result, str)
+
+
+def test_transformers_vision_batch_input_samples(model, image):
+    result = model(("Describe this image in one sentence.", image))
+    assert isinstance(result, str)
+    result = model(
+        ("Describe this image in one sentence.", image),
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    result = model(
+        (
+            [
+                "Describe this image in one sentence.",
+                "Describe this image in one sentence.",
+            ],
+            [image, image],
+        )
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    result = model(
+        (
+            [
+                "Describe this image in one sentence.",
+                "Describe this image in one sentence.",
+            ],
+            [image, image],
+        ),
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert isinstance(item, list)
+        assert len(item) == 2
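
For reference, a minimal usage sketch of the interface this diff introduces, mirroring the test file above. The tiny Blip2 checkpoint and the `max_new_tokens` keyword are taken directly from the tests; treat this as an illustration of the new call signature rather than documented API.

```python
import numpy as np
from PIL import Image
from transformers import Blip2ForConditionalGeneration

from outlines.models.transformers_vision import TransformersVision

# Build the model from a model name and a transformers model class; per the
# diff, the processor is resolved via AutoProcessor and the tokenizer is taken
# from the processor when no explicit classes are given.
model = TransformersVision(
    model_name="hf-internal-testing/tiny-random-Blip2Model",
    model_class=Blip2ForConditionalGeneration,
)

# Inputs are a (prompt, image) tuple, or a list of prompts paired with a list
# of images of the same length; extra keyword arguments are forwarded to the
# underlying `transformers` generation call.
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
result = model(("Describe this image in one sentence:", image), max_new_tokens=50)
print(result)
```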