Allow generation of multiple samples for each prompt
rlouf committed Jan 20, 2024
1 parent ce21732 commit 5cebac7
Showing 4 changed files with 261 additions and 42 deletions.
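
For context, here is a minimal usage sketch of the new `num_samples` argument (not part of the commit). It assumes the `outlines.models.transformers` and `outlines.generate.text` entry points; the model name and prompts are placeholders, and the keyword arguments follow the updated `SequenceGenerator.__call__` shown below.

import outlines

model = outlines.models.transformers("gpt2")  # placeholder model id
generator = outlines.generate.text(model)

# One prompt, one sample -> a single string (leading dimensions are squeezed)
answer = generator("What is 1+1?", max_tokens=8)

# Two prompts, three samples each -> nested lists of strings
answers = generator(
    ["What is 1+1?", "Name a color."],
    max_tokens=8,
    num_samples=3,
)
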
81 changes: 67 additions & 14 deletions outlines/generate/api.py
@@ -49,6 +49,7 @@ def get_generated_token_ids(
init_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
prompts: List[str],
last_state: GenerationState,
num_samples: int,
) -> List[torch.Tensor]:
"""Get the tokens generated so far.
@@ -60,18 +61,28 @@ def get_generated_token_ids(
The prompts passed to the generator.
last_state
The current state of the generation
num_samples
The number of samples taken for each sequence
Returns
-------
A list of tensors containing the token ids that have been generated so far.
"""
prompt_token_ids = init_state[0]
prompt_lengths = [len(prompt_token_ids[i]) for i in range(len(prompts))]
prompt_lengths = [
len(prompt_token_ids[i])
for _ in range(num_samples)
for i in range(len(prompts))
]

# We flatten the obtained token_ids since the tokenizer's decoder
# only accepts tensors with two dimensions
token_ids = last_state.token_ids.reshape((-1, last_state.token_ids.shape[-1]))

token_ids = [
cur_token_ids[length:]
for cur_token_ids, length in zip(last_state.token_ids, prompt_lengths)
for cur_token_ids, length in zip(token_ids, prompt_lengths)
]

return token_ids
@@ -150,9 +161,10 @@ def __call__(
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
num_samples: int = 1,
rng: Optional[torch.Generator] = None,
kv_cache: Optional[torch.tensor] = None,
) -> Union[str, List[str]]:
) -> Union[str, List[str], List[List[str]]]:
"""Generate the full text sequence.
Since `SequenceGenerator.stream` calls the tokenizer at every step this
@@ -205,15 +217,20 @@ def __call__(
init_fsm_states = [FSMState(0) for _ in range(num_sequences)]

states = sequence_generator(
self.generate_token, fsms, init_state, init_fsm_states, rng
self.generate_token,
fsms,
init_state,
init_fsm_states,
rng=rng,
num_samples=num_samples,
)

while True:
try:
last_state = next(states)
if max_tokens or stop_sequences:
generated_token_ids = self.get_generated_token_ids(
init_state, prompts, last_state
init_state, prompts, last_state, num_samples
)
if max_tokens and len(generated_token_ids[0]) >= max_tokens:
break
@@ -225,8 +242,9 @@ def __call__(
break

generated_token_ids = self.get_generated_token_ids(
init_state, prompts, last_state
init_state, prompts, last_state, num_samples
)

generated = self.tokenizer.decode(generated_token_ids)
stripped = [
self.strip_stop_sequences(sequence, stop_sequences)
@@ -242,16 +260,29 @@ def __call__(
+ " is raised nevertheless please open an issue: https://github.com/outlines-dev/outlines/issues"
)

return formatted if len(formatted) > 1 else formatted[0]
# We reshape the output to (sample_size, batch_size)
output = []
step = len(prompts)
for i in range(0, len(formatted), step):
output.append(formatted[i : i + step])

# We remove leading dimensions for the output
if len(prompts) == 1 and num_samples == 1:
return output[0][0]
elif num_samples == 1:
return output[0]
else:
return output

def stream(
self,
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
num_samples: int = 1,
rng: Optional[torch.Generator] = None,
kv_cache: Optional[torch.tensor] = None,
) -> Iterator[Union[List[str], str]]:
) -> Iterator[Union[List[str], List[List[str]], str]]:
"""Generate the text sequence one token at a time.
Since `Tokenizer.decode` strips the whitespaces from the tokens we have no
@@ -303,13 +334,20 @@ def stream(
init_fsm_states = [FSMState(0) for _ in range(num_sequences)]

states = sequence_generator(
self.generate_token, fsms, init_state, init_fsm_states, rng
self.generate_token,
fsms,
init_state,
init_fsm_states,
num_samples=num_samples,
rng=rng,
)

def token_generator() -> Iterator[Union[List[str], str]]:
previously_generated_sequences = ["" for _ in range(num_sequences)]
def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
previously_generated_sequences = [
"" for _ in range(num_sequences)
] * num_samples
num_generated = 0
is_stop_at_reached = [False for _ in range(num_sequences)]
is_stop_at_reached = [False for _ in range(num_sequences)] * num_samples
while True:
if (max_tokens and num_generated >= max_tokens) or all(
is_stop_at_reached
@@ -320,7 +358,10 @@ def token_generator() -> Iterator[Union[List[str], str]]:
num_generated += 1
except StopIteration:
return
generated_token_ids = sequence.token_ids[:, -num_generated:]
generated_token_ids = sequence.token_ids[:, :, -num_generated:]
generated_token_ids = generated_token_ids.reshape(
-1, generated_token_ids.shape[-1]
)
generated_sequences = self.tokenizer.decode(generated_token_ids)
next_tokens = [
token[len(sequence) :] if not stop else ""
@@ -341,7 +382,19 @@ def token_generator() -> Iterator[Union[List[str], str]]:
generated_sequences, is_stop_at_reached
)
]
yield next_tokens
# We reshape the output to (sample_size, batch_size)
output = []
step = len(prompts)
for i in range(0, len(next_tokens), step):
output.append(next_tokens[i : i + step])

# We remove leading dimensions for the output
if len(prompts) == 1 and num_samples == 1:
yield output[0][0]
elif num_samples == 1:
yield output[0]
else:
yield output

return token_generator()

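Aside (not part of the diff): the regrouping that `__call__` now applies to the flat list of decoded sequences can be sketched as a standalone helper. The `regroup` name and the placeholder strings are hypothetical; the chunking and dimension-squeezing mirror the code added above.

from typing import List, Union

def regroup(
    formatted: List[str], num_prompts: int, num_samples: int
) -> Union[str, List[str], List[List[str]]]:
    # Cut the flat list of num_samples * num_prompts sequences into
    # chunks of num_prompts, as done at the end of `__call__`
    output = [
        formatted[i : i + num_prompts]
        for i in range(0, len(formatted), num_prompts)
    ]
    # Squeeze the leading dimensions when they are of size 1
    if num_prompts == 1 and num_samples == 1:
        return output[0][0]
    elif num_samples == 1:
        return output[0]
    return output

print(regroup(["s0"], num_prompts=1, num_samples=1))        # 's0'
print(regroup(["s0", "s1"], num_prompts=2, num_samples=1))  # ['s0', 's1']
print(regroup(["s0", "s1", "s2", "s3"], num_prompts=2, num_samples=2))
# [['s0', 's1'], ['s2', 's3']]
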
29 changes: 25 additions & 4 deletions outlines/generate/generator.py
@@ -55,7 +55,8 @@ def sequence_generator(
fsms: List["FSM"],
init_state: Tuple,
fsm_states: List[FSMState],
rng: torch.Generator,
num_samples: int = 1,
rng: torch.Generator = torch.Generator(),
) -> Iterator[GenerationState]:
"""Generates sequences of tokens.
@@ -78,6 +79,17 @@
"""
token_ids, attention_masks, kv_cache = init_state
batch_shape = token_ids.shape[:-1]

# To take several samples we duplicate `token_ids`, `attention_masks`
# and `fsm_states` as many times as the number of samples requested.
# The resulting tensors are of shape (num_samples * num_batches, num_tokens)
token_ids = torch.repeat_interleave(token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)

fsm_states = fsm_states * num_samples
fsms = [fsm.copy() for fsm in fsms for _ in range(num_samples)]

while True:
allowed_tokens = get_allowed_tokens(fsms, fsm_states)

@@ -88,18 +100,27 @@
rng=rng,
allowed_tokens=allowed_tokens,
)

token_ids = update_token_ids(token_ids, next_token_ids)
attention_masks = expand_attention_masks(attention_masks)

fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids)
is_finished = is_generation_finished(fsms, fsm_states)

if is_finished:
yield GenerationState(token_ids, kv_cache, logits, fsm_states)
yield GenerationState(
token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:]),
kv_cache,
logits,
fsm_states,
)
return

yield GenerationState(token_ids, kv_cache, logits, fsm_states)
yield GenerationState(
token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:]),
kv_cache,
logits,
fsm_states,
)


def token_generator(model, sampler: "Sampler") -> Callable:
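Aside (not part of the diff): a toy sketch of the duplication and reshaping performed by the updated `sequence_generator`. The tensor values are placeholders; only the shapes matter.

import torch

num_samples = 2
token_ids = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])  # (batch=2, num_tokens=4)
attention_masks = torch.ones_like(token_ids)
batch_shape = token_ids.shape[:-1]

# Duplicate each prompt `num_samples` times -> (num_samples * batch, num_tokens)
token_ids = torch.repeat_interleave(token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
print(token_ids.shape)  # torch.Size([4, 4])

# The yielded GenerationState exposes a leading sample dimension again
reshaped = token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:])
print(reshaped.shape)  # torch.Size([2, 2, 4])
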
93 changes: 82 additions & 11 deletions tests/generate/test_generator.py
@@ -36,7 +36,7 @@ def copy(self):
class MockTokenizer:
def encode(self, _):
# Input: "test"
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1, 1]])
return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])

def decode(self, tokens):
return ["testx"[i] for i in tokens]
@@ -139,11 +139,11 @@ def sampler(biased_logits, *_):
init_fsm_states = [0]
generate = token_generator(MockModel(), sampler)
sequence = sequence_generator(
generate, [MockFSM()], init_state, init_fsm_states, torch.Generator()
generate, [MockFSM()], init_state, init_fsm_states, rng=torch.Generator()
)
result = next(sequence)

assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3]]))
assert torch.equal(result.token_ids, torch.tensor([[[0, 1, 2, 3, 3]]]))
assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]]))

with pytest.raises(StopIteration):
@@ -185,15 +185,15 @@ def sampler(biased_logits, *_):
init_fsm_states = [0]
generate = token_generator(MockModel(), sampler)
sequence = sequence_generator(
generate, [MockFSM()], init_state, init_fsm_states, torch.Generator()
generate, [MockFSM()], init_state, init_fsm_states, rng=torch.Generator()
)

result = next(sequence)
assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3]]))
assert torch.equal(result.token_ids, torch.tensor([[[0, 1, 2, 3, 3]]]))
assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]]))

result = next(sequence)
assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3, 3]]))
assert torch.equal(result.token_ids, torch.tensor([[[0, 1, 2, 3, 3, 3]]]))
assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]]))

with pytest.raises(StopIteration):
@@ -239,12 +239,12 @@ def sampler(biased_logits, *_):
fsms = [MockFSM(), MockFSM()]
generate = token_generator(MockModel(), sampler)
sequence = sequence_generator(
generate, fsms, init_state, init_fsm_states, torch.Generator()
generate, fsms, init_state, init_fsm_states, rng=torch.Generator()
)

result = next(sequence)
assert torch.equal(
result.token_ids, torch.tensor([[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]])
result.token_ids, torch.tensor([[[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]]])
)
assert torch.equal(
result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
@@ -254,6 +254,77 @@ def sampler(biased_logits, *_):
next(sequence)


def test_sequence_generator_2d_single_iteration_several_samples():
class MockFSM:
def next_state(self, state, next_token_ids):
return 0

def allowed_token_ids(self, _):
return [0, 1, 2, 3]

def is_final_state(self, _):
return True

class MockTokenizer:
def encode(self, _):
return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), torch.tensor(
[[1, 1, 1, 1], [1, 1, 1, 1]]
)

def decode(self, x):
return x

class MockModel:
def __init__(self):
self.tokenizer = MockTokenizer()

def __call__(*_):
return (
torch.tensor(
[[0, 1, 2, 3], [4, 5, 7, 6], [0, 1, 2, 3], [1, 5, 3, 4]],
dtype=torch.float,
),
None,
)

def sampler(biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True, dim=-1)

init_state = (
torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]]),
None,
)
init_fsm_states = [0, 0]
fsms = [MockFSM(), MockFSM()]
generate = token_generator(MockModel(), sampler)
sequence = sequence_generator(
generate,
fsms,
init_state,
init_fsm_states,
num_samples=2,
rng=torch.Generator(),
)

result = next(sequence)
expected_token_ids = torch.tensor(
[
[[0, 1, 2, 3, 3], [0, 1, 2, 3, 2]],
[[4, 5, 6, 7, 3], [4, 5, 6, 7, 1]],
]
)
assert torch.equal(result.token_ids, expected_token_ids)

expected_logits = torch.tensor(
[[0, 1, 2, 3], [4, 5, 7, 6], [0, 1, 2, 3], [1, 5, 3, 4]], dtype=torch.float
)
assert torch.equal(result.logits, expected_logits)

with pytest.raises(StopIteration):
next(sequence)


def test_sequence_generator_2d_several_iterations():
class MockFSM:
def next_state(self, state, next_token_ids):
@@ -296,20 +367,20 @@ def sampler(biased_logits, *_):
fsms = [MockFSM(), MockFSM()]
generate = token_generator(MockModel(), sampler)
sequence = sequence_generator(
generate, fsms, init_state, init_fsm_states, torch.Generator()
generate, fsms, init_state, init_fsm_states, rng=torch.Generator()
)

result = next(sequence)
assert torch.equal(
result.token_ids, torch.tensor([[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]])
result.token_ids, torch.tensor([[[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]]])
)
assert torch.equal(
result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
)

result = next(sequence)
assert torch.equal(
result.token_ids, torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 7, 2, 2]])
result.token_ids, torch.tensor([[[0, 1, 2, 3, 3, 3], [4, 5, 6, 7, 2, 2]]])
)
assert torch.equal(
result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
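Aside (not part of the diff): the generated ids yielded by `sequence_generator` now carry a leading sample dimension, and both `get_generated_token_ids` and `stream` flatten them before decoding because the tokenizer's decoder expects a two-dimensional tensor. A small sketch, with values borrowed from the new test above:

import torch

# (num_samples=2, batch=2, num_tokens=5)
token_ids = torch.tensor(
    [
        [[0, 1, 2, 3, 3], [0, 1, 2, 3, 2]],
        [[4, 5, 6, 7, 3], [4, 5, 6, 7, 1]],
    ]
)

# Flatten the leading dimensions before calling the tokenizer's decoder
flat = token_ids.reshape((-1, token_ids.shape[-1]))
print(flat.shape)  # torch.Size([4, 5])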