From 49c8669ede8330b5c85dfc1ba265afc8624d3c08 Mon Sep 17 00:00:00 2001 From: Robin Picard Date: Wed, 26 Feb 2025 17:56:52 +0100 Subject: [PATCH] Add support for Transformers audio models --- docs/reference/models/transformers.md | 2 +- .../models/transformers_multimodal.md | 196 ++++++++++++++++++ docs/reference/models/transformers_vision.md | 117 ----------- outlines/generate/cfg.py | 4 +- outlines/generate/fsm.py | 4 +- outlines/generate/regex.py | 4 +- outlines/generate/text.py | 4 +- outlines/models/__init__.py | 2 +- outlines/models/transformers.py | 76 +++---- outlines/models/transformers_vision.py | 0 ...ion.py => test_transformers_multimodal.py} | 36 ++-- 11 files changed, 256 insertions(+), 189 deletions(-) create mode 100644 docs/reference/models/transformers_multimodal.md delete mode 100644 docs/reference/models/transformers_vision.md delete mode 100644 outlines/models/transformers_vision.py rename tests/models/{test_transformers_vision.py => test_transformers_multimodal.py} (81%) diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md index 724a724c3..c3b397297 100644 --- a/docs/reference/models/transformers.md +++ b/docs/reference/models/transformers.md @@ -21,7 +21,7 @@ model = models.from_transformers( ) ``` -The model name must be a valid `transformers` model name. You can find a list of all in the HuggingFace library [here](https://huggingface.co/models). We currently support `CausalLM`, `Seq2Seq`, `Mamba` and vision models. +The model name must be a valid `transformers` model name. You can find a list of all in the HuggingFace library [here](https://huggingface.co/models). We currently support `CausalLM`, `Seq2Seq`, `Mamba`, vision and audio models. Be cautious with model selection though, some models such as `t5-base` don't include certain characters (`{`) and you may get an error when trying to perform structured generation. diff --git a/docs/reference/models/transformers_multimodal.md b/docs/reference/models/transformers_multimodal.md new file mode 100644 index 000000000..cea644204 --- /dev/null +++ b/docs/reference/models/transformers_multimodal.md @@ -0,0 +1,196 @@ +# Transformers Multimodal + +Outlines allows seamless use of [multimodal models](https://huggingface.co/learn/multimodal-models-course/en/unit1/introduction). + +`outlines.models.TransformersMultimodal` shares interfaces with, and is based on [outlines.models.Transformers](./transformers.md). + +## Create a `TransformersMultimodal` model + +`TransformersMultimodal` models inherit from `Transformers` and accept similar initialization parameters, with the exception that you should provide a processor instead of a tokenizer. + +You can use `outlines.from_transformers` to load a `transformers` model and processor: + +```python +from transformers import CLIPModel, CLIPProcessor +from outlines import models + +model_name = "openai/clip-vit-base-patch32" +model = models.from_transformers( + CLIPModel.from_pretrained(model_name), + CLIPProcessor.from_pretrained(model_name) +) +``` + +You must provider a processor adapted to your model. We currently only support vision and audio models. + +## Use the model to generate text from prompts and images/audios + +When calling the model, the prompt argument you provide must be a dictionary using the following format: +``` +{ + "text": Union[str, List[str]], + "images": Optional[Union[Any, List[Any]]], + "audios": Optional[Union[Any, List[Any]]], +} +``` + +The `text` key is required. Either the `images` or `audios` key must be provided. +The format and values of those keys must correspond to what you would provide to the `__call__` method of the processor if you were to call it directly. Thus, the text prompt must include the appropriate tags to indicate where the images or audios should be inserted (e.g. `` or `<|AUDIO|>`). + +### Vision models + +For easier use, we recommend you to create a convenience function to load a `PIL.Image` from URL. +```python +from PIL import Image +from io import BytesIO +from urllib.request import urlopen + +def img_from_url(url): + img_byte_stream = BytesIO(urlopen(url).read()) + return Image.open(img_byte_stream).convert("RGB") +``` + +You can then call the model with your prompts and images to generate text. +```python +from transformers import LlavaForConditionalGeneration, LlavaProcessor +from outlines import models + +MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration" + +model = models.from_transformers( + LlavaForConditionalGeneration.from_pretrained(MODEL_NAME), + LlavaProcessor.from_pretrained(MODEL_NAME) +) +prompt = { + "text": " detailed description:", + "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg") +} +model(prompt) +``` + +You can include several images per prompt by adding more `` tags to the prompt. Batching is also supported. +```python +from transformers import LlavaForConditionalGeneration +from outlines import models + +MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration" + +model = models.from_transformers( + LlavaForConditionalGeneration.from_pretrained(MODEL_NAME), + LlavaProcessor.from_pretrained(MODEL_NAME) +) +prompt = { + "text": ["detailed description:", ". What animals are present?"], + "images": [ + img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"), + img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"), + img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"), + img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"), + ] +} +model(prompt) +``` + +Here we have two prompts, each expecting two images. We correspondingly provide four images. This will generate two descriptions, one for each prompt. + +### Audio models + +For easier use, we recommend you to create a convenience function to turn a `.wav` file from a URL into an audio array. +```python +import soundfile +from io import BytesIO +from urllib.request import urlopen + +def audio_from_url(url): + audio_data = BytesIO(urlopen(url).read()) + audio_array, _ = soundfile.read(audio_data) + return audio_array +``` + +You can then call the model with your prompt and audio files to generate text. +```python +from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor +from outlines import models + +MODEL_NAME = "Qwen/Qwen2-Audio-7B-Instruct" +URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav" + +tf_model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_NAME) +tf_processor = AutoProcessor.from_pretrained(MODEL_NAME) + +model = models.TransformersMultiModal(tf_model, tf_processor) + +text = """<|im_start|>user +Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|> +<|im_end|> +<|im_start|>user +Write a sentence about what you hear in the audio.<|im_end|> +<|im_start|>assistant +""" +prompt = { + "text": text, + "audios": audio_from_url(URL) +} +model(prompt) +``` + +You can include multiple audio files by providing a list for the `audios` argument and adding more of the appropriate tags to the prompt. Batching is also supported. + +## Use the model for structured generation + +You can use the model to generate structured data by providing a value for the parameter `output_type` (the second positional argument of the `generate` method, right after the prompt). + +Supported types include `Json`, `Choice`, `Regex` and `CFG`. + +For instance to do classification, you can use the `Regex` type: +```python +from outlines import models +from outlines.types import Regex +from transformers import LlavaForConditionalGeneration, LlavaProcessor + +MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration" + +model = models.from_transformers( + LlavaForConditionalGeneration.from_pretrained(MODEL_NAME), + LlavaProcessor.from_pretrained(MODEL_NAME) +) +pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto" +prompt = { + "text": "detailed description:", + "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"), +} +model(prompt, Regex(pattern)) +``` + +Another example could be to generated a structured description of an image using the `Json` type: +```python +from outlines import models +from pydantic import BaseModel +from transformers import LlavaForConditionalGeneration, LlavaProcessor +from typing import List, Optional + +MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration" + +class ImageData(BaseModel): + caption: str + tags_list: List[str] + object_list: List[str] + is_photo: bool + +model = models.from_transformers( + LlavaForConditionalGeneration.from_pretrained(MODEL_NAME), + LlavaProcessor.from_pretrained(MODEL_NAME) +) +pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto" +prompt = { + "text": "detailed description:", + "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"), +} +model(prompt, Json(ImageData)) +``` + +## Resources + +### Choosing a model +- https://mmbench.opencompass.org.cn/leaderboard +- https://huggingface.co/spaces/WildVision/vision-arena diff --git a/docs/reference/models/transformers_vision.md b/docs/reference/models/transformers_vision.md deleted file mode 100644 index 03ba6a5c9..000000000 --- a/docs/reference/models/transformers_vision.md +++ /dev/null @@ -1,117 +0,0 @@ -# Transformers Vision - -Outlines allows seamless use of [vision models](https://huggingface.co/learn/computer-vision-course/en/unit4/multimodal-models/tasks-models-part1). - -`outlines.models.Transformers_vision` shares interfaces with, and is based on [outlines.models.Transformers](./transformers.md). - -## Create a `TransformersVision` model - -`TransformersVision` models inherit from `Transformers` and accept the same initialization parameters. - -In addition, they also accept the optional parameters `processor_class` and `processor_kwargs`. Those are used to create a `Processor` instance that is then used to preprocess the images. By default, `AutoProcessor` is used to create the processor as such: `AutoProcessor.from_pretrained(model_name, **processor_kwargs)`. - -If your model is not compatible with `AutoProcessor`, you must provide a value for the `processor_class` parameter. -For instance: -```python -from outlines import models -from transformers import CLIPModel, CLIPProcessor - -model = models.TransformersVision("openai/clip-vit-base-patch32", model_class=CLIPModel, processor_class=CLIPProcessor) -``` - -## Use the model to generate text from prompts and images - -When calling the model, the prompt argument you provide must be a dictionary with a key `"prompts"` and a key `"images"`. The associated values must be a string or a list of strings for the prompts, and a PIL image or a list of PIL images for the images. Your prompts must include `` tags tokens to indicate where the image should be inserted. There must be as many `` tags as there are images. - -For easier use, we recommend you to create a convenience function to load a `PIL.Image` from URL. -```python -from PIL import Image -from io import BytesIO -from urllib.request import urlopen - -def img_from_url(url): - img_byte_stream = BytesIO(urlopen(url).read()) - return Image.open(img_byte_stream).convert("RGB") -``` - -You can then call the model with your prompts and images to generate text. -```python -from transformers import LlavaForConditionalGeneration -from outlines import models - -model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) -prompt = { - "prompts": " detailed description:", - "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg") -} -model(prompt) -``` - -You can include several images per prompt by adding more `` tags to the prompt. Batching is also supported. -```python -from transformers import LlavaForConditionalGeneration -from outlines import models - -model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) -prompt = { - "prompts": ["detailed description:", ". What animals are present?"], - "images": [ - img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"), - img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"), - img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"), - img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"), - ] -} -model(prompt) -``` - -Here we have two prompts, each expecting two images. We correspondingly provide four images. This will generate two descriptions, one for each prompt. - -### Use the model for structured generation - -You can use the model to generate structured data by providing a value for the parameter `output_type` (the second positional argument of the `generate` method, right after the prompt). - -Supported types include `Json`, `Choice`, `Regex` and `CFG`. - -For instance to do classification, you can use the `Regex` type: -```python -from outlines import models -from outlines.types import Regex -from transformers import LlavaForConditionalGeneration - -model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) -pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto" -prompt = { - "prompts": "detailed description:", - "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"), -} -model(prompt, Regex(pattern)) -``` - -Another example could be to generated a structured description of an image using the `Json` type: -```python -from outlines import models -from pydantic import BaseModel -from transformers import LlavaForConditionalGeneration -from typing import List, Optional - -class ImageData(BaseModel): - caption: str - tags_list: List[str] - object_list: List[str] - is_photo: bool - -model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) -pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto" -prompt = { - "prompts": "detailed description:", - "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"), -} -model(prompt, Json(ImageData)) -``` - -## Resources - -### Choosing a model -- https://mmbench.opencompass.org.cn/leaderboard -- https://huggingface.co/spaces/WildVision/vision-arena diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index b677040d5..0ffbda6a3 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -4,7 +4,7 @@ SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import LlamaCpp, OpenAI, TransformersVision +from outlines.models import LlamaCpp, OpenAI, TransformersMultiModal from outlines.samplers import Sampler, multinomial @@ -33,7 +33,7 @@ def cfg( return SequenceGeneratorAdapter(model, logits_processor, sampler) -@cfg.register(TransformersVision) +@cfg.register(TransformersMultiModal) def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()): from outlines.processors import CFGLogitsProcessor diff --git a/outlines/generate/fsm.py b/outlines/generate/fsm.py index 1950812d2..81f974bd2 100644 --- a/outlines/generate/fsm.py +++ b/outlines/generate/fsm.py @@ -7,7 +7,7 @@ SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import TransformersVision +from outlines.models import TransformersMultiModal from outlines.samplers import Sampler, multinomial @@ -22,7 +22,7 @@ def fsm( return SequenceGeneratorAdapter(model, logits_processor, sampler) -@fsm.register(TransformersVision) +@fsm.register(TransformersMultiModal) def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()): from outlines.processors import GuideLogitsProcessor diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 326701f4b..477a7bf5e 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -4,7 +4,7 @@ SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import OpenAI, TransformersVision +from outlines.models import OpenAI, TransformersMultiModal from outlines.samplers import Sampler, multinomial from outlines.types import Regex @@ -39,7 +39,7 @@ def regex(model, regex_str: str | Regex, sampler: Sampler = multinomial()): return SequenceGeneratorAdapter(model, logits_processor, sampler) -@regex.register(TransformersVision) +@regex.register(TransformersMultiModal) def regex_vision( model, regex_str: str | Regex, diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 32530d0c4..e2e4cbe5d 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -4,7 +4,7 @@ SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import OpenAI, TransformersVision +from outlines.models import OpenAI, TransformersMultiModal from outlines.samplers import Sampler, multinomial @@ -34,7 +34,7 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter: return SequenceGeneratorAdapter(model, None, sampler) -@text.register(TransformersVision) +@text.register(TransformersMultiModal) def text_vision(model, sampler: Sampler = multinomial()): return VisionSequenceGeneratorAdapter(model, None, sampler) diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 6a9ba753f..fb982d3b2 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -20,7 +20,7 @@ from .transformers import ( Transformers, TransformerTokenizer, - TransformersVision, + TransformersMultiModal, from_transformers, ) from .vllm import VLLM, from_vllm diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index b65d0cefb..ceaf39b7d 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: import torch - from transformers import PreTrainedTokenizer, PreTrainedModel + from transformers import PreTrainedTokenizer, PreTrainedModel, ProcessorMixin __all__ = ["Transformers", "from_transformers"] @@ -256,12 +256,12 @@ def _decode_generation(self, generated_ids: "torch.Tensor"): ) -class TransformersVisionTypeAdapter(ModelTypeAdapter): - """Type adapter for TransformersVision models.""" +class TransformersMultiModalTypeAdapter(ModelTypeAdapter): + """Type adapter for TransformersMultiModal models.""" @singledispatchmethod def format_input(self, model_input): - """Generate the prompt argument to pass to the model. + """Generate the prompt arguments to pass to the model. Argument -------- @@ -272,18 +272,27 @@ def format_input(self, model_input): raise NotImplementedError( f"The input type {input} is not available. Please provide a " "dictionary with the following format: " - "{'prompts': Union[str, List[str]], 'images': Union[Any, List[Any]]}" - "Make sure the number of image tags in the prompts is equal to the " - "number of images provided." + "{'text': Union[str, List[str]], 'images': Optional[Union[Any, " + "List[Any]]], 'audios': Optional[Union[Any, List[Any]]]}" + "Either 'images' or 'audios' must be provided." + "Make sure the text is formatted correctly for the model " + "(e.g. include or <|AUDIO|> tags)." ) @format_input.register(dict) def format_list_input(self, model_input): - if "prompts" not in model_input or "images" not in model_input: + if ( + "text" not in model_input + or ( + "images" not in model_input + and "audios" not in model_input + ) + ): raise ValueError( - "The input must contain the following keys: 'prompts' and 'images'." + "The input must contain the 'text' key along with either " + "'images' or 'audios'." ) - return model_input["prompts"], model_input["images"] + return model_input def format_output_type(self, output_type): """Generate the logits processor argument to pass to the model. @@ -301,11 +310,11 @@ def format_output_type(self, output_type): return None -class TransformersVision(Transformers): +class TransformersMultiModal(Transformers): """Represents a `transformers` model with a vision processor.""" def __init__(self, model: "PreTrainedModel", processor): - """Create a TransformersVision model instance + """Create a TransformersMultiModal model instance We rely on the `__init__` method of the `Transformers` class to handle most of the initialization and then add elements specific to vision @@ -320,47 +329,25 @@ def __init__(self, model: "PreTrainedModel", processor): super().__init__(model, tokenizer) - self.type_adapter = TransformersVisionTypeAdapter() + self.type_adapter = TransformersMultiModalTypeAdapter() def _prepare_model_inputs(self, model_input, output_type): - prompts, images = self.type_adapter.format_input(model_input) + model_input = self.type_adapter.format_input(model_input) inputs = self.processor( - text=prompts, images=images, padding=True, return_tensors="pt" + **model_input, padding=True, return_tensors="pt" ).to(self.model.device) - return prompts, inputs + return model_input["text"], inputs def from_transformers( model: "PreTrainedModel", - tokenizer_or_processor: "PreTrainedTokenizer", + tokenizer_or_processor: Union["PreTrainedTokenizer", "ProcessorMixin"], ): from transformers import ( PreTrainedTokenizer, PreTrainedTokenizerFast, - Blip2Processor, - LlavaProcessor, - IdeficsProcessor, - CLIPProcessor, - Qwen2_5_VLProcessor, - Qwen2VLProcessor, - NougatProcessor, - LlavaNextProcessor, - PixtralProcessor, - PaliGemmaProcessor, - ) - - vision_processors = ( - Blip2Processor, - LlavaProcessor, - IdeficsProcessor, - CLIPProcessor, - Qwen2_5_VLProcessor, - Qwen2VLProcessor, - NougatProcessor, - LlavaNextProcessor, - PixtralProcessor, - PaliGemmaProcessor, + ProcessorMixin ) if isinstance( @@ -368,13 +355,12 @@ def from_transformers( ): tokenizer = tokenizer_or_processor return Transformers(model, tokenizer) - elif isinstance(tokenizer_or_processor, vision_processors): + elif isinstance(tokenizer_or_processor, ProcessorMixin): processor = tokenizer_or_processor - return TransformersVision(model, processor) + return TransformersMultiModal(model, processor) else: raise ValueError( "We could determine whether the model passed to `from_transformers`" - + " is a text-2-text of vision language model. If you passed a" - + " vision language model please open an issue pasting the error" - + " with the name of the vision language model." + + " is a text-2-text or a multi-modal model. Please provide a " + + "a transformers tokenizer or processor." ) diff --git a/outlines/models/transformers_vision.py b/outlines/models/transformers_vision.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/models/test_transformers_vision.py b/tests/models/test_transformers_multimodal.py similarity index 81% rename from tests/models/test_transformers_vision.py rename to tests/models/test_transformers_multimodal.py index 57dedcb69..9292bdf6a 100644 --- a/tests/models/test_transformers_vision.py +++ b/tests/models/test_transformers_multimodal.py @@ -1,3 +1,5 @@ +# we only test vision models here as audio models are too heavy to run on CI + import re from enum import Enum from io import BytesIO @@ -15,7 +17,7 @@ ) import outlines -from outlines.models.transformers import TransformersVision +from outlines.models.transformers import TransformersMultiModal from outlines.types import Choice, JsonType, Regex TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration" @@ -48,12 +50,12 @@ def test_transformers_vision_instantiate_simple(): Blip2ForConditionalGeneration.from_pretrained(TEST_MODEL), AutoProcessor.from_pretrained(TEST_MODEL), ) - assert isinstance(model, TransformersVision) + assert isinstance(model, TransformersMultiModal) def test_transformers_vision_simple(model, images): result = model.generate( - {"prompts": "Describe this image in one sentence:", "images": images[0]}, + {"text": "Describe this image in one sentence:", "images": images[0]}, None, ) assert isinstance(result, str) @@ -61,7 +63,7 @@ def test_transformers_vision_simple(model, images): def test_transformers_vision_call(model, images): result = model( - {"prompts": "Describe this image in one sentence:", "images": images[0]} + {"text": "Describe this image in one sentence:", "images": images[0]} ) assert isinstance(result, str) @@ -70,7 +72,7 @@ def test_transformers_vision_wrong_number_images(model, images): with pytest.raises(ValueError): a = model( { - "prompts": "Describe this image in one sentence:", + "text": "Describe this image in one sentence:", "images": [images[0], images[1]], } ) @@ -84,7 +86,7 @@ def test_transformers_vision_wrong_input_type(model): def test_transformers_inference_kwargs(model, images): result = model( - {"prompts": "Describe this image in one sentence:", "images": images[0]}, + {"text": "Describe this image in one sentence:", "images": images[0]}, max_new_tokens=100, ) assert isinstance(result, str) @@ -94,7 +96,7 @@ def test_transformers_invalid_inference_kwargs(model, images): with pytest.raises(ValueError): model( { - "prompts": "Describe this image in one sentence:", + "text": "Describe this image in one sentence:", "images": images[0], }, foo="bar", @@ -104,7 +106,7 @@ def test_transformers_invalid_inference_kwargs(model, images): def test_transformers_several_images(model, images): result = model( { - "prompts": "Describe this image in one sentence:", + "text": "Describe this image in one sentence:", "images": [images[0], images[1]], } ) @@ -116,7 +118,7 @@ class Foo(BaseModel): name: str result = model( - {"prompts": "Give a name to this animal.", "images": images[0]}, + {"text": "Give a name to this animal.", "images": images[0]}, JsonType(Foo), ) assert "name" in result @@ -124,7 +126,7 @@ class Foo(BaseModel): def test_transformers_vision_regex(model, images): result = model( - {"prompts": "How old is it?", "images": images[0]}, Regex(r"[0-9]") + {"text": "How old is it?", "images": images[0]}, Regex(r"[0-9]") ) assert isinstance(result, str) @@ -137,7 +139,7 @@ class Foo(Enum): dog = "dog" result = model( - {"prompts": "Is it a cat or a dog?", "images": images[0]}, Choice(Foo) + {"text": "Is it a cat or a dog?", "images": images[0]}, Choice(Foo) ) assert isinstance(result, str) @@ -146,7 +148,7 @@ class Foo(Enum): def test_transformers_vision_batch_samples(model, images): result = model( - {"prompts": "Describe this image in one sentence.", "images": images[0]}, + {"text": "Describe this image in one sentence.", "images": images[0]}, num_return_sequences=2, num_beams=2, ) @@ -154,7 +156,7 @@ def test_transformers_vision_batch_samples(model, images): assert len(result) == 2 result = model( { - "prompts": [ + "text": [ "Describe this image in one sentence.", "Describe this image in one sentence.", ], @@ -165,7 +167,7 @@ def test_transformers_vision_batch_samples(model, images): assert len(result) == 2 result = model( { - "prompts": [ + "text": [ "Describe this image in one sentence.", "Describe this image in one sentence.", ], @@ -187,7 +189,7 @@ class Foo(Enum): dog = "dog" result = model( - {"prompts": "Describe this image in one sentence.", "images": images[0]}, + {"text": "Describe this image in one sentence.", "images": images[0]}, Choice(Foo), num_return_sequences=2, num_beams=2, @@ -198,7 +200,7 @@ class Foo(Enum): assert item in ["cat", "dog"] result = model( { - "prompts": [ + "text": [ "Describe this image in one sentence.", "Describe this image in one sentence.", ], @@ -212,7 +214,7 @@ class Foo(Enum): assert item in ["cat", "dog"] result = model( { - "prompts": [ + "text": [ "Describe this image in one sentence.", "Describe this image in one sentence.", ],