diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 729c90ca6..1f9c54020 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -49,6 +49,7 @@ def get_generated_token_ids( init_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], prompts: List[str], last_state: GenerationState, + num_samples: int, ) -> List[torch.Tensor]: """Get the tokens generated so far. @@ -60,6 +61,8 @@ def get_generated_token_ids( The prompts passed to the generator. last_state The current state of the generation + num_samples + The number of samples taken for each sequence Returns ------- @@ -67,11 +70,19 @@ def get_generated_token_ids( """ prompt_token_ids = init_state[0] - prompt_lengths = [len(prompt_token_ids[i]) for i in range(len(prompts))] + prompt_lengths = [ + len(prompt_token_ids[i]) + for _ in range(num_samples) + for i in range(len(prompts)) + ] + + # We flatten the obtained token_ids since the tokenizer's decoder + # only accepts tensor with two dimensions + token_ids = last_state.token_ids.reshape((-1, last_state.token_ids.shape[-1])) token_ids = [ cur_token_ids[length:] - for cur_token_ids, length in zip(last_state.token_ids, prompt_lengths) + for cur_token_ids, length in zip(token_ids, prompt_lengths) ] return token_ids @@ -150,9 +161,10 @@ def __call__( prompts: Union[str, List[str]], max_tokens: Optional[int] = None, stop_at: Optional[Union[str, List[str]]] = None, + num_samples: int = 1, rng: Optional[torch.Generator] = None, kv_cache: Optional[torch.tensor] = None, - ) -> Union[str, List[str]]: + ) -> Union[str, List[str], List[List[str]]]: """Generate the full text sequence. Since `SequenceGenerator.stream` calls the tokenizer at every step this @@ -205,7 +217,12 @@ def __call__( init_fsm_states = [FSMState(0) for _ in range(num_sequences)] states = sequence_generator( - self.generate_token, fsms, init_state, init_fsm_states, rng + self.generate_token, + fsms, + init_state, + init_fsm_states, + rng=rng, + num_samples=num_samples, ) while True: @@ -213,7 +230,7 @@ def __call__( last_state = next(states) if max_tokens or stop_sequences: generated_token_ids = self.get_generated_token_ids( - init_state, prompts, last_state + init_state, prompts, last_state, num_samples ) if max_tokens and len(generated_token_ids[0]) >= max_tokens: break @@ -225,8 +242,9 @@ def __call__( break generated_token_ids = self.get_generated_token_ids( - init_state, prompts, last_state + init_state, prompts, last_state, num_samples ) + generated = self.tokenizer.decode(generated_token_ids) stripped = [ self.strip_stop_sequences(sequence, stop_sequences) @@ -242,16 +260,29 @@ def __call__( + " is raised nevertheless please open an issue: https://github.com/outlines-dev/outlines/issues" ) - return formatted if len(formatted) > 1 else formatted[0] + # We reshape the output to (sample_size, batch_size) + output = [] + step = len(prompts) + for i in range(0, len(formatted), step): + output.append(formatted[i : i + step]) + + # We remove leading dimensions for the output + if len(prompts) == 1 and num_samples == 1: + return output[0][0] + elif num_samples == 1: + return output[0] + else: + return output def stream( self, prompts: Union[str, List[str]], max_tokens: Optional[int] = None, stop_at: Optional[Union[str, List[str]]] = None, + num_samples: int = 1, rng: Optional[torch.Generator] = None, kv_cache: Optional[torch.tensor] = None, - ) -> Iterator[Union[List[str], str]]: + ) -> Iterator[Union[List[str], List[List[str]], str]]: """Generate the text sequence one token at a time. 
Since `Tokenizer.decode` strips the whitespaces from the tokens we have no @@ -303,13 +334,20 @@ def stream( init_fsm_states = [FSMState(0) for _ in range(num_sequences)] states = sequence_generator( - self.generate_token, fsms, init_state, init_fsm_states, rng + self.generate_token, + fsms, + init_state, + init_fsm_states, + num_samples=num_samples, + rng=rng, ) - def token_generator() -> Iterator[Union[List[str], str]]: - previously_generated_sequences = ["" for _ in range(num_sequences)] + def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: + previously_generated_sequences = [ + "" for _ in range(num_sequences) + ] * num_samples num_generated = 0 - is_stop_at_reached = [False for _ in range(num_sequences)] + is_stop_at_reached = [False for _ in range(num_sequences)] * num_samples while True: if (max_tokens and num_generated >= max_tokens) or all( is_stop_at_reached @@ -320,7 +358,10 @@ def token_generator() -> Iterator[Union[List[str], str]]: num_generated += 1 except StopIteration: return - generated_token_ids = sequence.token_ids[:, -num_generated:] + generated_token_ids = sequence.token_ids[:, :, -num_generated:] + generated_token_ids = generated_token_ids.reshape( + -1, generated_token_ids.shape[-1] + ) generated_sequences = self.tokenizer.decode(generated_token_ids) next_tokens = [ token[len(sequence) :] if not stop else "" @@ -341,7 +382,19 @@ def token_generator() -> Iterator[Union[List[str], str]]: generated_sequences, is_stop_at_reached ) ] - yield next_tokens + # We reshape the output to (sample_size, batch_size) + output = [] + step = len(prompts) + for i in range(0, len(next_tokens), step): + output.append(next_tokens[i : i + step]) + + # We remove leading dimensions for the output + if len(prompts) == 1 and num_samples == 1: + yield output[0][0] + elif num_samples == 1: + yield output[0] + else: + yield output return token_generator() diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index 44b469bc4..b53d92131 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -55,7 +55,8 @@ def sequence_generator( fsms: List["FSM"], init_state: Tuple, fsm_states: List[FSMState], - rng: torch.Generator, + num_samples: int = 1, + rng: torch.Generator = torch.Generator(), ) -> Iterator[GenerationState]: """Generates sequences of tokens. @@ -78,6 +79,17 @@ def sequence_generator( """ token_ids, attention_masks, kv_cache = init_state + batch_shape = token_ids.shape[:-1] + + # To take several samples we duplicate `token_ids`, `attention_masks` + # and `fsm_states` as many times as the number of samples requested. 
+ # The resulting tensors are of shape (num_samples * num_batches, num_tokens) + token_ids = torch.repeat_interleave(token_ids, num_samples, dim=0) + attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) + + fsm_states = fsm_states * num_samples + fsms = [fsm.copy() for fsm in fsms for _ in range(num_samples)] + while True: allowed_tokens = get_allowed_tokens(fsms, fsm_states) @@ -88,7 +100,6 @@ def sequence_generator( rng=rng, allowed_tokens=allowed_tokens, ) - token_ids = update_token_ids(token_ids, next_token_ids) attention_masks = expand_attention_masks(attention_masks) @@ -96,10 +107,20 @@ def sequence_generator( is_finished = is_generation_finished(fsms, fsm_states) if is_finished: - yield GenerationState(token_ids, kv_cache, logits, fsm_states) + yield GenerationState( + token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:]), + kv_cache, + logits, + fsm_states, + ) return - yield GenerationState(token_ids, kv_cache, logits, fsm_states) + yield GenerationState( + token_ids.reshape((num_samples,) + batch_shape + token_ids.shape[-1:]), + kv_cache, + logits, + fsm_states, + ) def token_generator(model, sampler: "Sampler") -> Callable: diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py index 9a7bd1188..93371da01 100644 --- a/tests/generate/test_generator.py +++ b/tests/generate/test_generator.py @@ -36,7 +36,7 @@ def copy(self): class MockTokenizer: def encode(self, _): # Input: "test" - return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1, 1]]) + return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) def decode(self, tokens): return ["testx"[i] for i in tokens] @@ -139,11 +139,11 @@ def sampler(biased_logits, *_): init_fsm_states = [0] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, [MockFSM()], init_state, init_fsm_states, torch.Generator() + generate, [MockFSM()], init_state, init_fsm_states, rng=torch.Generator() ) result = next(sequence) - assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3]])) + assert torch.equal(result.token_ids, torch.tensor([[[0, 1, 2, 3, 3]]])) assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]])) with pytest.raises(StopIteration): @@ -185,15 +185,15 @@ def sampler(biased_logits, *_): init_fsm_states = [0] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, [MockFSM()], init_state, init_fsm_states, torch.Generator() + generate, [MockFSM()], init_state, init_fsm_states, rng=torch.Generator() ) result = next(sequence) - assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3]])) + assert torch.equal(result.token_ids, torch.tensor([[[0, 1, 2, 3, 3]]])) assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]])) result = next(sequence) - assert torch.equal(result.token_ids, torch.tensor([[0, 1, 2, 3, 3, 3]])) + assert torch.equal(result.token_ids, torch.tensor([[[0, 1, 2, 3, 3, 3]]])) assert torch.equal(result.logits, torch.tensor([[0, 1, 2, 3]])) with pytest.raises(StopIteration): @@ -239,12 +239,12 @@ def sampler(biased_logits, *_): fsms = [MockFSM(), MockFSM()] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, fsms, init_state, init_fsm_states, torch.Generator() + generate, fsms, init_state, init_fsm_states, rng=torch.Generator() ) result = next(sequence) assert torch.equal( - result.token_ids, torch.tensor([[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]]) + result.token_ids, torch.tensor([[[0, 1, 2, 3, 3], [4, 5, 6, 7, 
2]]])
     )
     assert torch.equal(
         result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
@@ -254,6 +254,80 @@ def sampler(biased_logits, *_):
         next(sequence)
 
 
+def test_sequence_generator_2d_single_iteration_several_samples():
+    class MockFSM:
+        def next_state(self, state, next_token_ids):
+            return 0
+
+        def allowed_token_ids(self, _):
+            return [0, 1, 2, 3]
+
+        def is_final_state(self, _):
+            return True
+
+        def copy(self):
+            return self
+
+    class MockTokenizer:
+        def encode(self, _):
+            return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), torch.tensor(
+                [[1, 1, 1, 1], [1, 1, 1, 1]]
+            )
+
+        def decode(self, x):
+            return x
+
+    class MockModel:
+        def __init__(self):
+            self.tokenizer = MockTokenizer()
+
+        def __call__(*_):
+            return (
+                torch.tensor(
+                    [[0, 1, 2, 3], [4, 5, 7, 6], [0, 1, 2, 3], [1, 5, 3, 4]],
+                    dtype=torch.float,
+                ),
+                None,
+            )
+
+    def sampler(biased_logits, *_):
+        return torch.argmax(biased_logits, keepdims=True, dim=-1)
+
+    init_state = (
+        torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
+        torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]]),
+        None,
+    )
+    init_fsm_states = [0, 0]
+    fsms = [MockFSM(), MockFSM()]
+    generate = token_generator(MockModel(), sampler)
+    sequence = sequence_generator(
+        generate,
+        fsms,
+        init_state,
+        init_fsm_states,
+        num_samples=2,
+        rng=torch.Generator(),
+    )
+
+    result = next(sequence)
+    expected_token_ids = torch.tensor(
+        [
+            [[0, 1, 2, 3, 3], [0, 1, 2, 3, 2]],
+            [[4, 5, 6, 7, 3], [4, 5, 6, 7, 1]],
+        ]
+    )
+    assert torch.equal(result.token_ids, expected_token_ids)
+
+    expected_logits = torch.tensor(
+        [[0, 1, 2, 3], [4, 5, 7, 6], [0, 1, 2, 3], [1, 5, 3, 4]], dtype=torch.float
+    )
+    assert torch.equal(result.logits, expected_logits)
+
+    with pytest.raises(StopIteration):
+        next(sequence)
+
+
 def test_sequence_generator_2d_several_iterations():
     class MockFSM:
         def next_state(self, state, next_token_ids):
@@ -296,12 +370,12 @@ def sampler(biased_logits, *_):
     fsms = [MockFSM(), MockFSM()]
     generate = token_generator(MockModel(), sampler)
     sequence = sequence_generator(
-        generate, fsms, init_state, init_fsm_states, torch.Generator()
+        generate, fsms, init_state, init_fsm_states, rng=torch.Generator()
     )
 
     result = next(sequence)
     assert torch.equal(
-        result.token_ids, torch.tensor([[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]])
+        result.token_ids, torch.tensor([[[0, 1, 2, 3, 3], [4, 5, 6, 7, 2]]])
     )
     assert torch.equal(
         result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
@@ -309,7 +383,7 @@
 
     result = next(sequence)
     assert torch.equal(
-        result.token_ids, torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 7, 2, 2]])
+        result.token_ids, torch.tensor([[[0, 1, 2, 3, 3, 3], [4, 5, 6, 7, 2, 2]]])
     )
     assert torch.equal(
         result.logits, torch.tensor([[0, 1, 2, 3], [4, 5, 7, 6]], dtype=torch.float)
diff --git a/tests/generate/test_integration_transfomers.py b/tests/generate/test_integration_transfomers.py
index 0e6cfa4b4..938d76ddb 100644
--- a/tests/generate/test_integration_transfomers.py
+++ b/tests/generate/test_integration_transfomers.py
@@ -85,6 +85,28 @@ def test_transformers_integration_text():
     assert isinstance(sequence[0], str)
 
 
+def test_transformers_integration_text_multiple_samples():
+    rng = torch.Generator()
+    rng.manual_seed(10000)  # Chosen so <space> is generated
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    sequence = generate.text(model)("Write a short sentence ", num_samples=2, rng=rng)
+    assert isinstance(sequence, list)
+    assert len(sequence) == 2
+    assert model.tokenizer.eos_token not in sequence
+
+    prompts = ["Write a short sentence ", "And another one "]
+    sequence = generate.text(model)(
+        prompts, max_tokens=10, num_samples=2, stop_at=[".", ","], rng=rng
+    )
+    assert isinstance(sequence, list)
+    assert len(sequence) == 2
+    assert isinstance(sequence[0], list)
+    assert len(sequence[0]) == 2
+    assert isinstance(sequence[0][0], str)
+
+
 def test_transformers_integration_streaming():
     rng = torch.Generator()
     rng.manual_seed(10000)  # Choosen so is generated
@@ -96,10 +118,9 @@ def test_transformers_integration_streaming():
     )
 
     token = next(sequence)
-    assert isinstance(token, list)
-    assert isinstance(token[0], str)
+    assert isinstance(token, str)
 
-    remaining = "".join([token[0] for token in sequence])
+    remaining = "".join([token for token in sequence])
     assert isinstance(remaining, str)
 
     sequence = generate.text(model).stream(
@@ -111,27 +132,37 @@ def test_transformers_integration_streaming():
     assert isinstance(tokens[1], str)
 
 
-@pytest.mark.xfail(reason="not implemented")
-def test_transformers_integration_text_stop():
+def test_transformers_integration_streaming_batch_samples():
     rng = torch.Generator()
     rng.manual_seed(10000)  # Choosen so is generated
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
-    prompt = "Write a short sentence "
-    sequence = generate.text(model)(prompt, stop_at="a", rng=rng)
-    assert sequence[len(prompt) :].find("a") == -1
+    sequence = generate.text(model).stream(
+        ["Prompt1", "Prompt2"],
+        max_tokens=10,
+        stop_at=[".", ","],
+        num_samples=2,
+        rng=rng,
+    )
+    tokens = next(sequence)
+    assert isinstance(tokens, list)
+    assert len(tokens) == 2
+    assert isinstance(tokens[0], list)
+    assert len(tokens[0]) == 2
+    assert isinstance(tokens[1], list)
+    assert len(tokens[1]) == 2
 
 
-@pytest.mark.xfail(reason="not implemented")
-def test_transformers_integration_text_array_samples():
+def test_transformers_integration_text_stop():
     rng = torch.Generator()
-    rng.manual_seed(0)
+    rng.manual_seed(10000)  # Chosen so <space> is generated
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
-    prompts = ["Write a short sentence", "And another one"]
-    _ = generate.text(model)(prompts, max_tokens=10, rng=rng, samples=3)
+    prompt = "Write a short sentence "
+    sequence = generate.text(model)(prompt, stop_at="a", rng=rng)
+    assert sequence[len(prompt) :].find("a") == -1
 
 
 def test_transformers_various_regexes():
@@ -165,6 +196,26 @@ def test_transformers_various_regexes_prompt_list():
     assert re.fullmatch(regex_str, sequence[1]) is not None
 
 
+def test_transformers_various_regexes_prompt_list_multiple_samples():
+    rng = torch.Generator()
+    rng.manual_seed(0)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompt = "Write an email address"
+    regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})"
+    generator = generate.regex(model, regex_str)
+
+    # Two prompts
+    sequence = generator([prompt, prompt], num_samples=2, rng=rng)
+    assert isinstance(sequence, list)
+    assert len(sequence) == 2
+    assert re.fullmatch(regex_str, sequence[0][0]) is not None
+    assert re.fullmatch(regex_str, sequence[0][1]) is not None
+    assert re.fullmatch(regex_str, sequence[1][0]) is not None
+    assert re.fullmatch(regex_str, sequence[1][1]) is not None
+
+
 def test_transformers_integration_integer():
     rng = torch.Generator()
     rng.manual_seed(0)
@@ -343,6 +394,29 @@ class Spam(BaseModel):
     assert isinstance(result[1], BaseModel)
 
 
+def test_transformers_json_batch_multiple_samples():
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompts = ["Output some JSON ", "Output more JSON"]
+
+    class Spam(BaseModel):
+        foo: int
+        bar: float
+        spam: constr(max_length=10)
+        fuzz: bool
+
+    rng = torch.Generator()
+    rng.manual_seed(0)  # make sure that `bar` is not an int
+
+    result = generate.json(model, Spam)(prompts, max_tokens=500, rng=rng, num_samples=2)
+    assert isinstance(result, list)
+    assert len(result) == 2
+    assert isinstance(result[0][0], BaseModel)
+    assert isinstance(result[0][1], BaseModel)
+    assert isinstance(result[1][0], BaseModel)
+    assert isinstance(result[1][1], BaseModel)
+
+
 def test_transformers_json_str_enum():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
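
A minimal usage sketch of the new `num_samples` argument, mirroring the integration tests above (the tiny test model, prompts, and import style are the placeholder values used there; per the reshaping logic in `SequenceGenerator.__call__`, a single prompt with `num_samples=1` still returns a plain string, a prompt batch with `num_samples=1` returns a flat list, and multiple samples return a nested list whose outer level is described by the patch's "(sample_size, batch_size)" comment):

import torch

import outlines.generate as generate
import outlines.models as models

# Tiny test model, as used in the integration tests above (placeholder choice).
model = models.transformers(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
)

rng = torch.Generator()
rng.manual_seed(0)

# Single prompt, default num_samples=1 -> a plain string.
single = generate.text(model)("Write a short sentence ", max_tokens=10, rng=rng)

# Two prompts, two samples each -> a nested list; per the patch's reshaping
# comment, the outer level indexes samples and the inner level holds one
# completion per prompt.
nested = generate.text(model)(
    ["Write a short sentence ", "And another one "],
    max_tokens=10,
    num_samples=2,
    rng=rng,
)

assert isinstance(single, str)
assert len(nested) == 2 and len(nested[0]) == 2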