diff --git a/outlines/function.py b/outlines/function.py
index d7f6dec68..cabaf1fef 100644
--- a/outlines/function.py
+++ b/outlines/function.py
@@ -7,7 +7,7 @@
 import outlines
 
 if TYPE_CHECKING:
-    from outlines.generate.api import SequenceGenerator
+    from outlines.generate import SequenceGeneratorAdapter
     from outlines.templates import Template
 
 
@@ -25,7 +25,7 @@ class Function:
     prompt_template: "Template"
     schema: Union[str, Callable, object]
     model_name: str
-    generator: Optional["SequenceGenerator"] = None
+    generator: Optional["SequenceGeneratorAdapter"] = None
 
     @classmethod
     def from_github(cls, program_path: str, function_name: str = "fn"):
diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py
index 10bf78553..bc91b6124 100644
--- a/outlines/generate/__init__.py
+++ b/outlines/generate/__init__.py
@@ -5,7 +5,7 @@
 from outlines.processors import CFGLogitsProcessor, RegexLogitsProcessor
 from outlines.types import CFG, Choice, JsonType, List, Regex
 
-from .api import SequenceGenerator
+from .api import SequenceGeneratorAdapter
 from .cfg import cfg
 from .choice import choice
 from .format import format
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 5d3c52a8a..d9833eb12 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -3,8 +3,6 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
 
-from outlines.generate.generator import sequence_generator
-
 if TYPE_CHECKING:
     import torch
 
@@ -13,383 +11,6 @@
 ]
 
 
-class SequenceGenerator:
-    def __init__(
-        self,
-        fsm,
-        model,
-        sampler,
-        device,
-    ):
-        self.fsm = fsm
-        self.model = model
-        self.sampler = sampler
-        self.tokenizer = model.tokenizer
-        self.device = device
-        self.num_samples = sampler.samples
-
-    def get_generated_token_ids(
-        self,
-        prompt_token_ids: "torch.Tensor",
-        token_ids: "torch.Tensor",
-    ) -> List["torch.Tensor"]:
-        """Get the tokens generated so far.
-
-        Parameters
-        ----------
-        prompt_token_ids
-            Tensor that contains the token ids of the sequences' prompts.
-        token_ids
-            The generated token ids.
-
-        Returns
-        -------
-        A tensor that contains the token ids that have been generated so far.
-
-        """
-        prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
-        token_ids = [
-            cur_token_ids[length:]
-            for cur_token_ids, length in zip(token_ids, prompt_lengths)
-        ]
-
-        return token_ids
-
-    def is_stop_sequence_found(
-        self, generated_sequences: List[str], stop_sequences: List[str]
-    ) -> bool:
-        """Determine whether one of the stop sequences has been generated.
-
-        Parameters
-        ----------
-        generated_sequences
-            The list of sequences generated so far.
-        stop_sequences
-            The list that contains the sequence which stop the generation when
-            found.
-
-        Returns
-        -------
-        True if at least one of the stop sequences has been found in each generated
-        sequence.
-
-        """
-        return all(
-            [
-                any([seq in generated for seq in stop_sequences])
-                for generated in generated_sequences
-            ]
-        )
-
-    def strip_stop_sequences(
-        self, sequence: str, stop_sequences: Optional[List[str]]
-    ) -> str:
-        """Remove the stop sequences from the generated sequences.
-
-        Parameters
-        ----------
-        sequence
-            One of the generated sequences.
-        stop_sequences
-            The list that contains the sequence which stop the generation when
-            found.
-
-        """
-        if stop_sequences:
-            match_indexes = [sequence.find(seq) for seq in stop_sequences]
-            if any([index != -1 for index in match_indexes]):
-                # select the stop_sequence that is found first in the sequence
-                min_match_index_value = min([i for i in match_indexes if i != -1])
-                min_match_index_pos = match_indexes.index(min_match_index_value)
-                sequence = sequence[
-                    : match_indexes[min_match_index_pos]
-                    + len(stop_sequences[min_match_index_pos])
-                ]
-
-        return sequence
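For readers skimming the deleted `SequenceGenerator` class, the stop-sequence rule above is easy to miss in the diff: the earliest matching stop string wins, and the stop string itself is kept in the output. Here is a minimal, self-contained sketch of that rule (an illustration only, not part of the patch):

```python
from typing import List, Optional


def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
    """Truncate `sequence` just after the earliest stop sequence, if any is present."""
    if stop_sequences:
        match_indexes = [sequence.find(seq) for seq in stop_sequences]
        found = [i for i in match_indexes if i != -1]
        if found:
            first = min(found)
            which = match_indexes.index(first)
            # Keep the stop string itself, drop everything generated after it.
            sequence = sequence[: first + len(stop_sequences[which])]
    return sequence


print(strip_stop_sequences("name: Alice\nname: Bob", ["\n"]))  # -> "name: Alice\n"
```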
- - """ - if stop_sequences: - match_indexes = [sequence.find(seq) for seq in stop_sequences] - if any([index != -1 for index in match_indexes]): - # select the stop_sequence that is found first in the sequence - min_match_index_value = min([i for i in match_indexes if i != -1]) - min_match_index_pos = match_indexes.index(min_match_index_value) - sequence = sequence[ - : match_indexes[min_match_index_pos] - + len(stop_sequences[min_match_index_pos]) - ] - - return sequence - - def format_sequence(self, sequence: str) -> FormattedOutput: - """Translate the generated sequence to another type. - - This method is for instance overridden when generating JSON to either - return a dictionnary or a Pydantic model. - - Parameters - ---------- - sequence - A generated sequences. - - Returns - ------- - The formatted sequence. - - """ - return sequence - - def __call__( - self, - prompts: Union[str, List[str]], - max_tokens: Optional[int] = None, - stop_at: Optional[Union[str, List[str]]] = None, - rng: Optional["torch.Generator"] = None, - ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]: - """Generate the full text sequence. - - Since `SequenceGenerator.stream` calls the tokenizer at every step this - method loops over the generator returned by `sequence_generator` itself - so the tokenizer is called only once after all token ids have been - generated. - - Parameters - ---------- - prompts - A string or list of strings that are passed to the model before - generating the first token. - max_tokens - An integer representing maximum number of tokens that will be generated - (per prompt) - stop_at - A string or list of strings at which the text generated will stop - rng - The random number generator. Defaults to a non-seeded `torch.Generator` - instance. - - Returns - ------- - The generation(s), potentially cast to another type. - """ - import torch - - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(stop_at, str): - stop_at = [stop_at] - - stop_sequences = stop_at - num_samples = self.num_samples - - if rng is None: - rng = torch.Generator(device=self.device) - rng.seed() - - prompt_token_ids, attention_masks = self.tokenizer.encode(prompts) - prompt_token_ids = prompt_token_ids.to(self.device) - attention_masks = attention_masks.to(self.device) - - # To draw multiple samples we repeat the prompt as many times - # as there are samples. We copy the FSMs and initialize the - # FSM states. 
-
-    def stream(
-        self,
-        prompts: Union[str, List[str]],
-        max_tokens: Optional[int] = None,
-        stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional["torch.Generator"] = None,
-    ) -> Iterator[Union[List[str], str, List[List[str]]]]:
-        """Generate the text sequence one token at a time.
-
-        Since `Tokenizer.decode` strips the whitespaces from the tokens we have no
-        choice but to decode the generated token ids at each step and compare the
-        current decoded strings to the previously decoded strings.
-
-        Parameters
-        ----------
-        prompts
-            A string or list of strings that are passed to the model before
-            generating the first token.
-        max_tokens
-            An integer representing maximum number of tokens that will be generated
-            (per prompt)
-        stop_at
-            A string or list of strings at which the text generated will stop
-        rng
-            The random number generator. Defaults to a non-seeded `torch.Generator`
-            instance.
-
-        Returns
-        -------
-        A string or list of strings that contain the generated text.
-
-        """
-        import torch
-
-        if isinstance(prompts, str):
-            prompts = [prompts]
-
-        if isinstance(stop_at, str):
-            stop_at = [stop_at]
-
-        stop_sequences = stop_at
-        num_samples = self.num_samples
-
-        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
-        prompt_token_ids = prompt_token_ids.to(self.device)
-        attention_masks = attention_masks.to(prompt_token_ids.device)
-
-        # To draw multiple samples we repeat the prompt as many times
-        # as there are samples. We copy the FSMs and initialize the
-        # FSM states.
-        num_samples = self.num_samples
-        batch_size = len(prompts)
-
-        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
-        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
-        fsm_states = [0 for _ in range(batch_size * num_samples)]
-        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
-        weights = torch.zeros(
-            (batch_size * num_samples),
-            dtype=torch.float,
-            device=prompt_token_ids.device,
-        )
-
-        if rng is None:
-            rng = torch.Generator(device=prompt_token_ids.device)
-            rng.seed()
-
-        states = sequence_generator(
-            self.model,
-            self.sampler,
-            fsms,
-            prompt_token_ids,
-            weights,
-            attention_masks,
-            fsm_states,
-            rng=rng,
-        )
-
-        def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
-            previously_generated_sequences = [
-                "" for _ in range(batch_size)
-            ] * num_samples
-            num_generated = 0
-            is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
-            while True:
-                if (max_tokens and num_generated >= max_tokens) or all(
-                    is_stop_at_reached
-                ):
-                    return
-                try:
-                    sequence = next(states)
-                    num_generated += 1
-                except StopIteration:
-                    return
-                generated_token_ids = sequence.token_ids[:, -num_generated:]
-                generated_sequences = self.tokenizer.decode(generated_token_ids)
-                if stop_sequences:
-                    is_stop_at_reached = [
-                        stop
-                        or self.is_stop_sequence_found(
-                            [generated_sequence], stop_sequences
-                        )
-                        for generated_sequence, stop in zip(
-                            generated_sequences, is_stop_at_reached
-                        )
-                    ]
-
-                    generated_sequences = [
-                        (
-                            self.format_sequence(
-                                self.strip_stop_sequences(sequence, stop_sequences)
-                            )
-                            if stop
-                            else sequence
-                        )
-                        for sequence, stop in zip(
-                            generated_sequences, is_stop_at_reached
-                        )
-                    ]
-                next_tokens = [
-                    token[len(sequence) :]
-                    for token, sequence, stop in zip(
-                        generated_sequences,
-                        previously_generated_sequences,
-                        is_stop_at_reached,
-                    )
-                ]
-                previously_generated_sequences = generated_sequences
-                # We reshape the output to (batch_size, sample_size)
-                output: List[List[str]] = list()
-                for i in range(0, batch_size * num_samples, num_samples):
-                    output.append(next_tokens[i : i + num_samples])
-
-                # We remove leading dimensions for the output
-                if batch_size == 1 and num_samples == 1:
-                    yield output[0][0]
-                elif batch_size == 1:
-                    yield output[0]
-                elif num_samples == 1:
-                    yield [samples[0] for samples in output]
-                else:
-                    yield output
-
-        return token_generator()
-
-
 @dataclass(frozen=True)
 class GenerationParameters:
     """Generation parameters used in Outlines' public API."""
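Because `Tokenizer.decode` strips whitespace, the deleted `stream()` re-decodes the cumulative generation at every step and yields only the suffix that was not present in the previous decoding. A toy sketch of that delta computation (illustrative, not part of the patch):

```python
def stream_deltas(snapshots):
    """Yield only the text added since the previous snapshot, mirroring how the
    removed `stream()` diffed cumulative decodings instead of decoding one token."""
    previous = ""
    for text in snapshots:
        yield text[len(previous):]
        previous = text


# Cumulative decodings of the same sequence after steps 1, 2, 3.
print(list(stream_deltas(["Hel", "Hello", "Hello world"])))  # ['Hel', 'lo', ' world']
```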
diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py
deleted file mode 100644
index e506aa035..000000000
--- a/outlines/generate/generator.py
+++ /dev/null
@@ -1,312 +0,0 @@
-import dataclasses
-import math
-from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple
-
-if TYPE_CHECKING:
-    import torch
-
-    from outlines.fsm.guide import Guide
-
-
-class ContextLengthExceededError(Exception):
-    pass
-
-
-@dataclasses.dataclass(frozen=True)
-class GenerationState:
-    token_ids: "torch.Tensor"
-    kv_cache: "torch.Tensor"
-    logits: "torch.Tensor"
-    weights: "torch.Tensor"
-    fsm_states: List[int]
-
-
-def sequence_generator(
-    model: Callable,
-    sampler: Callable,
-    fsms: List["Guide"],
-    token_ids: "torch.Tensor",
-    sequence_weights: "torch.Tensor",
-    attention_masks: "torch.Tensor",
-    fsm_states: List[int],
-    rng: "torch.Generator",
-) -> Iterator[GenerationState]:
-    """Generates sequences of tokens.
-
-    Parameters
-    ----------
-    model
-        A callable that generates a probability distribution over the
-        vocabulary when passed a tensor of token ids.
-    sampler
-        A callable that returns the next token ids, their ancestor sequence and
-        the updated sequence weights when passed a distribution over the
-        vocabulary.
-    token_ids
-        A tensor of token ids on which the sequence distribution is conditioned, of
-        shape ``(n_seqs, n_prompt_tokens)``
-    sequence_weights
-        A tensor that contains the initial weights of the sequences, of shape
-        ``(n_seqs,)``
-    attention_masks
-        A tensor of tensors that represent the tokens considered at the attention
-        layer, of shape ``(n_seqs, n_prompt_tokens)``.
-    fsms
-        List of finite-state machines that drive the text generation,
-        one for each sequence in the batch.
-    fsm_states
-        The initial states of the finite-state machine for each sequence in the batch.
-
-    Yields
-    ------
-    A new sequence.
-
-    """
-    import torch
-
-    if rng is None:
-        rng = torch.Generator()
-
-    kv_cache = None
-
-    while True:
-        try:
-            logits, kv_cache = model(token_ids, attention_masks, kv_cache)
-        except IndexError:  # Exceeding the context length
-            raise ContextLengthExceededError(
-                "The input length exceeds the context length of the model."
-            )
-
-        allowed_tokens = get_allowed_tokens(fsms, fsm_states)
-        biased_logits = bias_logits(logits, allowed_tokens)
-        next_token_ids, ancestors, sequence_weights = sampler(
-            biased_logits, sequence_weights, rng
-        )
-
-        token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
-        attention_masks = update_attention_masks(attention_masks, ancestors)
-        kv_cache = reorder_kv_cache(kv_cache, ancestors)
-        if len(ancestors) > 1:
-            fsms = reorder_fsms(fsms, ancestors)
-            fsm_states = reorder_fsm_states(fsm_states, ancestors)
-
-        fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids)
-        is_finished = is_generation_finished(fsms, fsm_states)
-
-        if is_finished:
-            yield GenerationState(
-                token_ids,
-                kv_cache,
-                logits,
-                sequence_weights,
-                fsm_states,
-            )
-            return
-
-        yield GenerationState(
-            token_ids,
-            kv_cache,
-            logits,
-            sequence_weights,
-            fsm_states,
-        )
-
-
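The deleted `sequence_generator` loop has a fixed shape: run the model, mask the logits to the FSM-allowed tokens, sample, advance the FSM states, and stop once every FSM is final. The toy skeleton below mirrors that control flow with stand-in callables instead of a real model or `Guide` (a sketch for illustration; all names are hypothetical):

```python
import math


def constrained_loop(logits_fn, allowed_fn, advance_fn, is_final_fn, state, max_steps=10):
    """Toy skeleton of the deleted loop: mask logits to the allowed set, pick a
    token, advance the FSM state, stop when the FSM reaches a final state."""
    tokens = []
    for _ in range(max_steps):
        logits = logits_fn(tokens)
        allowed = allowed_fn(state)
        masked = {t: (l if t in allowed else -math.inf) for t, l in logits.items()}
        next_token = max(masked, key=masked.get)  # greedy stand-in for the sampler
        tokens.append(next_token)
        state = advance_fn(state, next_token)
        if is_final_fn(state):
            break
    return tokens


# Toy "FSM" that only allows token 1 and finishes after two steps.
print(constrained_loop(
    logits_fn=lambda toks: {0: 1.0, 1: 0.5},
    allowed_fn=lambda s: {1},
    advance_fn=lambda s, t: s + 1,
    is_final_fn=lambda s: s >= 2,
    state=0,
))  # [1, 1]
```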
-def get_next_fsm_states(
-    fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
-) -> List[int]:
-    """
-
-    Parameters
-    ----------
-    fsm
-        The finite-state machine used to monitor this batch.
-    next_token_ids
-        The tokens that were just generated.
-
-    Returns
-    -------
-    A `torch.Tensor` object that represents the next logit mask.
-
-    """
-    return [
-        fsm.get_next_state(fsm_state, int(token_id[0]))
-        for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
-    ]
-
-
-def get_allowed_tokens(
-    fsms: List["Guide"], fsm_states: List[int]
-) -> List[Optional[Iterable[int]]]:
-    """Get the new instructions for each sequence from the finite-state machine.
-
-    Parameters
-    ----------
-    fsm
-        The finite-state machine used to monitor this batch.
-    fsm_states
-        The FSM states corresponding to each sequence in the batch.
-
-    Returns
-    -------
-    A nested list that contains the ids of the logits to keep.
-
-    """
-    return [
-        fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states)
-    ]
-
-
-def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool:
-    """Determine if the generation is finished.
-
-    A generation is considered finished if the FSM of every sequence in the
-    batch is in a final state.
-
-    A better solution is to return finished sequences as soon as their FSM
-    is in a final state.
-
-    Parameters
-    ----------
-    fsm
-        The finite-state machine used to monitor this batch.
-    fsm_states
-        The FSM states corresponding to each sequence in the batch.
-
-    Returns
-    -------
-    Whether all sequences are finished sampling.
-
-    """
-    return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)])
-
-
-def update_token_ids(
-    token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor"
-) -> "torch.Tensor":
-    """Append the sampled tokens to the running sequence of tokens.
-
-    Parameters
-    ----------
-    token_ids
-        The current token sequences
-    next_token_ids
-        The tokens that were just generated and that we need to append
-        to the existing sequences.
-    ancestors
-        The sequences to which the token ids need to be added.
-
-    Returns
-    -------
-    A new sequence of token ids that contains the tokens that were
-    just generated.
-
-    """
-    import torch
-
-    token_ids = torch.index_select(token_ids, 0, ancestors)
-    return torch.concatenate([token_ids, next_token_ids], dim=-1)
-
-
-def update_attention_masks(
-    attention_masks: "torch.Tensor", ancestors: "torch.Tensor"
-) -> "torch.Tensor":
-    """Expand the attention masks.
-
-    Parameters
-    ----------
-    attention_masks
-        The attention masks for each sequence in the batch.
-    ancestors
-        The sequences to which the token ids need to be added.
-
-    Returns
-    -------
-    The attention masks padded with 1s.
-
-    """
-    import torch
-
-    attention_masks = torch.index_select(attention_masks, 0, ancestors)
-    return torch.concatenate(
-        [
-            attention_masks,
-            torch.ones(
-                attention_masks.shape[:-1] + (1,), device=attention_masks.device
-            ),
-        ],
-        axis=-1,
-    )
-
-
-def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]:
-    reordered_fsms = []
-    for ancestor in ancestors:
-        reordered_fsms.append(fsms[ancestor].copy())
-
-    return reordered_fsms
-
-
-def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]:
-    reordered_states = []
-    for ancestor in ancestors:
-        reordered_states.append(fsm_states[ancestor])
-
-    return reordered_states
-
-
-def reorder_kv_cache(
-    kv_cache: Optional[Tuple], ancestors: "torch.Tensor"
-) -> Optional[Tuple]:
-    """Re-order the KV-cache based on the ancestors.
-
-    In transformers, the object that stores the KV-cache is a tuple who elements
-    are the key cache and the value cache. Each of these caches are tuples where
-    each element correpond to a layer. To each layer corresponds a tensor whose
-    first dimension is the batch size.
-
-    """
-    import torch
-
-    if kv_cache is None:
-        return None
-
-    new_kv_cache: Tuple = tuple()
-    for cache_item in kv_cache:
-        new_cache_item: Tuple = tuple()
-        for layer in cache_item:
-            layer = torch.index_select(layer, 0, ancestors.to(layer.device))
-            new_cache_item += (layer,)
-        new_kv_cache += (new_cache_item,)
-
-    return new_kv_cache
-
-
-def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor":
-    """Mask the logits.
-
-    The function iterates over a nested list where each list corresponds to the
-    indices that need to be masked for each row in the array.
-
-    Parameters
-    ----------
-    logits
-        Two dimensional tensor that contains the next-token probability
-        distribution.
-    allowed_token_ids
-        A list that contains the tokens that can be generated by the model.
-
-    Returns
-    -------
-    A view of the original logits tensor where some values are masked.
-
-    """
-    import torch
-
-    biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
-    for i, ids in enumerate(allowed_token_ids):
-        if ids is not None:
-            biased_logits[i, ids] = logits[i, ids]
-        else:
-            biased_logits[i] = logits[i]
-    return biased_logits
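`bias_logits` is the core masking step that the logits-processor path takes over. As a standalone reference, the same rule can be run directly (a sketch restating the deleted helper for illustration, not part of the patch):

```python
import math

import torch


def bias_logits(logits: torch.Tensor, allowed_token_ids: list) -> torch.Tensor:
    """Keep only the allowed token ids per row; everything else becomes -inf."""
    biased = torch.full_like(logits, -math.inf)
    for i, ids in enumerate(allowed_token_ids):
        if ids is not None:
            biased[i, ids] = logits[i, ids]
        else:
            biased[i] = logits[i]  # None means "no constraint" for that row
    return biased


logits = torch.tensor([[1.0, 2.0, 3.0]])
print(bias_logits(logits, [[0, 2]]))  # tensor([[1., -inf, 3.]])
```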
- - """ - import torch - - biased_logits = torch.full_like(logits, -math.inf, device=logits.device) - for i, ids in enumerate(allowed_token_ids): - if ids is not None: - biased_logits[i, ids] = logits[i, ids] - else: - biased_logits[i] = logits[i] - return biased_logits diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py deleted file mode 100644 index 5a2edf8dc..000000000 --- a/tests/generate/test_generator.py +++ /dev/null @@ -1,497 +0,0 @@ -import math -from typing import Generator - -import pytest -import torch - -from outlines.fsm.guide import Generate -from outlines.generate.api import SequenceGenerator -from outlines.generate.generator import ( - bias_logits, - get_allowed_tokens, - get_next_fsm_states, - is_generation_finished, - sequence_generator, - update_attention_masks, - update_token_ids, -) - - -def test_sequence_generator_class(): - class MockFSM: - first_state = 0 - - def get_next_state(self, state, next_token_ids): - return 4 - - def get_next_instruction(self, *_): - return Generate([4]) - - def is_final_state(self, _): - return True - - def copy(self): - return self - - class MockTokenizer: - def encode(self, _): - # Input: "test" - return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) - - def decode(self, tokens): - return ["testx"[i] for i in tokens] - - class MockModel: - def __init__(self): - self.tokenizer = MockTokenizer() - - def __call__(*_): - return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None - - class sampler: - def __init__(self): - self.samples = 1 - - def __call__(self, biased_logits, *_): - return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None - - # Stream - generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu") - assert generator.device == "cpu" - assert isinstance(generator.tokenizer, MockTokenizer) - assert isinstance(generator.fsm, MockFSM) - - sequence = generator.stream("test") - assert isinstance(sequence, Generator) - - next(sequence) - - with pytest.raises(StopIteration): - next(sequence) - - # Call - generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu") - result = generator("test") - assert result == "x" - - -def test_sequence_generator_1d_single_iteration(): - class MockFSM: - def get_next_state(self, state, next_token_ids): - return 0 - - def get_next_instruction(self, _): - return Generate([0, 1, 2, 3]) - - def is_final_state(self, _): - return True - - def copy(self): - return self - - class MockTokenizer: - def encode(self, _): - return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) - - def decode(self, x): - return x - - class MockModel: - def __init__(self): - self.tokenizer = MockTokenizer() - - def __call__(*_): - return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None - - def sampler(biased_logits, *_): - return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None - - token_ids, attention_mask = ( - torch.tensor([[0, 1, 2, 3]]), - torch.tensor([[1, 1, 1, 1]]), - ) - init_fsm_states = [0] - sequence = sequence_generator( - MockModel(), - sampler, - [MockFSM()], - token_ids, - torch.tensor([0]), - attention_mask, - init_fsm_states, - rng=torch.Generator(), - ) - result = next(sequence) - - assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3]])) - assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]])) - - with pytest.raises(StopIteration): - next(sequence) - - -def test_sequence_generator_1d_several_iterations(): - class MockFSM: - def get_next_state(self, state, 
-        def get_next_state(self, state, next_token_ids):
-            return state + 1
-
-        def get_next_instruction(self, _):
-            return Generate([0, 1, 2, 3])
-
-        def is_final_state(self, state):
-            if state < 2:
-                return False
-            else:
-                return True
-
-        def copy(self):
-            return self
-
-    class MockTokenizer:
-        def encode(self, _):
-            return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]])
-
-        def decode(self, x):
-            return x
-
-    class MockModel:
-        def __init__(self):
-            self.tokenizer = MockTokenizer()
-
-        def __call__(*_):
-            return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None
-
-    def sampler(biased_logits, *_):
-        return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
-
-    token_ids, attention_mask = (
-        torch.tensor([[0, 1, 2, 3]]),
-        torch.tensor([[1, 1, 1, 1]]),
-    )
-    init_fsm_states = [0]
-    sequence = sequence_generator(
-        MockModel(),
-        sampler,
-        [MockFSM()],
-        token_ids,
-        torch.tensor([0]),
-        attention_mask,
-        init_fsm_states,
-        rng=torch.Generator(),
-    )
-
-    result = next(sequence)
-    assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3]]))
-    assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]]))
-
-    result = next(sequence)
-    assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3, 3]]))
-    assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]]))
-
-    with pytest.raises(StopIteration):
-        next(sequence)
-
-
-def test_sequence_generator_2d_single_iteration():
-    class MockFSM:
-        def get_next_state(self, state, next_token_ids):
-            return 0
-
-        def get_next_instruction(self, _):
-            return Generate([0, 1, 2, 3])
-
-        def is_final_state(self, _):
-            return True
-
-        def copy(self):
-            return self
-
-    class MockTokenizer:
-        def encode(self, _):
-            return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), torch.tensor(
-                [[1, 1, 1, 1], [1, 1, 1, 1]]
-            )
-
-        def decode(self, x):
-            return x
-
-    class MockModel:
-        def __init__(self):
-            self.tokenizer = MockTokenizer()
-
-        def __call__(*_):
-            return torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float), None
-
-    def sampler(biased_logits, *_):
-        return (
-            torch.argmax(biased_logits, keepdims=True, dim=-1),
-            torch.tensor([0, 1]),
-            None,
-        )
-
-    token_ids, attention_mask = (
-        torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
-        torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]]),
-    )
-    init_fsm_states = [0, 0]
-    fsms = [MockFSM(), MockFSM()]
-    sequence = sequence_generator(
-        MockModel(),
-        sampler,
-        fsms,
-        token_ids,
-        torch.tensor([0, 0]),
-        attention_mask,
-        init_fsm_states,
-        rng=torch.Generator(),
-    )
-
-    result = next(sequence)
-    assert torch.equal(
-        result.token_ids, torch.tensor([[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]])
-    )
-    assert torch.equal(
-        result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
-    )
-
-    with pytest.raises(StopIteration):
-        next(sequence)
-
-
-def test_sequence_generator_2d_several_iterations():
-    class MockFSM:
-        def get_next_state(self, state, next_token_ids):
-            return state + 1
-
-        def get_next_instruction(self, _):
-            return Generate([0, 1, 2, 3])
-
-        def is_final_state(self, state):
-            if state < 2:
-                return False
-            else:
-                return True
-
-        def copy(self):
-            return self
-
-    class MockTokenizer:
-        def encode(self, _):
-            return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), torch.tensor(
-                [[1, 1, 1, 1], [1, 1, 1, 1]]
-            )
-
-        def decode(self, x):
-            return x
-
-    class MockModel:
-        def __init__(self):
-            self.tokenizer = MockTokenizer()
-
-        def __call__(*_):
-            return torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float), None
-
-    def sampler(biased_logits, *_):
-        return (
-            torch.argmax(biased_logits, keepdims=True, dim=-1),
-            torch.tensor([0, 1]),
-            None,
-        )
-
-    token_ids, attention_mask = (
-        torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
-        torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]]),
-    )
-    init_fsm_states = [0, 0]
-    fsms = [MockFSM(), MockFSM()]
-    sequence = sequence_generator(
-        MockModel(),
-        sampler,
-        fsms,
-        token_ids,
-        torch.tensor([0, 0]),
-        attention_mask,
-        init_fsm_states,
-        rng=torch.Generator(),
-    )
-
-    result = next(sequence)
-    assert torch.equal(
-        result.token_ids, torch.tensor([[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]])
-    )
-    assert torch.equal(
-        result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
-    )
-
-    result = next(sequence)
-    assert torch.equal(
-        result.token_ids, torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 7, 2, 2]])
-    )
-    assert torch.equal(
-        result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
-    )
-
-    with pytest.raises(StopIteration):
-        next(sequence)
-
-
-def test_get_next_fsm_states():
-    class MockFSM:
-        def get_next_state(self, state, next_token_ids):
-            return 0
-
-        def copy(self):
-            return self
-
-    result = get_next_fsm_states([MockFSM()], [0], torch.tensor([[0]]))
-    assert result == [0]
-
-    result = get_next_fsm_states(
-        [MockFSM(), MockFSM()], [0, 0], torch.tensor([[0], [0]])
-    )
-    assert result == [0, 0]
-
-
-def test_get_get_next_instructions():
-    class MockFSM:
-        def get_next_instruction(self, _):
-            return Generate([1, 2, 3, 4])
-
-    result = get_allowed_tokens([MockFSM()], [0])
-    assert result == [[1, 2, 3, 4]]
-
-    result = get_allowed_tokens([MockFSM(), MockFSM()], [0, 1])
-    assert result == [[1, 2, 3, 4], [1, 2, 3, 4]]
-
-
-def test_is_generation_finished():
-    class MockFSMFinished:
-        def is_final_state(self, _):
-            return True
-
-    result = is_generation_finished([MockFSMFinished(), MockFSMFinished()], [1, 1])
-    assert result is True
-
-    class MockFSMNotFinished:
-        def is_final_state(self, state):
-            if state == 0:
-                return False
-            else:
-                return True
-
-    result = is_generation_finished(
-        [MockFSMNotFinished(), MockFSMNotFinished()], [0, 1]
-    )
-    assert result is False
-
-
-@pytest.mark.parametrize(
-    "token_ids,next_token_ids,ancestors,expected_result",
-    [
-        (
-            torch.tensor([[1]]),
-            torch.tensor([[2]]),
-            torch.tensor([0]),
-            torch.tensor([[1, 2]]),
-        ),
-        (
-            torch.tensor([[1], [2]]),
-            torch.tensor([[3], [4]]),
-            torch.tensor([0, 1]),
-            torch.tensor([[1, 3], [2, 4]]),
-        ),
-        (
-            torch.tensor([[1], [2]]),
-            torch.tensor([[3], [4]]),
-            torch.tensor([0, 0]),
-            torch.tensor([[1, 3], [1, 4]]),
-        ),
-        (
-            torch.tensor([[1], [2]]),
-            torch.tensor([[3], [4]]),
-            torch.tensor([1, 0]),
-            torch.tensor([[2, 3], [1, 4]]),
-        ),
-        (
-            torch.tensor([[1, 2], [3, 5]]),
-            torch.tensor([[3], [4]]),
-            torch.tensor([1, 0]),
-            torch.tensor([[3, 5, 3], [1, 2, 4]]),
-        ),
-    ],
-)
-def test_update_token_ids(token_ids, next_token_ids, ancestors, expected_result):
-    result = update_token_ids(token_ids, next_token_ids, ancestors)
-    assert torch.equal(result, expected_result)
-
-
-@pytest.mark.parametrize(
-    "attention_masks,ancestors,expected_result",
-    [
-        (
-            torch.tensor([[1, 1]], dtype=torch.float),
-            torch.tensor([0]),
-            torch.tensor([[1, 1, 1]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[1, 1], [1, 1]], dtype=torch.float),
-            torch.tensor([0, 0]),
-            torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[0, 1], [1, 1]], dtype=torch.float),
-            torch.tensor([0, 1]),
-            torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[0, 1], [1, 1]], dtype=torch.float),
-            torch.tensor([1, 0]),
-            torch.tensor([[1, 1, 1], [0, 1, 1]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[0, 1], [1, 1]], dtype=torch.float),
-            torch.tensor([1, 1]),
-            torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.float),
-        ),
-    ],
-)
-def test_expand_attention_masks(attention_masks, ancestors, expected_result):
-    result = update_attention_masks(attention_masks, ancestors)
-    assert torch.equal(result, expected_result)
-
-
-@pytest.mark.parametrize(
-    "logits,indices_to_mask,expected",
-    [
-        (
-            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
-            [[0, 1, 2, 3]],
-            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
-            [[0, 2, 3]],
-            torch.tensor([[1, -math.inf, 3, 4]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
-            [[0, 2]],
-            torch.tensor([[1, -math.inf, 3, -math.inf]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
-            [[1, 2], [0, 1]],
-            torch.tensor([[-math.inf, 2, 3], [4, 5, -math.inf]], dtype=torch.float),
-        ),
-        (
-            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
-            [[0, 2], [1]],
-            torch.tensor(
-                [[1, -math.inf, 3], [-math.inf, 5, -math.inf]], dtype=torch.float
-            ),
-        ),
-    ],
-)
-def test_bias_logits(logits, indices_to_mask, expected):
-    masked_logits = bias_logits(logits, indices_to_mask)
-    assert torch.equal(masked_logits, expected)
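To close the loop on the diff as a whole: with the hand-rolled token loop and its tests deleted, constrained generation is expected to flow through the logits-processor path imported at the top of `outlines/generate/__init__.py` and the `SequenceGeneratorAdapter` it now re-exports. A hedged sketch of downstream usage under that assumption (the model id is illustrative, and the exact constructor signatures are not confirmed by this patch):

```python
# A minimal sketch, assuming the public builders in outlines.generate keep
# their pre-existing signatures.
import outlines

model = outlines.models.transformers("microsoft/phi-2")  # any HF model id

# The builders are expected to return a SequenceGeneratorAdapter that applies a
# RegexLogitsProcessor / CFGLogitsProcessor instead of the deleted token loop.
choose = outlines.generate.choice(model, ["Positive", "Negative"])
print(choose("Review: I loved this movie.\nSentiment:"))
```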