From 8be46280cafb3e461ec54e3852c4ab5feb9109ad Mon Sep 17 00:00:00 2001 From: Robin Picard Date: Mon, 24 Feb 2025 08:06:39 +0100 Subject: [PATCH] Rename Local/API into Steerable/BlackBox for models/generators --- outlines/generate/__init__.py | 22 +++++++++++----------- outlines/models/__init__.py | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py index 0d4f3c8e3..10bf78553 100644 --- a/outlines/generate/__init__.py +++ b/outlines/generate/__init__.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Optional, Union, cast, get_args -from outlines.models import APIModel, LlamaCpp, LocalModel +from outlines.models import BlackBoxModel, SteerableModel from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor from outlines.types import CFG, Choice, JsonType, List, Regex @@ -16,8 +16,8 @@ @dataclass -class APIGenerator: - """Represents an API-based generator. +class BlackBoxGenerator: + """Represents a generator for which we don't control constrained generation. Attributes ---------- @@ -28,7 +28,7 @@ class APIGenerator: """ - model: APIModel + model: BlackBoxModel output_type: Optional[Union[JsonType, List, Choice, Regex]] = None def __post_init__(self): @@ -42,8 +42,8 @@ def __call__(self, prompt, **inference_kwargs): @dataclass -class LocalGenerator: - """Represents a local model-based generator. +class SteerableGenerator: + """Represents a generator for which we control constrained generation. We use this class to keep track of the logits processor which can be quite expensive to build. @@ -57,7 +57,7 @@ class LocalGenerator: """ - model: LocalModel + model: SteerableModel output_type: Optional[Union[JsonType, List, Choice, Regex]] def __post_init__(self): @@ -80,10 +80,10 @@ def __call__(self, prompt, **inference_kwargs): def Generator( - model: Union[LocalModel, APIModel], + model: Union[SteerableModel, BlackBoxModel], output_type: Optional[Union[JsonType, List, Choice, Regex, CFG]] = None, ): - if isinstance(model, APIModel): # type: ignore - return APIGenerator(model, output_type) # type: ignore + if isinstance(model, BlackBoxModel): # type: ignore + return BlackBoxGenerator(model, output_type) # type: ignore else: - return LocalGenerator(model, output_type) # type: ignore + return SteerableGenerator(model, output_type) # type: ignore diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index ee50356b0..6a9ba753f 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -29,5 +29,5 @@ Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM, Ollama ] -LocalModel = Union[LlamaCpp, Transformers, MLXLM, VLLM] -APIModel = Union[OpenAI, Anthropic, Gemini, Ollama, Dottxt] +SteerableModel = Union[LlamaCpp, Transformers, MLXLM, VLLM] +BlackBoxModel = Union[OpenAI, Anthropic, Gemini, Ollama, Dottxt]