Skip to content

Commit

Permalink
make recursion work and tests pass
Browse files: browse the repository at this point in the history
  • Loading branch information
tscholak committed Jan 2, 2024
1 parent a468fa2 commit b6bcc4b
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 61 deletions.
82 changes: 68 additions & 14 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,15 @@ def to_regex(resolver: None | Resolver, instance: Schema) -> str:
The instance to translate
"""

definitions = DEFINITIONS
# Marker subclass of ``str`` with no added behavior: a value of this type in
# ``definitions`` is a reference path that has been registered under a name
# but not yet translated. Such entries are later found via
# ``isinstance(value, Path)`` and replaced by a ``Regex`` built from
# ``resolver.lookup(value).contents``.
class Path(str):
pass

# Marker subclass of ``str`` with no added behavior: a value of this type in
# ``definitions`` is an already-translated regular-expression fragment. Every
# entry is asserted to be a ``Regex`` before being emitted as a named group
# ``(?P<name>...)`` in the final pattern.
class Regex(str):
pass

definitions: dict[str, Path | Regex] = {
name: Regex(regex) for name, regex in DEFINITIONS.items()
}

def go(instance: Schema) -> str:
if isinstance(instance, bool):
Expand Down Expand Up @@ -209,7 +217,7 @@ def go(instance: Schema) -> str:
name = re.escape(path.replace("/", "_").replace("#", "").replace("$", "_"))
assert resolver is not None, "Cannot resolve references without a resolver"
if name not in definitions:
definitions[name] = go(resolver.lookup(path).contents)
definitions[name] = Path(path)
return f"(?&{name})"

# The type keyword may either be a string or an array:
Expand Down Expand Up @@ -247,16 +255,8 @@ def go(instance: Schema) -> str:
return type_to_regex["integer"]

elif instance_type == "array":
min_items = instance.get("minItems", "0")
max_items = instance.get("maxItems", "")
if min_items == max_items:
num_repeats = "{" + str(int(min_items) - 1) + "}"
else:
num_repeats = "*"

if "items" in instance:
items_regex = go(instance["items"])
return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
else:
# Here we need to make the choice to exclude generating list of objects
# if the specification of the object is not given, even though a JSON
Expand All @@ -268,8 +268,52 @@ def go(instance: Schema) -> str:
{"type": "integer"},
{"type": "string"},
]
regexes = [go(t) for t in types]
return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
items_regex = rf"({'|'.join(go(t) for t in types)})"

min_items = instance.get("minItems")
min_items = int(min_items) if min_items is not None else 0
max_items = instance.get("maxItems")
max_items = int(max_items) if max_items is not None else None

if min_items == 0 and max_items is None:
middle = rf"({items_regex}(,{items_regex})*)?"

elif min_items > 0 and max_items is None:
middle = (
rf"{items_regex}(,{items_regex})"
+ r"{"
+ rf"{min_items-1},"
+ r"}"
)

elif min_items == 0 and max_items is not None:
if max_items == 0:
middle = r""
else:
middle = (
rf"({items_regex}(,{items_regex})"
+ r"{"
+ rf"0,{max_items-1}"
+ r"})?"
)

elif min_items > 0 and max_items is not None:
if max_items >= min_items:
middle = (
rf"{items_regex}(,{items_regex})"
+ r"{"
+ rf"{min_items-1},{max_items-1}"
+ r"}"
)
else:
raise ValueError(
"max_items must be greater than or equal to min_items"
)

else:
raise ValueError("min_items must be greater than or equal to 0")

return rf"\[{middle}\]"

elif instance_type == "boolean":
return type_to_regex["boolean"]
Expand All @@ -290,12 +334,22 @@ def go(instance: Schema) -> str:
it is, please open an issue on the Outlines repository"""
)

_regex = go(instance)
definitions["__self__"] = Regex(go(instance))

while any(isinstance(v, Path) for v in definitions.values()):
for name, value in definitions.items():
if isinstance(value, Path):
assert (
resolver is not None
), "Cannot resolve references without a resolver"
definitions[name] = Regex(go(resolver.lookup(value).contents))

regex = r"(?:"
for name, value in definitions.items():
assert isinstance(value, Regex)
regex += rf"(?P<{name}>{value})"
regex += r"){0}"
regex += _regex
regex += r"(?&__self__)"

return regex

Expand Down
Loading

0 comments on commit b6bcc4b

Please sign in to comment.