Add support for Transformers audio models
Robin Picard committed Feb 26, 2025
1 parent 9cd81ed commit 49c8669
Showing 11 changed files with 256 additions and 189 deletions.
2 changes: 1 addition & 1 deletion docs/reference/models/transformers.md
@@ -21,7 +21,7 @@ model = models.from_transformers(
)
```

The model name must be a valid `transformers` model name. You can find a list of all models in the HuggingFace library [here](https://huggingface.co/models). We currently support `CausalLM`, `Seq2Seq`, `Mamba` and vision models.
The model name must be a valid `transformers` model name. You can find a list of all models in the HuggingFace library [here](https://huggingface.co/models). We currently support `CausalLM`, `Seq2Seq`, `Mamba`, vision and audio models.

Be cautious with model selection though: some models, such as `t5-base`, don't include certain characters (`{`), and you may get an error when trying to perform structured generation.

196 changes: 196 additions & 0 deletions docs/reference/models/transformers_multimodal.md
@@ -0,0 +1,196 @@
# Transformers Multimodal

Outlines allows seamless use of [multimodal models](https://huggingface.co/learn/multimodal-models-course/en/unit1/introduction).

`outlines.models.TransformersMultiModal` shares interfaces with, and is based on, [outlines.models.Transformers](./transformers.md).

## Create a `TransformersMultiModal` model

`TransformersMultiModal` models inherit from `Transformers` and accept similar initialization parameters, except that you should provide a processor instead of a tokenizer.

You can use `outlines.from_transformers` to load a `transformers` model and processor:

```python
from transformers import CLIPModel, CLIPProcessor
from outlines import models

model_name = "openai/clip-vit-base-patch32"
model = models.from_transformers(
    CLIPModel.from_pretrained(model_name),
    CLIPProcessor.from_pretrained(model_name)
)
```

You must provide a processor adapted to your model. We currently only support vision and audio models.

## Use the model to generate text from prompts and images/audios

When calling the model, the prompt argument you provide must be a dictionary using the following format:
```
{
    "text": Union[str, List[str]],
    "images": Optional[Union[Any, List[Any]]],
    "audios": Optional[Union[Any, List[Any]]],
}
```

The `text` key is required. Either the `images` or `audios` key must be provided.
The format and values of those keys must correspond to what you would provide to the `__call__` method of the processor if you were to call it directly. Thus, the text prompt must include the appropriate tags to indicate where the images or audios should be inserted (e.g. `<image>` or `<|AUDIO|>`).
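
For example, a single-image prompt might look like the sketch below (the `<image>` tag and the type of image object depend on your model and processor; `pil_image` is a placeholder for an image you have already loaded):

```python
# A minimal sketch of the prompt dictionary for a vision model.
# `pil_image` is a placeholder for an already-loaded PIL.Image instance;
# the `<image>` tag must match what your processor expects.
prompt = {
    "text": "<image> What is shown in this picture?",
    "images": pil_image,
}
```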

### Vision models

For easier use, we recommend creating a convenience function to load a `PIL.Image` from a URL.
```python
from PIL import Image
from io import BytesIO
from urllib.request import urlopen

def img_from_url(url):
    img_byte_stream = BytesIO(urlopen(url).read())
    return Image.open(img_byte_stream).convert("RGB")
```

You can then call the model with your prompts and images to generate text.
```python
from transformers import LlavaForConditionalGeneration, LlavaProcessor
from outlines import models

MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration"

model = models.from_transformers(
    LlavaForConditionalGeneration.from_pretrained(MODEL_NAME),
    LlavaProcessor.from_pretrained(MODEL_NAME)
)
prompt = {
    "text": "<image> detailed description:",
    "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg")
}
model(prompt)
```

You can include several images per prompt by adding more `<image>` tags to the prompt. Batching is also supported.
```python
from transformers import LlavaForConditionalGeneration, LlavaProcessor
from outlines import models

MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration"

model = models.from_transformers(
    LlavaForConditionalGeneration.from_pretrained(MODEL_NAME),
    LlavaProcessor.from_pretrained(MODEL_NAME)
)
prompt = {
    "text": ["<image><image>detailed description:", "<image><image>. What animals are present?"],
    "images": [
        img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"),
        img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"),
        img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"),
        img_from_url("https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg"),
    ]
}
model(prompt)
```

Here we have two prompts, each expecting two images. We correspondingly provide four images. This will generate two descriptions, one for each prompt.

### Audio models

For easier use, we recommend creating a convenience function to turn a `.wav` file from a URL into an audio array.
```python
import soundfile
from io import BytesIO
from urllib.request import urlopen

def audio_from_url(url):
    audio_data = BytesIO(urlopen(url).read())
    audio_array, _ = soundfile.read(audio_data)
    return audio_array
```

You can then call the model with your prompt and audio files to generate text.
```python
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
from outlines import models

MODEL_NAME = "Qwen/Qwen2-Audio-7B-Instruct"
URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"

tf_model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_NAME)
tf_processor = AutoProcessor.from_pretrained(MODEL_NAME)

model = models.from_transformers(tf_model, tf_processor)

text = """<|im_start|>user
Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
<|im_end|>
<|im_start|>user
Write a sentence about what you hear in the audio.<|im_end|>
<|im_start|>assistant
"""
prompt = {
"text": text,
"audios": audio_from_url(URL)
}
model(prompt)
```

You can include multiple audio files by providing a list for the `audios` argument and adding more of the appropriate tags to the prompt. Batching is also supported.
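
For example, a two-audio prompt for the Qwen2-Audio model above might look like the sketch below (it reuses the `audio_from_url` helper; `URL_1` and `URL_2` are placeholders for the URLs of two `.wav` files):

```python
# A sketch of a multi-audio prompt: each <|AUDIO|> tag in the text is
# paired, in order, with one entry of the `audios` list.
# URL_1 and URL_2 are placeholders for two audio file URLs.
text = """<|im_start|>user
Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
Audio 2: <|audio_bos|><|AUDIO|><|audio_eos|>
Write one sentence comparing the two recordings.<|im_end|>
<|im_start|>assistant
"""
prompt = {
    "text": text,
    "audios": [audio_from_url(URL_1), audio_from_url(URL_2)],
}
model(prompt)
```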

## Use the model for structured generation

You can use the model to generate structured data by providing a value for the `output_type` parameter (the second positional argument when calling the model, right after the prompt).

Supported types include `Json`, `Choice`, `Regex` and `CFG`.

For instance, to do classification, you can use the `Regex` type:
```python
from outlines import models
from outlines.types import Regex
from transformers import LlavaForConditionalGeneration, LlavaProcessor

MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration"

model = models.from_transformers(
    LlavaForConditionalGeneration.from_pretrained(MODEL_NAME),
    LlavaProcessor.from_pretrained(MODEL_NAME)
)
pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto"
prompt = {
    "text": "<image>detailed description:",
    "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"),
}
model(prompt, Regex(pattern))
```

Another example is generating a structured description of an image using the `Json` type:
```python
from outlines import models
from outlines.types import Json
from pydantic import BaseModel
from transformers import LlavaForConditionalGeneration, LlavaProcessor
from typing import List

MODEL_NAME = "trl-internal-testing/tiny-LlavaForConditionalGeneration"

class ImageData(BaseModel):
    caption: str
    tags_list: List[str]
    object_list: List[str]
    is_photo: bool

model = models.from_transformers(
    LlavaForConditionalGeneration.from_pretrained(MODEL_NAME),
    LlavaProcessor.from_pretrained(MODEL_NAME)
)
prompt = {
    "text": "<image>detailed description:",
    "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"),
}
model(prompt, Json(ImageData))
```
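
Classification could also be expressed with the `Choice` type. The sketch below assumes `Choice` can be imported from `outlines.types` alongside `Regex` and takes a list of allowed strings, and it reuses the model and the `img_from_url` helper defined earlier:

```python
# A sketch using the Choice output type (assumed to live in outlines.types,
# like Regex above) to constrain the answer to a fixed set of labels.
from outlines.types import Choice

prompt = {
    "text": "<image>What planet is shown in this picture?",
    "images": img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg"),
}
model(prompt, Choice(["Mercury", "Venus", "Earth", "Mars", "Saturn", "Jupiter", "Neptune", "Uranus", "Pluto"]))
```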

## Resources

### Choosing a model
- https://mmbench.opencompass.org.cn/leaderboard
- https://huggingface.co/spaces/WildVision/vision-arena
117 changes: 0 additions & 117 deletions docs/reference/models/transformers_vision.md

This file was deleted.

4 changes: 2 additions & 2 deletions outlines/generate/cfg.py
@@ -4,7 +4,7 @@
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import LlamaCpp, OpenAI, TransformersVision
from outlines.models import LlamaCpp, OpenAI, TransformersMultiModal
from outlines.samplers import Sampler, multinomial


@@ -33,7 +33,7 @@ def cfg(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@cfg.register(TransformersVision)
@cfg.register(TransformersMultiModal)
def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
from outlines.processors import CFGLogitsProcessor

4 changes: 2 additions & 2 deletions outlines/generate/fsm.py
@@ -7,7 +7,7 @@
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import TransformersVision
from outlines.models import TransformersMultiModal
from outlines.samplers import Sampler, multinomial


@@ -22,7 +22,7 @@ def fsm(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@fsm.register(TransformersVision)
@fsm.register(TransformersMultiModal)
def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
from outlines.processors import GuideLogitsProcessor

4 changes: 2 additions & 2 deletions outlines/generate/regex.py
@@ -4,7 +4,7 @@
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import OpenAI, TransformersVision
from outlines.models import OpenAI, TransformersMultiModal
from outlines.samplers import Sampler, multinomial
from outlines.types import Regex

@@ -39,7 +39,7 @@ def regex(model, regex_str: str | Regex, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersVision)
@regex.register(TransformersMultiModal)
def regex_vision(
model,
regex_str: str | Regex,
4 changes: 2 additions & 2 deletions outlines/generate/text.py
@@ -4,7 +4,7 @@
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import OpenAI, TransformersVision
from outlines.models import OpenAI, TransformersMultiModal
from outlines.samplers import Sampler, multinomial


@@ -34,7 +34,7 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(TransformersVision)
@text.register(TransformersMultiModal)
def text_vision(model, sampler: Sampler = multinomial()):
return VisionSequenceGeneratorAdapter(model, None, sampler)
