Skip to content

Commit

Permalink
Move samplers module to root module
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jan 23, 2024
1 parent 69b2af5 commit 35998fe
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docs/api/samplers.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
::: outlines.generate.samplers
::: outlines.samplers
2 changes: 1 addition & 1 deletion outlines/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
sequence_generator,
token_generator,
)
from outlines.generate.samplers import MultinomialSampler, Sampler
from outlines.samplers import MultinomialSampler, Sampler


class SequenceGenerator:
Expand Down
2 changes: 1 addition & 1 deletion outlines/generate/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

if TYPE_CHECKING:
from outlines.fsm.fsm import FSM
from outlines.generate.samplers import Sampler
from outlines.models.tokenizer import Tokenizer
from outlines.samplers import Sampler


@dataclasses.dataclass(frozen=True)
Expand Down
18 changes: 5 additions & 13 deletions outlines/generate/samplers.py → outlines/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@ class GreedySampler:
Greedy sampling consists in choosing the token with the largest
likelihood at every step.
Attributes
----------
particles
The number of samples taken for each input sequence.
We don't allow more than one sample as this does not really make sense.
"""

def __init__(self, samples: int = 1):
self.particles = samples
def __init__(self):
self.particles = 1

def __call__(self, logits: torch.DoubleTensor, *_) -> torch.DoubleTensor:
"""Call the greedy sampler.
Expand All @@ -41,15 +38,10 @@ def __call__(self, logits: torch.DoubleTensor, *_) -> torch.DoubleTensor:
Returns
-------
The ids of the sampled tokens having shape ``(samples, n_seqs)``.
The ids of the sampled tokens, of shape ``(n_seqs, 1)``
"""
if self.particles == 1:
next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
else:
next_token_ids = torch.topk(
logits, self.particles, dim=-1, largest=True, sorted=True
).indices.T
next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)

return next_token_ids

Expand Down
2 changes: 1 addition & 1 deletion outlines/text/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable, List, Optional, Union

import outlines
from outlines.generate.samplers import MultinomialSampler, Sampler
from outlines.samplers import MultinomialSampler, Sampler


def json(
Expand Down
2 changes: 1 addition & 1 deletion tests/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import outlines.generate as generate
import outlines.models as models
from outlines.fsm.regex import reduced_vocabulary
from outlines.generate.samplers import multinomial
from outlines.models.transformers import TransformerTokenizer
from outlines.samplers import multinomial


def test_deprecation():
Expand Down
15 changes: 1 addition & 14 deletions tests/generate/test_samplers.py → tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@

import torch

from outlines.generate.samplers import (
GreedySampler,
MultinomialSampler,
greedy,
multinomial,
)
from outlines.samplers import GreedySampler, MultinomialSampler, greedy, multinomial


def test_aliases():
Expand All @@ -21,19 +16,11 @@ def test_greedy():
next_token_ids = sampler(logits)
assert next_token_ids.equal(torch.tensor([[2]]))

sampler = GreedySampler(2)
next_token_ids = sampler(logits)
assert next_token_ids.equal(torch.tensor([[2], [1]]))

sampler = GreedySampler()
logits = torch.tensor([[10.0, 0.0, 3.0], [-math.inf, 2.0, 5.0]])
next_token_ids = sampler(logits)
assert next_token_ids.equal(torch.tensor([[0], [2]]))

sampler = GreedySampler(2)
next_token_ids = sampler(logits)
assert next_token_ids.equal(torch.tensor([[0, 2], [2, 1]]))


def test_multinomial():
rng = torch.Generator()
Expand Down

0 comments on commit 35998fe

Please sign in to comment.