Skip to content

Commit

Permalink
Rename Local/API into Steerable/BlackBox for models/generators
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin Picard authored and rlouf committed Feb 24, 2025
1 parent b0937e2 commit 6d2378d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
22 changes: 11 additions & 11 deletions outlines/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Optional, Union, cast, get_args

from outlines.models import APIModel, LlamaCpp, LocalModel
from outlines.models import BlackBoxModel, SteerableModel
from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor
from outlines.types import CFG, Choice, JsonType, List, Regex

Expand All @@ -16,8 +16,8 @@


@dataclass
class APIGenerator:
"""Represents an API-based generator.
class BlackBoxGenerator:
"""Represents a generator for which we don't control constrained generation.
Attributes
----------
Expand All @@ -28,7 +28,7 @@ class APIGenerator:
"""

model: APIModel
model: BlackBoxModel
output_type: Optional[Union[JsonType, List, Choice, Regex]] = None

def __post_init__(self):
Expand All @@ -42,8 +42,8 @@ def __call__(self, prompt, **inference_kwargs):


@dataclass
class LocalGenerator:
"""Represents a local model-based generator.
class SteerableGenerator:
"""Represents a generator for which we control constrained generation.
We use this class to keep track of the logits processor which can be quite
expensive to build.
Expand All @@ -57,7 +57,7 @@ class LocalGenerator:
"""

model: LocalModel
model: SteerableModel
output_type: Optional[Union[JsonType, List, Choice, Regex]]

def __post_init__(self):
Expand All @@ -80,10 +80,10 @@ def __call__(self, prompt, **inference_kwargs):


def Generator(
model: Union[LocalModel, APIModel],
model: Union[SteerableModel, BlackBoxModel],
output_type: Optional[Union[JsonType, List, Choice, Regex, CFG]] = None,
):
if isinstance(model, APIModel): # type: ignore
return APIGenerator(model, output_type) # type: ignore
if isinstance(model, BlackBoxModel): # type: ignore
return BlackBoxGenerator(model, output_type) # type: ignore
else:
return LocalGenerator(model, output_type) # type: ignore
return SteerableGenerator(model, output_type) # type: ignore
6 changes: 3 additions & 3 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from .vllm import VLLM, from_vllm

LogitsGenerator = Union[
Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM, Ollama
Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM
]

LocalModel = Union[LlamaCpp, Transformers, MLXLM, VLLM]
APIModel = Union[OpenAI, Anthropic, Gemini, Ollama, Dottxt]
SteerableModel = Union[LlamaCpp, Transformers, VLLM, MLXLM]
BlackBoxModel = Union[OpenAI, Anthropic, Gemini, Ollama, Dottxt]

0 comments on commit 6d2378d

Please sign in to comment.