Skip to content

Commit

Permalink
Create the CFG type to be used as an output type
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin Picard committed Feb 13, 2025
1 parent b352c62 commit f2cd4b7
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 11 deletions.
32 changes: 21 additions & 11 deletions outlines/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args
from typing import Any, Optional, Union, cast, get_args

from outlines.models import APIModel, LocalModel
from outlines.types import Choice, Json, List, Regex
from outlines.models import APIModel, LlamaCpp, LocalModel
from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor
from outlines.types import CFG, Choice, Json, List, Regex

from .api import SequenceGenerator
from .cfg import cfg
Expand All @@ -13,9 +14,6 @@
from .regex import regex
from .text import text

if TYPE_CHECKING:
from outlines.processors import RegexLogitsProcessor


@dataclass
class APIGenerator:
Expand All @@ -33,6 +31,12 @@ class APIGenerator:
model: APIModel
output_type: Optional[Union[Json, List, Choice, Regex]] = None

def __post_init__(self):
if isinstance(self.output_type, CFG):
raise NotImplementedError(
"CFG generation is not supported for API-based models"
)

def __call__(self, prompt, **inference_kwargs):
return self.model.generate(prompt, self.output_type, **inference_kwargs)

Expand Down Expand Up @@ -60,18 +64,24 @@ def __post_init__(self):
if self.output_type is None:
self.logits_processor = None
else:
regex_string = self.output_type.to_regex()
self.logits_processor = RegexLogitsProcessor(
regex_string, self.model.tokenizer
)
if isinstance(self.output_type, CFG):
cfg_string = self.output_type.definition
self.logits_processor = CFGLogitsProcessor(
cfg_string, self.model.tokenizer
)
else:
regex_string = self.output_type.to_regex()
self.logits_processor = RegexLogitsProcessor(
regex_string, self.model.tokenizer
)

def __call__(self, prompt, **inference_kwargs):
return self.model.generate(prompt, self.logits_processor, **inference_kwargs)


def Generator(
model: Union[LocalModel, APIModel],
output_type: Optional[Union[Json, List, Choice, Regex]] = None,
output_type: Optional[Union[Json, List, Choice, Regex, CFG]] = None,
):
if isinstance(model, APIModel): # type: ignore
return APIGenerator(model, output_type) # type: ignore
Expand Down
11 changes: 11 additions & 0 deletions outlines/models/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union

from outlines.models.tokenizer import Tokenizer
from outlines.processors import CFGLogitsProcessor

if TYPE_CHECKING:
from llama_cpp import Llama
Expand Down Expand Up @@ -159,6 +160,11 @@ def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str:
"The `llama-cpp-python` library does not support batch inference."
)

if isinstance(logits_processor, CFGLogitsProcessor):
raise NotImplementedError(
"CFG generation is not supported for LlamaCpp due to bug in the llama_cpp tokenizer"
)

completion = self.model(
prompt,
logits_processor=LogitsProcessorList([logits_processor]),
Expand Down Expand Up @@ -197,6 +203,11 @@ def stream(
"The `llama-cpp-python` library does not support batch inference."
)

if isinstance(logits_processor, CFGLogitsProcessor):
raise NotImplementedError(
"CFG generation is not supported for LlamaCpp due to bug in the llama_cpp tokenizer"
)

generator = self.model(
prompt,
logits_processor=LogitsProcessorList([logits_processor]),
Expand Down
7 changes: 7 additions & 0 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,10 @@ class Regex:

def to_regex(self):
return self.definition


@dataclass
class CFG:
"""Represents a Context-Free Grammar as a string."""

definition: str
6 changes: 6 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,9 @@ class Model(BaseModel):
assert isinstance(format_fn(1), str)
does_match = re.match(regex_str, test_string) is not None
assert does_match is should_match


def test_type_cfg():
cfg_str = "?start: expression"
cfg_type = types.CFG(cfg_str)
assert cfg_type.definition == cfg_str

0 comments on commit f2cd4b7

Please sign in to comment.