Duplicate outlines timing code in outlines-core
rlouf committed Oct 21, 2024
1 parent 0e02ffb · commit 245fc5e
Showing 3 changed files with 130 additions and 27 deletions.
12 changes: 6 additions & 6 deletions src/benchmark_lfe.py
@@ -42,11 +42,11 @@ def _get_enforcer(self, regex_name):
         tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
         return TokenEnforcer(tokenizer_data, parser)
 
-    def time_lfe_total(self, _, regex_name):
+    def time_total(self, _, regex_name):
         enforcer = self._get_enforcer(regex_name)
         self._exhaust_samples(enforcer)
 
-    def time_lfe_first_token(self, _, regex_name):
+    def time_first_token(self, _, regex_name):
         enforcer = self._get_enforcer(regex_name)
         self._get_first_token(enforcer)

@@ -68,7 +68,7 @@ def setup(self, model, regex_name):
         self.enforcer = self._get_enforcer(regex_name)
         self._get_first_token(self.enforcer)
 
-    def time_lfe_runtime(self, *args):
+    def time_runtime(self, *args):
         self._exhaust_samples(self.enforcer)


@@ -87,11 +87,11 @@ def _get_enforcer(self, json_schema_name):
         tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
         return TokenEnforcer(tokenizer_data, parser)
 
-    def time_lfe_total(self, _, json_schema_name):
+    def time_total(self, _, json_schema_name):
         enforcer = self._get_enforcer(json_schema_name)
         self._exhaust_samples(enforcer)
 
-    def time_lfe_first_token(self, _, json_schema_name):
+    def time_first_token(self, _, json_schema_name):
         enforcer = self._get_enforcer(json_schema_name)
         self._get_first_token(enforcer)

@@ -113,5 +113,5 @@ def setup(self, model, json_schema_name):
         self.enforcer = self._get_enforcer(json_schema_name)
         self._get_first_token(self.enforcer)
 
-    def time_lfe_runtime(self, *args):
+    def time_runtime(self, *args):
         self._exhaust_samples(self.enforcer)
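These classes appear to follow airspeed velocity (asv) conventions (params, param_names, setup, teardown, timeout, and time_*-prefixed timing methods), and asv already reports each timer under its class as ClassName.time_method, so the library prefix in the old method names only repeated what the class encodes. A minimal sketch of that naming scheme; the class name and parameter values below are illustrative, not taken from this repository:

# Hypothetical asv-style benchmark; names are illustrative only.
class LMFormatEnforcerRegex:
    params = [["gpt2"], ["phone_number"]]
    param_names = ["model", "regex_name"]

    def setup(self, model, regex_name):
        pass  # build tokenizer, enforcer, and samples here

    # asv reports this as "LMFormatEnforcerRegex.time_total", so an
    # extra "lfe" infix would only duplicate the class name.
    def time_total(self, model, regex_name):
        pass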
12 changes: 6 additions & 6 deletions src/benchmark_outlines.py
@@ -60,12 +60,12 @@ def setup(self, model, regex_name):
         samples = regex_cases[regex_name]["samples"]
         self.do_setup(model, samples)
 
-    def time_outlines_total(self, _, regex_name):
+    def time_total(self, _, regex_name):
         regex_string = regex_cases[regex_name]["regex"]
         guide = self.guide_class(regex_string, self.tokenizer)
         self._exhaust_samples(guide)
 
-    def time_outlines_first_token(self, _, regex_name):
+    def time_first_token(self, _, regex_name):
         regex_string = regex_cases[regex_name]["regex"]
         guide = self.guide_class(regex_string, self.tokenizer)
         self._get_first_token(guide)
@@ -87,7 +87,7 @@ def setup(self, model, regex_name):
         self.guide = self.guide_class(regex_string, self.tokenizer)
         self._get_first_token(self.guide)
 
-    def time_outlines_runtime(self, *args):
+    def time_runtime(self, *args):
         self._exhaust_samples(self.guide)


@@ -102,13 +102,13 @@ def setup(self, model, json_schema_name):
         samples = json_cases[json_schema_name]["samples"]
         self.do_setup(model, samples)
 
-    def time_outlines_total(self, _, json_schema_name):
+    def time_total(self, _, json_schema_name):
         json_string = json_cases[json_schema_name]["schema"]
         regex_string = self.json_from_regex_fn(json.dumps(json_string))
         guide = self.guide_class(regex_string, self.tokenizer)
         self._exhaust_samples(guide)
 
-    def time_outlines_first_token(self, _, json_schema_name):
+    def time_first_token(self, _, json_schema_name):
         json_string = json_cases[json_schema_name]["schema"]
         regex_string = self.json_from_regex_fn(json.dumps(json_string))
         guide = self.guide_class(regex_string, self.tokenizer)
@@ -134,5 +134,5 @@ def setup(self, model, json_schema_name):
         self.guide = self.guide_class(regex_string, self.tokenizer)
         self._get_first_token(self.guide)
 
-    def time_outlines_runtime(self, *args):
+    def time_runtime(self, *args):
         self._exhaust_samples(self.guide)
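The JSON-schema benchmarks above lower the schema to a regular expression before compiling a guide from it. A minimal sketch of that pipeline, using the build_regex_from_schema helper imported in benchmark_outlines_core.py below; the schema itself is illustrative:

import json

from outlines_core.fsm.json_schema import build_regex_from_schema

# build_regex_from_schema expects the schema serialized as a JSON
# string, hence the json.dumps round-trip seen in the timers above.
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
regex_string = build_regex_from_schema(json.dumps(schema))
# A guide compiled from regex_string then constrains decoding token by token.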
133 changes: 118 additions & 15 deletions src/benchmark_outlines_core.py
@@ -1,16 +1,16 @@
-from outlines.caching import cache
+import json
 
+import outlines.caching as caching
+import torch
 from outlines.models.transformers import TransformerTokenizer
 from outlines_core.fsm.guide import RegexGuide, create_states_mapping
 from outlines_core.fsm.json_schema import build_regex_from_schema
 from transformers import AutoTokenizer
 
-from .benchmark_outlines import (
-    OutlinesJsonSchema,
-    OutlinesJsonSchemaRunTime,
-    OutlinesRegex,
-    OutlinesRegexRunTime,
-)
 from .data import json_cases, models, regex_cases
 
 
-@cache()
+@caching.cache()
 def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
     return create_states_mapping(regex_string, tokenizer, *args, **kwargs)
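One reading of this import change (an inference from the diff, not documented anywhere): binding the whole outlines.caching module keeps both the cache() decorator and clear_cache() in scope, which the new teardown hook below uses to wipe cached state so one benchmark's cached index never leaks into the next run. A small sketch of the pairing, with a throwaway function:

import outlines.caching as caching

# A throwaway cached function; the decorated call is computed once
# and memoized by outlines' cache.
@caching.cache()
def double(x):
    return x * 2

double(21)             # computed, then cached
caching.clear_cache()  # reset so the next benchmark starts cold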

@@ -36,19 +36,122 @@ def from_regex(
     )
 
 
-class OutlinesCoreRegex(OutlinesRegex):
+class OutlinesCoreBenchmark:
     guide_class = CachedOutlinesCoreRegexGuide.from_regex
 
+    def do_setup(self, model, samples):
+        """Set up the benchmark."""
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model, clean_up_tokenization_spaces=True
+        )
+        self.tokenizer = TransformerTokenizer(self.tokenizer)
 
-class OutlinesCoreRegexRunTime(OutlinesRegexRunTime):
-    guide_class = CachedOutlinesCoreRegexGuide.from_regex
+        self.all_tokenized_samples = [
+            self.tokenizer.encode(sample)[0][0] for sample in samples
+        ]
 
+    def _exhaust_samples(self, guide):
+        state = guide.initial_state
+        for sample_tokens in self.all_tokenized_samples:
+            for token in sample_tokens:
+                if isinstance(token, torch.Tensor):
+                    token = token.item()
+                state = guide.get_next_state(state, token)
+                _ = guide.get_next_instruction(state)
 
-class OutlinesCoreJsonSchema(OutlinesJsonSchema):
-    guide_class = CachedOutlinesCoreRegexGuide.from_regex
+    def _get_first_token(self, guide):
+        """Get first token to verify lazy index is fully warmed up"""
+        state = guide.get_next_state(
+            guide.initial_state, self.all_tokenized_samples[0][0]
+        )
+        _ = guide.get_next_instruction(state)
 
+    def teardown(self, *args):
+        caching.clear_cache()
+
+
+class OutlinesCoreRegex(OutlinesCoreBenchmark):
+    params = [models, regex_cases.keys()]
+    param_names = ["model", "regex_name"]
+    timeout = 1200
+
+    def setup(self, model, regex_name):
+        samples = regex_cases[regex_name]["samples"]
+        self.do_setup(model, samples)
+
+    def time_total(self, _, regex_name):
+        regex_string = regex_cases[regex_name]["regex"]
+        guide = self.guide_class(regex_string, self.tokenizer)
+        self._exhaust_samples(guide)
+
+    def time_first_token(self, _, regex_name):
+        regex_string = regex_cases[regex_name]["regex"]
+        guide = self.guide_class(regex_string, self.tokenizer)
+        self._get_first_token(guide)
+
+
+class OutlinesCoreRegexRunTime(OutlinesCoreBenchmark):
+    """Class which warms-up Guide in setup steps"""
+
+    params = [models, regex_cases.keys()]
+    param_names = ["model", "regex_name"]
+    timeout = 1200
+
+    def setup(self, model, regex_name):
+        samples = regex_cases[regex_name]["samples"]
+        self.do_setup(model, samples)
+
+        # ensure warmed up so we're only measuring runtime
+        regex_string = regex_cases[regex_name]["regex"]
+        self.guide = self.guide_class(regex_string, self.tokenizer)
+        self._get_first_token(self.guide)
+
+    def time_runtime(self, *args):
+        self._exhaust_samples(self.guide)
+
+
+class OutlinesCoreJsonSchema(OutlinesCoreBenchmark):
+    json_from_regex_fn = lambda self, schema: build_regex_from_schema(schema)
+
+    params = [models, json_cases.keys()]
+    param_names = ["model", "json_schema_name"]
+    timeout = 1200
+
+    def setup(self, model, json_schema_name):
+        samples = json_cases[json_schema_name]["samples"]
+        self.do_setup(model, samples)
+
+    def time_total(self, _, json_schema_name):
+        json_string = json_cases[json_schema_name]["schema"]
+        regex_string = self.json_from_regex_fn(json.dumps(json_string))
+        guide = self.guide_class(regex_string, self.tokenizer)
+        self._exhaust_samples(guide)
+
+    def time_first_token(self, _, json_schema_name):
+        json_string = json_cases[json_schema_name]["schema"]
+        regex_string = self.json_from_regex_fn(json.dumps(json_string))
+        guide = self.guide_class(regex_string, self.tokenizer)
+        self._get_first_token(guide)
+
+
+class OutlinesCoreJsonSchemaRunTime(OutlinesCoreBenchmark):
+    """Class which warms-up Guide in setup steps"""
 
-class OutlinesCoreJsonSchemaRunTime(OutlinesJsonSchemaRunTime):
-    guide_class = CachedOutlinesCoreRegexGuide.from_regex
+    json_from_regex_fn = lambda self, schema: build_regex_from_schema(schema)
+
+    params = [models, json_cases.keys()]
+    param_names = ["model", "json_schema_name"]
+    timeout = 1200
+
+    def setup(self, model, json_schema_name):
+        samples = json_cases[json_schema_name]["samples"]
+        self.do_setup(model, samples)
+
+        # ensure warmed up so we're only measuring runtime
+        json_string = json_cases[json_schema_name]["schema"]
+        regex_string = self.json_from_regex_fn(json.dumps(json_string))
+        self.guide = self.guide_class(regex_string, self.tokenizer)
+        self._get_first_token(self.guide)
+
+    def time_runtime(self, *args):
+        self._exhaust_samples(self.guide)
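Taken together, the three timers in each class separate index-compilation cost from decoding cost: time_first_token drives a cold guide (dominated by building the index), time_runtime walks samples through a guide pre-warmed in setup, and time_total covers both. A hypothetical standalone invocation outside the benchmark runner; the model and case names are illustrative, and the import path assumes the src/ layout shown above is importable as a package:

from src.benchmark_outlines_core import OutlinesCoreRegex

bench = OutlinesCoreRegex()
bench.setup("gpt2", "phone_number")           # tokenizer + tokenized samples
bench.time_first_token(None, "phone_number")  # cold guide: compile-heavy
bench.time_total(None, "phone_number")        # compile + full token walk
bench.teardown()                              # clear outlines' cache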
