Fix Tokenizer for several added special tokens
pavel-esir committed Jan 30, 2025
1 parent 38ab055 commit d6c7419
Showing 2 changed files with 45 additions and 21 deletions.
25 changes: 14 additions & 11 deletions src/cpp/src/make_tokenizer_stateful.cpp
@@ -14,20 +14,14 @@ using namespace ov;
 using namespace ov::op;
 
 bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
-
     std::shared_ptr<ov::Node> combine_seg_node;
     for (auto node: model->get_ordered_ops()) {
         if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
             combine_seg_node = node;
         }
     }
-    if (!combine_seg_node || combine_seg_node->input_value(1).get_element_type() != ov::element::i32) {
-        return false;
-    }
-
-    std::shared_ptr<v0::Constant> input_1_const = std::dynamic_pointer_cast<v0::Constant>(combine_seg_node->get_input_node_shared_ptr(1));
-    if (!input_1_const) {
-        return false;
+    if (!combine_seg_node) {
+        return false;
     }
 
     op::util::VariableInfo var_info{ov::Shape{}, ov::element::boolean, ADD_SPECIAL_TOKENS_VAR_ID};
@@ -37,11 +31,20 @@ bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
     auto default_mode_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{}, std::vector{true});
     auto read_value = std::make_shared<v6::ReadValue>(default_mode_const, variable);
     auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
-    auto select_node = std::make_shared<v1::Select>(read_value, input_1_const, zero_constant);
-    combine_seg_node->input(1).replace_source_output(select_node->output(0));
-
+    size_t num_special_tokens = (combine_seg_node->get_input_size() - 1) / 3;
+    for (size_t i = 0; i < num_special_tokens; i++) {
+        // If input is constant then it's special tokens, otherwise it's tokens from input text.
+        auto const_input = std::dynamic_pointer_cast<v0::Constant>(combine_seg_node->get_input_node_shared_ptr(3*i + 1));
+        if (!const_input) {
+            continue;
+        }
+
+        auto select_node = std::make_shared<v1::Select>(read_value, const_input, zero_constant);
+        combine_seg_node->input(3*i + 1).replace_source_output(select_node);
+    }
+
     auto assign = std::make_shared<v6::Assign>(read_value, variable);
 
     model->add_sinks({assign});
     model->add_variables({variable});
     return true;
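For context on the loop above: a CombineSegments node takes three inputs per segment plus one trailing input, which is why the pass computes (get_input_size() - 1) / 3; the input at position 3*i + 1 of a segment is a Constant exactly when that segment is a baked-in special token, and each such Constant gets wrapped in a Select driven by the ADD_SPECIAL_TOKENS state variable so one compiled tokenizer can include or drop all special tokens at run time. A minimal Python inspection sketch of the same arithmetic — the model id is just an example, and it assumes openvino_tokenizers is installed:

    # Count CombineSegments segments in a converted tokenizer graph.
    from transformers import AutoTokenizer
    from openvino_tokenizers import convert_tokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)

    for node in ov_tokenizer.get_ordered_ops():
        if node.get_type_name() == "CombineSegments":
            # Same layout the C++ pass relies on: 3 inputs per segment + 1 trailing input.
            num_segments = (node.get_input_size() - 1) // 3
            print(f"CombineSegments: {node.get_input_size()} inputs -> {num_segments} segments")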
41 changes: 31 additions & 10 deletions tests/python_tests/test_tokenizer.py
@@ -6,6 +6,7 @@
 import numpy as np
 from transformers import AutoTokenizer
 from typing import Dict, Tuple, List
+from pathlib import Path
 import openvino_genai
 import json
 
@@ -18,6 +19,21 @@
 )
 
 
+def load_hf_tokenizer(model_id: str, hf_load_params: dict = None):
+    hf_load_params = hf_load_params or {}
+    return AutoTokenizer.from_pretrained(model_id, **hf_load_params)
+
+
+def convert_and_load_genai_tokenizer(hf_tokenizer: AutoTokenizer, models_path: Path):
+    from openvino_tokenizers import convert_tokenizer
+    from openvino import save_model
+
+    tokenizer, detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
+    save_model(tokenizer, models_path / "openvino_tokenizer.xml")
+    save_model(detokenizer, models_path / "openvino_detokenizer.xml")
+    return openvino_genai.Tokenizer(models_path)
+
+
 def load_genai_tokenizer_with_configs(configs: List[Tuple], temp_path):
     delete_rt_info(configs, temp_path)
 
@@ -220,22 +236,27 @@ def test_set_chat_template():
     'Why is the Sun yellow?',
     'What was my first question?',
     ['Why is the Sun yellow?'],
-    "若我有一亿美元,在人工智能盛行的今天,我怎样投资才能收益最大化?",
+    "如果您有任何疑问,请联系我们,我们将予以解答。",
     "מחרוזת בדיקה",
     "Multiline\nstring!\nWow!",
 ]
+@pytest.mark.parametrize("model_id", [
+    "katuni4ka/tiny-random-phi3",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    ("black-forest-labs/FLUX.1-dev", dict(subfolder="tokenizer")),  # FLUX.1-dev has tokenizer in subfolder
+])
 @pytest.mark.precommit
 @pytest.mark.nightly
 @pytest.mark.parametrize("prompt", prompts)
-def test_encode_decode_with_special_tokens_option(prompt):
-    import numpy as np
-    model_descr = get_models_list()[0]
-    model_id, path, hf_tokenizer, model_opt, ov_pipe = read_model((model_descr[0], model_descr[1]))
-    ov_tokenzier = ov_pipe.get_tokenizer()
+def test_special_tokens(tmp_path, prompt, model_id):
+    model_id, hf_load_params = (model_id[0], model_id[1]) if isinstance(model_id, tuple) else (model_id, {})
+    hf_tokenizer = load_hf_tokenizer(model_id, hf_load_params)
+    ov_tokenizer = convert_and_load_genai_tokenizer(hf_tokenizer, tmp_path)
 
     # Calling encode with 'add_special_tokens' will set state flag.
-    ov_res_add_spec = ov_tokenzier.encode(prompt, add_special_tokens=True).input_ids.data
-    ov_res_no_spec = ov_tokenzier.encode(prompt, add_special_tokens=False).input_ids.data
+    ov_res_add_spec = ov_tokenizer.encode(prompt, add_special_tokens=True).input_ids.data
+    ov_res_no_spec = ov_tokenizer.encode(prompt, add_special_tokens=False).input_ids.data
     hf_res_add_spec = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=True)["input_ids"]
     hf_res_no_spec = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=False)["input_ids"]
     assert np.all(ov_res_add_spec == hf_res_add_spec)
@@ -246,8 +267,8 @@ def test_encode_decode_with_special_tokens_option(prompt):
     assert hf_res_add_spec.size != hf_res_no_spec.size
 
     # Decode with 'skip_special_tokens'
-    decoded_genai_skip_spec = ov_tokenzier.decode(hf_res_add_spec, skip_special_tokens=True)[0]
-    decoded_genai_no_skip = ov_tokenzier.decode(hf_res_add_spec, skip_special_tokens=False)[0]
+    decoded_genai_skip_spec = ov_tokenizer.decode(hf_res_add_spec, skip_special_tokens=True)[0]
+    decoded_genai_no_skip = ov_tokenizer.decode(hf_res_add_spec, skip_special_tokens=False)[0]
     decoded_hf_skip_spec = hf_tokenizer.decode(hf_res_add_spec[0], skip_special_tokens=True)
     decoded_hf_no_skip = hf_tokenizer.decode(hf_res_add_spec[0], skip_special_tokens=False)
     assert decoded_genai_skip_spec == decoded_hf_skip_spec
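The new FLUX.1-dev test case loads a tokenizer from a subfolder — and, per the commit title, one carrying several added special tokens. A hypothetical standalone composition of the two helpers introduced in this file (assuming they are importable from the test module; "flux_tokenizer" is an arbitrary output directory):

    # Hypothetical standalone usage of load_hf_tokenizer / convert_and_load_genai_tokenizer.
    from pathlib import Path

    models_path = Path("flux_tokenizer")  # arbitrary output directory, created below
    models_path.mkdir(exist_ok=True)

    hf_tok = load_hf_tokenizer("black-forest-labs/FLUX.1-dev", dict(subfolder="tokenizer"))
    ov_tok = convert_and_load_genai_tokenizer(hf_tok, models_path)

    # The add_special_tokens flag flips the tokenizer's internal state variable,
    # so one compiled model serves both modes.
    ids_with = ov_tok.encode("Why is the Sun yellow?", add_special_tokens=True).input_ids.data
    ids_without = ov_tok.encode("Why is the Sun yellow?", add_special_tokens=False).input_ids.data
    assert ids_with.size != ids_without.size  # special tokens were actually toggled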
