diff --git a/.gitignore b/.gitignore index 8e91982e2..66ce54fc7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__ *_version.py docs/build .coverage +.idea/ diff --git a/docs/cookbook/index.md b/docs/cookbook/index.md index 856b47287..6e920cba5 100644 --- a/docs/cookbook/index.md +++ b/docs/cookbook/index.md @@ -1,6 +1,6 @@ # Examples -- [Classification](classification): Classify customer requests. +- [Classification](classification.md): Classify customer requests. - [Named Entity Extraction](extraction.md): Extract information from pizza orders. - [Dating Profile](dating_profiles.md): Build dating profiles from descriptions using prompt templating and JSON-guided generation. - [Chain Of Density](chain_of_density.md): Summarize documents using chain of density prompting and JSON-guided generation. diff --git a/docs/cookbook/models_playing_chess.md b/docs/cookbook/models_playing_chess.md index a6490199b..56285affd 100644 --- a/docs/cookbook/models_playing_chess.md +++ b/docs/cookbook/models_playing_chess.md @@ -1,12 +1,16 @@ # Large language models playing chess -In this example we will make a quantized version of Mistral-7B play chess against itself. On its own the model easily generates invalid move, so we will give it a little help. At each step we will generate a regex that only matches valid move, and use it to help the model only generating valid moves. +In this example we will make a Phi-2 model play chess against itself. On its own the model easily generates invalid moves, so we will give it a little help. At each step we will generate a regex that only matches valid move, and use it to help the model only generating valid moves. ## The chessboard The game will be played on a standard checkboard. We will use the `chess` [library](https://github.com/niklasf/python-chess) to track the opponents' moves, and check that the moves are valid. ```python +%pip install outlines -q +%pip install chess -q +%pip install transformers accelerate einops -q + import chess board = chess.Board("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") @@ -14,17 +18,18 @@ board = chess.Board("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") ## The opponents -Mistral-7B quantized will be playing against itself: +Phi-2 will be playing against itself: ```python from outlines import models -board_state = models.transformers("TheBloke/Mistral-7B-OpenOrca-AWQ", device="cuda") +model = models.transformers("microsoft/phi-2") + ``` ## A little help for the language model -To make sure Mistral-7B generates valid chess moves we will use Outline's regex-guided generation. We define a function that takes the current state of the board and returns a regex that matches all possible legal moves: +To make sure Phi-2 generates valid chess moves we will use Outline's regex-guided generation. We define a function that takes the current state of the board and returns a regex that matches all possible legal moves: ```python import re @@ -44,22 +49,22 @@ def legal_moves_regex(board): The prompt corresponds to the current state of the board, so we start with: ```python -prompt = "Score: 1-0 WhiteElo: 1600 BlackElo: 1600 Timecontrol: 1800+0 Moves: 1." +prompt = "Let's play Chess. Moves: " + ``` We update the prompt at each step so it reflects the state of the board after the previous move. -## Let's play! - +## Let's play ```python from outlines import generate - +board_state = " " turn_number = 0 while not board.is_game_over(): regex_pattern = legal_moves_regex(board) - guided = generate.regex(model, regex_pattern)(board_state) + guided = generate.regex(model, regex_pattern)(prompt + board_state) move = board.parse_san(guided) if turn_number % 2 == 0 : # It's White's turn @@ -74,7 +79,10 @@ while not board.is_game_over(): print(board_state) ``` -It turns out Mistal-7B (quantized) is not very good at playing chess: the game systematically ends because of the threefold repetition rule. +Interestingly enough, Phi-2 hates capturing. +```pgn + e4 e5 1.Nf3 Ne7 3.b4 Nf5 5.Nc3 Ne7 7.Bb5 a6 9.Na4 b6 11.c3 Nec6 13.c4 a5 15.d4 Qg5 17.Nd2 Bb7 19.dxe5 +``` *This example was originally authored by [@903124S](https://x.com/903124S) in [this gist](https://gist.github.com/903124/cfbefa24da95e2316e0d5e8ef8ed360d).* diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md new file mode 100644 index 000000000..826a2cfec --- /dev/null +++ b/docs/reference/models/llamacpp.md @@ -0,0 +1,15 @@ +# Llama.cpp + +!!! Installation + + You need to install the `llama-cpp-python` library to be able to use these models in Outlines. + +Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/llama.cpp) using the [llama-cpp-python library](https://github.com/abetlen/llama-cpp-python). Llamacpp allows to run quantized models on machines with limited compute. + +Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory: + +```python +from outlines import models, generate + +model = models.llamacpp("./phi-2.Q4_K_M.gguf", device="cpu") +``` diff --git a/docs/reference/openai_text_generation.md b/docs/reference/models/openai.md similarity index 98% rename from docs/reference/openai_text_generation.md rename to docs/reference/models/openai.md index 5eb1c8c1d..1706c1143 100644 --- a/docs/reference/openai_text_generation.md +++ b/docs/reference/models/openai.md @@ -1,5 +1,9 @@ # Generate text with the OpenAI API +!!! Installation + + You need to install the `openai` and `tiktoken` libraries to be able to use the OpenAI API in Outlines. + Outlines supports models available via the OpenAI Chat API, e.g. ChatGPT and GPT-4. The following models can be used with Outlines: ```python @@ -12,6 +16,7 @@ print(type(model)) # OpenAI ``` + It is possible to pass a system message to the model when initializing it: ```python diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md index 787e3b24a..f3232ca10 100644 --- a/docs/reference/vllm.md +++ b/docs/reference/vllm.md @@ -49,9 +49,7 @@ curl http://127.0.0.1:8000/generate \ Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program. -Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. - -You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs. +Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs. [requests]: https://requests.readthedocs.io/en/latest/ [vllm]: https://docs.vllm.ai/en/latest/index.html diff --git a/examples/cfg.py b/examples/cfg.py index c4937c46f..faca2bf3b 100644 --- a/examples/cfg.py +++ b/examples/cfg.py @@ -1,5 +1,3 @@ -from lark.exceptions import UnexpectedCharacters, UnexpectedToken - import outlines.generate as generate import outlines.models as models @@ -49,19 +47,13 @@ model = models.transformers("hf-internal-testing/tiny-random-gpt2") batch_size = 10 -max_tokens = 30 for grammar in [nlamb_grammar, calc_grammar]: generator = generate.cfg(model, grammar) - sequences = generator([" "] * batch_size, max_tokens=max_tokens) + sequences = generator([" "] * batch_size) for seq in sequences: try: parse = generator.fsm.parser.parse(seq) assert parse is not None print("SUCCESS", seq) - except (UnexpectedCharacters, UnexpectedToken): - if generator.fsm.num_tokens_generated == max_tokens: - print("MAXTOKEN", seq) - else: - print("FAILURE", seq) except Exception: print("FAILURE", seq) diff --git a/examples/llamacpp_example.py b/examples/llamacpp_example.py new file mode 100644 index 000000000..56dacca5d --- /dev/null +++ b/examples/llamacpp_example.py @@ -0,0 +1,46 @@ +from enum import Enum + +import torch +from pydantic import BaseModel, constr + +import outlines + + +class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace = "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + +class Armor(str, Enum): + leather = "leather" + chainmail = "chainmail" + plate = "plate" + + +class Character(BaseModel): + name: constr(max_length=10) + age: int + armor: Armor + weapon: Weapon + strength: int + + +if __name__ == "__main__": + # Download model from https://huggingface.co/TheBloke/phi-2-GGUF + model = outlines.models.llamacpp("./phi-2.Q3_K_M.gguf", device="cpu") + + # Construct guided sequence generator + generator = outlines.generate.json(model, Character, max_tokens=512) + + # Draw a sample + rng = torch.Generator(device="cpu") + rng.manual_seed(789005) + + prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:" + + sequence = generator(prompt, rng=rng) + print(sequence) diff --git a/mkdocs.yml b/mkdocs.yml index 59f3d7b40..f15a070cb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -126,7 +126,8 @@ nav: - Prompt templating: reference/prompting.md - Outlines functions: reference/functions.md - Models: - - OpenAI: reference/openai_text_generation.md + - OpenAI: reference/models/openai.md + - Llama.cpp: reference/models/llamacpp.md - API Reference: - api/index.md diff --git a/outlines/__init__.py b/outlines/__init__.py index 3d75d0aca..d7740496d 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -11,6 +11,7 @@ "clear_cache", "disable_cache", "get_cache", + "Function", "prompt", "vectorize", ] diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index 1c9297eb3..1d95071e0 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -14,16 +14,16 @@ class FSM(Protocol): - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: ... - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: ... - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: ... - def reset(self) -> None: + def copy(self) -> "FSM": ... @@ -36,17 +36,12 @@ class StopAtTokenFSM(FSM): """ - def __init__( - self, - tokenizer: "Tokenizer", - stop_token_id: int, - ): + def __init__(self, tokenizer: "Tokenizer", stop_token_id: int): self.stop_token_id = stop_token_id - self.num_tokens_generated = 0 self.vocabulary = tokenizer.vocabulary.values() self.final_states = {1} - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: """Generate a list of allowed tokens for the next step. When in the initial state we allow every token to be generated. @@ -56,8 +51,6 @@ def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: ---------- state The current state of the FSM. - idx - The index of the current input in the batch. Returns ------- @@ -69,7 +62,7 @@ def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: else: return [self.stop_token_id] - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. The FSM stays in the initial state `0` unless the specified stop token @@ -82,29 +75,24 @@ def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: The current state of the FSM. token_id The id of the token that was just generated. - idx - The index of the current input in the batch. Returns ------- The new state of the FSM. """ - if idx == 0: - self.num_tokens_generated += 1 - if token_id == self.stop_token_id: return FSMState(1) return FSMState(0) - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: """Determine whether the current state of the FSM is a final state.""" return state in self.final_states - def reset(self) -> None: - """Reset the FSM to its initial state. Here this only resets the token counter.""" - self.num_tokens_generated = 0 + def copy(self) -> "StopAtTokenFSM": + """Create a copy of the FSM.""" + return self class RegexFSM(FSM): @@ -152,11 +140,10 @@ def create_states_mapping( ) = create_states_mapping( regex_string, tuple(sorted(tokenizer.vocabulary.items())) ) - self.num_tokens_generated = 0 self.vocabulary = tokenizer.vocabulary.values() self.end_token_id = tokenizer.eos_token_id - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: """Generate a list of allowed tokens for the next step. The initialization of the FSM builds an index which maps FSM states to a @@ -173,8 +160,6 @@ def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: ---------- state The current state of the FSM. - idx - The index of the current input in the batch. Returns ------- @@ -188,7 +173,7 @@ def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: else: return list(next_tokens_to_end_states.keys()) - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. We use the index to determine to which state the FSM should transition @@ -200,17 +185,12 @@ def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: The current state of the FSM. token_id The id of the token that was just generated. - idx - The index of the current input in the batch. Returns ------- The new state of the FSM. """ - if idx == 0: - self.num_tokens_generated += 1 - if token_id == self.end_token_id: return FSMState(-1) @@ -221,24 +201,22 @@ def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: return FSMState(next_state) - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: """Determine whether the current state of the FSM is a final state.""" return state in self.final_states - def reset(self) -> None: - """Reset the FSM to its initial state. Here this only resets the token counter.""" - self.num_tokens_generated = 0 + def copy(self) -> "RegexFSM": + """Create a copy of the FSM.""" + return self class CFGFSM(FSM): """FSM to generate text that is in the language of a context-free grammar.""" - def __init__( - self, - cfg_string: str, - tokenizer: "Tokenizer", - ): - # self.parser = PartialLark(cfg_string, parser="lalr") + def __init__(self, cfg_string: str, tokenizer: "Tokenizer"): + self.cfg_string = cfg_string + self.tokenizer = tokenizer + self.parser = Lark( cfg_string, parser="lalr", @@ -253,59 +231,52 @@ def __init__( self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp() self.terminal_regexps["$END"] = tokenizer.eos_token - self.tokenizer = tokenizer - self.num_tokens_generated = 0 - self.generations: List[str] = [] - self.regex_fsms: List[RegexFSM] = [] - self.reset_state: List[bool] = [] - self.allow_eos: List[bool] = [] - self.done: List[bool] = [] - - def _set_next_regex_fsm(self, idx: int = 0) -> None: + self.generation = "" + self.reset_state = False + self.allow_eos = False + self.done = False + self.regex_fsm: RegexFSM + + def _set_next_regex_fsm(self) -> None: """Use the CFG incremental parser to set the next regex FSM. - Check what the CFG incremental parser proposes next. - If the only proposal is the EOS token, - we set the state to done and return. - If there are other proposals, - we set a new regex FSM and return. + Check what the CFG incremental parser proposes next: + - If the only proposal is the EOS token we set the state to done and + return. + - If there are other proposals, we set a new regex FSM and return. """ - interactive = self.parser.parse_interactive(self.generations[idx]) + interactive = self.parser.parse_interactive(self.generation) interactive.exhaust_lexer() options = {self.terminal_regexps[x] for x in interactive.accepts()} if self.terminal_regexps["$END"] in options: options.remove(self.terminal_regexps["$END"]) if len(options) == 0: - self.done[idx] = True + self.done = True return - self.allow_eos[idx] = True + self.allow_eos = True options.add("") assert len(options) > 1 regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" - args = ( - regex_string, - self.tokenizer, - ) - if len(self.regex_fsms) <= idx: - self.regex_fsms.append(RegexFSM(*args)) - else: - self.regex_fsms[idx] = RegexFSM(*args) - self.reset_state[idx] = True + self.regex_fsm = RegexFSM(regex_string, self.tokenizer) + self.reset_state = True - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: """Generate a list of allowed tokens for the next step. - Upon initialization, the CFG incremental parser is used to determine the first regex. + Upon initialization, the CFG incremental parser is used to determine the + first regex. This regex is used for proposals until either: - - the regex is exhausted, and its only remaining option is the EOS token, - in which case we always transition to the next regex - - the regex can be exhausted, but the EOS token is not the only remaining option, - in which case we transition to the next regex with probability P (TODO) - or remove the possibility of generating the EOS token and continue with the current regex + + - The regex is exhausted, and its only remaining option is the EOS + token, in which case we always transition to the next regex + - The regex can be exhausted, but the EOS token is not the only + remaining option, in which case we transition to the next regex with + probability P (TODO) or remove the possibility of generating the EOS + token and continue with the current regex The CFG incremental parser is allowed to propose the EOS token from any final state, and once it is generated, the FSM will continue to always generate the EOS token. @@ -314,22 +285,14 @@ def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: ---------- state The current state of the FSM. - idx - The index of the current input in the batch. Returns ------- A list that contains the tokens to mask. """ - if len(self.generations) <= idx: - self.generations.append("") - self.reset_state.append(False) - self.allow_eos.append(False) - self.done.append(False) - - if len(self.regex_fsms) > idx: - proposal = self.regex_fsms[idx].allowed_token_ids(state) + if self.generation != "": + proposal = self.regex_fsm.allowed_token_ids(state) if self.tokenizer.eos_token_id not in proposal: return proposal if set(proposal) != {self.tokenizer.eos_token_id}: @@ -337,23 +300,23 @@ def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] return proposal - self._set_next_regex_fsm(idx) + self._set_next_regex_fsm() - if self.done[idx]: + if self.done: return [self.tokenizer.eos_token_id] - if self.reset_state[idx]: + if self.reset_state: state = FSMState(0) - proposal = self.regex_fsms[idx].allowed_token_ids(state) - if self.allow_eos[idx]: - self.allow_eos[idx] = False + proposal = self.regex_fsm.allowed_token_ids(state) + if self.allow_eos: + self.allow_eos = False else: proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] assert len(proposal) > 0 return proposal - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. Transitions the underlying regex FSM to its next state. @@ -366,34 +329,26 @@ def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: The current state of the FSM. token_id The id of the token that was just generated. - idx - The index of the current input in the batch. Returns ------- The new state of the FSM. """ - if idx == 0: - self.num_tokens_generated += 1 if token_id == self.tokenizer.eos_token_id: - self.done[idx] = True + self.done = True return FSMState(-1) - if self.reset_state[idx]: - self.reset_state[idx] = False + if self.reset_state: + self.reset_state = False state = FSMState(0) - self.generations[idx] += self.tokenizer.decode([token_id])[0] + self.generation += self.tokenizer.decode([token_id])[0] - return self.regex_fsms[idx].next_state(state, token_id, idx) + return self.regex_fsm.next_state(state, token_id) - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: """Return whether the current state of the FSM is a final state.""" - return self.done[idx] - - def reset(self) -> None: - """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs.""" - self.num_tokens_generated = 0 - self.generations = [] - self.regex_fsms = [] - self.reset_state = [] - self.done = [] + return self.done + + def copy(self) -> "CFGFSM": + """Create a copy of the FSM.""" + return CFGFSM(self.cfg_string, self.tokenizer) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 891f3d9c9..b1876152f 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -81,12 +81,10 @@ def to_regex(resolver: Resolver, instance: dict): ---- Many features of JSON schema are missing: - Support the fact that fields in an object are optional by default - - Handle `required` keyword - Handle `additionalProperties` keyword - Handle types defined as a list - Handle constraints on numbers - Handle special patterns: `date`, `uri`, etc. - - Handle optional fields (not in `required`) This does not support recursive definitions. @@ -102,13 +100,43 @@ def to_regex(resolver: Resolver, instance: dict): if "properties" in instance: regex = "" regex += r"\{" - for i, (name, value) in enumerate(instance["properties"].items()): - regex += f'{whitespace}"{name}"{whitespace}:{whitespace}' - regex += to_regex(resolver, value) - - # No comma after the last key-value pair in JSON - if i < len(instance["properties"]) - 1: - regex += f"{whitespace}," + properties = instance["properties"] + required_properties = instance.get("required", []) + is_required = [item in required_properties for item in properties] + # If at least one property is required, we include the one in the lastest position + # without any comma. + # For each property before it (optional or required), we add with a comma after the property. + # For each property after it (optional), we add with a comma before the property. + if any(is_required): + last_required_pos = max([i for i, value in enumerate(is_required) if value]) + for i, (name, value) in enumerate(properties.items()): + subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}' + subregex += to_regex(resolver, value) + if i < last_required_pos: + subregex = f"{subregex}{whitespace}," + elif i > last_required_pos: + subregex = f"{whitespace},{subregex}" + regex += subregex if is_required[i] else f"({subregex})?" + # If no property is required, we have to create a possible pattern for each property in which + # it's the last one necessarilly present. Then, we add the others as optional before and after + # following the same strategy as described above. + # The whole block is made optional to allow the case in which no property is returned. + else: + property_subregexes = [] + for i, (name, value) in enumerate(properties.items()): + subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}' + subregex += to_regex(resolver, value) + property_subregexes.append(subregex) + possible_patterns = [] + for i in range(len(property_subregexes)): + pattern = "" + for subregex in property_subregexes[:i]: + pattern += f"({subregex}{whitespace},)?" + pattern += property_subregexes[i] + for subregex in property_subregexes[i + 1 :]: + pattern += f"({whitespace},{subregex})?" + possible_patterns.append(pattern) + regex += f"({'|'.join(possible_patterns)})?" regex += f"{whitespace}" + r"\}" diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 3be3af790..729c90ca6 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -184,8 +184,6 @@ def __call__( """ - self.fsm.reset() - if isinstance(prompts, str): prompts = [prompts] @@ -195,6 +193,7 @@ def __call__( stop_sequences = stop_at or self.stop_sequences max_tokens = max_tokens or self.max_tokens num_sequences = len(prompts) + fsms = [self.fsm.copy() for _ in prompts] if rng is None: rng = torch.Generator(device=self.device) @@ -206,7 +205,7 @@ def __call__( init_fsm_states = [FSMState(0) for _ in range(num_sequences)] states = sequence_generator( - self.generate_token, self.fsm, init_state, init_fsm_states, rng + self.generate_token, fsms, init_state, init_fsm_states, rng ) while True: @@ -283,13 +282,16 @@ def stream( """ - self.fsm.reset() - - max_tokens = max_tokens or self.max_tokens + if isinstance(prompts, str): + prompts = [prompts] if isinstance(stop_at, str): stop_at = [stop_at] + stop_sequences = stop_at or self.stop_sequences + max_tokens = max_tokens or self.max_tokens + num_sequences = len(prompts) + fsms = [self.fsm.copy() for _ in prompts] if rng is None: rng = torch.Generator(device=self.device) @@ -298,14 +300,10 @@ def stream( init_state = init_generator_state( self.tokenizer, self.device, prompts, kv_cache ) - - token_ids = init_state[1] - num_sequences = token_ids.shape[0] - init_fsm_states = [FSMState(0) for _ in range(num_sequences)] states = sequence_generator( - self.generate_token, self.fsm, init_state, init_fsm_states, rng + self.generate_token, fsms, init_state, init_fsm_states, rng ) def token_generator() -> Iterator[Union[List[str], str]]: @@ -398,10 +396,7 @@ def cfg( def format( - model, - python_type, - max_tokens: Optional[int] = None, - sampler: Sampler = multinomial, + model, python_type, max_tokens: Optional[int] = None, sampler: Sampler = multinomial ): regex_str = python_types_to_regex(python_type) return regex(model, regex_str, max_tokens, sampler) diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index ed298d210..44b469bc4 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -52,7 +52,7 @@ def init_generator_state( def sequence_generator( token_generator: Callable, - fsm: "FSM", + fsms: List["FSM"], init_state: Tuple, fsm_states: List[FSMState], rng: torch.Generator, @@ -64,8 +64,9 @@ def sequence_generator( token_generator A callable that generate a new token given the current generation state and logits biases. - fsm - The finite-state machine that drives the text generation. + fsms + List of finite-state machines that drive the text generation, + one for each sequence in the batch. init_state The initial generation state for the batches. fsm_states @@ -78,7 +79,7 @@ def sequence_generator( """ token_ids, attention_masks, kv_cache = init_state while True: - allowed_tokens = get_allowed_tokens(fsm, fsm_states) + allowed_tokens = get_allowed_tokens(fsms, fsm_states) next_token_ids, kv_cache, logits, _ = token_generator( token_ids, @@ -91,8 +92,8 @@ def sequence_generator( token_ids = update_token_ids(token_ids, next_token_ids) attention_masks = expand_attention_masks(attention_masks) - fsm_states = get_next_fsm_states(fsm, fsm_states, next_token_ids) - is_finished = is_generation_finished(fsm, fsm_states) + fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids) + is_finished = is_generation_finished(fsms, fsm_states) if is_finished: yield GenerationState(token_ids, kv_cache, logits, fsm_states) @@ -149,7 +150,7 @@ def generate( def get_next_fsm_states( - fsm: "FSM", fsm_states: List[FSMState], next_token_ids: torch.Tensor + fsms: List["FSM"], fsm_states: List[FSMState], next_token_ids: torch.Tensor ) -> List[FSMState]: """ @@ -166,14 +167,12 @@ def get_next_fsm_states( """ return [ - fsm.next_state(fsm_state, int(token_id[0]), idx) - for idx, fsm_state, token_id in zip( - range(len(fsm_states)), fsm_states, next_token_ids - ) + fsm.next_state(fsm_state, int(token_id[0])) + for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids) ] -def get_allowed_tokens(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor: +def get_allowed_tokens(fsms: List["FSM"], fsm_states: List[FSMState]) -> torch.Tensor: """Get the new instructions for each sequence from the finite-state machine. Parameters @@ -188,10 +187,10 @@ def get_allowed_tokens(fsm: "FSM", fsm_states: List[FSMState]) -> torch.Tensor: A nested list that contains the ids of the logits to keep. """ - return [fsm.allowed_token_ids(state, idx) for idx, state in enumerate(fsm_states)] + return [fsm.allowed_token_ids(state) for fsm, state in zip(fsms, fsm_states)] -def is_generation_finished(fsm: "FSM", fsm_states: List[FSMState]) -> bool: +def is_generation_finished(fsms: List["FSM"], fsm_states: List[FSMState]) -> bool: """Determine if the generation is finished. A generation is considered finished if the FSM of every sequence in the @@ -212,7 +211,7 @@ def is_generation_finished(fsm: "FSM", fsm_states: List[FSMState]) -> bool: Whether all sequences are finished sampling. """ - return all([fsm.is_final_state(state, idx) for idx, state in enumerate(fsm_states)]) + return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)]) @torch.inference_mode() @@ -264,10 +263,7 @@ def expand_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor: @torch.inference_mode() -def bias_logits( - logits: torch.Tensor, - allowed_token_ids: List, -) -> torch.Tensor: +def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor: """Mask the logits. The function iterates over a nested list where each list corresponds to the diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 1d82ff8f5..d8d995750 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -8,6 +8,7 @@ from .awq import awq from .exllamav2 import exl2 from .gptq import gptq +from .llamacpp import LlamaCpp, llamacpp from .mamba import Mamba, mamba from .openai import OpenAI, openai from .transformers import Transformer, transformers diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py new file mode 100644 index 000000000..c51f600f8 --- /dev/null +++ b/outlines/models/llamacpp.py @@ -0,0 +1,304 @@ +import ctypes +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from numpy.typing import NDArray + +from outlines.models.tokenizer import Tokenizer + + +class LlamaCpp: + """Represents a `llama_cpp` model.""" + + def __init__( + self, llama_instance, model, tokenizer, device, context_params, **kwargs + ): + self.device = device + self.llama_instance = llama_instance + self.tokenizer = tokenizer + + # Note: the concept of padding does not exist in llama.cpp as a batched sequence is just + # a flat array of tokens that can be assigned to one or more sequences. + # To make it compatible with the transformers inspired tokenizer interface + # we need a padding token to homogenize to token_ids tensor. + self.pad_token_id = -1 + + self.n_past = 0 + self.n_vocab = kwargs.pop("n_vocab") + + self.ctx = llama_instance.llama_new_context_with_model(model, context_params) + + def forward(self, input_ids: torch.LongTensor, *_): + """Compute a forward pass through the llama_cpp model.""" + if input_ids.ndim == 2: + seq_tensor = input_ids[:, self.n_past :] + elif input_ids.ndim == 1: + seq_tensor = input_ids.view(1, -1)[:, self.n_past :] + else: + raise Exception("Only one and two dimensional inputs allowed.") + + tokens_total = torch.numel(seq_tensor[seq_tensor != self.pad_token_id]) + batch = self.llama_instance.llama_batch_init(tokens_total, 0, 1) + + seq_token_ids = [] + for seq_idx, seq in enumerate(seq_tensor): + for token_pos, token_id in enumerate(seq): + if token_id == self.pad_token_id: + break + batch.token[batch.n_tokens] = token_id.item() + batch.pos[batch.n_tokens] = token_pos + batch.seq_id[batch.n_tokens][0] = seq_idx + batch.n_seq_id[batch.n_tokens] = 1 + batch.logits[batch.n_tokens] = False + + batch.n_tokens += 1 + self.n_past += 1 + + batch.logits[batch.n_tokens - 1] = True + seq_token_ids.append(batch.n_tokens - 1) + + if self.llama_instance.llama_decode(self.ctx, batch) != 0: + print("Error decoding") + + all_logits = [] + for seq_token in seq_token_ids: + logits = self.llama_instance.llama_get_logits_ith(self.ctx, seq_token) + logits_list = (ctypes.c_float * self.n_vocab)( + *[logits[token_id] for token_id in range(self.n_vocab)] + ) + logits_tensor = torch.tensor(logits_list) + all_logits.append(logits_tensor) + + self.llama_instance.llama_batch_free(batch) + + stacked_logits = torch.stack(all_logits) + return stacked_logits, None + + def __call__( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + past_key_values: Optional[Tuple] = None, + ) -> torch.FloatTensor: + logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) + next_token_logits = logits + + return next_token_logits, kv_cache + + +class LlamaCppTokenizer(Tokenizer): + def __init__(self, llama_instance, model, model_name: str, **kwargs): + self.model_name = model_name + self.llama_instance = llama_instance + self.is_llama = False + + self.model = model + self.n_vocab = kwargs.pop("n_vocab") + + self.eos_token_id = llama_instance.llama_token_eos(model) + self.eos_token = self._get_eos_token() + self.pad_token_id = -1 + self.bos_token_id = llama_instance.llama_token_eos(model) + self.nl_token_id = llama_instance.llama_token_nl(model) + self.vocabulary = {} + self._create_vocabulary() + + self.n_past = 0 + + self.special_tokens = { + self.eos_token_id, + self.pad_token_id, + self.bos_token_id, + self.nl_token_id, + } + + def _create_vocabulary(self): + for t in range(self.n_vocab): + size = 32 + buffer = (ctypes.c_char * size)() + n = self.llama_instance.llama_token_to_piece( + self.model, self.llama_instance.llama_token(t), buffer, size + ) + + try: + token_piece = buffer[:n].decode("utf-8") + self.vocabulary[token_piece] = t + except Exception as e: + print(f"Failed to convert token ({buffer[:n]}): {e}") + continue + + def encode( + self, prompt: Union[str, List[str]] + ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: + if isinstance(prompt, list): + prompts = prompt + else: + prompts = [prompt] + + max_len = 0 + token_ids = [] + for p in prompts: + embd_inp = (self.llama_instance.llama_token * (len(p) + 1))() + + n_of_tok = self.llama_instance.llama_tokenize( + model=self.model, + text=bytes(str(p), "utf-8"), + text_len=len(embd_inp), + tokens=embd_inp, + n_max_tokens=len(embd_inp), + add_bos=self.n_past == 0, + special=False, + ) + + self.n_past += n_of_tok + + if n_of_tok > max_len: + max_len = n_of_tok + + embd_inp = embd_inp[:n_of_tok] + token_ids.append(np.array(embd_inp)) + + max_len = np.max([len(a) for a in token_ids]) + padded = np.asarray( + [ + np.pad( + a, + (0, max_len - len(a)), + "constant", + constant_values=self.pad_token_id, + ) + for a in token_ids + ] + ) + + token_ids = torch.LongTensor(padded) + return token_ids, torch.ones_like(token_ids) + + def decode(self, token_ids: NDArray[np.int64]) -> List[str]: + if isinstance(token_ids, list): + token_ids = np.array(token_ids) + if token_ids.ndim == 1: + token_ids = [token_ids] + + pieces = [] + for tokens in token_ids: + seq = [] + for id in tokens: + size = 32 + buffer = (ctypes.c_char * size)() + n = self.llama_instance.llama_token_to_piece( + self.model, self.llama_instance.llama_token(id), buffer, size + ) + + token_piece = buffer[:n].decode("utf-8") # type: ignore + + seq.append(token_piece) + + pieces.append("".join(seq)) + + return pieces + + def _get_eos_token(self): + size = 32 + buffer = (ctypes.c_char * size)() + n = self.llama_instance.llama_token_to_piece( + self.model, self.llama_instance.llama_token(self.eos_token_id), buffer, size + ) + + token_piece = buffer[:n].decode("utf-8") + + return token_piece + + def convert_token_to_string(self, token: str) -> str: + return token + + def __eq__(self, other): + if isinstance(other, type(self)): + return other.model_name == self.model_name and other.kwargs == self.kwargs + return NotImplemented + + def __hash__(self): + return hash(self.model_name) + + +def llamacpp( + model_name: str, + device: Optional[str] = None, + model_kwargs: dict = {}, + tokenizer_kwargs: dict = {}, +): + try: + import llama_cpp + except ImportError: + raise ImportError( + "The `llama-cpp-python` library needs to be installed in order to use LlamaCpp." + ) + + if device is None: + device = "cpu" + + llama_cpp.llama_backend_init(numa=False) + + model_params = llama_cpp.llama_model_default_params() + + if "cuda" in device: + model_params.n_gpu_layers = 999 + else: + model_params.n_gpu_layers = model_kwargs.pop( + "n_gpu_layers", model_params.n_gpu_layers + ) + + if "tensor_split" in model_kwargs.keys(): + tensor_split = model_kwargs.get("tensor_split") + if isinstance(tensor_split, list): + tensor_split_arr = (ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES)( + *[t for t in tensor_split] + ) + model_params.tensor_split = tensor_split_arr + + context_params = llama_cpp.llama_context_default_params() + context_params.n_batch = model_kwargs.pop("n_batch", context_params.n_batch) + context_params.n_ctx = model_kwargs.pop("n_ctx", context_params.n_ctx) + context_params.n_threads = model_kwargs.pop("n_threads", context_params.n_threads) + context_params.n_threads_batch = model_kwargs.pop( + "n_threads_batch", context_params.n_threads_batch + ) + context_params.rope_scaling_type = model_kwargs.pop( + "rope_scaling_type", context_params.rope_scaling_type + ) + context_params.rope_freq_base = model_kwargs.pop( + "rope_freq_base", context_params.rope_freq_base + ) + context_params.rope_freq_scale = model_kwargs.pop( + "rope_freq_scale", context_params.rope_freq_scale + ) + context_params.yarn_ext_factor = model_kwargs.pop( + "yarn_ext_factor", context_params.yarn_ext_factor + ) + context_params.yarn_attn_factor = model_kwargs.pop( + "yarn_attn_factor", context_params.yarn_attn_factor + ) + context_params.yarn_beta_fast = model_kwargs.pop( + "yarn_beta_fast", context_params.yarn_beta_fast + ) + context_params.yarn_beta_slow = model_kwargs.pop( + "yarn_beta_slow", context_params.yarn_beta_slow + ) + context_params.yarn_orig_ctx = model_kwargs.pop( + "yarn_orig_ctx", context_params.yarn_orig_ctx + ) + context_params.offload_kqv = model_kwargs.pop( + "offload_kqv", context_params.offload_kqv + ) + + model = llama_cpp.llama_load_model_from_file( + model_name.encode("utf-8"), model_params + ) + + model_kwargs["n_vocab"] = llama_cpp.llama_n_vocab(model) + tokenizer_kwargs["n_vocab"] = model_kwargs.get("n_vocab") + + tokenizer = LlamaCppTokenizer(llama_cpp, model, model_name, **tokenizer_kwargs) + + return LlamaCpp(llama_cpp, model, tokenizer, "cpu", context_params, **model_kwargs) diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 2e01445cb..542b6fc26 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -104,10 +104,6 @@ def __init__( parameters that cannot be set by calling this class' methods. """ - if model_name not in ["gpt-4", "gpt-3.5-turbo"]: - raise ValueError( - "Invalid model_name. It must be either 'gpt-4' or 'gpt-3.5-turbo'." - ) try: import openai @@ -125,6 +121,13 @@ def __init__( raise ValueError( "You must specify an API key to use the OpenAI API integration." ) + try: + client = openai.OpenAI() + client.models.retrieve(model_name) + except openai.NotFoundError: + raise ValueError( + "Invalid model_name. Check openai models list at https://platform.openai.com/docs/models" + ) if config is not None: self.config = replace(config, model=model_name) # type: ignore diff --git a/pyproject.toml b/pyproject.toml index 6c5d5610c..05885cb3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ test = [ "beartype<0.16.0", "datasets", "responses", + "llama-cpp-python", + "huggingface_hub" ] serve = [ "vllm==0.2.6", @@ -113,6 +115,8 @@ module = [ "tiktoken.*", "torch.*", "transformers.*", + "llama_cpp", + "huggingface_hub", "lark.*", "interegular.*", "datasets.*", diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py index 128c76216..61ae82525 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_fsm.py @@ -82,27 +82,27 @@ def decode(self, token_ids): assert set(fsm.allowed_token_ids(state=0)) == {1, 3, 5} state = fsm.next_state(state=0, token_id=1) - assert fsm.generations[0] == "{" + assert fsm.generation == "{" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3} state = fsm.next_state(state=state, token_id=3) - assert fsm.generations[0] == "{[" + assert fsm.generation == "{[" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {1, 3, 4} state = fsm.next_state(state=state, token_id=4) - assert fsm.generations[0] == "{[]" + assert fsm.generation == "{[]" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {2} state = fsm.next_state(state=state, token_id=2) - assert fsm.generations[0] == "{[]}" + assert fsm.generation == "{[]}" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {5} state = fsm.next_state(state=state, token_id=5) - assert fsm.generations[0] == "{[]}" + assert fsm.generation == "{[]}" assert fsm.is_final_state(0) @@ -133,18 +133,18 @@ def decode(self, token_ids): assert set(fsm.allowed_token_ids(state=0)) == {1} state = fsm.next_state(state=0, token_id=1) - assert fsm.generations[0] == "(" + assert fsm.generation == "(" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {1, 2} state = fsm.next_state(state=state, token_id=2) - assert fsm.generations[0] == "()" + assert fsm.generation == "()" assert not fsm.is_final_state(0) # possible to continue or terminate assert set(fsm.allowed_token_ids(state=state)) == {1, 3} state = fsm.next_state(state=state, token_id=3) # feed eos - assert fsm.generations[0] == "()" + assert fsm.generation == "()" assert fsm.is_final_state(0) # once eos generated, can only terminate @@ -176,21 +176,21 @@ def decode(self, token_ids): fsm = CFGFSM(cfg_str, tokenizer) assert set(fsm.allowed_token_ids(state=0)) == {1, 2} - assert fsm.reset_state[0] # starting new regex + assert fsm.reset_state # starting new regex state = fsm.next_state(state=0, token_id=1) - assert fsm.generations[0] == "a" + assert fsm.generation == "a" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {1} - assert not fsm.reset_state[0] # continuing current regex + assert not fsm.reset_state # continuing current regex state = fsm.next_state(state=state, token_id=1) - assert fsm.generations[0] == "aa" + assert fsm.generation == "aa" assert not fsm.is_final_state(0) assert set(fsm.allowed_token_ids(state=state)) == {3} - assert not fsm.reset_state[0] # completing current regex + assert not fsm.reset_state # completing current regex state = fsm.next_state(state=state, token_id=3) - assert fsm.generations[0] == "aa" + assert fsm.generation == "aa" assert fsm.is_final_state(0) @@ -225,7 +225,7 @@ def decode(self, token_ids): assert set(fsm.allowed_token_ids(state=0)) == {1, 2} state = fsm.next_state(state=0, token_id=1) - assert fsm.generations[0] == "a" + assert fsm.generation == "a" assert not fsm.is_final_state(0) # INTENDED LOGIC @@ -240,5 +240,5 @@ def decode(self, token_ids): # This implementation is sound, and always terminates, but is not complete assert set(fsm.allowed_token_ids(state=state)) == {3} state = fsm.next_state(state=state, token_id=3) - assert fsm.generations[0] == "a" + assert fsm.generation == "a" assert fsm.is_final_state(0) diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 454d120c7..990f14215 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -182,6 +182,7 @@ def test_match_number(pattern, does_match): "title": "Foo", "type": "object", "properties": {"count": {"title": "Count", "type": "integer"}}, + "required": ["count"], }, '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*\\}', [('{\n "count": 100\n}', True)], @@ -262,8 +263,10 @@ def test_match_number(pattern, does_match): "title": "Foo", "type": "object", "properties": {"spam": {"title": "Spam", "type": "integer"}}, + "required": ["spam"], } }, + "required": ["fuzz"], }, f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], @@ -278,7 +281,7 @@ def test_match_number(pattern, does_match): "name": {"title": "Name", "type": "string"}, "a": {"$ref": "#/properties/name"}, }, - "required": ["user_id", "name"], + "required": ["user_id", "name", "a"], }, f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], @@ -293,7 +296,7 @@ def test_match_number(pattern, does_match): "name": {"title": "Name", "type": "string"}, "name2": {"$ref": "#/$defs/name"}, }, - "required": ["user_id", "name"], + "required": ["user_id", "name", "name2"], }, f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], @@ -310,8 +313,10 @@ def test_match_number(pattern, does_match): "address": {"$ref": "customer#/$defs/address"}, }, "required": [ + "name", "first_name", "last_name", + "address", "shipping_address", "billing_address", ], @@ -343,6 +348,95 @@ def test_match_number(pattern, does_match): ) ], ), + # Optional properties + # Last required property in first position + ( + { + "properties": { + "name": {"type": "string"}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "weapon": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + }, + "required": ["name"], + "title": "Character", + "type": "object", + }, + f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER})))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*(({STRING})|(null)|({STRING}null)|(null{STRING})))?[\\n ]*\\}}', + [ + ('{ "name" : "Player" }', True), + ('{ "name" : "Player", "weapon" : "sword" }', True), + ('{ "age" : 10, "weapon" : "sword" }', False), + ], + ), + # Last required property in middle position + ( + { + "properties": { + "name": {"type": "string"}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "weapon": {"type": "string"}, + "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + }, + "required": ["name", "weapon"], + "title": "Character", + "type": "object", + }, + f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER}))[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER})))?[\\n ]*\\}}', + [ + ('{ "name" : "Player" , "weapon" : "sword" }', True), + ( + '{ "name" : "Player", "age" : 10, "weapon" : "sword" , "strength" : 10 }', + True, + ), + ('{ "weapon" : "sword" }', False), + ], + ), + # Last required property in last position + ( + { + "properties": { + "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "age": {"type": "integer"}, + "armor": {"type": "string"}, + "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "weapon": {"title": "Weapon", "type": "string"}, + }, + "required": ["age", "armor", "weapon"], + "title": "Character", + "type": "object", + }, + f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*(({STRING})|(null)|({STRING}null)|(null{STRING}))[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER}))[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + [ + ( + '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', + True, + ), + ('{ "age" : 10, "armor" : "plate", "weapon" : "sword" }', True), + ( + '{ "name" : "Kahlhanbeh", "armor" : "plate", "weapon" : "sword" }', + False, + ), + ], + ), + # All properties are optional + ( + { + "properties": { + "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + }, + "title": "Character", + "type": "object", + }, + f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*(({STRING})|(null)|({STRING}null)|(null{STRING}))([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER})))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER})))?|([\\n ]*"name"[\\n ]*:[\\n ]*(({STRING})|(null)|({STRING}null)|(null{STRING}))[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER}))([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER})))?|([\\n ]*"name"[\\n ]*:[\\n ]*(({STRING})|(null)|({STRING}null)|(null{STRING}))[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER}))[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*(({INTEGER})|(null)|({INTEGER}null)|(null{INTEGER})))?[\\n ]*\\}}', + [ + ('{ "name" : "Player" }', True), + ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), + ('{ "age" : 10, "strength" : 10 }', True), + ("{ }", True), + ], + ), ], ) def test_match(schema, regex, examples): diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py index b1fdf2409..9a7bd1188 100644 --- a/tests/generate/test_generator.py +++ b/tests/generate/test_generator.py @@ -21,17 +21,17 @@ def test_sequence_generator_class(): class MockFSM: - def next_state(self, state, next_token_ids, _): + def next_state(self, state, next_token_ids): return 4 def allowed_token_ids(self, *_): return [4] - def is_final_state(self, _, idx=0): + def is_final_state(self, _): return True - def reset(self): - pass + def copy(self): + return self class MockTokenizer: def encode(self, _): @@ -74,18 +74,15 @@ def sampler(biased_logits, *_): def test_init_sequence_generator(): class MockFSM: - def next_state(self, state, next_token_ids, idx=0): + def next_state(self, state, next_token_ids): return 0 - def allowed_token_ids(self, _, idx=0): + def allowed_token_ids(self, _): return [] - def is_final_state(self, _, idx=0): + def is_final_state(self, _): return True - def reset(self): - pass - class MockTokenizer: def encode(self, _): return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) @@ -112,18 +109,15 @@ def sampler(biased_logits, *_): def test_sequence_generator_1d_single_iteration(): class MockFSM: - def next_state(self, state, next_token_ids, idx=0): + def next_state(self, state, next_token_ids): return 0 - def allowed_token_ids(self, _, idx=0): + def allowed_token_ids(self, _): return [0, 1, 2, 3] - def is_final_state(self, _, idx=0): + def is_final_state(self, _): return True - def reset(self): - pass - class MockTokenizer: def encode(self, _): return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) @@ -145,7 +139,7 @@ def sampler(biased_logits, *_): init_fsm_states = [0] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, MockFSM(), init_state, init_fsm_states, torch.Generator() + generate, [MockFSM()], init_state, init_fsm_states, torch.Generator() ) result = next(sequence) @@ -158,21 +152,18 @@ def sampler(biased_logits, *_): def test_sequence_generator_1d_several_iterations(): class MockFSM: - def next_state(self, state, next_token_ids, idx=0): + def next_state(self, state, next_token_ids): return FSMState(state + 1) - def allowed_token_ids(self, _, idx=0): + def allowed_token_ids(self, _): return [0, 1, 2, 3] - def is_final_state(self, state, idx=0): + def is_final_state(self, state): if state < 2: return False else: return True - def reset(self): - pass - class MockTokenizer: def encode(self, _): return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) @@ -194,7 +185,7 @@ def sampler(biased_logits, *_): init_fsm_states = [0] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, MockFSM(), init_state, init_fsm_states, torch.Generator() + generate, [MockFSM()], init_state, init_fsm_states, torch.Generator() ) result = next(sequence) @@ -211,18 +202,15 @@ def sampler(biased_logits, *_): def test_sequence_generator_2d_single_iteration(): class MockFSM: - def next_state(self, state, next_token_ids, idx=0): + def next_state(self, state, next_token_ids): return 0 - def allowed_token_ids(self, _, idx=0): + def allowed_token_ids(self, _): return [0, 1, 2, 3] - def is_final_state(self, _, idx=0): + def is_final_state(self, _): return True - def reset(self): - pass - class MockTokenizer: def encode(self, _): return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), torch.tensor( @@ -248,9 +236,10 @@ def sampler(biased_logits, *_): None, ) init_fsm_states = [0, 0] + fsms = [MockFSM(), MockFSM()] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, MockFSM(), init_state, init_fsm_states, torch.Generator() + generate, fsms, init_state, init_fsm_states, torch.Generator() ) result = next(sequence) @@ -267,21 +256,18 @@ def sampler(biased_logits, *_): def test_sequence_generator_2d_several_iterations(): class MockFSM: - def next_state(self, state, next_token_ids, idx=0): + def next_state(self, state, next_token_ids): return FSMState(state + 1) - def allowed_token_ids(self, _, idx=0): + def allowed_token_ids(self, _): return [0, 1, 2, 3] - def is_final_state(self, state, idx=0): + def is_final_state(self, state): if state < 2: return False else: return True - def reset(self): - pass - class MockTokenizer: def encode(self, _): return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), torch.tensor( @@ -307,9 +293,10 @@ def sampler(biased_logits, *_): None, ) init_fsm_states = [0, 0] + fsms = [MockFSM(), MockFSM()] generate = token_generator(MockModel(), sampler) sequence = sequence_generator( - generate, MockFSM(), init_state, init_fsm_states, torch.Generator() + generate, fsms, init_state, init_fsm_states, torch.Generator() ) result = next(sequence) @@ -411,44 +398,48 @@ def sampler(biased_logits, *_): def test_get_next_fsm_states(): class MockFSM: - def next_state(self, state, next_token_ids, idx=0): + def next_state(self, state, next_token_ids): return 0 - result = get_next_fsm_states(MockFSM(), [0], torch.tensor([[0]])) + result = get_next_fsm_states([MockFSM()], [0], torch.tensor([[0]])) assert result == [0] - result = get_next_fsm_states(MockFSM(), [0, 0], torch.tensor([[0], [0]])) + result = get_next_fsm_states( + [MockFSM(), MockFSM()], [0, 0], torch.tensor([[0], [0]]) + ) assert result == [0, 0] def test_get_allowed_token_idss(): class MockFSM: - def allowed_token_ids(self, _, idx=0): + def allowed_token_ids(self, _): return [1, 2, 3, 4] - result = get_allowed_tokens(MockFSM(), [0]) + result = get_allowed_tokens([MockFSM()], [0]) assert result == [[1, 2, 3, 4]] - result = get_allowed_tokens(MockFSM(), [0, 1]) + result = get_allowed_tokens([MockFSM(), MockFSM()], [0, 1]) assert result == [[1, 2, 3, 4], [1, 2, 3, 4]] def test_is_generation_finished(): class MockFSMFinished: - def is_final_state(self, _, idx=0): + def is_final_state(self, _): return True - result = is_generation_finished(MockFSMFinished(), [1, 1]) + result = is_generation_finished([MockFSMFinished(), MockFSMFinished()], [1, 1]) assert result is True class MockFSMNotFinished: - def is_final_state(self, state, idx=0): + def is_final_state(self, state): if state == 0: return False else: return True - result = is_generation_finished(MockFSMNotFinished(), [0, 1]) + result = is_generation_finished( + [MockFSMNotFinished(), MockFSMNotFinished()], [0, 1] + ) assert result is False diff --git a/tests/generate/test_integration_transfomers.py b/tests/generate/test_integration_transfomers.py index cf129a8b5..0e6cfa4b4 100644 --- a/tests/generate/test_integration_transfomers.py +++ b/tests/generate/test_integration_transfomers.py @@ -310,7 +310,8 @@ def test_transformers_json_schema(): "properties": { "foo" : {"type": "integer"}, "bar": {"type": "string", "maxLength": 4} - } + }, + "required": ["foo", "bar"] } """ diff --git a/tests/models/test_llama_cpp.py b/tests/models/test_llama_cpp.py new file mode 100644 index 000000000..68e998239 --- /dev/null +++ b/tests/models/test_llama_cpp.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest +import torch +from huggingface_hub import hf_hub_download + +from outlines.models.llamacpp import llamacpp + +TEST_MODEL = "./llama-test-model/llama-2-tiny-random.gguf" + + +@pytest.fixture(scope="session") +def model_download(tmp_path_factory): + tmp_path_factory.mktemp("./llama-test-model") + hf_hub_download( + repo_id="aladar/llama-2-tiny-random-GGUF", + local_dir="./llama-test-model", + local_dir_use_symlinks=False, + filename="llama-2-tiny-random.gguf", + ) + + +def test_tokenizer(model_download): + model = llamacpp(TEST_MODEL, "cpu") + + tokenizer = model.tokenizer + assert tokenizer.eos_token_id == 2 + assert tokenizer.pad_token_id == -1 + + token_ids, attention_mask = tokenizer.encode(["Test", "test bla hallo"]) + assert token_ids.ndim == 2 + assert token_ids.shape[0] == 2 + assert token_ids[0, -1] == -1 + assert token_ids[1, -1] != -1 + assert isinstance(token_ids, torch.LongTensor) + assert token_ids.shape == attention_mask.shape + + token_ids, attention_mask = tokenizer.encode("Test") + assert token_ids.ndim == 2 + assert token_ids.shape[0] == 1 + assert isinstance(token_ids, torch.LongTensor) + assert token_ids.shape == attention_mask.shape + + text = tokenizer.decode(np.array([0, 1, 2])) + assert isinstance(text, list) + + +def test_model(model_download): + model = llamacpp(TEST_MODEL) + + input_ids = torch.tensor([[0, 1, 2]]) + logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + + assert logits.ndim == 2 + assert logits.shape[0] == 1 + + model.n_past = 0 + input_ids = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + + assert logits.ndim == 2 + assert logits.shape[0] == 3 + + model.n_past = 0 + input_ids = torch.tensor([[0, 1, 2], [3, -1, -1]]) + logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + + assert logits.ndim == 2 + assert logits.shape[0] == 2 diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2801673fe..c2e885eb1 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -1,7 +1,6 @@ import pytest from outlines.models.openai import ( - OpenAI, build_optimistic_mask, find_longest_intersection, find_response_choices_intersection, @@ -49,11 +48,3 @@ def test_find_longest_common_prefix(response, choice, expected_prefix): def test_build_optimistic_mask(transposed, mask_size, expected_mask): mask = build_optimistic_mask(transposed, mask_size) assert mask == expected_mask - - -def test_model_name_validation(): - with pytest.raises(ValueError): - OpenAI(model_name="invalid_model_name") - - with pytest.raises(ValueError): - OpenAI(model_name="gpt-4-1106-preview")