Skip to content

Commit

Permalink
Address few rounds of GitHub PR reviews ;)
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Jan 15, 2025
1 parent 6e35c07 commit 4454ab3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 13 deletions.
13 changes: 9 additions & 4 deletions outlines/function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib.util
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

Expand All @@ -10,9 +11,6 @@
from outlines.generate.api import SequenceGenerator
from outlines.prompts import Prompt

# Raising a warning here caused all the tests to fail…
print("The 'function' module is deprecated and will be removed in a future release.")


@dataclass
class Function:
Expand All @@ -25,7 +23,8 @@ class Function:
Note:
This class is part of the deprecated 'function' module and will be removed
in a future release.
in a future release (1.0.0).
Please pin your version to <1.0.0 if you need to continue using it.
"""

Expand All @@ -34,6 +33,12 @@ class Function:
model_name: str
generator: Optional["SequenceGenerator"] = None

def __post_init__(self):
warnings.warn(
"The 'function' module is deprecated and will be removed in a future release (1.0.0).",
DeprecationWarning,
)

@classmethod
def from_github(cls, program_path: str, function_name: str = "fn"):
"""Load a function stored on GitHub"""
Expand Down
17 changes: 12 additions & 5 deletions outlines/outline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from dataclasses import dataclass

import jsonschema
from pydantic import BaseModel

from outlines import generate
Expand Down Expand Up @@ -41,21 +42,27 @@ def template(a: int) -> str:
"""

def __init__(self, model, template, output_type):
if not (isinstance(output_type, str) or issubclass(output_type, BaseModel)):
if isinstance(output_type, str):
try:
jsonschema.Draft7Validator.check_schema(json.loads(output_type))
except jsonschema.exceptions.SchemaError as e:
raise TypeError(f"Invalid JSON Schema: {e.message}")
elif not issubclass(output_type, BaseModel):
raise TypeError(
"output_type must be a Pydantic model or a JSON Schema string"
"output_type must be a Pydantic model or a valid JSON Schema string"
)

self.template = template
self.output_type = output_type
self.generator = generate.json(model, output_type)

def __call__(self, *args):
prompt = self.template(*args)
def __call__(self, *args, **kwargs):
prompt = self.template(*args, **kwargs)
response = self.generator(prompt)
try:
if isinstance(self.output_type, str):
return json.loads(response)
return self.output_type.model_validate_json(response)
return self.output_type.parse_raw(response)
except (ValueError, SyntaxError):
# If `outlines.generate.json` works as intended, this error should never be raised.
raise ValueError(f"Unable to parse response: {response.strip()}")
31 changes: 27 additions & 4 deletions tests/test_outline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import Mock, patch

import pytest
from pydantic import BaseModel

from outlines import Outline
Expand All @@ -17,25 +18,47 @@ def test_outline():
mock_model = Mock()
mock_generator = Mock()
mock_generator.return_value = '{"result": 6}'

with patch("outlines.generate.json", return_value=mock_generator):
outline_instance = Outline(mock_model, template, OutputModel)
assert issubclass(outline_instance.output_type, BaseModel)
result = outline_instance(3)

assert result.result == 6


def test_outline_with_json_schema():
mock_model = Mock()
mock_generator = Mock()
mock_generator.return_value = '{"result": 6}'

with patch("outlines.generate.json", return_value=mock_generator):
outline_instance = Outline(
mock_model,
template,
'{"type": "object", "properties": {"result": {"type": "integer"}}}',
)
result = outline_instance(3)

assert result["result"] == 6


def test_invalid_output_type():
mock_model = Mock()
with pytest.raises(TypeError):
Outline(mock_model, template, int)


def test_invalid_json_response():
mock_model = Mock()
mock_generator = Mock()
mock_generator.return_value = "invalid json"
with patch("outlines.generate.json", return_value=mock_generator):
outline_instance = Outline(mock_model, template, OutputModel)
with pytest.raises(ValueError, match="Unable to parse response"):
outline_instance(3)


def test_invalid_json_schema():
mock_model = Mock()
invalid_json_schema = (
'{"type": "object", "properties": {"result": {"type": "invalid_type"}}}'
)
with pytest.raises(TypeError, match="Invalid JSON Schema"):
Outline(mock_model, template, invalid_json_schema)

0 comments on commit 4454ab3

Please sign in to comment.