diff --git a/docs/api/samplers.md b/docs/api/samplers.md
index c125e9a66..2b9b34234 100644
--- a/docs/api/samplers.md
+++ b/docs/api/samplers.md
@@ -1 +1 @@
-::: outlines.generate.samplers
+::: outlines.samplers
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 9015b4e73..32db1189e 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -14,7 +14,7 @@
     sequence_generator,
     token_generator,
 )
-from outlines.generate.samplers import MultinomialSampler, Sampler
+from outlines.samplers import MultinomialSampler, Sampler
 
 
 class SequenceGenerator:
diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py
index 71f37a2c6..1a3e72e23 100644
--- a/outlines/generate/generator.py
+++ b/outlines/generate/generator.py
@@ -8,8 +8,8 @@
 
 if TYPE_CHECKING:
     from outlines.fsm.fsm import FSM
-    from outlines.generate.samplers import Sampler
     from outlines.models.tokenizer import Tokenizer
+    from outlines.samplers import Sampler
 
 
 @dataclasses.dataclass(frozen=True)
diff --git a/outlines/generate/samplers.py b/outlines/samplers.py
similarity index 79%
rename from outlines/generate/samplers.py
rename to outlines/samplers.py
index 2323eb6fa..2bb465b28 100644
--- a/outlines/generate/samplers.py
+++ b/outlines/samplers.py
@@ -18,15 +18,12 @@ class GreedySampler:
     Greedy sampling consists in choosing the token with the largest
     likelihood at every step.
 
-    Attributes
-    ----------
-    particles
-        The number of samples taken for each input sequence.
+    We don't allow more than one sample as this does not really make sense.
 
     """
 
-    def __init__(self, samples: int = 1):
-        self.particles = samples
+    def __init__(self):
+        self.particles = 1
 
     def __call__(self, logits: torch.DoubleTensor, *_) -> torch.DoubleTensor:
         """Call the greedy sampler.
@@ -41,15 +38,10 @@ def __call__(self, logits: torch.DoubleTensor, *_) -> torch.DoubleTensor:
 
         Returns
         -------
-        The ids of the sampled tokens having shape ``(samples, n_seqs)``.
+        The ids of the sampled tokens, of shape ``(n_seqs, 1)``
 
         """
-        if self.particles == 1:
-            next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
-        else:
-            next_token_ids = torch.topk(
-                logits, self.particles, dim=-1, largest=True, sorted=True
-            ).indices.T
+        next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
 
         return next_token_ids
 
diff --git a/outlines/text/generate/api.py b/outlines/text/generate/api.py
index 613e9ea1f..d9e733b3f 100644
--- a/outlines/text/generate/api.py
+++ b/outlines/text/generate/api.py
@@ -2,7 +2,7 @@
 from typing import Callable, List, Optional, Union
 
 import outlines
-from outlines.generate.samplers import MultinomialSampler, Sampler
+from outlines.samplers import MultinomialSampler, Sampler
 
 
 def json(
diff --git a/tests/generate/test_integration_transfomers.py b/tests/generate/test_integration_transfomers.py
index 1b2139d73..8ebc3d4de 100644
--- a/tests/generate/test_integration_transfomers.py
+++ b/tests/generate/test_integration_transfomers.py
@@ -10,8 +10,8 @@
 import outlines.generate as generate
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
-from outlines.generate.samplers import multinomial
 from outlines.models.transformers import TransformerTokenizer
+from outlines.samplers import multinomial
 
 
 def test_deprecation():
diff --git a/tests/generate/test_samplers.py b/tests/test_samplers.py
similarity index 72%
rename from tests/generate/test_samplers.py
rename to tests/test_samplers.py
index a5a383004..5cdcc06b8 100644
--- a/tests/generate/test_samplers.py
+++ b/tests/test_samplers.py
@@ -2,12 +2,7 @@
 
 import torch
 
-from outlines.generate.samplers import (
-    GreedySampler,
-    MultinomialSampler,
-    greedy,
-    multinomial,
-)
+from outlines.samplers import GreedySampler, MultinomialSampler, greedy, multinomial
 
 
 def test_aliases():
@@ -21,19 +16,11 @@ def test_greedy():
     next_token_ids = sampler(logits)
     assert next_token_ids.equal(torch.tensor([[2]]))
 
-    sampler = GreedySampler(2)
-    next_token_ids = sampler(logits)
-    assert next_token_ids.equal(torch.tensor([[2], [1]]))
-
     sampler = GreedySampler()
     logits = torch.tensor([[10.0, 0.0, 3.0], [-math.inf, 2.0, 5.0]])
     next_token_ids = sampler(logits)
     assert next_token_ids.equal(torch.tensor([[0], [2]]))
 
-    sampler = GreedySampler(2)
-    next_token_ids = sampler(logits)
-    assert next_token_ids.equal(torch.tensor([[0, 2], [2, 1]]))
-
 
 def test_multinomial():
     rng = torch.Generator()
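For reference, a minimal usage sketch of the sampler after this change, based only on what the diff above shows: GreedySampler now lives in outlines.samplers, takes no constructor argument, and returns token ids of shape (n_seqs, 1), mirroring the assertions kept in tests/test_samplers.py.

    import math

    import torch

    from outlines.samplers import GreedySampler

    # Greedy sampling picks the argmax token independently for each sequence in the batch.
    sampler = GreedySampler()
    logits = torch.tensor([[10.0, 0.0, 3.0], [-math.inf, 2.0, 5.0]])
    next_token_ids = sampler(logits)  # shape (n_seqs, 1)
    assert next_token_ids.equal(torch.tensor([[0], [2]]))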