Skip to content

Commit

Permalink
Misc type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 23, 2025
1 parent dbff27e commit 1af6d3f
Show file tree
Hide file tree
Showing 14 changed files with 44 additions and 16 deletions.
2 changes: 1 addition & 1 deletion outlines/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_cache():
"""
from outlines._version import __version__ as outlines_version # type: ignore

outlines_cache_dir = os.environ.get('OUTLINES_CACHE_DIR')
outlines_cache_dir = os.environ.get("OUTLINES_CACHE_DIR")
xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
home_dir = os.path.normpath(os.path.expanduser("~"))
if outlines_cache_dir:
Expand Down
9 changes: 6 additions & 3 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

from outlines.types import Regex, boolean as boolean_regex, date as date_regex
from outlines.types import datetime as datetime_regex
from outlines.types import integer as integer_regex, number as number_regex, time as time_regex
from outlines.types import (
integer as integer_regex,
number as number_regex,
time as time_regex,
)


class FormatFunction(Protocol):
def __call__(self, sequence: str) -> Any:
...
def __call__(self, sequence: str) -> Any: ...


def python_types_to_regex(python_type: Type) -> Tuple[Regex, FormatFunction]:
Expand Down
1 change: 1 addition & 0 deletions outlines/models/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Integration with Anthropic's API."""

from functools import singledispatchmethod
from typing import Union

Expand Down
1 change: 1 addition & 0 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Integration with OpenAI's API."""

from functools import singledispatchmethod
from types import NoneType
from typing import Optional, Union
Expand Down
1 change: 1 addition & 0 deletions outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

Expand Down
3 changes: 1 addition & 2 deletions outlines/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def __call__(
next_token_logits: "torch.DoubleTensor",
sequence_weights: "torch.DoubleTensor",
rng: "torch.Generator",
) -> "torch.DoubleTensor":
...
) -> "torch.DoubleTensor": ...


@dataclass(frozen=True)
Expand Down
3 changes: 2 additions & 1 deletion outlines/serve/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ async def generate(request: Request) -> Response:
logits_processors = []

sampling_params = SamplingParams(
**request_dict, logits_processors=logits_processors # type: ignore
**request_dict,
logits_processors=logits_processors, # type: ignore
)
request_id = random_uuid()

Expand Down
11 changes: 10 additions & 1 deletion outlines/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@
from typing_extensions import _TypedDictMeta # type: ignore

from . import airports, countries, locale
from outlines.types.dsl import Regex, json_schema, one_or_more, optional, regex, repeat, zero_or_more, times
from outlines.types.dsl import (
Regex,
json_schema,
one_or_more,
optional,
regex,
repeat,
zero_or_more,
times,
)


# Python types
Expand Down
1 change: 1 addition & 0 deletions outlines/types/airports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generate valid airport codes."""

from enum import Enum

import airportsdata
Expand Down
1 change: 1 addition & 0 deletions outlines/types/countries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generate valid country codes and names."""

from enum import Enum

from iso3166 import countries
Expand Down
1 change: 0 additions & 1 deletion outlines/types/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __get_pydantic_core_schema__(
def __get_pydantic_json_schema__(
self, core_schema: cs.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:

return {"type": "string", "pattern": to_regex(self)}

def validate(self, value: str) -> str:
Expand Down
6 changes: 2 additions & 4 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def model_mamba(tmp_path_factory):
def model_bart(tmp_path_factory):
from transformers import AutoModelForSeq2SeqLM

return models.Transformers(
"facebook/bart-base", model_class=AutoModelForSeq2SeqLM
)
return models.Transformers("facebook/bart-base", model_class=AutoModelForSeq2SeqLM)


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -281,7 +279,7 @@ def test_generate_json(request, model_fixture, sample_schema):


# TODO: add support for genson in the Regex type of v1.0
#def test_integrate_genson_generate_json(request):
# def test_integrate_genson_generate_json(request):
# from genson import SchemaBuilder
#
# builder = SchemaBuilder()
Expand Down
3 changes: 1 addition & 2 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def test_transformers_instantiate_mamba():

def test_transformers_instantiate_tokenizer_kwargs():
model = Transformers(
TEST_MODEL,
tokenizer_kwargs={"additional_special_tokens": ["<t1>", "<t2>"]}
TEST_MODEL, tokenizer_kwargs={"additional_special_tokens": ["<t1>", "<t2>"]}
)
assert "<t1>" in model.tokenizer.special_tokens
assert "<t2>" in model.tokenizer.special_tokens
Expand Down
17 changes: 16 additions & 1 deletion tests/types/test_to_regex.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
import pytest


from outlines.types.dsl import String, Regex, JsonSchema, KleeneStar, KleenePlus, QuantifyBetween, QuantifyExact, QuantifyMaximum, QuantifyMinimum, Sequence, Alternatives, Optional, Term, to_regex
from outlines.types.dsl import (
String,
Regex,
JsonSchema,
KleeneStar,
KleenePlus,
QuantifyBetween,
QuantifyExact,
QuantifyMaximum,
QuantifyMinimum,
Sequence,
Alternatives,
Optional,
Term,
to_regex,
)


def test_to_regex_simple():
Expand Down

0 comments on commit 1af6d3f

Please sign in to comment.