Pass number of samples via sampler
rlouf committed Jan 23, 2024
1 parent 6a8b9fb commit 69b2af5
Showing 7 changed files with 153 additions and 89 deletions.
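In short: the number of samples is no longer a `num_samples` argument on the generator's `__call__` and `stream` methods; it is carried by the sampler instance and read back as `sampler.particles`. A minimal usage sketch of the new API, based on the diffs below (the `outlines.models.transformers` loader call and the model name are illustrative assumptions, not part of this commit):

    import outlines
    from outlines.generate.api import text
    from outlines.generate.samplers import MultinomialSampler

    # Assumed for illustration: any model object accepted by the generators.
    model = outlines.models.transformers("gpt2")  # hypothetical loader call

    # Before this commit: generator("prompt", num_samples=3)
    # After this commit: the sample count is configured on the sampler itself.
    sampler = MultinomialSampler(samples=3)
    generator = text(model, max_tokens=64, sampler=sampler)
    answers = generator("Write a haiku about finite-state machines.")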
31 changes: 17 additions & 14 deletions outlines/generate/api.py
@@ -14,7 +14,7 @@
sequence_generator,
token_generator,
)
from outlines.generate.samplers import Sampler, multinomial
from outlines.generate.samplers import MultinomialSampler, Sampler


class SequenceGenerator:
@@ -25,7 +25,6 @@ def __init__(
sampler,
device,
*,
num_samples: int = 1,
max_tokens=None,
stop_at=None,
):
@@ -34,7 +33,7 @@ def __init__(
self.tokenizer = model.tokenizer
self.device = device
self.max_tokens = max_tokens
self.num_samples = num_samples
self.num_particles = sampler.particles

if isinstance(stop_at, str):
stop_at = [stop_at]
@@ -172,7 +171,6 @@ def __call__(
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
num_samples: int = 1,
rng: Optional[torch.Generator] = None,
kv_cache: Optional[torch.tensor] = None,
) -> Union[str, List[str], List[List[str]]]:
@@ -215,6 +213,7 @@ def __call__(

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
num_samples = self.num_particles
num_sequences = len(prompts)
fsms = [self.fsm.copy() for _ in prompts]

@@ -282,7 +281,6 @@ def stream(
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
num_samples: int = 1,
rng: Optional[torch.Generator] = None,
kv_cache: Optional[torch.tensor] = None,
) -> Iterator[Union[List[str], List[List[str]], str]]:
@@ -324,6 +322,7 @@ def stream(

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
num_samples = self.num_particles
num_sequences = len(prompts)
fsms = [self.fsm.copy() for _ in prompts]

@@ -407,7 +406,7 @@ def text(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
*,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
eos_token = model.tokenizer.eos_token_id
fsm = StopAtTokenFSM(model.tokenizer, eos_token)
@@ -424,7 +423,7 @@ def regex(
model,
regex_str: str,
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
fsm = RegexFSM(regex_str, model.tokenizer)

@@ -439,7 +438,7 @@ def cfg(
cfg_str: str,
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
fsm = CFGFSM(cfg_str, model.tokenizer)

@@ -452,7 +451,10 @@ def cfg(


def format(
model, python_type, max_tokens: Optional[int] = None, sampler: Sampler = multinomial
model,
python_type,
max_tokens: Optional[int] = None,
sampler: Sampler = MultinomialSampler(),
):
regex_str = python_types_to_regex(python_type)
return regex(model, regex_str, max_tokens, sampler)
@@ -462,7 +464,7 @@ def choice(
model,
choices: List[str],
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
regex_str = r"(" + r"|".join(choices) + r")"
return regex(model, regex_str, max_tokens, sampler)
@@ -472,30 +474,31 @@ def json(
model,
schema_object: Union[str, object, Callable],
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
def handle_json_decoder_error(fn):
"""Handle errors during JSON decoding gracefully."""

def wrapper(*args):
try:
fn(*args)
return fn(*args)
except pyjson.decoder.JSONDecodeError:
raise TypeError(
"Could not format the output of the model into a dictionary or a Pydantic model."
+ " The model has likely exceeded its context length. Please try again using `constr` (for Pydantic)"
+ " and `maxLength` (for JSON Schema) to limit the length of the string fields. If this exception"
+ " is raised nevertheless please open an issue: https://github.com/outlines-dev/outlines/issues"
)
return wrapper

return wrapper

@handle_json_decoder_error
def format_output_to_pydantic(output):
return schema_object.parse_raw(output)

@handle_json_decoder_error
def format_output_to_dict(output):
return pyjson.load(output)
return pyjson.loads(output)

if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
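Two small fixes ride along in the `json` helper above: the error-handling wrapper now returns the wrapped call's result (previously the parsed value was computed and then discarded), and dictionary parsing uses `pyjson.loads`, which takes a string, instead of `pyjson.load`, which expects a file-like object. A stripped-down sketch of the same decorator pattern, with hypothetical names, showing why the `return` matters:

    import json

    def handle_decode_errors(fn):
        """Turn JSON decoding failures into a friendlier TypeError."""
        def wrapper(*args):
            try:
                return fn(*args)  # without `return`, wrapper would always yield None
            except json.decoder.JSONDecodeError:
                raise TypeError("Model output could not be parsed as JSON.")
        return wrapper

    @handle_decode_errors
    def to_dict(output: str) -> dict:
        return json.loads(output)  # json.load would require a file object

    assert to_dict('{"a": 1}') == {"a": 1}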
2 changes: 1 addition & 1 deletion outlines/generate/generator.py
@@ -165,7 +165,7 @@ def generate(
)

biased_logits = bias_logits(logits, allowed_tokens)
next_token_ids = sampler(biased_logits, 1, rng)
next_token_ids = sampler(biased_logits, rng)

return next_token_ids, new_kv_cache, logits, biased_logits

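With the sample count removed from the call above, `generate` now invokes the sampler as `sampler(biased_logits, rng)` and leaves the number of draws to the sampler object itself. A minimal sketch of a sampler that satisfies this calling convention (the class name is hypothetical; it mirrors the mock used in the tests further down):

    import torch

    class ArgmaxSampler:
        """Return the single most likely token for each sequence."""

        def __init__(self):
            self.particles = 1  # one sample per input sequence

        def __call__(self, biased_logits, rng=None):
            # rng is unused: argmax is deterministic.
            return torch.argmax(biased_logits, dim=-1, keepdim=True)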
110 changes: 70 additions & 40 deletions outlines/generate/samplers.py
@@ -4,68 +4,98 @@


class Sampler(Protocol):
particles: int

def __call__(
self, logits: torch.DoubleTensor, samples: int, rng: torch.Generator
self, logits: torch.DoubleTensor, rng: torch.Generator
) -> torch.DoubleTensor:
...


def greedy(logits: torch.DoubleTensor, samples: int, *_) -> torch.DoubleTensor:
class GreedySampler:
"""Greedy Sampling algorithm.
Greedy sampling consists in choosing the token with the largest
likelihood at every step.
Parameters
Attributes
----------
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
samples
The number of sequences to produce. In this case, the top-`samples`
logit values are returned.
rng
A random number generator.
Returns
-------
The ids of the sampled tokens having shape ``(samples, n_seqs)``.
particles
The number of samples taken for each input sequence.
"""
if samples == 1:
next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
else:
next_token_ids = torch.topk(
logits, samples, dim=-1, largest=True, sorted=True
).indices.T

return next_token_ids
def __init__(self, samples: int = 1):
self.particles = samples

def __call__(self, logits: torch.DoubleTensor, *_) -> torch.DoubleTensor:
"""Call the greedy sampler.
Parameters
----------
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
rng
A random number generator.
Returns
-------
The ids of the sampled tokens having shape ``(samples, n_seqs)``.
"""
if self.particles == 1:
next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
else:
next_token_ids = torch.topk(
logits, self.particles, dim=-1, largest=True, sorted=True
).indices.T

return next_token_ids

def multinomial(
logits: torch.DoubleTensor, samples: int, rng: torch.Generator
) -> torch.DoubleTensor:

greedy = GreedySampler


class MultinomialSampler:
"""Multinomial sampling algorithm.
Multinomial sampling consists in randomly sampling the next token assuming
its distribution is a Categorical distribution parametrized by the
next-token logits.
Parameters
Attributes
----------
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
samples
The number of sequences to sample.
rng
A random number generator.
Returns
-------
The ids of the sampled tokens having shape ``(samples, n_seqs)``.
particles
The number of samples taken for each input sequence.
"""
probs = torch.nn.functional.softmax(logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=samples, generator=rng)
return next_token_ids

def __init__(self, samples: int = 1):
self.particles = samples

def __call__(
self, logits: torch.DoubleTensor, rng: torch.Generator
) -> torch.DoubleTensor:
"""Call the multinomial sampler.
Parameters
----------
logits
A tensor of shape ``(n_seqs, vocab_size,)`` that represents the
probability distribution of the next token over the vocabulary.
rng
A random number generator.
Returns
-------
The ids of the sampled tokens having shape ``(samples, n_seqs)``.
"""
probs = torch.nn.functional.softmax(logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
return next_token_ids


multinomial = MultinomialSampler
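A quick shape check of the two samplers defined above, assuming two input sequences and a five-token vocabulary (the tensor values are arbitrary):

    import torch
    from outlines.generate.samplers import GreedySampler, MultinomialSampler

    rng = torch.Generator()
    rng.manual_seed(0)

    # Two sequences, five-token vocabulary.
    logits = torch.tensor(
        [[0.1, 0.2, 3.0, 0.0, -1.0],
         [2.0, 0.1, 0.1, 0.1, 0.1]],
        dtype=torch.double,
    )

    greedy = GreedySampler(samples=2)
    print(greedy(logits).shape)            # torch.Size([2, 2]): top-2 token ids

    multinomial = MultinomialSampler()     # particles defaults to 1
    print(multinomial(logits, rng).shape)  # torch.Size([2, 1]): one draw per sequence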
12 changes: 6 additions & 6 deletions outlines/text/generate/api.py
@@ -2,15 +2,15 @@
from typing import Callable, List, Optional, Union

import outlines
from outlines.generate.samplers import Sampler, multinomial
from outlines.generate.samplers import MultinomialSampler, Sampler


def json(
model,
schema_object: Union[str, object, Callable],
max_tokens: Optional[int] = None,
*,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
warnings.warn(
"`outlines.text.generate.json` is deprecated, please use `outlines.generate.json` instead. "
@@ -25,7 +25,7 @@ def regex(
regex_str: str,
max_tokens: Optional[int] = None,
*,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
warnings.warn(
"`outlines.text.generate.regex` is deprecated, please use `outlines.generate.regex` instead. "
@@ -39,7 +39,7 @@ def format(
model,
python_type,
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
warnings.warn(
"`outlines.text.generate.format` is deprecated, please use `outlines.generate.format` instead. "
@@ -54,7 +54,7 @@ def continuation(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
*,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
warnings.warn(
"`outlines.text.generate.continuation` is deprecated, please use `outlines.generate.text` instead. "
@@ -70,7 +70,7 @@ def choice(
choices: List[str],
max_tokens: Optional[int] = None,
*,
sampler: Sampler = multinomial,
sampler: Sampler = MultinomialSampler(),
):
warnings.warn(
"`outlines.text.generate.choice` is deprecated, please use `outlines.generate.choice` instead. "
12 changes: 8 additions & 4 deletions tests/generate/test_generator.py
@@ -48,11 +48,15 @@ def __init__(self):
def __call__(*_):
return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

def sampler(biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True)
class sampler:
def __init__(self):
self.particles = 1

def __call__(self, biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True)

# Stream
generator = SequenceGenerator(MockFSM(), MockModel(), sampler, "cpu")
generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
assert generator.device == "cpu"
assert isinstance(generator.tokenizer, MockTokenizer)
assert isinstance(generator.fsm, MockFSM)
@@ -67,7 +71,7 @@ def sampler(biased_logits, *_):
next(sequence)

# Call
generator = SequenceGenerator(MockFSM(), MockModel(), sampler, "cpu")
generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
result = generator("test")
assert result == "x"
