From f208928b49e3f952fc5d9aae205b8f2c7524a810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 21 Oct 2024 11:37:04 +0200 Subject: [PATCH] Add tokenizer data build time to `lm-format-enforcer` timings Indeed, this needs to be run every time one starts a new process to perform structured generation. This is equivalent to `outlines`'s compilation step. --- src/benchmark_lfe.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/benchmark_lfe.py b/src/benchmark_lfe.py index dddfc75..680f285 100644 --- a/src/benchmark_lfe.py +++ b/src/benchmark_lfe.py @@ -13,7 +13,6 @@ def do_setup(self, model, samples): self.tokenizer = AutoTokenizer.from_pretrained( model, clean_up_tokenization_spaces=True ) - self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer) self.all_tokenized_samples = [ self.tokenizer.encode(sample) for sample in samples ] @@ -27,9 +26,6 @@ def _get_first_token(self, token_enforcer): """Get first token to verify lazy index is fully warmed up""" _ = token_enforcer.get_allowed_tokens(self.all_tokenized_samples[0][:1]) - def teardown(self, *args): - del self.tokenizer_data - class LMFormatEnforcerRegex(LMFormatEnforcerBenchmark): params = [models, regex_cases.keys()] @@ -43,7 +39,8 @@ def setup(self, model, regex_name): def _get_enforcer(self, regex_name): pattern = regex_cases[regex_name]["regex"] parser = RegexParser(pattern) - return TokenEnforcer(self.tokenizer_data, parser) + tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer) + return TokenEnforcer(tokenizer_data, parser) def time_lfe_total(self, _, regex_name): enforcer = self._get_enforcer(regex_name) @@ -87,7 +84,8 @@ def setup(self, model, json_schema_name): def _get_enforcer(self, json_schema_name): schema = json_cases[json_schema_name]["schema"] parser = JsonSchemaParser(schema) - return TokenEnforcer(self.tokenizer_data, parser) + tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer) + return TokenEnforcer(tokenizer_data, parser) def time_lfe_total(self, _, json_schema_name): enforcer = self._get_enforcer(json_schema_name)