diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md index 3649b9b13..f24591436 100644 --- a/docs/reference/models/transformers.md +++ b/docs/reference/models/transformers.md @@ -1,96 +1,63 @@ -# transformers +# Transformers !!! Installation - You need to install the `transformer`, `datasets` and `torch` libraries to be able to use these models in Outlines, or alternatively: + You need to install the `transformers` library to be able to use these models in Outlines, or alternatively: ```bash pip install "outlines[transformers]" ``` +## Create a `Transformers` model -Outlines provides an integration with the `torch` implementation of causal models in the [transformers][transformers] library. You can initialize the model by passing its name: - +The only mandatory argument to instantiate a `Transformers` model is the name of the model to use. ```python from outlines import models -model = models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda") +model = models.Transformers("microsoft/Phi-3-mini-4k-instruct") ``` -If you need more fine-grained control you can also initialize the model and tokenizer separately: +The model name must be a valid `transformers` model name. You can find a list of all available models on the Hugging Face Hub [here](https://huggingface.co/models). + +When instantiated this way, the class creates the model from the `transformers` library using the `AutoModelForCausalLM` class by default (`transformers.AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)`). +You can also provide keyword arguments in an optional `model_kwargs` parameter. Those will be passed to the `from_pretrained` method of the model class. One such argument is `device_map`, which allows you to specify the device on which the model will be loaded. +For instance: ```python -from transformers import AutoModelForCausalLM, AutoTokenizer from outlines import models -llm = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True) -tokenizer = AutoTokenizer.from_pretrained("gpt2") -model = models.Transformers(llm, tokenizer) +model = models.Transformers("microsoft/Phi-3-mini-4k-instruct", model_kwargs={"device_map": "cuda"}) ``` -## Using Logits Processors - -There are two ways to use Outlines Structured Generation with HuggingFace Transformers: +## Alternative model classes -1. Use Outlines generation wrapper, `outlines.models.transformers` -2. Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM` - -Outlines supports a myriad of logits processors for structured generation. In these example, we will use the `RegexLogitsProcessor` which guarantees generated text matches the specified pattern. - -### Using `outlines.models.transformers` +If the model you want to use is not compatible with `AutoModelForCausalLM`, you must provide a value for the `model_class` parameter. This value must be a valid `transformers` model class. +For instance: ```python -import outlines - -time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?" 
- -model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda") -generator = outlines.generate.regex(model, time_regex_pattern) +from outlines import models +from transformers import AutoModelForSeq2SeqLM -output = generator("The the best time to visit a dentist is at ") -print(output) -# 2:30 pm +model = models.Transformers("facebook/bart-large", model_class=AutoModelForSeq2SeqLM) ``` -### Using models initialized via the `transformers` library - -```python -import outlines -import transformers - +When you instantiate a `Transformers` model, the class also creates a `Tokenizer` instance from the `AutoTokenizer` class. You can provide keyword arguments in an optional `tokenizer_kwargs` parameter. Those will be passed to the `from_pretrained` method of the tokenizer class, as in `tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)`. -model_uri = "microsoft/Phi-3-mini-4k-instruct" +Similarly, if your model is not compatible with `AutoTokenizer`, you must provide a value for the `tokenizer_class` parameter. -outlines_tokenizer = outlines.models.TransformerTokenizer( - transformers.AutoTokenizer.from_pretrained(model_uri) -) -phone_number_logits_processor = outlines.processors.RegexLogitsProcessor( - "\\+?[1-9][0-9]{7,14}", # phone number pattern - outlines_tokenizer, -) - -generator = transformers.pipeline('text-generation', model=model_uri) +```python +from outlines import models +from transformers import T5ForConditionalGeneration, T5Tokenizer -output = generator( - "Jenny gave me her number it's ", - logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor]) +model_pile_t5 = models.Transformers( + model_name="EleutherAI/pile-t5-large", + model_class=T5ForConditionalGeneration, + tokenizer_class=T5Tokenizer ) -print(output) -# [{'generated_text': "Jenny gave me her number it's 2125550182"}] -# not quite 8675309 what we expected, but it is a valid phone number ``` -[transformers]: https://github.com/huggingface/transformers - - -## Alternative Model Classes - -`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc. - -However other variants with unique behavior can be used as well by passing the appropriate class. - ### Mamba [Mamba](https://github.com/state-spaces/mamba) is a transformers alternative which employs memory efficient, linear-time decoding. @@ -100,25 +67,14 @@ To use Mamba with outlines you must first install the necessary requirements: ``` pip install causal-conv1d>=1.2.0 mamba-ssm torch transformers ``` -Then you can either create an Mamba-2 Outlines model via -```python -import outlines - -model = outlines.models.mamba("state-spaces/mamba-2.8b-hf") -``` - -or explicitly with +Then you can create a `Mamba` Outlines model via: ```python -import outlines -from transformers import MambaForCausalLM +from outlines import models -model = outlines.models.transformers( - "state-spaces/mamba-2.8b-hf", - model_class=MambaForCausalLM -) +model = models.Mamba("state-spaces/mamba-2.8b-hf", model_kwargs={"device_map": "cuda"}, tokenizer_kwargs={"padding_side": "left"}) ``` - +Alternatively, you can use the `Transformers` class to create a `Mamba` model by providing the appropriate `model_class` and `tokenizer_class` arguments. Read [`transformers`'s documentation](https://huggingface.co/docs/transformers/en/model_doc/mamba) for more information.
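For reference, a minimal sketch of that alternative, assuming your installed `transformers` version exposes `MambaForCausalLM` (the class the `Mamba` wrapper uses under the hood):

```python
from outlines import models
from transformers import MambaForCausalLM

# Equivalent to models.Mamba("state-spaces/mamba-2.8b-hf", ...), but built
# through the generic Transformers class by passing the model class explicitly.
# tokenizer_class is omitted, so AutoTokenizer is used, as in the Mamba wrapper.
model = models.Transformers(
    "state-spaces/mamba-2.8b-hf",
    model_class=MambaForCausalLM,
    model_kwargs={"device_map": "cuda"},
)
```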
@@ -128,21 +84,43 @@ You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines. Be cautious with model selection though, some models such as `t5-base` don't include certain characters (`{`) and you may get an error when trying to perform structured generation. -T5 Example: +## Use the model to generate text + +Once you have created a `Transformers` model, you can use it to generate text by calling the model instance. ```python -import outlines -from transformers import AutoModelForSeq2SeqLM +model("Hello, how are you?") ``` -model_pile_t5 = outlines.models.transformers( - model_name="EleutherAI/pile-t5-large", - model_class=AutoModelForSeq2SeqLM, -) +You can also first create a `Generator` and then call it. ```python +from outlines import Generator + +generator = Generator(model) +generator("Hello, how are you?") ``` -Bart Example: +`Transformers` models typically support batching and the generation of several samples at once. + +For instance: ```python -model_bart = outlines.models.transformers( - model_name="facebook/bart-large", - model_class=AutoModelForSeq2SeqLM, -) +model(["Hello, how are you?", "Respond with one word. Not more."], num_return_sequences=2, num_beams=2) +``` + +This would generate two sequences for each prompt, for a total of four sequences (returned as a list containing two lists of two elements each). + +## Use the model to generate structured data + +`Transformers` models can generate structured data by providing a value for the parameter `output_type` (the second positional argument of the `generate` method, right after the prompt). + +Supported types include `Json`, `Choice`, `Regex` and `CFG`. + +For instance: +```python +from outlines.types import Json +from pydantic import BaseModel + +class Character(BaseModel): + name: str + +model("Create a character with a name.", Json(Character)) ``` diff --git a/docs/reference/models/transformers_vision.md b/docs/reference/models/transformers_vision.md index 451e54a8f..03ba6a5c9 100644 --- a/docs/reference/models/transformers_vision.md +++ b/docs/reference/models/transformers_vision.md @@ -2,33 +2,28 @@ Outlines allows seamless use of [vision models](https://huggingface.co/learn/computer-vision-course/en/unit4/multimodal-models/tasks-models-part1). -`outlines.models.transformers_vision` shares interfaces with, and is based on [outlines.models.transformers](./transformers.md). +`outlines.models.TransformersVision` shares interfaces with, and is based on, [outlines.models.Transformers](./transformers.md). -Tasks supported include +## Create a `TransformersVision` model -- image + text -> text -- video + text -> text +`TransformersVision` models inherit from `Transformers` and accept the same initialization parameters. +In addition, they accept the optional parameters `processor_class` and `processor_kwargs`. Those are used to create a `Processor` instance that is then used to preprocess the images. By default, `AutoProcessor` is used to create the processor, as in `AutoProcessor.from_pretrained(model_name, **processor_kwargs)`. +If your model is not compatible with `AutoProcessor`, you must provide a value for the `processor_class` parameter. 
+For instance: +```python +from outlines import models +from transformers import CLIPModel, CLIPProcessor -## Example: Using [Llava-Next](https://huggingface.co/docs/transformers/en/model_doc/llava_next) Vision Models +model = models.TransformersVision("openai/clip-vit-base-patch32", model_class=CLIPModel, processor_class=CLIPProcessor) +``` -Install dependencies -`pip install torchvision pillow flash-attn` +## Use the model to generate text from prompts and images -Create the model -```python -import outlines -from transformers import LlavaNextForConditionalGeneration - -model = outlines.models.transformers_vision( - "llava-hf/llava-v1.6-mistral-7b-hf", - model_class=LlavaNextForConditionalGeneration, - device="cuda", -) -``` +When calling the model, the prompt argument you provide must be a dictionary with a key `"prompts"` and a key `"images"`. The associated values must be a string or a list of strings for the prompts, and a PIL image or a list of PIL images for the images. Your prompts must include `<image>` tags to indicate where the images should be inserted. There must be as many `<image>` tags as there are images. -Create convenience function to load a `PIL.Image` from URL +For easier use, we recommend creating a convenience function to load a `PIL.Image` from a URL. ```python from PIL import Image from io import BytesIO @@ -39,56 +34,65 @@ def img_from_url(url): return Image.open(img_byte_stream).convert("RGB") ``` -### Describing an image - +You can then call the model with your prompts and images to generate text. ```python -description_generator = outlines.generate.text(model) -description_generator( - "<image> detailed description:", - [img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg")] -) +from transformers import LlavaForConditionalGeneration +from outlines import models + +model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) +prompt = { + "prompts": "<image> detailed description:", + "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg") +} +model(prompt) ``` -> This is a color photograph featuring a Siamese cat with striking blue eyes. The cat has a creamy coat and a light eye color, which is typical for the Siamese breed. Its features include elongated ears, a long, thin tail, and a striking coat pattern. The cat is sitting in an indoor setting, possibly on a cat tower or a similar raised platform, which is covered with a beige fabric, providing a comfortable and soft surface for the cat to rest or perch. The surface of the wall behind the cat appears to be a light-colored stucco or plaster. - -#### Multiple Images - -To include multiple images in your prompt you simply add more `<image>` tokens to the prompt - +You can include several images per prompt by adding more `<image>` tags to the prompt. Batching is also supported. 
```python -image_urls = [ -    "https://cdn1.byjus.com/wp-content/uploads/2020/08/ShapeArtboard-1-copy-3.png", # triangle -    "https://cdn1.byjus.com/wp-content/uploads/2020/08/ShapeArtboard-1-copy-11.png", # hexagon -] -description_generator = outlines.generate.text(model) -description_generator( -    "What shapes are present?", -    list(map(img_from_url, image_urls)), -) +from transformers import LlavaForConditionalGeneration +from outlines import models + +model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) +prompt = { +    "prompts": ["<image><image>detailed description:", "<image><image>. What animals are present?"], +    "images": [ +        img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"), +        img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"), +        img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"), +        img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"), +    ] +} +model(prompt) ``` -> There are two shapes present. One shape is a hexagon and the other shape is an triangle. ' +Here we have two prompts, each expecting two images. We correspondingly provide four images. This will generate two descriptions, one for each prompt. + +### Use the model for structured generation +You can use the model to generate structured data by providing a value for the parameter `output_type` (the second positional argument of the `generate` method, right after the prompt). -### Classifying an Image +Supported types include `Json`, `Choice`, `Regex` and `CFG`. +For instance, to do classification, you can use the `Regex` type: ```python -pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto" -planet_generator = outlines.generate.regex(model, pattern) +from outlines import models +from outlines.types import Regex +from transformers import LlavaForConditionalGeneration -planet_generator( -    "What planet is this: ", -    [img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg")] -) +model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) +pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto" +prompt = { +    "prompts": "<image>detailed description:", +    "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"), +} +model(prompt, Regex(pattern)) ``` -> Saturn - - -### Extracting Structured Image data - +Another example is to generate a structured description of an image using the `Json` type: ```python +from outlines import models +from outlines.types import Json from pydantic import BaseModel +from transformers import LlavaForConditionalGeneration from typing import List, Optional class ImageData(BaseModel): @@ -97,17 +101,15 @@ class ImageData(BaseModel): object_list: List[str] is_photo: bool -image_data_generator = outlines.generate.json(model, ImageData) - -image_data_generator( -    "<image> detailed JSON metadata:", -    [img_from_url("https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg")] -) +model = models.TransformersVision("trl-internal-testing/tiny-LlavaForConditionalGeneration", model_class=LlavaForConditionalGeneration) +prompt = { +    "prompts": "<image>detailed description:", +    "images": 
img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"), +} +model(prompt, Json(ImageData)) ``` -> `ImageData(caption='An astronaut on the moon', tags_list=['moon', 'space', 'nasa', 'americanflag'], object_list=['moon', 'moon_surface', 'space_suit', 'americanflag'], is_photo=True)` - - ## Resources ### Choosing a model diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index ea228c5c6..4254e7cf6 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -16,13 +16,13 @@ from .mlxlm import MLXLM, mlxlm from .ollama import Ollama from .openai import AzureOpenAI, OpenAI -from .transformers import Transformers, TransformerTokenizer, mamba, transformers -from .transformers_vision import TransformersVision, transformers_vision +from .transformers import Mamba, Transformers, TransformerTokenizer +from .transformers_vision import TransformersVision from .vllm import VLLM, vllm LogitsGenerator = Union[ Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM, Ollama ] -LocalModel = LlamaCpp +LocalModel = Union[LlamaCpp, Transformers] APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama] diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 444492500..3d87443f6 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,17 +1,14 @@ -import dataclasses -import inspect -from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union +from functools import singledispatchmethod +from typing import TYPE_CHECKING, List, Tuple, Union -from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.base import Model, ModelTypeAdapter from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: import torch - from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers import PreTrainedTokenizer - from outlines.processors import OutlinesLogitsProcessor - -__all__ = ["transformers"] +__all__ = ["Transformers", "Mamba"] KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...] @@ -126,230 +123,130 @@ def __setstate__(self, state): self.__init__(state["tokenizer"]) -class Transformers: - """Represents a `transformers` model.""" +class TransformersTypeAdapter(ModelTypeAdapter): + @singledispatchmethod + def format_input(self, model_input): + """Generate the prompt argument to pass to the model. - def __init__( - self, - model: "PreTrainedModel", - tokenizer: "PreTrainedTokenizer", - ): - self.model = model - self.tokenizer = TransformerTokenizer(tokenizer) - - def forward( - self, - input_ids: "torch.LongTensor", - attention_mask: "torch.LongTensor", - past_key_values: Optional[Tuple] = None, - ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]: - """Compute a forward pass through the transformer model. - - Parameters - ---------- - input_ids - The input token ids. Must be one or two dimensional. - attention_mask - The attention mask. Must be one or two dimensional. - past_key_values - A tuple of tuples containing the cached key and value tensors for each - attention head. - - Returns - ------- - The computed logits and the new cached key and value tensors. + Argument + -------- + model_input + The input passed by the user. """ - try: - import torch - except ImportError: - ImportError( - "The `torch` library needs to be installed to use `transformers` models." 
- ) - assert 0 < input_ids.ndim < 3 - - if past_key_values: - input_ids = input_ids[..., -1].unsqueeze(-1) - - with torch.inference_mode(): - output = self.model( - input_ids, - attention_mask=attention_mask, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - past_key_values=past_key_values, - ) + raise NotImplementedError( + f"The input type {input} is not available." + "Please use a string or a list of strings." + ) - return output.logits, output.past_key_values + @format_input.register(str) + def format_str_input(self, model_input): + return model_input - def __call__( - self, - input_ids: "torch.LongTensor", - attention_mask: "torch.LongTensor", - past_key_values: Optional[Tuple] = None, - ) -> "torch.FloatTensor": - logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) - next_token_logits = logits[..., -1, :] + @format_input.register(list) + def format_list_input(self, model_input): + return model_input - return next_token_logits, kv_cache + def format_output_type(self, output_type): + """Generate the logits processor argument to pass to the model. + + Argument + -------- + output_type + The logits processor provided. - def generate( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor: Optional["OutlinesLogitsProcessor"], - sampling_parameters: SamplingParameters, - ) -> Union[str, List[str], List[List[str]]]: - """Generate text using `transformers`. - - Arguments - --------- - prompts - A prompt or list of prompts. - generation_parameters - An instance of `GenerationParameters` that contains the prompt, - the maximum number of tokens, stop sequences and seed. All the - arguments to `SequenceGeneratorAdapter`'s `__cal__` method. - logits_processor - The logits processor to use when generating text. - sampling_parameters - An instance of `SamplingParameters`, a dataclass that contains - the name of the sampler to use and related parameters as available - in Outlines. - - Returns - ------- - The generated text """ - if isinstance(prompts, str): - # convert to 2d - input_ids, attention_mask = self.tokenizer.encode([prompts]) - else: - input_ids, attention_mask = self.tokenizer.encode(prompts) + from transformers import LogitsProcessorList - inputs = { - "input_ids": input_ids.to(self.model.device), - "attention_mask": attention_mask.to(self.model.device), - } - if ( - "attention_mask" - not in inspect.signature(self.model.forward).parameters.keys() - ): - del inputs["attention_mask"] - - generation_kwargs = self._get_generation_kwargs( - prompts, - generation_parameters, - logits_processor, - sampling_parameters, - ) - generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) + if output_type is not None: + return LogitsProcessorList([output_type]) + return None - # if single str input and single sample per input, convert to a 1D output - if isinstance(prompts, str): - generated_ids = generated_ids.squeeze(0) - return self._decode_generation(generated_ids) +class Transformers(Model): + """Represents a `transformers` model.""" - def stream( + def __init__( self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor: Optional["OutlinesLogitsProcessor"], - sampling_parameters: SamplingParameters, - ) -> Iterator[Union[str, List[str]]]: - """ - Temporary stream stand-in which implements stream() signature - and equivalent behaviour but isn't yielded until generation completes. 
+ model_name: str, + model_class=None, + model_kwargs: dict = {}, + tokenizer_class=None, + tokenizer_kwargs: dict = {}, + ): + """Create a Transformers model instance + + Parameters: + ---------- + model_name + The name of the transformers model to use; + model_class + The Transformers model class from which to create the model. + If not provided,`AutoModelForCausalLM` will be used. + If you gave the name of a non-causal language model, + you must provide a value for this parameter. + model_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the model class. + tokenizer_class + The Transformers tokenizer class from which to create the tokenizer. + If not provided,`AutoTokenizer` will be used. + If you gave the name of a model that is not compatible with `AutoTokenizer`, + you must provide a value for this parameter. + tokenizer_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the tokenizer class. - TODO: implement following completion of https://github.com/huggingface/transformers/issues/30810 """ - if isinstance(prompts, str): - # convert to 2d - input_ids, attention_mask = self.tokenizer.encode([prompts]) - else: - input_ids, attention_mask = self.tokenizer.encode(prompts) + if model_class is None or tokenizer_class is None: + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError: + raise ImportError( + "The `transformers` library needs to be installed in order to use `transformers` models." + ) + if model_class is None: + model_class = AutoModelForCausalLM + if tokenizer_class is None: + tokenizer_class = AutoTokenizer + self.model = model_class.from_pretrained(model_name, **model_kwargs) + tokenizer_kwargs.setdefault("padding_side", "left") + self.tokenizer = TransformerTokenizer( + tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs) + ) + self.type_adapter = TransformersTypeAdapter() + + def generate(self, model_input, output_type, **inference_kwargs): + prompts = self.type_adapter.format_input(model_input) + input_ids, attention_mask = self.tokenizer.encode(prompts) inputs = { "input_ids": input_ids.to(self.model.device), "attention_mask": attention_mask.to(self.model.device), } - if ( - "attention_mask" - not in inspect.signature(self.model.forward).parameters.keys() - ): - del inputs["attention_mask"] - - generation_kwargs = self._get_generation_kwargs( - prompts, - generation_parameters, - logits_processor, - sampling_parameters, + + logits_processor = self.type_adapter.format_output_type(output_type) + + generated_ids = self._generate_output_seq( + prompts, inputs, logits_processor=logits_processor, **inference_kwargs ) - generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) - # if single str input and single sample per input, convert to a 1D output - if isinstance(prompts, str): + # if single str input, convert to a 1D outputt + if isinstance(model_input, str): generated_ids = generated_ids.squeeze(0) - for i in range(generated_ids.size(-1)): - output_group_ids = generated_ids.select(-1, i).unsqueeze(-1) - yield self._decode_generation(output_group_ids) + return self._decode_generation(generated_ids) - def _get_generation_kwargs( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor: Optional["OutlinesLogitsProcessor"], - sampling_parameters: SamplingParameters, - ) -> dict: + def stream(self, model_input, output_type, **inference_kwargs): """ - Conert outlines generation 
parameters into model.generate kwargs + TODO: implement following completion of https://github.com/huggingface/transformers/issues/30810 """ - from transformers import GenerationConfig, LogitsProcessorList, set_seed - - max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) - sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( - sampling_parameters - ) - if max_new_tokens is None: - max_new_tokens = int(2**30) - - # global seed, not desirable - if seed is not None: - set_seed(seed) - - if logits_processor is not None: - logits_processor_list = LogitsProcessorList([logits_processor]) - else: - logits_processor_list = None - - generation_config = GenerationConfig( - max_new_tokens=max_new_tokens, - stop_strings=stop_at, - num_return_sequences=(num_samples or 1), - top_p=top_p, - top_k=top_k, - temperature=temperature, - do_sample=(sampler == "multinomial"), - num_beams=(num_samples if sampler == "beam_search" else 1), - eos_token_id=self.tokenizer.eos_token_id, - pad_token_id=self.tokenizer.pad_token_id, + raise NotImplementedError( + "Streaming is not implemented for Transformers models." ) - return dict( - logits_processor=logits_processor_list, - generation_config=generation_config, - tokenizer=self.tokenizer.tokenizer, - ) - - def _generate_output_seq( - self, prompts, inputs, generation_config, **generation_kwargs - ): + def _generate_output_seq(self, prompts, inputs, **inference_kwargs): input_ids = inputs["input_ids"] - output_ids = self.model.generate( - **inputs, generation_config=generation_config, **generation_kwargs - ) + output_ids = self.model.generate(**inputs, **inference_kwargs) # encoder-decoder returns output_ids only, decoder-only returns full seq ids if self.model.config.is_encoder_decoder: @@ -358,12 +255,10 @@ def _generate_output_seq( generated_ids = output_ids[:, input_ids.shape[1] :] # if batch list inputs AND multiple samples per input, convert generated_id to 3D view - num_samples = generation_config.num_return_sequences or 1 - + num_samples = inference_kwargs.get("num_return_sequences", 1) if num_samples > 1 and isinstance(prompts, list): batch_size = input_ids.size(0) - num_return_sequences = generation_config.num_return_sequences or 1 - generated_ids = generated_ids.view(batch_size, num_return_sequences, -1) + generated_ids = generated_ids.view(batch_size, num_samples, -1) return generated_ids @@ -383,76 +278,41 @@ def _decode_generation(self, generated_ids: "torch.Tensor"): ) -def transformers( - model_name: str, - device: Optional[str] = None, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, - model_class=None, - tokenizer_class=None, -): - """Instantiate a model from the `transformers` library and its tokenizer. - - Parameters - ---------- - model_name - The name of the model as listed on Hugging Face's model page. - device - The device(s) on which the model should be loaded. This overrides - the `device_map` entry in `model_kwargs` when provided. - model_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the model. - tokenizer_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the tokenizer. - - Returns - ------- - A `TransformersModel` model instance. 
+class Mamba(Transformers): + """Represents a Mamba model.""" - """ - if model_class is None or tokenizer_class is None: + def __init__( + self, + model_name: str, + model_kwargs: dict = {}, + tokenizer_kwargs: dict = {}, + ): + """ + Create a Mamba model instance + + Parameters: + ---------- + model_name + The name of the transformers model to use. It will be passed to + the `from_pretrained` method of the `MambaForCausalLM` class. + model_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the `MambaForCausalLM` class. + tokenizer_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the `AutoTokenizer` class. + """ try: - from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers import MambaForCausalLM + except ImportError: raise ImportError( - "The `transformers` library needs to be installed in order to use `transformers` models." + "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba." ) - if model_class is None: - model_class = AutoModelForCausalLM - if tokenizer_class is None: - tokenizer_class = AutoTokenizer - - if device is not None: - model_kwargs["device_map"] = device - - model = model_class.from_pretrained(model_name, **model_kwargs) - tokenizer_kwargs.setdefault("padding_side", "left") - tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs) - - return Transformers(model, tokenizer) - - -def mamba( - model_name: str, - device: Optional[str] = None, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, -): - try: - from transformers import MambaForCausalLM - - except ImportError: - raise ImportError( - "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba." + return super().__init__( + model_name=model_name, + model_class=MambaForCausalLM, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, ) - - return transformers( - model_name=model_name, - device=device, - model_kwargs=model_kwargs, - tokenizer_kwargs=tokenizer_kwargs, - model_class=MambaForCausalLM, - ) diff --git a/outlines/models/transformers_vision.py b/outlines/models/transformers_vision.py index 772645b80..b1c1c1de4 100644 --- a/outlines/models/transformers_vision.py +++ b/outlines/models/transformers_vision.py @@ -1,138 +1,146 @@ -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union +from functools import singledispatchmethod -from outlines.generate.api import GenerationParameters, SamplingParameters from outlines.models import Transformers +from outlines.models.base import ModelTypeAdapter -if TYPE_CHECKING: - from outlines.processors import OutlinesLogitsProcessor + +class TransformersVisionTypeAdapter(ModelTypeAdapter): + """Type adapter for TransformersVision models.""" + + @singledispatchmethod + def format_input(self, model_input): + """Generate the prompt argument to pass to the model. + + Argument + -------- + model_input + The input passed by the user. + + """ + raise NotImplementedError( + f"The input type {input} is not available. Please provide a " + "dictionary with the following format: " + "{'prompts': Union[str, List[str]], 'images': Union[Any, List[Any]]}" + "Make sure the number of image tags in the prompts is equal to the " + "number of images provided." 
+ ) + + @format_input.register(dict) + def format_list_input(self, model_input): + if "prompts" not in model_input or "images" not in model_input: + raise ValueError( + "The input must contain the following keys: 'prompts' and 'images'." + ) + return model_input["prompts"], model_input["images"] + + def format_output_type(self, output_type): + """Generate the logits processor argument to pass to the model. + + Argument + -------- + output_type + The logits processor provided. + + """ + from transformers import LogitsProcessorList + + if output_type is not None: + return LogitsProcessorList([output_type]) + return None class TransformersVision(Transformers): - def __init__(self, model, tokenizer, processor): - super().__init__(model, tokenizer) - self.processor = processor + """Represents a `transformers` model with a vision processor.""" - def generate( # type: ignore + def __init__( self, - prompts: Union[str, List[str]], - media: Union[List[Any], List[List[Any]]], - generation_parameters: GenerationParameters, - logits_processor: Optional["OutlinesLogitsProcessor"], - sampling_parameters: SamplingParameters, - ) -> Union[str, List[str], List[List[str]]]: - """Generate text using `transformers`. - - Arguments - --------- - prompts - A prompt or list of prompts. - media - A List[PIL.Image] or List[List[PIL.Image]] - generation_parameters - An instance of `GenerationParameters` that contains the prompt, - the maximum number of tokens, stop sequences and seed. All the - arguments to `SequenceGeneratorAdapter`'s `__cal__` method. - logits_processor - The logits processor to use when generating text. - sampling_parameters - An instance of `SamplingParameters`, a dataclass that contains - the name of the sampler to use and related parameters as available - in Outlines. - - Returns - ------- - The generated text + model_name: str, + model_class, + model_kwargs: dict = {}, + tokenizer_class=None, + tokenizer_kwargs: dict = {}, + processor_class=None, + processor_kwargs: dict = {}, + ): + """Create a TransformersVision model instance + + We rely on the `__init__` method of the `Transformers` class to handle + most of the initialization and then add elements specific to vision + models. + + Parameters + ---------- + model_name + The name of the transformers model to use; + model_class + The Transformers model class from which to create the model. + model_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the model class. + tokenizer_class + The Transformers tokenizer class from which to create the tokenizer. + If not provided,`AutoTokenizer` will be used. + If you gave the name of a model that is not compatible with `AutoTokenizer`, + you must provide a value for this parameter. + tokenizer_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the tokenizer class. + processor_class + The Transformers processor class from which to create the processor. + If not provided,`AutoProcessor` will be used. + If you gave the name of a model that is not compatible with `AutoProcessor`, + you must provide a value for this parameter. + processor_kwargs + A dictionary of keyword arguments to pass to the `from_pretrained` + method of the processor class. + """ + if processor_class is None: + try: + from transformers import AutoProcessor + + processor_class = AutoProcessor + except ImportError: + raise ImportError( + "The `transformers` library needs to be installed in order to use `transformers` models." 
+ ) + + processor_kwargs.setdefault("padding_side", "left") + processor_kwargs.setdefault("pad_token", "[PAD]") + self.processor = processor_class.from_pretrained(model_name, **processor_kwargs) + + if tokenizer_class is None and getattr(self.processor, "tokenizer", None): + tokenizer_class = type(self.processor.tokenizer) + + super().__init__( + model_name, + model_class, + model_kwargs, + tokenizer_class, + tokenizer_kwargs, + ) + + self.type_adapter = TransformersVisionTypeAdapter() + + def generate(self, model_input, output_type, **inference_kwargs): + prompts, images = self.type_adapter.format_input(model_input) + inputs = self.processor( - text=prompts, images=media, padding=True, return_tensors="pt" + text=prompts, images=images, padding=True, return_tensors="pt" ).to(self.model.device) + logits_processor = self.type_adapter.format_output_type(output_type) - generation_kwargs = self._get_generation_kwargs( - prompts, - generation_parameters, - logits_processor, - sampling_parameters, + generated_ids = self._generate_output_seq( + prompts, inputs, logits_processor=logits_processor, **inference_kwargs ) - generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) - # if single str input and single sample per input, convert to a 1D output + # if single str input, convert to a 1D outputt if isinstance(prompts, str): - # Should always be true until NotImplementedError above is fixed generated_ids = generated_ids.squeeze(0) return self._decode_generation(generated_ids) - def stream( # type: ignore - self, - prompts: Union[str, List[str]], - media: Union[Any, List[Any]], # TODO: docstring - generation_parameters: GenerationParameters, - logits_processor: Optional["OutlinesLogitsProcessor"], - sampling_parameters: SamplingParameters, - ) -> Iterator[Union[str, List[str]]]: - raise NotImplementedError - - -def transformers_vision( - model_name: str, - model_class, - device: Optional[str] = None, - model_kwargs: dict = {}, - processor_kwargs: dict = {}, - tokenizer_class=None, - processor_class=None, -): - """Instantiate a model from the `transformers` library and its tokenizer. - - Parameters - ---------- - model_name - The name of the model as listed on Hugging Face's model page. - model_class - The `PreTrainedModel` class from transformers to use in initializing the vision model from `model_name`. - https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel - device - The device(s) on which the model should be loaded. This overrides - the `device_map` entry in `model_kwargs` when provided. - model_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the model. - processor_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the processor. - - Returns - ------- - A `TransformersModel` model instance. - - """ - if processor_class is None or tokenizer_class is None: - try: - from transformers import AutoProcessor, AutoTokenizer - except ImportError: - raise ImportError( - "The `transformers` library needs to be installed in order to use `transformers` models." 
- ) - if processor_class is None: - processor_class = AutoProcessor - if tokenizer_class is None: - tokenizer_class = AutoTokenizer - - if device is not None: - model_kwargs["device_map"] = device - - model = model_class.from_pretrained(model_name, **model_kwargs) - - processor_kwargs.setdefault("padding_side", "left") - processor_kwargs.setdefault("pad_token", "[PAD]") - processor = processor_class.from_pretrained(model_name, **processor_kwargs) - - if tokenizer_class is None: - if getattr(processor, "tokenizer", None): - tokenizer = processor.tokenizer - else: - tokenizer = AutoTokenizer.from_pretrained(model_name, **processor_kwargs) - else: - tokenizer = tokenizer_class.from_pretrained(model_name, **processor_kwargs) - - return TransformersVision(model, tokenizer, processor) + def stream(self, model_input, output_type, **inference_kwargs): + raise NotImplementedError( + "Streaming is not implemented for TransformersVision models." + ) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index b4582b0f8..fd69f2193 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -5,7 +5,7 @@ from jsonschema import Draft202012Validator as Validator from jsonschema.exceptions import SchemaError -from outlines_core.json_schema import build_regex_from_schema +from outlines_core.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel, TypeAdapter from typing_extensions import _TypedDictMeta # type: ignore diff --git a/outlines/types/dsl.py b/outlines/types/dsl.py index 86feabca0..645b0486a 100644 --- a/outlines/types/dsl.py +++ b/outlines/types/dsl.py @@ -138,6 +138,9 @@ def _display_node(self) -> str: def __repr__(self): return f"Regex(pattern='{self.pattern}')" + def to_regex(self) -> str: + return self.pattern + class JsonSchema(Term): def __init__(self, schema: Union[dict, str, type[BaseModel]]): diff --git a/pyproject.toml b/pyproject.toml index 6ad04f601..02e76eb43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,11 +158,8 @@ module = [ "iso3166.*", "airportsdata.*", "outlines_core.*", -<<<<<<< HEAD "genson", -======= "ollama.*", ->>>>>>> 5af6b93 (Create the Ollama model) ] ignore_missing_imports = true diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index c7e9bdc32..aaa4509bf 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -54,25 +54,25 @@ def model_mlxlm_phi3(tmp_path_factory): @pytest.fixture(scope="session") def model_transformers_random(tmp_path_factory): - return models.transformers("hf-internal-testing/tiny-random-gpt2", device="cpu") + return models.Transformers("hf-internal-testing/tiny-random-gpt2") @pytest.fixture(scope="session") def model_transformers_opt125m(tmp_path_factory): - return models.transformers("facebook/opt-125m", device="cpu") + return models.Transformers("facebook/opt-125m") @pytest.fixture(scope="session") def model_mamba(tmp_path_factory): - return models.mamba(model_name="state-spaces/mamba-130m-hf", device="cpu") + return models.Mamba(model_name="state-spaces/mamba-130m-hf") @pytest.fixture(scope="session") def model_bart(tmp_path_factory): from transformers import AutoModelForSeq2SeqLM - return models.transformers( - "facebook/bart-base", device="cpu", model_class=AutoModelForSeq2SeqLM + return models.Transformers( + "facebook/bart-base", model_class=AutoModelForSeq2SeqLM ) @@ -123,11 +123,11 @@ def model_t5(tmp_path_factory): "model_exllamav2", "model_mlxlm", "model_mlxlm_phi3", - "model_transformers_random", - 
"model_transformers_opt125m", - "model_mamba", - "model_bart", - "model_transformers_vision", + # "model_transformers_random", + # "model_transformers_opt125m", + # "model_mamba", + # "model_bart", + # "model_transformers_vision", "model_vllm", ) @@ -280,20 +280,21 @@ def test_generate_json(request, model_fixture, sample_schema): generator(**get_inputs(model_fixture), max_tokens=100) -def test_integrate_genson_generate_json(request): - from genson import SchemaBuilder - - builder = SchemaBuilder() - builder.add_schema({"type": "object", "properties": {}}) - builder.add_object({"name": "Toto", "age": 5}) - - model = request.getfixturevalue("model_transformers_opt125m") - - generator = generate.json(model, builder) - res = generator("Return a json of a young boy") - - assert "name" in res - assert "age" in res +# TODO: add support for genson in the Regex type of v1.0 +#def test_integrate_genson_generate_json(request): +# from genson import SchemaBuilder +# +# builder = SchemaBuilder() +# builder.add_schema({"type": "object", "properties": {}}) +# builder.add_object({"name": "Toto", "age": 5}) +# +# model = request.getfixturevalue("model_transformers_opt125m") +# +# generator = generate.json(model, builder) +# res = generator("Return a json of a young boy") +# +# assert "name" in res +# assert "age" in res @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py deleted file mode 100644 index 92c5d789c..000000000 --- a/tests/generate/test_integration_transformers.py +++ /dev/null @@ -1,569 +0,0 @@ -import datetime -import re -from enum import Enum -from functools import partial -from typing import List, Union - -import pytest -import torch -from outlines_core.fsm.regex import reduced_vocabulary -from pydantic import BaseModel, constr - -import outlines.generate as generate -import outlines.models as models -from outlines.models.transformers import Transformers, TransformerTokenizer -from outlines.samplers import beam_search, greedy, multinomial - - -@pytest.fixture(scope="session") -def model(tmp_path_factory): - return models.transformers( - "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu" - ) - - -def test_transformers_integration_text(model): - sequence = generate.text(model)( - "Write a short sentence ", seed=10000, max_tokens=10 - ) - assert isinstance(sequence, str) - assert model.tokenizer.eos_token not in sequence - - sequence = generate.text(model)( - "Write a short sentence ", - max_tokens=20, - stop_at="a", - seed=10000, - ) - assert isinstance(sequence, str) - - prompts = ["Write a short sentence ", "And another one "] - sequence = generate.text(model)( - prompts, max_tokens=10, stop_at=[".", ","], seed=10000 - ) - assert isinstance(sequence, list) - assert len(sequence) == 2 - assert isinstance(sequence[0], str) - - -def test_transformers_integration_text_multiple_samples(model): - sampler = multinomial(2) - - sequence = generate.text(model, sampler=sampler)( - "Write a short sentence ", seed=10000 - ) - assert isinstance(sequence, list) - assert len(sequence) == 2 - assert model.tokenizer.eos_token not in sequence - - prompts = ["Write a short sentence ", "And another one "] - sequence = generate.text(model, sampler=sampler)( - prompts, max_tokens=10, stop_at=[".", ","], seed=10000 - ) - assert isinstance(sequence, list) - assert len(sequence) == 2 - assert isinstance(sequence[0], list) - assert len(sequence) == 2 - assert isinstance(sequence[0][0], 
str) - - -def test_transformers_integration_streaming(model): - sequence = generate.text(model).stream( - "Write a short sentence ", max_tokens=10, stop_at=[".", ","], seed=10000 - ) - - token = next(sequence) - assert isinstance(token, str) - - remaining = "".join([token for token in sequence]) - assert isinstance(remaining, str) - - sequence = generate.text(model).stream( - ["Prompt1", "Prompt2"], max_tokens=10, stop_at=[".", ","], seed=10000 - ) - tokens = next(sequence) - assert isinstance(tokens, list) - assert isinstance(tokens[0], str) - assert isinstance(tokens[1], str) - - -def test_transformers_integration_streaming_batch_samples(model): - sampler = multinomial(samples=2) - - sequence = generate.text(model, sampler=sampler).stream( - ["Prompt1", "Prompt2"], max_tokens=10, stop_at=[".", ","], seed=10000 - ) - tokens = next(sequence) - assert isinstance(tokens, list) - assert len(tokens) == 2 - assert isinstance(tokens[0], list) - assert len(tokens[0]) == 2 - assert isinstance(tokens[0], list) - assert len(tokens[1]) == 2 - - -def test_transformers_integration_streaming_batch_beam_search(model): - sampler = beam_search(beams=2) - - sequence = generate.regex(model, r"ab[cd]e", sampler=sampler).stream( - ["Prompt1", "Prompt2"], max_tokens=10, stop_at=["c", "d"], seed=10000 - ) - tokens = next(sequence) - assert isinstance(tokens, list) - assert len(tokens) == 2 - assert isinstance(tokens[0], list) - assert len(tokens[0]) == 2 - assert isinstance(tokens[0], list) - assert len(tokens[1]) == 2 - - -def test_transformers_integration_text_stop(model): - prompt = "Write a short sentence " - sequence = generate.text(model)(prompt, stop_at="a", seed=10000) - assert sequence[len(prompt) :].find("a") == -1 - - -def test_transformers_various_regexes(model): - prompt = "Write an email address" - regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" - generator = generate.regex(model, regex_str) - - # One prompt - sequence = generator(prompt, seed=0) - assert re.fullmatch(regex_str, sequence) is not None - - -def test_transformers_various_regexes_prompt_list(model): - prompt = "Write an email address" - regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" - generator = generate.regex(model, regex_str) - - # Two prompts - sequence = generator([prompt, prompt], seed=0) - assert re.fullmatch(regex_str, sequence[0]) is not None - assert re.fullmatch(regex_str, sequence[1]) is not None - - -def test_transformers_various_regexes_prompt_list_multiple_samples(model): - sampler = multinomial(samples=2) - prompt = "Write an email address" - regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" - generator = generate.regex(model, regex_str, sampler=sampler) - - # Two prompts - sequence = generator([prompt, prompt], seed=0, max_tokens=500) - assert isinstance(sequence, list) - assert len(sequence) == 2 - assert re.fullmatch(regex_str, sequence[0][0]) is not None - assert re.fullmatch(regex_str, sequence[0][1]) is not None - assert re.fullmatch(regex_str, sequence[1][0]) is not None - assert re.fullmatch(regex_str, sequence[1][1]) is not None - - -def test_transformers_various_regexes_prompt_list_beam_search(model): - sampler = beam_search(5) - prompt_1 = "Write an email address" - prompt_2 = "Random" - regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" - generator = generate.regex(model, regex_str, sampler=sampler) - - # Two prompts - sequence = generator([prompt_1, prompt_2], seed=0) - assert isinstance(sequence, list) - assert len(sequence) == 2 - assert len(sequence[0]) == 5 - assert re.fullmatch(regex_str, 
sequence[0][0]) is not None
-    assert re.fullmatch(regex_str, sequence[0][1]) is not None
-    assert re.fullmatch(regex_str, sequence[1][0]) is not None
-    assert re.fullmatch(regex_str, sequence[1][1]) is not None
-
-
-def test_transformers_integration_integer(model):
-    prompt = "Write a short sentence"
-    sequence = generate.format(model, int)(prompt, max_tokens=10, seed=0)
-
-    assert isinstance(sequence, int)
-
-
-def test_transformers_integration_integer_array(model):
-    prompts = ["Give me a number", "And another one"]
-    sequence = generate.format(model, int)(prompts, max_tokens=10, seed=0)
-    assert isinstance(sequence, list)
-    assert len(sequence) == 2
-    assert isinstance(sequence[0], int)
-    assert isinstance(sequence[1], int)
-
-
-def test_transformers_integration_float(model):
-    prompt = "Write a short sentence"
-    sequence = generate.format(model, float)(prompt, max_tokens=10, seed=0)
-
-    assert sequence != ""
-    assert isinstance(sequence, float)
-
-
-def test_transformers_integration_bool(model):
-    prompt = "Is this True or False?"
-    sequence = generate.format(model, bool)(prompt, max_tokens=10, seed=0)
-
-    assert sequence != ""
-    assert isinstance(sequence, bool)
-
-
-def test_transformers_integration_date(model):
-    prompt = "What day is it today?"
-    sequence = generate.format(model, datetime.date)(prompt, max_tokens=10, seed=0)
-
-    assert sequence != ""
-    assert isinstance(sequence, datetime.date)
-
-
-def test_transformers_integration_time(model):
-    prompt = "What time is it?"
-    sequence = generate.format(model, datetime.time)(prompt, max_tokens=10, seed=0)
-
-    assert sequence != ""
-    assert isinstance(sequence, datetime.time)
-
-
-def test_transformers_integration_datetime(model):
-    prompt = "What time is it?"
-    sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20, seed=0)
-
-    assert sequence != 0
-    assert isinstance(sequence, datetime.datetime)
-
-
-def test_transformers_integration_choice(model):
-    prompt = "Write a short sentence "
-    sequence = generate.choice(model, ["test", "choice"])(prompt, seed=0)
-
-    assert sequence == "test" or sequence == "choice"
-
-
-def test_transformers_integration_with_pad_token():
-    model_name = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
-    model = models.transformers(model_name, device="meta")
-    assert model.tokenizer.pad_token_id == 1
-    assert model.tokenizer.pad_token == ""
-
-
-def test_transformers_json_basic(model):
-    prompt = "Output some JSON "
-
-    class Spam(BaseModel):
-        foo: int
-        bar: float
-        spam: constr(max_length=10)
-        fuzz: bool
-
-    result = generate.json(model, Spam)(prompt, max_tokens=500, seed=0)
-    assert isinstance(result, BaseModel)
-    assert isinstance(result.foo, int)
-    assert isinstance(result.bar, float)
-    assert isinstance(result.spam, str)
-    assert isinstance(result.fuzz, bool)
-    assert len(result.spam) <= 10
-
-
-def test_transformers_json_schema(model):
-    prompt = "Output some JSON "
-
-    schema = """{
-        "title": "spam",
-        "type": "object",
-        "properties": {
-            "foo" : {"type": "integer"},
-            "bar": {"type": "string", "maxLength": 4}
-        },
-        "required": ["foo", "bar"]
-    }
-    """
-
-    result = generate.json(model, schema)(prompt, max_tokens=500, seed=0)
-    assert isinstance(result, dict)
-    assert isinstance(result["foo"], int)
-    assert isinstance(result["bar"], str)
-
-
-def test_transformers_json_batch(model):
-    prompts = ["Output some JSON ", "Output more JSON"]
-
-    class Spam(BaseModel):
-        foo: int
-        bar: float
-        spam: constr(max_length=10)
-        fuzz: bool
-
-    result = generate.json(model, Spam)(prompts, max_tokens=500, seed=0)
-    assert isinstance(result[0], BaseModel)
-    assert isinstance(result[1], BaseModel)
-
-
-def test_transformers_json_batch_multiple_samples(model):
-    sampler = multinomial(samples=2)
-    prompts = ["Output some JSON ", "Output more JSON"]
-
-    class Spam(BaseModel):
-        foo: int
-        bar: float
-        spam: constr(max_length=10)
-        fuzz: bool
-
-    result = generate.json(model, Spam, sampler=sampler)(
-        prompts, max_tokens=500, seed=0
-    )
-    assert isinstance(result, list)
-    assert len(result) == 2
-    assert isinstance(result[0][0], BaseModel)
-    assert isinstance(result[0][1], BaseModel)
-    assert isinstance(result[1][0], BaseModel)
-    assert isinstance(result[1][1], BaseModel)
-
-
-def test_transformers_json_str_enum(model):
-    prompt = "Output some JSON "
-
-    class Name(str, Enum):
-        john = "John"
-        marc = "Marc"
-        michel = "Michel"
-
-    class User(BaseModel):
-        user_id: int
-        name: Name
-
-    result = generate.json(model, User)(prompt, seed=0)
-    assert isinstance(result, BaseModel)
-    assert isinstance(result.user_id, int)
-    assert result.name in ["John", "Marc", "Michel"]
-
-
-def test_transformers_json_int_enum(model):
-    prompt = "Output some JSON "
-
-    class Id(int, Enum):
-        one = 1
-        two = 2
-
-    class User(BaseModel):
-        user_id: Id
-
-    result = generate.json(model, User)(prompt, seed=0)
-    assert isinstance(result, BaseModel)
-    assert isinstance(result.user_id, int)
-    assert result.user_id in [1, 2]
-
-
-def add(a: int, b: int) -> int:
-    return a + b
-
-
-def mul(c: float, d: float) -> float:
-    return c * d
-
-
-def test_transformers_json_function_enum(model):
-    prompt = "Output some JSON "
-
-    class Operation(Enum):
-        add = partial(add)
-        mul = partial(mul)
-
-    result = generate.json(model, Operation)(prompt, seed=0)
-    assert isinstance(result, dict)
-    assert len(result) == 2
-    for k, v in result.items():
-        assert k in ["a", "b", "c", "d"]
-        assert isinstance(v, (int, float))
-
-
-def test_transformers_json_array(model):
-    prompt = "Output some JSON "
-
-    class User(BaseModel):
-        user_id: int
-        value: List[float]
-
-    result = generate.json(model, User)(prompt, seed=0)
-    assert isinstance(result, BaseModel)
-    assert isinstance(result.user_id, int)
-    assert isinstance(result.value, list)
-    for value in result.value:
-        assert isinstance(value, float) or isinstance(value, int)
-
-
-@pytest.mark.xfail(reason="The implementation of `anyOf` is incorrect")
-def test_transformers_json_union(model):
-    prompt = "Output some JSON "
-
-    class Spam(BaseModel):
-        foo: int
-        bar: Union[constr(max_length=10), float]
-
-    result = generate.json(model, Spam)(prompt, max_tokens=100, seed=4)
-    assert isinstance(result, BaseModel)
-    assert (
-        isinstance(result.bar, int)
-        or isinstance(result.bar, float)
-        or isinstance(result.bar, str)
-    )
-
-
-def test_transformers_json_function(model):
-    prompt = "Output arguments for the function"
-
-    def function(foo: int, bar: List[int]):
-        return foo + sum(bar)
-
-    sequence = generate.json(model, function)(prompt, max_tokens=100, seed=4)
-    assert isinstance(sequence, dict)
-    assert isinstance(function(**sequence), int)
-
-
-def test_transformers_logits_vocab_size(model):
-    # Artificially increase the weights/logits size relative
-    # to the vocabulary
-    model.model.resize_token_embeddings(pad_to_multiple_of=3)
-
-    assert len(model.tokenizer.vocabulary) == 1024
-    assert model.model.base_model.wte.weight.shape[0] == 1026
-
-    generator = generate.choice(model, ["True", "False"])
-
-    sequence = generator("blah", seed=101)
-    assert sequence == "False"
-
-
-def test_transformers_json_custom_ws(model):
-    prompt = "Output some JSON with newlines"  # try to force model to use newlines
-
-    schema = """{
-        "title": "spam",
-        "type": "object",
-        "properties": {
-            "foo" : {"type": "integer"},
-            "bar": {"type": "integer"}
-        },
-        "required": ["foo", "bar"]
-    }
-    """
-
-    generator = generate.json(model, schema, whitespace_pattern=r"[ ]?")
-    generator.format_sequence = lambda x: x  # patch to return raw text
-    assert "\n" not in generator(prompt, max_tokens=500, seed=0)
-
-
-def test_transformers_reduced_vocabulary_caching():
-    from transformers import AutoTokenizer
-
-    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer = TransformerTokenizer(hf_tokenizer)
-    tokenizer2 = TransformerTokenizer(hf_tokenizer)
-
-    # TODO: We might actually want only one copy of a given tokenizer.
-    assert tokenizer is not tokenizer2
-
-    vocab = reduced_vocabulary(tokenizer)
-    vocab2 = reduced_vocabulary(tokenizer2)
-
-    assert vocab2 is vocab
-
-
-@pytest.mark.skip(reason="Custom Sampler Disabled in Transformers Integration")
-def test_custom_sampler(model):
-    seen = False
-    target_token_ids = model.tokenizer.encode(["c"])[0]
-
-    class biased_sampler:
-        def __init__(self, samples: int = 1):
-            self.samples = samples
-
-        def __call__(
-            logits: torch.DoubleTensor, samples: int, *_
-        ) -> torch.DoubleTensor:
-            nonlocal seen
-
-            if not seen:
-                seen = True
-                return target_token_ids, torch.tensor([0]), None
-            else:
-                return (
-                    torch.tensor([[model.tokenizer.eos_token_id]]),
-                    torch.tensor([0]),
-                    None,
-                )
-
-    generator = generate.choice(model, ["a", "b", "c"], sampler=biased_sampler(1))
-    sequence = generator(
-        """What is 1+1?
-        a. 3
-        b. 4
-        c. 2"""
-    )
-
-    assert sequence == "c"
-
-
-def test_transformers_use_existing_model_and_tokenizer():
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-
-    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
-    model = Transformers(hf_model, hf_tokenizer)
-    sequence = generate.text(model)("Write a short sentence ", seed=10000)
-    assert isinstance(sequence, str)
-
-
-@pytest.mark.skip("Caching for guide was temporarily turned off")
-def test_RegexGuide_caching(temp_cache_dir):
-    import outlines.caching
-    from outlines.fsm.guide import cached_create_states_mapping
-
-    assert outlines.caching._caching_enabled
-
-    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
-    prompt = "What is the IP address of the Google DNS servers? "
-
-    cache = outlines.caching.get_cache()
-
-    # Returns (hits, misses)
-    _ = cache.stats(enable=True)
-    assert cache.statistics
-
-    assert cached_create_states_mapping.__memory__ is cache
-
-    model = models.transformers(
-        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM", device="cpu"
-    )
-    generator = generate.regex(model, regex, sampler=greedy())
-    assert cache.stats() == (0, 1)
-
-    model_2 = models.transformers(
-        "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
-    )
-    generator_2 = generate.regex(model_2, regex, sampler=greedy())
-    assert cache.stats() == (0, 2)
-
-    # These two different models and tokenizers should not have the same state
-    # mapping results
-    assert (
-        generator.logits_processor.guide.states_to_token_maps
-        != generator_2.logits_processor.guide.states_to_token_maps
-    )
-
-    generator_3 = generate.regex(model_2, regex, sampler=greedy())
-    assert cache.stats() == (1, 2)
-    assert (
-        generator_2.logits_processor.guide.states_to_token_maps
-        == generator_3.logits_processor.guide.states_to_token_maps
-    )
-
-    # Just for fun...
-    structured = generator(prompt, max_tokens=30)
-    structured_2 = generator_2(prompt, max_tokens=30)
-
-    assert re.fullmatch(regex, structured)
-    assert re.fullmatch(regex, structured_2)
-    assert structured != structured_2
diff --git a/tests/generate/test_integration_transformers_vision.py b/tests/generate/test_integration_transformers_vision.py
deleted file mode 100644
index ee4f84c06..000000000
--- a/tests/generate/test_integration_transformers_vision.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from io import BytesIO
-from urllib.request import urlopen
-
-import pytest
-from PIL import Image
-from transformers import AutoProcessor, LlavaForConditionalGeneration
-
-import outlines
-from outlines.models.transformers_vision import transformers_vision
-
-IMAGE_URLS = [
-    "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg",
-    "https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg",
-    "https://upload.wikimedia.org/wikipedia/commons/b/be/Tamias-rufus-001.jpg",
-]
-
-
-def img_from_url(url):
-    img_byte_stream = BytesIO(urlopen(url).read())
-    return Image.open(img_byte_stream).convert("RGB")
-
-
-@pytest.fixture(scope="session")
-def model(tmp_path_factory):
-    return transformers_vision(
-        "trl-internal-testing/tiny-LlavaForConditionalGeneration",
-        model_class=LlavaForConditionalGeneration,
-        device="cpu",
-    )
-
-
-@pytest.fixture(scope="session")
-def processor(tmp_path_factory):
-    return AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
-
-
-def test_single_image_text_gen(model, processor):
-    conversation = [
-        {
-            "role": "user",
-            "content": [{"type": "text", "text": "What is this?"}, {"type": "image"}],
-        },
-    ]
-    generator = outlines.generate.text(model)
-    sequence = generator(
-        processor.apply_chat_template(conversation),
-        [img_from_url(IMAGE_URLS[0])],
-        seed=10000,
-        max_tokens=10,
-    )
-    assert isinstance(sequence, str)
-
-
-def test_multi_image_text_gen(model, processor):
-    """If the length of image tags and number of images we pass are > 1 and equal,
-    we should yield a successful generation.
-    """
-    conversation = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": "What do all these have in common?"},
-            ]
-            + [{"type": "image"} for _ in range(len(IMAGE_URLS))],
-        },
-    ]
-    generator = outlines.generate.text(model)
-    sequence = generator(
-        processor.apply_chat_template(conversation),
-        [img_from_url(i) for i in IMAGE_URLS],
-        seed=10000,
-        max_tokens=10,
-    )
-    assert isinstance(sequence, str)
-
-
-def test_mismatched_image_text_gen(model, processor):
-    """If the length of image tags and number of images we pass are unequal,
-    we should raise an error.
-    """
-    conversation = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": "I'm passing 3 images, but only 1 image tag"},
-                {"type": "image"},
-            ],
-        },
-    ]
-    generator = outlines.generate.text(model)
-    with pytest.raises(ValueError):
-        _ = generator(
-            processor.apply_chat_template(conversation),
-            [img_from_url(i) for i in IMAGE_URLS],
-            seed=10000,
-            max_tokens=10,
-        )
-
-
-def test_single_image_choice(model, processor):
-    conversation = [
-        {
-            "role": "user",
-            "content": [{"type": "text", "text": "What is this?"}, {"type": "image"}],
-        },
-    ]
-    choices = ["cat", "dog"]
-    generator = outlines.generate.choice(model, choices)
-    sequence = generator(
-        processor.apply_chat_template(conversation),
-        [img_from_url(IMAGE_URLS[0])],
-        seed=10000,
-        max_tokens=10,
-    )
-    assert isinstance(sequence, str)
-    assert sequence in choices
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 1404d287a..249893ff9 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -1,117 +1,153 @@
+import re
+from enum import Enum
+
 import pytest
-import torch
-from transformers import AutoTokenizer
-from transformers.models.gpt2 import GPT2TokenizerFast
+from pydantic import BaseModel
+from transformers import AutoModelForSeq2SeqLM
+
+from outlines.models.transformers import Mamba, Transformers
+from outlines.types import Choice, Json, Regex
+
+TEST_MODEL = "erwanf/gpt2-mini"
+TEST_MODEL_SEQ2SEQ = "hf-internal-testing/tiny-random-t5"
+TEST_MODEL_MAMBA = "hf-internal-testing/tiny-random-MambaForCausalLM"
+
+
+def test_transformers_instantiate_simple():
+    model = Transformers(TEST_MODEL)
+    assert isinstance(model, Transformers)
+
+
+def test_transformers_instantiate_wrong_kwargs():
+    with pytest.raises(TypeError):
+        Transformers(TEST_MODEL, model_kwargs={"foo": "bar"})
+
+
+def test_transformers_instantiate_other_model_class():
+    model = Transformers(
+        model_name=TEST_MODEL_SEQ2SEQ, model_class=AutoModelForSeq2SeqLM
+    )
+    assert isinstance(model, Transformers)
+
+
+def test_transformers_instantiate_mamba():
+    model = Mamba(
+        model_name=TEST_MODEL_MAMBA,
+    )
+    assert isinstance(model, Mamba)
+    assert isinstance(model, Transformers)
+
+
+def test_transformers_instantiate_tokenizer_kwargs():
+    model = Transformers(
+        TEST_MODEL,
+        tokenizer_kwargs={"additional_special_tokens": ["", ""]}
+    )
+    assert "" in model.tokenizer.special_tokens
+    assert "" in model.tokenizer.special_tokens
+
+
+@pytest.fixture
+def model():
+    return Transformers(TEST_MODEL)
+
+
+def test_transformers_simple(model):
+    result = model.generate("Respond with one word. Not more.", None)
+    assert isinstance(result, str)
+
+
+def test_transformers_call(model):
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+
 
-from outlines.models.transformers import TransformerTokenizer, transformers
+def test_transformers_inference_kwargs(model):
+    result = model("Respond with one word. Not more.", max_new_tokens=100)
+    assert isinstance(result, str)
 
 
-TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+def test_transformers_invalid_inference_kwargs(model):
+    with pytest.raises(ValueError):
+        model("Respond with one word. Not more.", foo="bar")
 
 
-def test_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL, padding_side="left")
-    tokenizer = TransformerTokenizer(tokenizer)
-    assert tokenizer.eos_token_id == 0
-    assert tokenizer.pad_token_id == 0
-    assert isinstance(tokenizer.tokenizer, GPT2TokenizerFast)
-    token_ids, attention_mask = tokenizer.encode("Test")
-    assert token_ids.ndim == 2
-    assert token_ids.shape[0] == 1
-    assert isinstance(token_ids, torch.LongTensor)
-    assert token_ids.shape == attention_mask.shape
+def test_transformers_regex(model):
+    result = model("Give a number between 0 and 9.", Regex(r"[0-9]"))
+    assert isinstance(result, str)
+    assert re.match(r"[0-9]", result)
 
-    token_ids, attention_mask = tokenizer.encode(["Test", "Test"])
-    assert token_ids.ndim == 2
-    assert token_ids.shape[0] == 2
-    assert isinstance(token_ids, torch.LongTensor)
-    assert token_ids.shape == attention_mask.shape
-    token_ids, attention_mask = tokenizer.encode(["Test", "A long sentence"])
-    assert token_ids.shape == attention_mask.shape
-    assert attention_mask[0][0] == tokenizer.pad_token_id
+def test_transformers_json(model):
+    class Character(BaseModel):
+        name: str
 
-    text = tokenizer.decode(torch.tensor([[0, 1, 2]]))
-    isinstance(text, str)
+    result = model("Create a character with a name.", Json(Character))
+    assert "name" in result
 
-    text = tokenizer.decode(torch.tensor([[0, 1, 2], [3, 4, 5]]))
-    isinstance(text, list)
-    isinstance(text[0], str)
-    isinstance(text[1], str)
-    tokenizer = AutoTokenizer.from_pretrained(
-        TEST_MODEL, additional_special_tokens=["", ""]
+
+def test_transformers_choice(model):
+    class Foo(Enum):
+        cat = "cat"
+        dog = "dog"
+
+    result = model("Cat or dog?", Choice(Foo))
+    assert result in ["cat", "dog"]
+
+
+def test_transformers_batch_samples(model):
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+    result = model(
+        "Respond with one word. Not more.", num_return_sequences=2, num_beams=2
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    result = model(
+        ["Respond with one word. Not more.", "Respond with one word. Not more."]
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    result = model(
+        ["Respond with one word. Not more.", "Respond with one word. Not more."],
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert isinstance(item, list)
+        assert len(item) == 2
+
+
+def test_transformers_batch_samples_constrained(model):
+    class Foo(Enum):
+        cat = "cat"
+        dog = "dog"
+
+    result = model("Cat or dog?", Choice(Foo), num_return_sequences=2, num_beams=2)
+    assert isinstance(result, list)
+    assert len(result) == 2
+    assert result[0] in ["cat", "dog"]
+    assert result[1] in ["cat", "dog"]
+    result = model(
+        ["Cat or dog?", "Cat or dog?"],
+        Choice(Foo),
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    assert result[0] in ["cat", "dog"]
+    assert result[1] in ["cat", "dog"]
+    result = model(
+        ["Cat or dog?", "Cat or dog?"],
+        Choice(Foo),
+        num_return_sequences=2,
+        num_beams=2,
     )
-    tokenizer = TransformerTokenizer(tokenizer)
-    assert "" in tokenizer.special_tokens
-    assert "" in tokenizer.special_tokens
-
-
-def test_llama_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    tokenizer = TransformerTokenizer(tokenizer)
-
-    # Broken
-    assert tokenizer.tokenizer.convert_tokens_to_string(["▁baz"]) == "baz"
-    assert tokenizer.tokenizer.convert_tokens_to_string(["<0x20>"]) == ""
-    assert tokenizer.tokenizer.convert_tokens_to_string(["▁▁▁"]) == " "
-
-    # Not broken
-    assert tokenizer.convert_token_to_string("▁baz") == " baz"
-    assert tokenizer.convert_token_to_string("<0x20>") == " "
-    assert tokenizer.convert_token_to_string("▁▁▁") == " "
-
-
-def test_model():
-    model = transformers(TEST_MODEL, device="cpu")
-    assert isinstance(model.tokenizer, TransformerTokenizer)
-    assert model.model.device.type == "cpu"
-
-    model = transformers(TEST_MODEL, model_kwargs={"device_map": "cpu"})
-    assert isinstance(model.tokenizer, TransformerTokenizer)
-    assert model.model.device.type == "cpu"
-
-    model = transformers(TEST_MODEL, device="cpu", model_kwargs={"device_map": "cuda"})
-    assert isinstance(model.tokenizer, TransformerTokenizer)
-    assert model.model.device.type == "cpu"
-
-    input_ids = torch.tensor([[0, 1, 2]])
-    logits, kv_cache = model(input_ids, torch.ones_like(input_ids))
-    assert logits.type() == "torch.FloatTensor"
-    assert logits.ndim == 2
-    assert logits.shape[0] == 1
-    assert len(kv_cache) == model.model.config.n_layer
-    assert len(kv_cache[0]) == 2
-    assert kv_cache[0][0].shape[1] == model.model.config.n_head
-    assert kv_cache[0][0].shape[2] == 3  # number of tokens
-
-    input_ids = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
-    logits, kv_cache = model(input_ids, torch.ones_like(input_ids))
-    assert logits.type() == "torch.FloatTensor"
-    assert logits.ndim == 2
-    assert logits.shape[0] == 3
-    assert len(kv_cache) == model.model.config.n_layer
-    assert len(kv_cache[0]) == 2
-    assert kv_cache[0][0].shape[1] == model.model.config.n_head
-    assert kv_cache[0][0].shape[2] == 3  # number of tokens
-
-    with pytest.raises(AssertionError):
-        input_ids = torch.tensor([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [0, 1, 2]]])
-        logits = model(input_ids, torch.ones_like(input_ids))
-
-
-def test_tokenizer_eq_hash():
-    tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")
-
-    tokenizer = TransformerTokenizer(tokenizer_hf)
-    tokenizer_2 = TransformerTokenizer(tokenizer_hf)
-
-    assert tokenizer == tokenizer_2
-    assert hash(tokenizer) == hash(tokenizer_2)
-
-    tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2")
-    tokenizer_hf_2.add_tokens(["test_token"])
-
-    tokenizer_3 = TransformerTokenizer(tokenizer_hf_2)
-    assert tokenizer != tokenizer_3
-    assert hash(tokenizer) != hash(tokenizer_3)
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert isinstance(item, list)
+        assert len(item) == 2
+        assert item[0] in ["cat", "dog"]
+        assert item[1] in ["cat", "dog"]
diff --git a/tests/models/test_transformers_vision.py b/tests/models/test_transformers_vision.py
new file mode 100644
index 000000000..44ca4e050
--- /dev/null
+++ b/tests/models/test_transformers_vision.py
@@ -0,0 +1,235 @@
+import re
+from enum import Enum
+from io import BytesIO
+from urllib.request import urlopen
+
+import pytest
+from PIL import Image
+from pydantic import BaseModel
+from transformers import (
+    Blip2ForConditionalGeneration,
+    CLIPModel,
+    CLIPProcessor,
+    LlavaForConditionalGeneration,
+)
+
+from outlines.models.transformers_vision import TransformersVision
+from outlines.types import Choice, Json, Regex
+
+TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration"
+TEST_CLIP_MODEL = "openai/clip-vit-base-patch32"
+IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg",
+]
+
+
+@pytest.fixture
+def images():
+    def img_from_url(url):
+        img_byte_stream = BytesIO(urlopen(url).read())
+        return Image.open(img_byte_stream).convert("RGB")
+
+    return [img_from_url(url) for url in IMAGE_URLS]
+
+
+@pytest.fixture
+def model():
+    return TransformersVision(
+        model_name=TEST_MODEL, model_class=LlavaForConditionalGeneration
+    )
+
+
+def test_transformers_vision_instantiate_simple():
+    model = TransformersVision(
+        model_name=TEST_MODEL,
+        model_class=Blip2ForConditionalGeneration,
+    )
+    assert isinstance(model, TransformersVision)
+
+
+def test_transformers_vision_instantiate_other_processor_class():
+    model = TransformersVision(
+        model_name=TEST_CLIP_MODEL, model_class=CLIPModel, processor_class=CLIPProcessor
+    )
+    assert isinstance(model, TransformersVision)
+
+
+def test_transformers_vision_simple(model, images):
+    result = model.generate(
+        {"prompts": "Describe this image in one sentence:", "images": images[0]},
+        None,
+    )
+    assert isinstance(result, str)
+
+
+def test_transformers_vision_call(model, images):
+    result = model(
+        {"prompts": "Describe this image in one sentence:", "images": images[0]}
+    )
+    assert isinstance(result, str)
+
+
+def test_transformers_vision_wrong_number_images(model, images):
+    with pytest.raises(ValueError):
+        a = model(
+            {
+                "prompts": "Describe this image in one sentence:",
+                "images": [images[0], images[1]],
+            }
+        )
+        print(a)
+
+
+def test_transformers_vision_wrong_input_type(model):
+    with pytest.raises(NotImplementedError):
+        model.generate("invalid input", None)
+
+
+def test_transformers_inference_kwargs(model, images):
+    result = model(
+        {"prompts": "Describe this image in one sentence:", "images": images[0]},
+        max_new_tokens=100,
+    )
+    assert isinstance(result, str)
+
+
+def test_transformers_invalid_inference_kwargs(model, images):
+    with pytest.raises(ValueError):
+        model(
+            {
+                "prompts": "Describe this image in one sentence:",
+                "images": images[0],
+            },
+            foo="bar",
+        )
+
+
+def test_transformers_several_images(model, images):
+    result = model(
+        {
+            "prompts": "Describe this image in one sentence:",
+            "images": [images[0], images[1]],
+        }
+    )
+    assert isinstance(result, str)
+
+
+def test_transformers_vision_json(model, images):
+    class Foo(BaseModel):
+        name: str
+
+    result = model(
+        {"prompts": "Give a name to this animal.", "images": images[0]},
+        Json(Foo),
+    )
+    assert "name" in result
+
+
+def test_transformers_vision_regex(model, images):
+    result = model(
+        {"prompts": "How old is it?", "images": images[0]}, Regex(r"[0-9]")
+    )
+
+    assert isinstance(result, str)
+    assert re.match(r"[0-9]", result)
+
+
+def test_transformers_vision_choice(model, images):
+    class Foo(Enum):
+        cat = "cat"
+        dog = "dog"
+
+    result = model(
+        {"prompts": "Is it a cat or a dog?", "images": images[0]}, Choice(Foo)
+    )
+
+    assert isinstance(result, str)
+    assert result in ["cat", "dog"]
+
+
+def test_transformers_vision_batch_samples(model, images):
+    result = model(
+        {"prompts": "Describe this image in one sentence.", "images": images[0]},
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    result = model(
+        {
+            "prompts": [
+                "Describe this image in one sentence.",
+                "Describe this image in one sentence.",
+            ],
+            "images": [images[0], images[1]],
+        }
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    result = model(
+        {
+            "prompts": [
+                "Describe this image in one sentence.",
+                "Describe this image in one sentence.",
+            ],
+            "images": [images[0], images[1]],
+        },
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert isinstance(item, list)
+        assert len(item) == 2
+
+
+def test_transformers_vision_batch_samples_constrained(model, images):
+    class Foo(Enum):
+        cat = "cat"
+        dog = "dog"
+
+    result = model(
+        {"prompts": "Describe this image in one sentence.", "images": images[0]},
+        Choice(Foo),
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert item in ["cat", "dog"]
+    result = model(
+        {
+            "prompts": [
+                "Describe this image in one sentence.",
+                "Describe this image in one sentence.",
+            ],
+            "images": [images[0], images[1]],
+        },
+        Choice(Foo),
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert item in ["cat", "dog"]
+    result = model(
+        {
+            "prompts": [
+                "Describe this image in one sentence.",
+                "Describe this image in one sentence.",
+            ],
+            "images": [images[0], images[1]],
+        },
+        Choice(Foo),
+        num_return_sequences=2,
+        num_beams=2,
+    )
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for item in result:
+        assert isinstance(item, list)
+        assert len(item) == 2
+        assert item[0] in ["cat", "dog"]
+        assert item[1] in ["cat", "dog"]