diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index 778c27c6f..ea7eb19ee 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -208,20 +208,31 @@ def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenize
     tokenizer.vocabulary = tokenizer.get_vocab()
     tokenizer.special_tokens = set(tokenizer.all_special_tokens)
 
+    tokenizer.convert_token_to_string = convert_token_to_string
+    return tokenizer
+
-    def convert_token_to_string(token: Union[str, bytes]) -> str:
-        string = tokenizer.convert_tokens_to_string([token])
+def convert_token_to_string(
+    token: Union[str, bytes], tokenizer: PreTrainedTokenizerBase
+) -> str:
+    """Convert a token to a string.
 
-        # A hack to handle missing spaces to HF's Llama tokenizers
-        if (
-            type(token) is str
-            and token.startswith(SPIECE_UNDERLINE)
-            or token == "<0x20>"
-        ):
-            return " " + string
+    Parameters
+    ----------
+    token
+        The token to convert.
+    tokenizer
+        The tokenizer of the model.
 
-        return string
+    Returns
+    -------
+    str
+        The string representation of the token.
+    """
+    string = tokenizer.convert_tokens_to_string([token])
 
-    tokenizer.convert_token_to_string = convert_token_to_string
+    # A hack to handle missing spaces to HF's Llama tokenizers
+    if type(token) is str and token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+        return " " + string
 
-    return tokenizer
+    return string
 
diff --git a/tests/models/test_vllm.py b/tests/models/test_vllm.py
new file mode 100644
index 000000000..1e2de4aae
--- /dev/null
+++ b/tests/models/test_vllm.py
@@ -0,0 +1,30 @@
+"""Tests for the `vllm` module."""
+
+import pytest
+from transformers import SPIECE_UNDERLINE, AutoTokenizer
+
+from outlines.models.vllm import adapt_tokenizer, convert_token_to_string
+
+TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+
+
+def test_adapt_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL, padding_side="left")
+    adapted_tokenizer = adapt_tokenizer(tokenizer=tokenizer)
+    assert hasattr(adapted_tokenizer, "vocabulary")
+    assert hasattr(adapted_tokenizer, "special_tokens")
+    assert adapted_tokenizer.convert_token_to_string == convert_token_to_string
+
+
+@pytest.mark.parametrize(
+    "token, expected",
+    [
+        ("baz", "baz"),
+        ("<0x20>", " <0x20>"),
+        (SPIECE_UNDERLINE, f" {SPIECE_UNDERLINE}"),
+    ],
+)
+def test_convert_token_to_string(token, expected):
+    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL, padding_side="left")
+    output = convert_token_to_string(token=token, tokenizer=tokenizer)
+    assert output == expected