From 37808feaad717148e5b2a02f94b29eed93e82f28 Mon Sep 17 00:00:00 2001 From: Robin Picard Date: Sun, 23 Feb 2025 21:48:08 +0100 Subject: [PATCH] Rename Local/API into Steerable/Enclosed for models/generators --- outlines/generate/__init__.py | 22 +++++++++++----------- outlines/models/__init__.py | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py index 2fb99a030..4616c744f 100644 --- a/outlines/generate/__init__.py +++ b/outlines/generate/__init__.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Optional, Union, cast, get_args -from outlines.models import APIModel, LlamaCpp, LocalModel +from outlines.models import EnclosedModel, SteerableModel from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor from outlines.types import CFG, Choice, Json, List, Regex @@ -16,8 +16,8 @@ @dataclass -class APIGenerator: - """Represents an API-based generator. +class EnclosedGenerator: + """Represents a generator for which we don't control constrained generation. Attributes ---------- @@ -28,7 +28,7 @@ class APIGenerator: """ - model: APIModel + model: EnclosedModel output_type: Optional[Union[Json, List, Choice, Regex]] = None def __post_init__(self): @@ -42,8 +42,8 @@ def __call__(self, prompt, **inference_kwargs): @dataclass -class LocalGenerator: - """Represents a local model-based generator. +class SteerableGenerator: + """Represents a generator for which we control constrained generation. We use this class to keep track of the logits processor which can be quite expensive to build. @@ -57,7 +57,7 @@ class LocalGenerator: """ - model: LocalModel + model: SteerableModel output_type: Optional[Union[Json, List, Choice, Regex]] def __post_init__(self): @@ -80,10 +80,10 @@ def __call__(self, prompt, **inference_kwargs): def Generator( - model: Union[LocalModel, APIModel], + model: Union[SteerableModel, EnclosedModel], output_type: Optional[Union[Json, List, Choice, Regex, CFG]] = None, ): - if isinstance(model, APIModel): # type: ignore - return APIGenerator(model, output_type) # type: ignore + if isinstance(model, EnclosedModel): # type: ignore + return EnclosedGenerator(model, output_type) # type: ignore else: - return LocalGenerator(model, output_type) # type: ignore + return SteerableGenerator(model, output_type) # type: ignore diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 4254e7cf6..e31c6bddb 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -24,5 +24,5 @@ Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM, Ollama ] -LocalModel = Union[LlamaCpp, Transformers] -APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama] +SteerableModel = Union[LlamaCpp, Transformers] +EnclosedModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini, Ollama]