diff --git a/outlines/caching.py b/outlines/caching.py index 0831c40bb..6882bac6b 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -51,7 +51,7 @@ def get_cache(): """ from outlines._version import __version__ as outlines_version # type: ignore - outlines_cache_dir = os.environ.get('OUTLINES_CACHE_DIR') + outlines_cache_dir = os.environ.get("OUTLINES_CACHE_DIR") xdg_cache_home = os.environ.get("XDG_CACHE_HOME") home_dir = os.path.normpath(os.path.expanduser("~")) if outlines_cache_dir: diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index e5a1f8f47..f6409aa66 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -4,12 +4,15 @@ from outlines.types import Regex, boolean as boolean_regex, date as date_regex from outlines.types import datetime as datetime_regex -from outlines.types import integer as integer_regex, number as number_regex, time as time_regex +from outlines.types import ( + integer as integer_regex, + number as number_regex, + time as time_regex, +) class FormatFunction(Protocol): - def __call__(self, sequence: str) -> Any: - ... + def __call__(self, sequence: str) -> Any: ... def python_types_to_regex(python_type: Type) -> Tuple[Regex, FormatFunction]: diff --git a/outlines/models/anthropic.py b/outlines/models/anthropic.py index 830c0cf77..83ea5f229 100644 --- a/outlines/models/anthropic.py +++ b/outlines/models/anthropic.py @@ -1,4 +1,5 @@ """Integration with Anthropic's API.""" + from functools import singledispatchmethod from typing import Union diff --git a/outlines/models/openai.py b/outlines/models/openai.py index f082dbc4c..378bb830d 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -1,4 +1,5 @@ """Integration with OpenAI's API.""" + from functools import singledispatchmethod from types import NoneType from typing import Optional, Union diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index 583fcc98f..8ce69c32a 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -23,6 +23,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import math from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union diff --git a/outlines/samplers.py b/outlines/samplers.py index 3ab1728fc..3fef673b1 100644 --- a/outlines/samplers.py +++ b/outlines/samplers.py @@ -14,8 +14,7 @@ def __call__( next_token_logits: "torch.DoubleTensor", sequence_weights: "torch.DoubleTensor", rng: "torch.Generator", - ) -> "torch.DoubleTensor": - ... + ) -> "torch.DoubleTensor": ... @dataclass(frozen=True) diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index 998fbc459..61d3ed7af 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -78,7 +78,8 @@ async def generate(request: Request) -> Response: logits_processors = [] sampling_params = SamplingParams( - **request_dict, logits_processors=logits_processors # type: ignore + **request_dict, + logits_processors=logits_processors, # type: ignore ) request_id = random_uuid() diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index fd69f2193..0788d9dca 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -10,7 +10,16 @@ from typing_extensions import _TypedDictMeta # type: ignore from . import airports, countries, locale -from outlines.types.dsl import Regex, json_schema, one_or_more, optional, regex, repeat, zero_or_more, times +from outlines.types.dsl import ( + Regex, + json_schema, + one_or_more, + optional, + regex, + repeat, + zero_or_more, + times, +) # Python types diff --git a/outlines/types/airports.py b/outlines/types/airports.py index ec0ef72bd..6e3d011b4 100644 --- a/outlines/types/airports.py +++ b/outlines/types/airports.py @@ -1,4 +1,5 @@ """Generate valid airport codes.""" + from enum import Enum import airportsdata diff --git a/outlines/types/countries.py b/outlines/types/countries.py index c612640b4..96be735d3 100644 --- a/outlines/types/countries.py +++ b/outlines/types/countries.py @@ -1,4 +1,5 @@ """Generate valid country codes and names.""" + from enum import Enum from iso3166 import countries diff --git a/outlines/types/dsl.py b/outlines/types/dsl.py index 645b0486a..a00f4a639 100644 --- a/outlines/types/dsl.py +++ b/outlines/types/dsl.py @@ -69,7 +69,6 @@ def __get_pydantic_core_schema__( def __get_pydantic_json_schema__( self, core_schema: cs.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: - return {"type": "string", "pattern": to_regex(self)} def validate(self, value: str) -> str: diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index aaa4509bf..64dee6e4b 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -71,9 +71,7 @@ def model_mamba(tmp_path_factory): def model_bart(tmp_path_factory): from transformers import AutoModelForSeq2SeqLM - return models.Transformers( - "facebook/bart-base", model_class=AutoModelForSeq2SeqLM - ) + return models.Transformers("facebook/bart-base", model_class=AutoModelForSeq2SeqLM) @pytest.fixture(scope="session") @@ -281,7 +279,7 @@ def test_generate_json(request, model_fixture, sample_schema): # TODO: add support for genson in the Regex type of v1.0 -#def test_integrate_genson_generate_json(request): +# def test_integrate_genson_generate_json(request): # from genson import SchemaBuilder # # builder = SchemaBuilder() diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 249893ff9..b9f0888f5 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -40,8 +40,7 @@ def test_transformers_instantiate_mamba(): def test_transformers_instantiate_tokenizer_kwargs(): model = Transformers( - TEST_MODEL, - tokenizer_kwargs={"additional_special_tokens": ["", ""]} + TEST_MODEL, tokenizer_kwargs={"additional_special_tokens": ["", ""]} ) assert "" in model.tokenizer.special_tokens assert "" in model.tokenizer.special_tokens diff --git a/tests/types/test_to_regex.py b/tests/types/test_to_regex.py index 6cb566fc5..4b0403ac6 100644 --- a/tests/types/test_to_regex.py +++ b/tests/types/test_to_regex.py @@ -1,7 +1,22 @@ import pytest -from outlines.types.dsl import String, Regex, JsonSchema, KleeneStar, KleenePlus, QuantifyBetween, QuantifyExact, QuantifyMaximum, QuantifyMinimum, Sequence, Alternatives, Optional, Term, to_regex +from outlines.types.dsl import ( + String, + Regex, + JsonSchema, + KleeneStar, + KleenePlus, + QuantifyBetween, + QuantifyExact, + QuantifyMaximum, + QuantifyMinimum, + Sequence, + Alternatives, + Optional, + Term, + to_regex, +) def test_to_regex_simple():