From 6d82c910fd8d83bf806aea293e4c610d9ad9e5dd Mon Sep 17 00:00:00 2001 From: Cameron Pfiffer Date: Fri, 7 Feb 2025 16:28:56 -0800 Subject: [PATCH] Add LogitTrackingProcessor --- docs/reference/processors.md | 247 +++++++++ examples/logit_tracking_demo.py | 313 ++++++++++++ mkdocs.yml | 1 + outlines/caching.py | 2 +- outlines/fsm/types.py | 9 +- outlines/models/openai.py | 9 +- outlines/processors/__init__.py | 12 + outlines/processors/base_logits_processor.py | 2 +- outlines/processors/structured.py | 1 + outlines/processors/tracking.py | 509 +++++++++++++++++++ outlines/samplers.py | 3 +- outlines/serve/serve.py | 3 +- outlines/types/__init__.py | 11 +- outlines/types/airports.py | 1 + outlines/types/countries.py | 1 + outlines/types/dsl.py | 1 - pyproject.toml | 1 + tests/processors/test_tracking.py | 337 ++++++++++++ tests/types/test_to_regex.py | 17 +- 19 files changed, 1468 insertions(+), 12 deletions(-) create mode 100644 docs/reference/processors.md create mode 100644 examples/logit_tracking_demo.py create mode 100644 outlines/processors/tracking.py create mode 100644 tests/processors/test_tracking.py diff --git a/docs/reference/processors.md b/docs/reference/processors.md new file mode 100644 index 000000000..6476e5112 --- /dev/null +++ b/docs/reference/processors.md @@ -0,0 +1,247 @@ +# Logit processors + +Logit processors modify token probabilities during text generation to enforce constraints or analyze the generation process. While processors can be used directly, most users will interact with them through the high-level generation APIs (see [Generating JSON](generation/json.md), [Regex Generation](generation/regex.md), and [CFG Generation](generation/cfg.md)). + +Users can track the token probabilities and logits at each step of the generation process using the `LogitTrackingProcessor`. This is useful for debugging and understanding the generation process. + +## Available Processors + +Outlines provides several specialized processors for different use cases: + +- `JSONLogitsProcessor`: Ensures generation follows a JSON schema +- `RegexLogitsProcessor`: Constrains generation to match a regex pattern +- `CFGLogitsProcessor`: Enforces a context-free grammar +- `LogitTrackingProcessor`: Tracks token probabilities and logits + +### RegexLogitsProcessor + +The `RegexLogitsProcessor` constrains generation to match a regular expression pattern: + +```python +from outlines.processors import RegexLogitsProcessor + +# Create a processor that only allows 4-digit numbers +processor = RegexLogitsProcessor(r"[0-9]{4}", tokenizer) + +# Use with a generator +generator = outlines.generate.regex(model, r"[0-9]{4}") +generator.logits_processor = processor +``` + +See [Regex Generation](generation/regex.md) for more details and examples. + +### JSONLogitsProcessor + +The `JSONLogitsProcessor` ensures generation follows a JSON schema defined using Pydantic: + +```python +from pydantic import BaseModel +from outlines.processors import JSONLogitsProcessor + +class Response(BaseModel): + name: str + age: int + city: str + +# Create processor from schema +processor = JSONLogitsProcessor(Response, tokenizer) + +# Use with a generator +generator = outlines.generate.json(model, Response) +generator.logits_processor = processor +``` + +See [Generating JSON](generation/json.md) for more details and examples. 
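+
+All of these processors share the same low-level interface: each one is a callable that maps `(input_ids, logits)` to a new logits tensor, which is what allows the tracking processor described below to wrap any of them. If you want to see a constraint in action without running generation, you can invoke a processor on a logits tensor yourself. The following is a minimal sketch rather than documented API: the model name is only an example, the uniform zero logits stand in for real model output, and `tokenizer.vocabulary` is assumed to be the token-to-id mapping.
+
+```python
+import torch
+import transformers
+
+import outlines.models as models
+from outlines.processors import RegexLogitsProcessor
+
+model_uri = "HuggingFaceTB/SmolLM2-135M-Instruct"  # example model
+tokenizer = models.TransformerTokenizer(
+    transformers.AutoTokenizer.from_pretrained(model_uri)
+)
+processor = RegexLogitsProcessor(r"[0-9]{4}", tokenizer)
+
+# A single-sequence batch, and one logit per vocabulary entry
+input_ids = torch.tensor([[0]])
+logits = torch.zeros(1, len(tokenizer.vocabulary))
+
+# Tokens the pattern disallows at this point come back as -inf
+masked = processor(input_ids, logits)
+```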
+
+### CFGLogitsProcessor
+
+The `CFGLogitsProcessor` constrains generation to follow a context-free grammar:
+
+```python
+from outlines.processors import CFGLogitsProcessor
+
+# Define a simple grammar
+grammar = """
+start: NUMBER "+" NUMBER "=" NUMBER
+NUMBER: /[0-9]+/
+"""
+
+# Create processor from grammar
+processor = CFGLogitsProcessor(grammar, tokenizer)
+
+# Use with a generator
+generator = outlines.generate.cfg(model, grammar)
+generator.logits_processor = processor
+```
+
+See [CFG Generation](generation/cfg.md) for more details and examples.
+
+## Tracking logit scores and token probabilities
+
+The `LogitTrackingProcessor` wraps any processor to track logit scores and token probabilities before and after processing. This is useful for:
+
+- Debugging logit processors by analyzing how they modify token probabilities
+- Visualizing the effects of logit biasing on token distributions
+- Understanding how constraints affect the generation process
+- Validating that processors are working as intended
+
+### Adding tracking to a generator
+
+The simplest way to add tracking is using the convenience function `track_logits`:
+
+```python
+from outlines import generate, models
+from outlines.processors import track_logits
+from pydantic import BaseModel
+
+# Define your schema
+class Person(BaseModel):
+    name: str
+    age: int
+
+# Create generator with tracking
+model = models.transformers("HuggingFaceTB/SmolLM2-135M-Instruct")
+generator = generate.json(model, Person)
+generator = track_logits(generator)  # Enable tracking
+
+# Apply templating if needed
+prompt = model.tokenizer.tokenizer.apply_chat_template(
+    [{"role": "system", "content": "You are a helpful assistant, responding in JSON."},
+     {"role": "user", "content": "Make me a person with a name and age. Return the JSON only."}],
+    tokenize=False,
+    add_bos=True,
+    add_generation_prompt=True,
+)
+
+# Generate the response
+response = generator(prompt)
+```
+
+**NOTE**: You **must** call `generator.logits_processor.clear()` between generations, otherwise logits from earlier generations accumulate in the tracking data and get mixed with the new ones. Alternatively, construct a new generator and call `track_logits` again to start tracking from scratch.
+
+### Analyzing generation results
+
+Once tracking is enabled, you can analyze the generation process in several ways:
+
+1. Get the logits and probabilities at each position as a matrix:
+
+```python
+# Raw logits as a dictionary with two keys: unstructured and structured
+logits = generator.logits_processor.get_logits()
+
+# Get a vocab_size x n_positions matrix of logits for
+# structured and unstructured logits
+unstructured_logits = logits['unstructured']
+structured_logits = logits['structured']
+
+probabilities = generator.logits_processor.get_probabilities()
+
+# Get a vocab_size x n_positions matrix of probabilities
+# for structured and unstructured logits
+unstructured_probs = probabilities['unstructured']
+structured_probs = probabilities['structured']
+```
+
+2. Get the top tokens at each position:
+
+```python
+# Get top 5 tokens at each position
+top_k = generator.logits_processor.get_top_tokens(k=5)
+
+# Analyze each position
+for position_dict in top_k:
+    print(f"\nPosition {position_dict['position']}:")
+    print(f"Text so far: {position_dict['text_so_far']}")
+
+    for token in position_dict['tokens']:
+        print(f"\nToken: {token['token']}")
+        print(f"Unstructured probability: {token['unstructured_prob']:.3f}")
+        print(f"Structured probability: {token['structured_prob']:.3f}")
+        print(f"Unstructured logit: {token['unstructured_logit']:.3f}")
+        print(f"Structured logit: {token['structured_logit']:.3f}")
+        print(f"Was chosen: {token['is_chosen']}")
+```
+
+3. Convert to a pandas DataFrame for analysis:
+
+```python
+import pandas as pd
+
+# Get all tokens with probability > 1%
+df = generator.logits_processor.get_probabilities_dataframe(min_value=0.01)
+print(df)
+#    position token   natural  constrained  chosen
+# 0         0   You  0.021324          0.0   False
+# 1         0   The  0.021959          0.0   False
+# 2         0  Sure  0.025492          0.0   False
+# 3         0  JSON  0.031045          0.0   False
+# 4         0    To  0.031047          0.0   False
+```
+
+4. Get the generated sequence up to a position:
+
+```python
+# Get text generated up to position 5
+text = generator.logits_processor.sequence_up_to(5)
+```
+
+### Memory management
+
+The tracking processor stores every tracked logit array in memory for analysis. When generation runs on a GPU, the arrays are copied to main (CPU) memory as they are recorded, so tracking does not consume GPU memory, but host memory still grows with sequence length. For long sequences, you have several options:
+
+1. Clear tracking data when no longer needed:
+```python
+generator.logits_processor.clear()
+```
+
+2. Filter data when analyzing:
+```python
+# Only analyze specific positions
+results = generator.logits_processor.get_top_tokens(positions=[0, 1, 2])
+
+# Only look at high probability tokens
+df = generator.logits_processor.get_probabilities_dataframe(min_value=0.01)
+```
+
+### Important notes about logit tracking
+
+- Tracking logits is a slow operation, so do not use it in production environments
+- The processor will accumulate logits if you call `generator(prompt)` multiple times, so the stored data can mix tokens from several generations. Use `generator.logits_processor.clear()` to reset the processor, or construct a new generator and call `track_logits` again to start tracking from scratch.
+- Processed logits will contain `-inf` values when structured outputs are used
+- Token decoding requires the wrapped processor to have a tokenizer attribute
+- Memory usage grows linearly with sequence length
+- The tracking processor only supports single-batch processing
+
+## Using the tracking processor directly
+
+The tracking processor can be used directly with transformers pipelines:
+
+```python
+import outlines.models as models
+import transformers
+from outlines.processors import RegexLogitsProcessor
+from outlines.processors.tracking import LogitTracker
+
+model_uri = "HuggingFaceTB/SmolLM2-135M-Instruct"
+
+outlines_tokenizer = models.TransformerTokenizer(
+    transformers.AutoTokenizer.from_pretrained(model_uri)
+)
+phone_number_logits_processor = LogitTracker(RegexLogitsProcessor(
+    "\\+?[1-9][0-9]{7,14}",  # phone number pattern
+    outlines_tokenizer,
+))
+
+generator = transformers.pipeline('text-generation', model=model_uri)
+
+# Perform inference
+output = generator(
+    "Jenny gave me her number it's ",
+    logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
+)
+
+# Retrieve the logits
+phone_number_logits_processor.get_logits()
+```
diff --git a/examples/logit_tracking_demo.py b/examples/logit_tracking_demo.py
new file mode 100644
index 000000000..8fde00ffc
--- /dev/null
+++ b/examples/logit_tracking_demo.py
@@ -0,0 +1,313 @@
+"""
+Demo script showing how to use the LogitTrackingProcessor to analyze token probabilities.
+
+This script demonstrates:
+1. How language models naturally choose tokens
+2. How structural constraints (like JSON or regex) affect these choices
+3. Visualization of probability distributions
+4. Analysis of token selection patterns
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+from pydantic import BaseModel, Field
+
+import outlines.models as models
+import outlines.generate as generate
+from outlines.processors.tracking import track_logits
+
+def plot_token_distributions(tracking_processor, k=10, positions=None, prefix=""):
+    """Plot token probability distributions before and after applying constraints.
+
+    Creates a horizontal bar chart showing:
+    - Blue bars: What tokens the model would naturally choose
+    - Orange bars: What tokens are allowed by structural constraints
+
+    Parameters
+    ----------
+    tracking_processor : LogitTrackingProcessor
+        The processor containing tracked probabilities
+    k : int, optional
+        Number of top tokens to show in each plot, by default 10
+    positions : List[int], optional
+        Which positions to plot. If None, plots all positions.
+ prefix : str, optional + Prefix for the output filename + + Notes + ----- + - Bar height indicates probability (how likely the model thinks each token is) + - Tokens are sorted by maximum probability across both distributions + - Only probabilities > 1% show their exact values + - Grid lines help compare probabilities between tokens + """ + # Get probability matrices and vocab mapping + probs = tracking_processor.get_probabilities() + vocab = tracking_processor.get_vocab_mapping() + + # Determine positions to plot + if positions is None: + positions = list(range(probs['unstructured'].shape[1])) + n_positions = len(positions) + + # Create plot + fig, axes = plt.subplots(1, n_positions, figsize=(7 * n_positions, 10)) + if n_positions == 1: + axes = [axes] + + for idx, pos in enumerate(positions): + # Get probabilities for this position + unstructured = probs['unstructured'][:, pos] + structured = probs['structured'][:, pos] + + # Get top k tokens by maximum probability + top_indices = np.argsort(np.maximum(unstructured, structured))[-k:] + + # Create bar positions + y = np.arange(len(top_indices)) + height = 0.35 + + # Plot bars + axes[idx].barh(y - height/2, unstructured[top_indices], height, + label='Natural Choice', alpha=0.7, color='skyblue') + axes[idx].barh(y + height/2, structured[top_indices], height, + label='After Constraints', alpha=0.7, color='orange') + + # Customize plot + axes[idx].set_title(f'Token {pos+1} in Sequence') + axes[idx].set_yticks(y) + axes[idx].set_yticklabels([vocab[i] for i in top_indices]) + axes[idx].set_xlabel('Probability') + + # Add legend + axes[idx].legend(loc='upper right', bbox_to_anchor=(1, 1.1)) + axes[idx].grid(True, alpha=0.3) + + # Add probability values + for i, (v1, v2) in enumerate(zip(unstructured[top_indices], structured[top_indices])): + if v1 > 0.01: # Only show probabilities > 1% + axes[idx].text(v1 + 0.01, i - height/2, f'{v1:.1%}', va='center') + if v2 > 0.01: + axes[idx].text(v2 + 0.01, i + height/2, f'{v2:.1%}', va='center') + + plt.tight_layout() + plt.savefig(f"{prefix}token_distributions.png", dpi=300, bbox_inches='tight') + plt.close() + + +def plot_heatmap(tracking_processor, k=50, positions=None, prefix="", show_both=True, kind="logits", show_tokens=True): + """Plot a heatmap of token probabilities across sequence positions. + + Creates a heatmap visualization showing how token probabilities evolve + across different positions in the sequence. Optionally shows both + natural and constrained probabilities side by side. + + Parameters + ---------- + tracking_processor : LogitTrackingProcessor + The processor containing tracked probabilities + k : int, optional + Number of top tokens to include in the heatmap, by default 50 + positions : List[int], optional + Which positions to plot. If None, plots all positions. + prefix : str, optional + Prefix for the output filename + show_both : bool, optional + If True, shows both natural and constrained probabilities side by side. + If False, only shows natural probabilities. 
kind : str, optional
+        Whether to plot logits or probabilities, by default "logits"
+    show_tokens : bool, optional
+        Whether to show the token strings on the y-axis, by default True
+
+    Notes
+    -----
+    - Brighter colors indicate higher probabilities
+    - Y-axis shows token strings
+    - X-axis shows position in sequence
+    - Near-zero probabilities are masked out (shown in gray)
+    - For constrained generation, blocked tokens appear masked
+    """
+    # Get probability matrices and vocab mapping
+    if kind == "logits":
+        things = tracking_processor.get_logits()
+        # For logits, mask out very negative values
+        threshold = -10  # Logits below this are effectively zero probability
+    else:
+        things = tracking_processor.get_probabilities()
+        # For probabilities, mask out near-zero values
+        threshold = 0.001  # Probabilities below 0.1% are masked
+
+    vocab = tracking_processor.get_vocab_mapping()
+
+    # Determine positions to plot
+    if positions is None:
+        positions = list(range(things['unstructured'].shape[1]))
+
+    # Get indices of top k tokens (by maximum probability across all positions)
+    max_probs = np.maximum(
+        things['unstructured'].max(axis=1),
+        things['structured'].max(axis=1)
+    )
+    top_indices = np.argsort(max_probs)[-k:]
+
+    # Create masked arrays for better visualization; the same threshold
+    # comparison applies whether we are plotting logits or probabilities
+    def mask_array(arr):
+        return np.ma.masked_where(arr < threshold, arr)
+
+    unstructured_masked = mask_array(things['unstructured'][top_indices][:, positions])
+    structured_masked = mask_array(things['structured'][top_indices][:, positions])
+
+    # Create figure
+    if show_both:
+        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))
+        fig.suptitle(f'Token {kind.capitalize()} Evolution', fontsize=16, y=1.05)
+    else:
+        fig, ax1 = plt.subplots(1, 1, figsize=(8, 10))
+
+    # Plot natural probabilities with masked array
+    im1 = ax1.imshow(
+        unstructured_masked,
+        aspect='auto',
+        cmap='viridis',
+    )
+    ax1.set_title(f'Natural Token {kind.capitalize()}')
+    ax1.set_xlabel('Position in Sequence')
+    ax1.set_ylabel('Token')
+    if show_tokens:
+        ax1.set_yticks(range(len(top_indices)))
+        ax1.set_yticklabels([vocab[i] for i in top_indices])
+    plt.colorbar(im1, ax=ax1, label=f'{kind.capitalize()}')
+
+    # Plot constrained probabilities if requested
+    if show_both:
+        im2 = ax2.imshow(
+            structured_masked,
+            aspect='auto',
+            cmap='viridis',
+        )
+        ax2.set_title(f'Constrained Token {kind.capitalize()}')
+        ax2.set_xlabel('Position in Sequence')
+        ax2.set_yticks([])  # Hide y-ticks since they're the same as ax1
+        plt.colorbar(im2, ax=ax2, label=f'{kind.capitalize()}')
+
+    plt.tight_layout()
+    plt.savefig(f"{prefix}{kind}_heatmap.png", dpi=300, bbox_inches='tight')
+    plt.close()
+
+
+# This function applies a simple chat template to the prompt
+def template(tokenizer, prompt: str, system_prompt: str = "You are a helpful assistant, responding in JSON.") -> str:
+    return tokenizer.apply_chat_template(
+        [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
+        tokenize=False,
+        add_bos=True,
+        add_generation_prompt=True,
+    )
+
+def display_token_analysis(results, show_logits=True):
+    """Helper to display token analysis results in a readable format."""
+    for position_data in results:
+        print(f"\nPosition {position_data['position']}:")
+        text = position_data['text_so_far']
+        print(text)
+        print("-" * 80)
+
+        # Print header
+        header = f"{'Token':<20} {'Natural Prob':<15} {'Constrained Prob':<15}"
+        if show_logits:
+            header += f"{'Natural Logit':<15} {'Constrained 
Logit':<15}" + print(header) + print("-" * 80) + + # Print each token's info + for token_info in position_data['tokens']: + # Add arrow for chosen token + prefix = "→" if token_info['is_chosen'] else " " + + # Format probabilities as percentages (format first, then pad) + unstructured_prob = f"{token_info['unstructured_prob']:.1%}" + structured_prob = f"{token_info['structured_prob']:.1%}" + + # Build the line piece by piece + line = f"{prefix} {repr(token_info['token']):<20} {unstructured_prob:<15} {structured_prob:<15}" + + # Add logits if requested + if show_logits: + unstructured_logit = f"{token_info['unstructured_logit']:.2f}" + structured_logit = f"{token_info['structured_logit']:.2f}" + line += f"{unstructured_logit:<15} {structured_logit:<15}" + + print(line) + +def analyze_json_generation(model): + """Analyze generation with JSON structure constraints.""" + print("\n=== Analyzing JSON-Structured Generation ===") + + # Define the required JSON structure + class Person(BaseModel): + name: str + age: int + zip_code: str = Field(pattern=r"^\d{5}$") + state: str = Field(pattern=r"^[A-Z]{2}$") + + # Create generator with tracking + generator = generate.json(model, Person) + generator = track_logits(generator) + + # Generate JSON + prompt = template(model.tokenizer.tokenizer, "Make me a person with a name, age, zip code, and state. Return the JSON only.") + print(f"\nPrompt: {prompt}") + result = generator(prompt) + print(f"Generated JSON: {result}") + + # Show how constraints affect token choices + print("\nAnalyzing token choices with JSON constraints:") + print("1. Token generation analysis (showing probabilities and logits):") + results = generator.logits_processor.get_top_tokens(k=5, positions=[0, 1, 2, 3, 4]) + display_token_analysis(results, show_logits=True) + + # Convert to dataframe + df = generator.logits_processor.get_probabilities_dataframe(min_value=0.01) + + # Retrieve only the tokens that were chosen + chosen = df[df.chosen] + print(chosen) + + # Show sequence at different points + print("\n2. Generation sequence at different points:") + for pos in [5, 10, 15, 20]: + print(f"\nFirst {pos} tokens: {repr(generator.logits_processor.sequence_up_to(pos))}") + + # Visualize how JSON structure affects probabilities + print("\nCreating visualizations:") + print("1. Bar plot comparing natural vs constrained probabilities") + plot_token_distributions(generator.logits_processor, k=30, positions=[0, 1, 2], prefix="structured_") + + print("2. 
Heatmap showing probability evolution with/without constraints") + plot_heatmap( + generator.logits_processor, + k=10000, + kind="logits", + prefix="structured_", + show_both=True, + show_tokens=False + ) + +def main(): + print("Loading model and tokenizer...") + + model_uri = "HuggingFaceTB/SmolLM2-135M-Instruct" + model = models.transformers(model_uri) + + # Run examples + analyze_json_generation(model) + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index 2c31494da..a5a154dcc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -145,6 +145,7 @@ nav: - Grammar: reference/generation/cfg.md - Creating Grammars: reference/generation/creating_grammars.md - Custom FSM operations: reference/generation/custom_fsm_ops.md + - Processors: reference/processors.md - Utilities: - Serve with vLLM: reference/serve/vllm.md - Serve with LM Studio: reference/serve/lmstudio.md diff --git a/outlines/caching.py b/outlines/caching.py index 0831c40bb..6882bac6b 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -51,7 +51,7 @@ def get_cache(): """ from outlines._version import __version__ as outlines_version # type: ignore - outlines_cache_dir = os.environ.get('OUTLINES_CACHE_DIR') + outlines_cache_dir = os.environ.get("OUTLINES_CACHE_DIR") xdg_cache_home = os.environ.get("XDG_CACHE_HOME") home_dir = os.path.normpath(os.path.expanduser("~")) if outlines_cache_dir: diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index e5a1f8f47..f6409aa66 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -4,12 +4,15 @@ from outlines.types import Regex, boolean as boolean_regex, date as date_regex from outlines.types import datetime as datetime_regex -from outlines.types import integer as integer_regex, number as number_regex, time as time_regex +from outlines.types import ( + integer as integer_regex, + number as number_regex, + time as time_regex, +) class FormatFunction(Protocol): - def __call__(self, sequence: str) -> Any: - ... + def __call__(self, sequence: str) -> Any: ... 
 def python_types_to_regex(python_type: Type) -> Tuple[Regex, FormatFunction]:
diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index 40ade1c25..e46107bd4 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -1,4 +1,5 @@
 """Integration with OpenAI's API."""
+
 import copy
 import functools
 from dataclasses import asdict, dataclass, field, replace
@@ -139,7 +140,13 @@ def __call__(
         if samples is None:
             samples = self.config.n
 
-        config = replace(self.config, max_tokens=max_tokens, temperature=temperature, n=samples, stop=stop_at)  # type: ignore
+        config = replace(
+            self.config,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            n=samples,
+            stop=stop_at,
+        )  # type: ignore
 
         response, prompt_tokens, completion_tokens = generate_chat(
             prompt, system_prompt, self.client, config
diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py
index f0f0f829b..0957a2343 100644
--- a/outlines/processors/__init__.py
+++ b/outlines/processors/__init__.py
@@ -5,3 +5,15 @@
     OutlinesLogitsProcessor,
     RegexLogitsProcessor,
 )
+from .tracking import LogitTracker, track_logits
+
+__all__ = [
+    "CFGLogitsProcessor",
+    "GuideLogitsProcessor",
+    "JSONLogitsProcessor",
+    "OutlinesLogitsProcessor",
+    "RegexLogitsProcessor",
+    # Logit tracking
+    "LogitTracker",
+    "track_logits",
+]
diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py
index d6fe346e0..9f183d8f4 100644
--- a/outlines/processors/base_logits_processor.py
+++ b/outlines/processors/base_logits_processor.py
@@ -9,7 +9,7 @@
     import mlx.core as mx
 
 
-Array = Union[NDArray, torch.Tensor, List, "mx.array"]
+Array = Union[NDArray, torch.Tensor, "mx.array"]
 
 
 def is_mlx_array_type(array_type):
diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py
index 583fcc98f..8ce69c32a 100644
--- a/outlines/processors/structured.py
+++ b/outlines/processors/structured.py
@@ -23,6 +23,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
diff --git a/outlines/processors/tracking.py b/outlines/processors/tracking.py
new file mode 100644
index 000000000..b0a0f28fd
--- /dev/null
+++ b/outlines/processors/tracking.py
@@ -0,0 +1,509 @@
+"""
+A simple logit processor that tracks probabilities for both structured and unstructured generation.
+
+For each token generated, we store:
+- The raw logits the model would assign naturally
+- The filtered logits after applying structural constraints
+- A mapping from vocabulary indices to token strings
+"""
+
+from typing import TYPE_CHECKING, Optional, Union, List, Dict, Any
+
+import numpy as np
+import torch
+
+from .base_logits_processor import OutlinesLogitsProcessor, Array
+from ..models.tokenizer import Tokenizer
+
+if TYPE_CHECKING:
+    from outlines.generate import SequenceGenerator
+
+
+class LogitTracker(OutlinesLogitsProcessor):
+    """Tracks logits for both structured and unstructured text generation.
+
+    For each position in the sequence, this class stores:
+
+    - `unstructured_logits`: Raw values of the logits from the model
+    - `structured_logits`: Values of the logits after applying constraints
+    - `vocab_tokens`: Mapping from vocab indices to token strings
+    - `chosen_tokens`: Track actual sampled token IDs during generation
+
+    Each logit matrix has:
+    - Columns: One for each position in the generated sequence
+    - Rows: One for each token in the vocabulary
+
+    Attributes
+    ----------
+    processor : Optional[OutlinesLogitsProcessor]
+        The processor that applies structural constraints
+    unstructured_logits : List[Array]
+        Raw logits from the model for each position
+    structured_logits : List[Array]
+        Logits after applying constraints for each position
+    vocab_tokens : Optional[List[str]]
+        Mapping from vocabulary indices to token strings
+    chosen_tokens : List[int]
+        Actual chosen token IDs during generation. This keeps a log of
+        the tokens the model generates, and is used internally by
+        various convenience functions.
+    _started : bool
+        Tracks whether to start appending chosen tokens. This is set to True
+        on the first call to process_logits, and remains True thereafter.
+    tokenizer : Optional[Tokenizer]
+        The tokenizer used for decoding tokens
+    """
+
+    def __init__(self, processor: Optional[OutlinesLogitsProcessor]):
+        """Initialize the tracking processor.
+
+        Parameters
+        ----------
+        processor : Optional[OutlinesLogitsProcessor]
+            The processor that applies structural constraints.
+        """
+        self.processor = processor
+        self.unstructured_logits: List[
+            Array
+        ] = []  # List of logit arrays, one per position
+        self.structured_logits: List[
+            Array
+        ] = []  # List of logit arrays, one per position
+        self.vocab_tokens: Optional[List[str]] = (
+            None  # Will store the vocabulary mapping
+        )
+        self.chosen_tokens: List[
+            int
+        ] = []  # Track actual chosen tokens during generation
+        self._started: bool = False  # Tracks whether to start appending chosen tokens
+        self.tokenizer: Optional[Tokenizer] = None
+
+        # If the processor has a tokenizer, use it
+        if processor is not None and hasattr(processor, "tokenizer"):
+            self.tokenizer = processor.tokenizer
+
+    def process_logits(
+        self, input_ids: List[List[int]], logits: torch.Tensor
+    ) -> torch.Tensor:
+        """Process logits and store them.
+
+        This method:
+        1. Stores the raw logits from the model
+        2. Applies any structural constraints if a processor exists
+        3. Stores the constrained logits
+        4. Tracks the chosen token ID
+
+        Parameters
+        ----------
+        input_ids : List[List[int]]
+            The input token ids for the sequence. Must be single batch.
+        logits : torch.Tensor
+            The original logits to process, shape (1, vocab_size)
+
+        Returns
+        -------
+        torch.Tensor
+            The processed logits, shape (1, vocab_size)
+
+        Raises
+        ------
+        ValueError
+            If batch size > 1 is provided. The tracking processor currently
+            only supports single-batch processing.
+        """
+        # Enforce single batch processing
+        if logits.shape[0] > 1:
+            raise ValueError(
+                "LogitTracker only supports single-batch processing. "
+                f"Got batch size {logits.shape[0]}"
+            )
+        if len(input_ids) > 1:
+            raise ValueError(
+                "LogitTracker only supports single-batch processing. "
" + f"Got {len(input_ids)} sequences" + ) + + # Always store the raw logits as unstructured + self.unstructured_logits.append(logits[0].detach().cpu().numpy().copy()) + + # Store the actual chosen token ID if available + if self._started and len(input_ids[0]) > 1: + # Get the last token from the current sequence + self.chosen_tokens.append(input_ids[0][-1]) + + # If we haven't started tracking yet, do so now. + # this will only happen on the first call to process_logits. + else: + self._started = True + + # Apply structural constraints if we have a processor + if self.processor is not None: + processed = self.processor.process_logits(input_ids, logits) + self.structured_logits.append(processed[0].detach().cpu().numpy().copy()) + return processed + + # For unconstrained generation, structured = unstructured + self.structured_logits.append(logits[0].detach().cpu().numpy().copy()) + return logits + + def get_probabilities(self) -> Dict[str, Array]: + """Get probability distributions computed from stored logits. + + Returns + ------- + Dict[str, Union[List[Array], Array]] + Contains a dictionary with two keys: + - unstructured: Raw probability distributions + - structured: Probability distributions after constraints + Each can be either a list of arrays or a single matrix + """ + # Convert logits to probabilities + unstructured_probs = [ + torch.softmax(torch.tensor(logits), dim=-1).numpy() + for logits in self.unstructured_logits + ] + structured_probs = [ + torch.softmax(torch.tensor(logits), dim=-1).numpy() + for logits in self.structured_logits + ] + + # Stack arrays into matrices + unstructured = np.column_stack(unstructured_probs) + structured = np.column_stack(structured_probs) + + return {"unstructured": unstructured, "structured": structured} + + def get_logits(self) -> Dict[str, Array]: + """Get the stored logit values. + + Returns + ------- + Dict[str, Array] + Contains a dictionary with two keys: + + - unstructured: Raw logit values + - structured: Logit values after constraints + + Each matrix will have shape (vocab_size, n_positions), i.e. + return_value['unstructured'] is a vocab_size x n_positions matrix. + """ + unstructured = np.column_stack(self.unstructured_logits) + structured = np.column_stack(self.structured_logits) + + return {"unstructured": unstructured, "structured": structured} + + def get_top_tokens( + self, + k: int = 10, + positions: Optional[Union[int, List[int]]] = None, + include_logits: bool = True, + ) -> List[Dict[str, Any]]: + """Get the top k tokens at specified positions with their probabilities and logits. + + Parameters + ---------- + k : int, optional + Number of top tokens to return, by default 10 + positions : Union[int, List[int]], optional + Position(s) to analyze. Can be a single position or list of positions. + By default analyzes all positions. 
+        include_logits : bool, optional
+            Whether to include raw logit values in addition to probabilities
+
+        Returns
+        -------
+        List[Dict[str, Any]]
+            List of dictionaries, one per position, containing:
+            - position: Position in sequence
+            - text_so_far: Text generated up to this position
+            - tokens: List of top k token dictionaries, each containing:
+                - token: The token string
+                - unstructured_prob: Unconstrained probability
+                - structured_prob: Probability after constraints
+                - unstructured_logit: Raw logit value (if include_logits=True)
+                - structured_logit: Constrained logit value (if include_logits=True)
+                - is_chosen: Whether this token was actually chosen
+        """
+        # Convert single position to list
+        if positions is None:
+            positions = list(range(len(self.structured_logits)))
+        elif isinstance(positions, int):
+            positions = [positions]
+
+        # Get probabilities and logits
+        probs = self.get_probabilities()
+        logits = self.get_logits() if include_logits else None
+
+        # Get vocab mapping
+        vocab = self.get_vocab_mapping()
+
+        results = []
+        for pos in positions:
+            if pos >= len(self.unstructured_logits):
+                continue
+
+            # Get text generated so far
+            text_so_far = self.sequence_up_to(pos)
+
+            # Get values for this position
+            u_probs = probs["unstructured"][:, pos]
+            s_probs = probs["structured"][:, pos]
+
+            if logits is not None:
+                u_logits = logits["unstructured"][:, pos]
+                s_logits = logits["structured"][:, pos]
+
+            # Get top k indices by maximum probability
+            max_probs = np.maximum(u_probs, s_probs)
+            # Ensure we get exactly k indices by setting k to min(k, vocab_size)
+            k_actual = min(k, len(max_probs))
+            top_indices = np.argsort(max_probs)[-k_actual:][::-1]
+
+            # Get the actual next token for comparison
+            next_token = (
+                self.sequence_up_to(pos + 1)[len(text_so_far) :]
+                if pos < len(self.structured_logits) - 1
+                else ""
+            )
+
+            # Build token info list
+            tokens = []
+            for idx in top_indices:
+                token = vocab[idx]
+                token_info = {
+                    "token": token,
+                    "unstructured_prob": float(u_probs[idx]),
+                    "structured_prob": float(s_probs[idx]),
+                    "is_chosen": token == next_token,
+                }
+
+                if include_logits:
+                    token_info.update(
+                        {
+                            "unstructured_logit": float(u_logits[idx]),
+                            "structured_logit": float(s_logits[idx]),
+                        }
+                    )
+
+                tokens.append(token_info)
+
+            results.append(
+                {"position": pos, "text_so_far": text_so_far, "tokens": tokens}
+            )
+
+        return results
+
+    def get_vocab_mapping(self) -> List[str]:
+        """Get the mapping from vocabulary indices to token strings.
+
+        Returns
+        -------
+        List[str]
+            List of token strings, where index matches vocabulary index
+
+        Raises
+        ------
+        AttributeError
+            If no tokenizer is available
+        """
+        if self.tokenizer is None:
+            raise AttributeError("No tokenizer available for mapping tokens")
+
+        if self.vocab_tokens is None:
+            # Create the mapping if we haven't yet
+            vocab_size = len(self.unstructured_logits[0])
+            self.vocab_tokens = [
+                self.tokenizer.decode([i])[0] for i in range(vocab_size)
+            ]
+
+        return self.vocab_tokens
+
+    def clear(self):
+        """Clear all tracked logits and chosen tokens."""
+        self.unstructured_logits = []
+        self.structured_logits = []
+        self.chosen_tokens = []
+
+    def get_probabilities_dataframe(self, min_value: Optional[float] = None):
+        """Convert tracked probabilities to a pandas DataFrame."""
+        values = self.get_probabilities()
+        return self._to_dataframe(values, min_value)
+
+    def get_logits_dataframe(self, min_value: Optional[float] = None):
+        """Convert tracked logits to a pandas DataFrame."""
+        values = self.get_logits()
+        return self._to_dataframe(values, min_value)
+
+    def _to_dataframe(
+        self,
+        values,
+        min_value: Optional[float] = None,
+    ):
+        """Convert tracking data to a pandas DataFrame for analysis.
+
+        Parameters
+        ----------
+        values : Dict[str, Array]
+            Logits or probabilities, as returned by `get_logits` or
+            `get_probabilities`.
+        min_value : Optional[float], optional
+            If provided, only include tokens with values >= min_value in
+            either the structured or unstructured distribution; within each
+            position, at most the 10 highest-valued such tokens are kept
+
+        Returns
+        -------
+        pd.DataFrame
+            DataFrame with columns:
+            - position: Token position in sequence
+            - token: String representation of token
+            - natural: Raw model values (probs/logits)
+            - constrained: Values after constraints
+            - chosen: Whether this token was chosen (True/False)
+
+        Raises
+        ------
+        ImportError
+            If pandas is not installed
+        """
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ImportError(
+                "The `pandas` library is required to convert values to a DataFrame."
+                " Install `pandas` with: pip install pandas"
+            )
+
+        # Get vocab mapping
+        vocab = self.get_vocab_mapping()
+
+        # Create lists to store data
+        rows = []
+
+        # Process each position
+        for pos in range(
+            values["unstructured"].shape[1]
+        ):  # Use shape[1] for number of positions
+            u_vals = values["unstructured"][:, pos]
+            s_vals = values["structured"][:, pos]
+
+            # Get the chosen token at this position if available
+            chosen_token = (
+                vocab[self.chosen_tokens[pos]]
+                if pos < len(self.chosen_tokens)
+                else None
+            )
+
+            # Get indices to include based on filters
+            if min_value is not None:
+                # Get maximum value between structured/unstructured for sorting
+                max_vals = np.maximum(u_vals, s_vals)
+
+                # Keep at most the 10 highest values that pass the min_value filter
+                valid_indices = np.where(max_vals >= min_value)[0]
+                if len(valid_indices) > 0:  # Only sort if we have valid indices
+                    valid_indices = valid_indices[
+                        np.argsort(max_vals[valid_indices])[-10:]
+                    ]
+            else:
+                # No filters: include all tokens
+                valid_indices = range(len(vocab))
+
+            # Add rows for valid indices
+            for idx in valid_indices:
+                token = vocab[idx]
+                rows.append(
+                    {
+                        "position": pos,
+                        "token": token,
+                        "natural": float(
+                            u_vals[idx]
+                        ),  # Convert to float to avoid numpy type issues
+                        "constrained": float(s_vals[idx]),
+                        "chosen": token == chosen_token,
+                    }
+                )
+
+        return pd.DataFrame(rows)
+
+    def sequence_up_to(self, pos: Optional[int] = None) -> str:
+        """Get the sequence of tokens generated up to a position.
+
+        Parameters
+        ----------
+        pos : Optional[int], optional
+            Position to reconstruct up to (exclusive).
+            If None, returns the entire sequence.
+
+        Returns
+        -------
+        str
+            The concatenated string of chosen tokens
+
+        Raises
+        ------
+        AttributeError
+            If no tokenizer is available for decoding
+        """
+        if not self.chosen_tokens:
+            return ""
+
+        # Get the tokenizer, preferring the wrapped processor's
+        if self.processor is not None and hasattr(self.processor, "tokenizer"):
+            tokenizer = self.processor.tokenizer
+        else:
+            tokenizer = self.tokenizer
+
+        if tokenizer is None:
+            raise AttributeError("No tokenizer available for decoding sequence")
+
+        # Get tokens up to the specified position
+        end_pos = len(self.chosen_tokens) if pos is None else pos
+        tokens_to_decode = self.chosen_tokens[:end_pos]
+
+        # Decode the sequence
+        return "".join(tokenizer.decode(tokens_to_decode))
+
+
+def track_logits(generator: "SequenceGenerator") -> "SequenceGenerator":
+    """Add logit tracking to a structured generator.
+
+    This is a convenience function that wraps a generator's logits processor
+    with a `LogitTracker`, enabling analysis of token probabilities and
+    logits during generation.
+
+    Tracking currently only works with structured generators;
+    `outlines.generate.text` is not supported.
+
+    Parameters
+    ----------
+    generator : SequenceGenerator
+        The generator to add tracking to
+
+    Returns
+    -------
+    SequenceGenerator
+        The same generator with tracking enabled
+
+    Examples
+    --------
+    >>> # Track probabilities for JSON generation
+    >>> generator = generate.json(model, schema)
+    >>> generator = track_logits(generator)
+    """
+    # If there's no logits_processor, throw an error. Logit tracking
+    # is currently only supported for structured generators.
+    if not hasattr(generator, "logits_processor"):
+        raise ValueError("Logit tracking is not supported for this generator")
+
+    # Create tracking processor, wrapping any existing processor
+    tracking = LogitTracker(generator.logits_processor)
+
+    # Add tokenizer for token mapping
+    if hasattr(generator.logits_processor, "tokenizer"):
+        tracking.tokenizer = generator.logits_processor.tokenizer
+
+    # Set as the generator's processor
+    generator.logits_processor = tracking
+
+    return generator
diff --git a/outlines/samplers.py b/outlines/samplers.py
index 3ab1728fc..3fef673b1 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -14,8 +14,7 @@ def __call__(
         next_token_logits: "torch.DoubleTensor",
         sequence_weights: "torch.DoubleTensor",
         rng: "torch.Generator",
-    ) -> "torch.DoubleTensor":
-        ...
+    ) -> "torch.DoubleTensor": ...
 
 
 @dataclass(frozen=True)
diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py
index 998fbc459..61d3ed7af 100644
--- a/outlines/serve/serve.py
+++ b/outlines/serve/serve.py
@@ -78,7 +78,8 @@ async def generate(request: Request) -> Response:
         logits_processors = []
 
     sampling_params = SamplingParams(
-        **request_dict, logits_processors=logits_processors  # type: ignore
+        **request_dict,
+        logits_processors=logits_processors,  # type: ignore
     )
 
     request_id = random_uuid()
diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py
index 9511720af..dbd8f507b 100644
--- a/outlines/types/__init__.py
+++ b/outlines/types/__init__.py
@@ -1,7 +1,16 @@
 from enum import Enum
 
 from . 
import airports, countries, locale -from outlines.types.dsl import Regex, json_schema, one_or_more, optional, regex, repeat, zero_or_more, times +from outlines.types.dsl import ( + Regex, + json_schema, + one_or_more, + optional, + regex, + repeat, + zero_or_more, + times, +) # Python types integer = Regex(r"[+-]?(0|[1-9][0-9]*)") diff --git a/outlines/types/airports.py b/outlines/types/airports.py index ec0ef72bd..6e3d011b4 100644 --- a/outlines/types/airports.py +++ b/outlines/types/airports.py @@ -1,4 +1,5 @@ """Generate valid airport codes.""" + from enum import Enum import airportsdata diff --git a/outlines/types/countries.py b/outlines/types/countries.py index c612640b4..96be735d3 100644 --- a/outlines/types/countries.py +++ b/outlines/types/countries.py @@ -1,4 +1,5 @@ """Generate valid country codes and names.""" + from enum import Enum from iso3166 import countries diff --git a/outlines/types/dsl.py b/outlines/types/dsl.py index 86feabca0..94b4d26d9 100644 --- a/outlines/types/dsl.py +++ b/outlines/types/dsl.py @@ -69,7 +69,6 @@ def __get_pydantic_core_schema__( def __get_pydantic_json_schema__( self, core_schema: cs.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: - return {"type": "string", "pattern": to_regex(self)} def validate(self, value: str) -> str: diff --git a/pyproject.toml b/pyproject.toml index 969290b07..d62b5d664 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ module = [ "jaxlib", "jax.numpy", "jinja2", + "pandas.*", "jsonschema.*", "openai.*", "mamba_ssm.*", diff --git a/tests/processors/test_tracking.py b/tests/processors/test_tracking.py new file mode 100644 index 000000000..3077509e3 --- /dev/null +++ b/tests/processors/test_tracking.py @@ -0,0 +1,337 @@ +import pytest +import torch +import numpy as np +import pandas as pd +from typing import List + +from outlines.processors.tracking import LogitTracker, track_logits +from outlines.processors.base_logits_processor import OutlinesLogitsProcessor + + +class MockProcessor(OutlinesLogitsProcessor): + """Mock processor that modifies logits in a predictable way.""" + + def __init__(self): + self.tokenizer = MockTokenizer() + + def process_logits( + self, input_ids: List[List[int]], logits: torch.Tensor + ) -> torch.Tensor: + # For testing purposes, set every other logit to -inf + processed = logits.clone() + processed[:, ::2] = float("-inf") + return processed + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def decode(self, token_ids): + if not token_ids: # Handle empty list case + return "" + # Concatenate all tokens + return "".join(f"token_{tid}" for tid in token_ids) + + +@pytest.fixture +def processor(): + """Fixture for creating a tracking processor with a mock base processor.""" + base = MockProcessor() + processor = LogitTracker(base) + processor.tokenizer = base.tokenizer # Ensure tokenizer is available + return processor + + +def test_initialization(): + """Test initialization with various parameters.""" + base = MockProcessor() + + # Basic initialization + processor = LogitTracker(base) + assert processor.processor == base + assert len(processor.unstructured_logits) == 0 + assert len(processor.structured_logits) == 0 + assert processor.vocab_tokens is None + assert len(processor.chosen_tokens) == 0 + assert not processor._started + + # Without processor + processor = LogitTracker(None) + assert processor.processor is None + + +@pytest.mark.parametrize("vocab_size", [10, 100]) +def test_logit_processing(processor, vocab_size): + """Test logit processing 
with different vocab sizes.""" + input_ids = [[0]] # Single batch + logits = torch.ones(1, vocab_size) # Single batch + + processed = processor.process_logits(input_ids, logits) + + # Check tracking + assert len(processor.unstructured_logits) == 1 + assert len(processor.structured_logits) == 1 + assert processor.unstructured_logits[0].shape == (vocab_size,) + assert processor.structured_logits[0].shape == (vocab_size,) + + # Check original logits preserved + assert torch.allclose(torch.tensor(processor.unstructured_logits[0]), logits[0]) + + # Check processing (every other logit should be -inf) + assert torch.all(torch.isinf(processed[:, ::2])) + assert not torch.any(torch.isinf(processed[:, 1::2])) + + # Check chosen token tracking + assert processor._started + assert len(processor.chosen_tokens) == 0 # First call doesn't add token + + +def test_batch_size_validation(processor): + """Test that multi-batch processing raises an error.""" + # Test with multiple sequences in input_ids + with pytest.raises(ValueError, match="only supports single-batch processing"): + processor.process_logits([[0], [0]], torch.ones(2, 10)) + + # Test with multiple batches in logits + with pytest.raises(ValueError, match="only supports single-batch processing"): + processor.process_logits([[0]], torch.ones(2, 10)) + + +def test_chosen_token_tracking(processor): + """Test tracking of chosen tokens during generation.""" + # First token + processor.process_logits([[0]], torch.ones(1, 10)) + assert len(processor.chosen_tokens) == 0 + assert processor._started + + # Second token - should track the previous choice + processor.process_logits([[0, 1]], torch.ones(1, 10)) + assert len(processor.chosen_tokens) == 1 + assert processor.chosen_tokens[0] == 1 + + # Third token + processor.process_logits([[0, 1, 2]], torch.ones(1, 10)) + assert len(processor.chosen_tokens) == 2 + assert processor.chosen_tokens[1] == 2 + + +def test_get_probabilities(processor): + """Test probability distribution computation.""" + # Process a few positions + for i in range(3): + # Create logits that will result in valid probability distributions + logits = torch.full((1, 10), -100.0) # Very negative but not -inf + logits[0, i] = 0.0 # Make one token dominate the probability mass + print(f"\nPosition {i} logits:") + print(f"Raw logits: {logits[0]}") + processor.process_logits([[j for j in range(i + 1)]], logits) + + # Print the softmax of these logits to debug + probs = torch.softmax(logits[0], dim=-1) + print(f"Raw probabilities: {probs}") + print(f"Probability sum: {probs.sum()}") + + probs = processor.get_probabilities() + print("\nProbabilities:") + print(f"Unstructured shape: {probs['unstructured'].shape}") + + # For matrix form, we need to check each position (column) separately + for pos in range(probs["unstructured"].shape[1]): + dist = probs["unstructured"][:, pos] + print(f"Position {pos} sum: {np.sum(dist)}") + print(f"Position {pos} distribution: {dist}") + assert np.allclose(np.sum(dist), 1.0, rtol=1e-5) + + # Check structured probabilities + dist = probs["structured"][:, pos] + valid_probs = dist[~np.isinf(dist)] + if len(valid_probs) > 0: + assert np.allclose(np.sum(valid_probs), 1.0, rtol=1e-5) + + +def test_get_logits(processor): + """Test logit value retrieval.""" + # Process a few positions with known values + for i in range(3): + logits = torch.full((1, 10), float(i)) + processor.process_logits([[j for j in range(i + 1)]], logits) + + logits = processor.get_logits() + + assert set(logits.keys()) == {"unstructured", 
"structured"} + + assert isinstance(logits["unstructured"], np.ndarray) + assert logits["unstructured"].shape == (10, 3) + assert logits["structured"].shape == (10, 3) + # Check values match what we put in + for i in range(3): + assert np.allclose(logits["unstructured"][:, i], i) + + +def test_get_top_tokens(processor): + """Test top token retrieval with various parameters.""" + # Process some logits with known values + logits = torch.tensor([[2.0, -1.0, 1.0, 0.0, 3.0]]) + processor.process_logits([[0]], logits) + + # Test with different k values and explicitly disable logits + results = processor.get_top_tokens(k=2, include_logits=False) + assert len(results) == 1 # One position + assert len(results[0]["tokens"]) == 2 # k=2 tokens + + # Check token info structure + token_info = results[0]["tokens"][0] + assert set(token_info.keys()) == { + "token", + "unstructured_prob", + "structured_prob", + "is_chosen", + } + + # Test position filtering + results = processor.get_top_tokens(positions=[0]) + assert len(results) == 1 + + # Test invalid position + results = processor.get_top_tokens(positions=[100]) + assert len(results) == 0 + + +def test_sequence_reconstruction(processor): + """Test sequence reconstruction from chosen tokens.""" + # Process a sequence + tokens = [[0], [0, 1], [0, 1, 2]] + for ids in tokens: + print(f"\nProcessing tokens: {ids}") + processor.process_logits([ids], torch.ones(1, 10)) + + print(f"\nFinal chosen_tokens: {processor.chosen_tokens}") + print(f"sequence(0): '{processor.sequence_up_to(0)}'") + print(f"sequence(1): '{processor.sequence_up_to(1)}'") + print(f"sequence(2): '{processor.sequence_up_to(2)}'") + + # Test different positions + assert processor.sequence_up_to(0) == "" # No tokens yet + assert processor.sequence_up_to(1) == "token_1" # First token + assert processor.sequence_up_to(2) == "token_1token_2" # Two tokens + assert processor.sequence_up_to() == "token_1token_2" # Full sequence + + # Test position beyond current sequence + assert processor.sequence_up_to(100) == "token_1token_2" + + +def test_to_dataframe(processor): + """Test DataFrame conversion with various parameters.""" + # Skip if pandas not available + pytest.importorskip("pandas") + + # Process some logits + logits = torch.tensor([[2.0, -1.0, 1.0, 0.0, 3.0]]) + processor.process_logits([[0]], logits) + + # Test probabilities + df = processor.get_probabilities_dataframe() + assert isinstance(df, pd.DataFrame) + assert set(df.columns) == {"position", "token", "natural", "constrained", "chosen"} + assert df["position"].nunique() == 1 + assert (df["natural"] >= 0).all() and (df["natural"] <= 1).all() + + # Test logits + df = processor.get_logits_dataframe() + assert not ( + (df["natural"] >= 0) & (df["natural"] <= 1) + ).all() # Logits can be any value + + # Test min_value filter + df = processor.get_probabilities_dataframe(min_value=0.1) + assert len(df) > 0 + assert ((df["natural"].abs() >= 0.1) | (df["constrained"].abs() >= 0.1)).all() + + +def test_clear(processor): + """Test clearing tracked data.""" + # Add some data + processor.process_logits([[0]], torch.ones(1, 10)) + processor.process_logits([[0, 1]], torch.ones(1, 10)) + + assert len(processor.unstructured_logits) > 0 + assert len(processor.structured_logits) > 0 + assert len(processor.chosen_tokens) > 0 + assert processor._started # Should be True after processing + + # Clear + processor.clear() + + assert len(processor.unstructured_logits) == 0 + assert len(processor.structured_logits) == 0 + assert len(processor.chosen_tokens) == 0 
+ assert processor._started # Should remain True after clear + + +def test_track_logits_helper(): + """Test the track_logits convenience function.""" + + class MockGenerator: + def __init__(self): + self.logits_processor = MockProcessor() + + generator = MockGenerator() + tracked = track_logits(generator) + + assert isinstance(tracked.logits_processor, LogitTracker) + assert isinstance(tracked.logits_processor.processor, MockProcessor) + + +@pytest.mark.parametrize( + "invalid_value", + [ + "not a processor", # Invalid processor type + None, # No processor + ], +) +def test_invalid_inputs(invalid_value): + """Test handling of invalid inputs.""" + if isinstance(invalid_value, str): + processor = LogitTracker(invalid_value) + with pytest.raises(AttributeError): + processor.process_logits([[0]], torch.ones(1, 10)) + else: + # None processor should work but not modify logits + processor = LogitTracker(invalid_value) + logits = torch.ones(1, 10) + result = processor.process_logits([[0]], logits) + assert torch.allclose(result, logits) + + +def test_missing_tokenizer(): + """Test behavior when processor has no tokenizer.""" + + class ProcessorWithoutTokenizer(OutlinesLogitsProcessor): + def process_logits(self, input_ids, logits): + return logits + + processor = LogitTracker(ProcessorWithoutTokenizer()) + processor.process_logits([[0]], torch.ones(1, 10)) + + with pytest.raises(AttributeError): + processor.get_vocab_mapping() + + +def test_shape_mismatch(): + """Test error handling for shape mismatch between input_ids and logits.""" + processor = LogitTracker(MockProcessor()) + + # Test batch size mismatch + input_ids = [[0]] # batch_size=1 + logits = torch.ones(2, 10) # batch_size=2 + + print("\nShape mismatch test:") + print(f"input_ids shape: {len(input_ids)}x{len(input_ids[0])}") + print(f"logits shape: {logits.shape}") + print(f"logits[0]: {logits[0]}") # Print first batch + print(f"logits: {logits}") # Print full tensor + + # We need to ensure the processor validates batch sizes + # This should fail because logits has batch_size=2 but input_ids has batch_size=1 + with pytest.raises(ValueError, match=r"only supports single-batch processing"): + processor.process_logits(input_ids, logits) diff --git a/tests/types/test_to_regex.py b/tests/types/test_to_regex.py index 6cb566fc5..4b0403ac6 100644 --- a/tests/types/test_to_regex.py +++ b/tests/types/test_to_regex.py @@ -1,7 +1,22 @@ import pytest -from outlines.types.dsl import String, Regex, JsonSchema, KleeneStar, KleenePlus, QuantifyBetween, QuantifyExact, QuantifyMaximum, QuantifyMinimum, Sequence, Alternatives, Optional, Term, to_regex +from outlines.types.dsl import ( + String, + Regex, + JsonSchema, + KleeneStar, + KleenePlus, + QuantifyBetween, + QuantifyExact, + QuantifyMaximum, + QuantifyMinimum, + Sequence, + Alternatives, + Optional, + Term, + to_regex, +) def test_to_regex_simple():