Add llama logits processor
dtiarks committed Jan 18, 2024
1 parent 98271d9 commit 79bd994
Showing 4 changed files with 202 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ __pycache__
docs/build
.coverage
.idea/
*.gguf
68 changes: 68 additions & 0 deletions examples/llamacpp_processor.py
@@ -0,0 +1,68 @@
from enum import Enum

import numpy as np
from llama_cpp import Llama, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
from numpy.typing import NDArray
from pydantic import BaseModel, constr

from outlines.generate.processors.llamacpp import JSONLogitsProcessor


class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int


# TODO: why do we need this?
class EosCriteria(StoppingCriteria):
    def __init__(self, eos_token_id):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: NDArray[np.intc], logits: NDArray[np.single]):
        if self.eos_token_id in input_ids[1:]:
            return True
        return False


if __name__ == "__main__":
    llama = Llama("./phi-2.Q4_K_M.gguf")

    prompt = b"Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

    logits_processor = JSONLogitsProcessor(Character, llama)
    stopping_criteria_list = StoppingCriteriaList([EosCriteria(llama.token_eos())])

    json_str = ""
    tokens = llama.tokenize(prompt)
    for token in llama.generate(
        tokens,
        top_k=40,
        top_p=0.95,
        temp=0.7,
        logits_processor=LogitsProcessorList([logits_processor]),
        stopping_criteria=stopping_criteria_list,
    ):
        d = llama.detokenize([token])
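        # A single token may be an incomplete UTF-8 byte sequence (e.g. one
        # half of a multi-byte character); such tokens are skipped below.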
        try:
            json_str += d.decode("utf-8")
        except UnicodeDecodeError:
            continue

    print(json_str)
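Not part of the commit, but a natural follow-up: the generated string can be validated back into the `Character` model, closing the loop on the schema. A minimal sketch, assuming pydantic v2 (`Character.parse_raw` plays the same role on v1):

    character = Character.model_validate_json(json_str)
    print(character.name, character.weapon.value)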
Empty file.
133 changes: 133 additions & 0 deletions outlines/generate/processors/llamacpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import json
import math
from collections import defaultdict
from typing import DefaultDict, List, Tuple, Union

import numpy as np
from numpy.typing import NDArray

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.models.tokenizer import Tokenizer


class LlamaCppTokenizer(Tokenizer):
    def __init__(self, llama_instance, **kwargs):
        self.model_name = "llama"
        self.llama_instance = llama_instance
        self.kwargs = kwargs
        self.is_llama = False

        self.n_vocab = llama_instance.n_vocab()

        self.eos_token_id = llama_instance.token_eos()
        self.eos_token = llama_instance.detokenize([self.eos_token_id])
        self.pad_token_id = -1
        self.bos_token_id = llama_instance.token_bos()
        self.nl_token_id = llama_instance.token_nl()
        self.vocabulary = {}
        self._create_vocabulary()

        self.special_tokens = {}

    def _create_vocabulary(self):
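        # Map each detokenized piece back to its token id. Pieces that fail
        # to decode as UTF-8 are skipped; if two ids produce the same piece,
        # the later id overwrites the earlier entry.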
        for t in range(self.n_vocab):
            token_piece = ""
            try:
                token_piece = self.llama_instance.detokenize([t]).decode("utf-8")
                self.vocabulary[token_piece] = t
            except Exception as e:
                print(f"Failed to convert token ({token_piece}): {e}")
                continue

    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        if isinstance(prompt, str):
            prompt = prompt.encode("utf-8")
        token_ids = np.asarray(self.llama_instance.tokenize(prompt))
        return token_ids, np.ones_like(token_ids)

    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        if isinstance(token_ids, list):
            token_ids = np.array(token_ids)
        if token_ids.ndim == 1:
            token_ids = np.expand_dims(token_ids, 0)

        decoded = [
            self.llama_instance.detokenize(seq.tolist()).decode(
                "utf-8", errors="ignore"
            )
            for seq in token_ids
        ]

        return decoded

    def convert_token_to_string(self, token: str) -> str:
        return token

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return other.model_name == self.model_name and other.kwargs == self.kwargs
        return NotImplemented

    def __hash__(self):
        return hash(self.model_name)


class RegexLogitsProcessor:
    def __init__(self, regex_string, llama):
        """Compile the FSM that drives the regex-guided generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        llama
            An instance of `llama_cpp.Llama`

        """
        self.tokenizer = LlamaCppTokenizer(llama)

        fsm = RegexFSM(regex_string, self.tokenizer)
        self.fsm = fsm

        self.fsm_state = None

    def __call__(
        self, input_ids: NDArray[np.int64], scores: NDArray[np.float32]
    ) -> NDArray[np.float32]:
        """Use the FSM to bias the logits before sampling the next token."""

        # TODO: sequence id handling
        seq_id = 0

        if len(input_ids) == 0 or self.fsm_state is None:  # Initialize the fsm states
            self.fsm_state: DefaultDict[int, int] = defaultdict(int)  # type: ignore
        else:
            last_token = input_ids[-1]
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[seq_id], last_token
            )

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

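        # Build an additive mask that is -inf everywhere except at the token
        # ids the FSM allows, so disallowed tokens can never be sampled.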
        mask = np.full(scores.shape[-1], -math.inf, dtype=np.float32)
        mask[allowed_tokens] = 0
        biased_scores = scores + mask

        biased_scores[self.tokenizer.eos_token_id] = 0

        return biased_scores


class JSONLogitsProcessor(RegexLogitsProcessor):
    def __init__(self, schema, llm):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to generate
        llm
            An instance of `llama_cpp.Llama`

        """
        if isinstance(schema, dict):
            schema = json.dumps(schema)
        regex_string = build_regex_from_object(schema)
        super().__init__(regex_string, llm)

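Not shown in the commit's example, but the new `RegexLogitsProcessor` can also be used on its own, without a JSON schema. A minimal sketch of regex-constrained generation, assuming the same `phi-2.Q4_K_M.gguf` file as the example above:

from llama_cpp import Llama, LogitsProcessorList

from outlines.generate.processors.llamacpp import RegexLogitsProcessor

llama = Llama("./phi-2.Q4_K_M.gguf")

# Constrain the output to a signed integer or decimal number.
processor = RegexLogitsProcessor(r"[-+]?[0-9]+(\.[0-9]+)?", llama)

tokens = llama.tokenize(b"Instruct: What is 2 + 2?\nOutput:")
answer = b""
for token in llama.generate(
    tokens,
    temp=0.7,
    logits_processor=LogitsProcessorList([processor]),
):
    if token == llama.token_eos():
        break
    answer += llama.detokenize([token])

print(answer.decode("utf-8", errors="ignore"))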