Refactor the Transformers and TransformersVision model (dottxt-ai#1430)
The objective of this commit is to adapt `Transformers` and `TransformersVision` to the
model interface used in v1.0. The main changes are:
- make the models inherit from the base class `Model` and create a
`ModelTypeAdapter` for them
- modify the signature of the `generate` method and make it the only way of
generating text with the model.
- modify the handling of model/tokenizer optional arguments so that Outlines no longer
manipulates them itself (users choose the arguments they need and pass them directly; see the usage sketch below).
- delete the tests for those models in the `generate` directory and test them instead
in the `models` directory
RobinPicard authored and rlouf committed Feb 24, 2025
1 parent d72b649 commit b0937e2
Showing 6 changed files with 32 additions and 18 deletions.
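
For context, a minimal usage sketch of the interface this commit targets, assuming only the `outlines.from_transformers` entry point that appears in the updated docs and tests below; the tokenizer argument shown is the one exercised in the new test, and under the new design users pass such arguments directly to the `transformers` loaders:

```python
import transformers
import outlines

# The user instantiates the Hugging Face model and tokenizer themselves,
# choosing whatever optional arguments they need; Outlines no longer
# manipulates these arguments on the user's behalf.
hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct"
)
hf_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    additional_special_tokens=["<t1>", "<t2>"],  # user-chosen tokenizer kwarg
)

# Outlines wraps the pair in its `Transformers` model class.
model = outlines.from_transformers(hf_model, hf_tokenizer)
```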
4 changes: 3 additions & 1 deletion docs/reference/models/transformers.md
@@ -10,8 +10,10 @@

You can use `outlines.from_transformers` to load a `transformers` model and tokenizer:

You can also provide keyword arguments in an optional `model_kwargs` parameter. Those will be passed to the `from_pretrained` method of the model class. One such argument is `device_map`, which allows you to specify the device on which the model will be loaded.

For instance:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from outlines import models

model_name = "microsoft/Phi-3-mini-4k-instruct"
```
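
A hedged sketch of the `device_map` usage mentioned above, not the docs' own example: since the commit leaves optional arguments to the user, the option goes straight to `from_pretrained`, and the resulting objects are then handed to `outlines.from_transformers`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import outlines

model_name = "microsoft/Phi-3-mini-4k-instruct"

# device_map is a standard from_pretrained argument; "auto" lets
# transformers place the weights across the available devices.
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

model = outlines.from_transformers(hf_model, hf_tokenizer)
```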
2 changes: 1 addition & 1 deletion outlines/models/transformers.py
@@ -316,7 +316,7 @@ def __init__(self, model: "PreTrainedModel", processor):
        self.processor.padding_side = "left"
        self.processor.pad_token = "[PAD]"

        tokenizer: "PreTrainedTokenizer" = self.processor.tokenizer
        tokenizer = self.processor.tokenizer

        super().__init__(model, tokenizer)

2 changes: 1 addition & 1 deletion outlines/types/__init__.py
@@ -5,9 +5,9 @@

from jsonschema import Draft202012Validator as Validator
from jsonschema.exceptions import SchemaError
from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel, TypeAdapter
from typing_extensions import _TypedDictMeta # type: ignore
from outlines_core.fsm.json_schema import build_regex_from_schema

from . import airports, countries, locale
from outlines.types.dsl import (
30 changes: 15 additions & 15 deletions tests/generate/test_generate.py
@@ -302,21 +302,21 @@ def test_generate_json(request, model_fixture, sample_schema):
generator(**get_inputs(model_fixture), max_tokens=100)


# TODO: add support for genson in the Regex type of v1.0
# def test_integrate_genson_generate_json(request):
#     from genson import SchemaBuilder
#
#     builder = SchemaBuilder()
#     builder.add_schema({"type": "object", "properties": {}})
#     builder.add_object({"name": "Toto", "age": 5})
#
#     model = request.getfixturevalue("model_transformers_opt125m")
#
#     generator = generate.json(model, builder)
#     res = generator("Return a json of a young boy")
#
#     assert "name" in res
#     assert "age" in res
@pytest.mark.xfail(reason="Genson has not been added to JsonType.")
def test_integrate_genson_generate_json(request):
    from genson import SchemaBuilder

    builder = SchemaBuilder()
    builder.add_schema({"type": "object", "properties": {}})
    builder.add_object({"name": "Toto", "age": 5})

    model = request.getfixturevalue("model_transformers_opt125m")

    generator = generate.json(model, builder)
    res = generator("Return a json of a young boy")

    assert "name" in res
    assert "age" in res


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
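
For reference, a short sketch (separate from the diff) of what the genson `SchemaBuilder` in the test above builds; once genson is supported by the `JsonType`/`Regex` layer, it would presumably consume the plain JSON Schema dict returned by `to_schema()`:

```python
from genson import SchemaBuilder

builder = SchemaBuilder()
builder.add_schema({"type": "object", "properties": {}})
builder.add_object({"name": "Toto", "age": 5})

# to_schema() returns a plain JSON Schema dict, roughly:
# {"type": "object",
#  "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
#  "required": ["age", "name"], ...}
print(builder.to_schema())
```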
11 changes: 11 additions & 0 deletions tests/models/test_transformers.py
@@ -38,6 +38,17 @@ def test_transformers_instantiate_mamba():
    assert isinstance(model, Transformers)


def test_transformers_instantiate_tokenizer_kwargs():
    model = outlines.from_transformers(
        transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL),
        transformers.AutoTokenizer.from_pretrained(
            TEST_MODEL, additional_special_tokens=["<t1>", "<t2>"]
        ),
    )
    assert "<t1>" in model.tokenizer.special_tokens
    assert "<t2>" in model.tokenizer.special_tokens


@pytest.fixture
def model():
    model = outlines.from_transformers(
1 change: 1 addition & 0 deletions tests/models/test_transformers_vision.py
@@ -18,6 +18,7 @@
from outlines.models.transformers import TransformersVision
from outlines.types import Choice, JsonType, Regex


TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration"
TEST_CLIP_MODEL = "openai/clip-vit-base-patch32"
IMAGE_URLS = [
