diff --git a/src/cpp/src/make_tokenizer_stateful.cpp b/src/cpp/src/make_tokenizer_stateful.cpp
index 547ecdac92..b028ac5b18 100644
--- a/src/cpp/src/make_tokenizer_stateful.cpp
+++ b/src/cpp/src/make_tokenizer_stateful.cpp
@@ -14,34 +14,43 @@ using namespace ov;
 using namespace ov::op;
 
 bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
-
     std::shared_ptr<ov::Node> combine_seg_node;
     for (auto node: model->get_ordered_ops()) {
         if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
             combine_seg_node = node;
         }
     }
 
-    if (!combine_seg_node || combine_seg_node->input_value(1).get_element_type() != ov::element::i32) {
-        return false;
+    if (!combine_seg_node) {
+        return false;
     }
-    std::shared_ptr<v0::Constant> input_1_const = std::dynamic_pointer_cast<v0::Constant>(combine_seg_node->get_input_node_shared_ptr(1));
-    if (!input_1_const) {
-        return false;
+    size_t num_segments = (combine_seg_node->get_input_size() - 1) / 3;
+    std::vector<Input<ov::Node>> const_inputs;
+    const_inputs.reserve(num_segments);
+
+    for (size_t i = 0; i < num_segments; i++) {
+        // If input is constant then it's special tokens, otherwise it's tokens from input text.
+        auto const_input = std::dynamic_pointer_cast<v0::Constant>(combine_seg_node->get_input_node_shared_ptr(3*i + 1));
+        if (const_input) {
+            const_inputs.emplace_back(combine_seg_node->input(3*i + 1));
+        }
+    }
+    if (const_inputs.empty()) {
+        return false;
     }
-
-    op::util::VariableInfo var_info{ov::Shape{}, ov::element::boolean, ADD_SPECIAL_TOKENS_VAR_ID};
-    auto variable = std::make_shared<op::util::Variable>(var_info);
 
     // Default mode is add_special_tokens.
     auto default_mode_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{}, std::vector{true});
+    auto variable = std::make_shared<op::util::Variable>(op::util::VariableInfo{Shape{}, element::boolean, ADD_SPECIAL_TOKENS_VAR_ID});
    auto read_value = std::make_shared<v6::ReadValue>(default_mode_const, variable);
    auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
-    auto select_node = std::make_shared<v1::Select>(read_value, input_1_const, zero_constant);
-    combine_seg_node->input(1).replace_source_output(select_node->output(0));
+
+    for (size_t i = 0; i < const_inputs.size(); i++) {
+        auto select_node = std::make_shared<v1::Select>(read_value, const_inputs[i].get_source_output(), zero_constant);
+        const_inputs[i].replace_source_output(select_node);
+    }
 
     auto assign = std::make_shared<v6::Assign>(read_value, variable);
-
     model->add_sinks({assign});
     model->add_variables({variable});
     return true;
diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py
index 88690e872a..64482e6fc0 100644
--- a/tests/python_tests/common.py
+++ b/tests/python_tests/common.py
@@ -8,7 +8,7 @@ from optimum.intel import OVModelForCausalLM
 from pathlib import Path
 
-from openvino_genai import ContinuousBatchingPipeline, LLMPipeline, SchedulerConfig, GenerationResult, GenerationConfig, DecodedResults, StopCriteria, StreamerBase
+from openvino_genai import ContinuousBatchingPipeline, LLMPipeline, SchedulerConfig, GenerationResult, GenerationConfig, DecodedResults, StopCriteria, StreamerBase, Tokenizer
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers import GenerationConfig as HFGenerationConfig
 from typing import List, Tuple, Callable
@@ -462,13 +462,17 @@ def convert_models(opt_model : OVModelForCausalLM, hf_tokenizer : AutoTokenizer,
     opt_model.generation_config.save_pretrained(models_path)
 
     # convert tokenizers as well
+    convert_and_save_tokenizer(hf_tokenizer, models_path)
+
+
+def convert_and_save_tokenizer(hf_tokenizer : AutoTokenizer, models_path: Path):
     from openvino_tokenizers import convert_tokenizer
-    from openvino import serialize
+    from openvino import save_model
 
     tokenizer, detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
-    serialize(tokenizer, models_path / "openvino_tokenizer.xml")
-    serialize(detokenizer, models_path / "openvino_detokenizer.xml")
-
+    save_model(tokenizer, models_path / "openvino_tokenizer.xml")
+    save_model(detokenizer, models_path / "openvino_detokenizer.xml")
+
 
 def run_llm_pipeline_with_ref(model_id: str,
                               prompts: List[str],
diff --git a/tests/python_tests/test_sampling.py b/tests/python_tests/test_sampling.py
index 28b2afd42a..fa445e96f1 100644
--- a/tests/python_tests/test_sampling.py
+++ b/tests/python_tests/test_sampling.py
@@ -72,7 +72,7 @@ def test_stop_strings(tmp_path, generation_config):
 def test_greedy(tmp_path, generation_config, prompt, use_cb):
     model_id : str = "katuni4ka/tiny-random-phi3"
     if sys.platform.startswith('win') and prompt.startswith('你'):
-        pytest.skip("For unknown reason this prompt fails on Win")
+        pytest.skip("CVS-160780 - Fails on Win with 'RuntimeError: No mapping for the Unicode character exists in the target multi-byte code page'")
 
     run_llm_pipeline_with_ref(model_id=model_id,
                               prompts=[prompt],
diff --git a/tests/python_tests/test_tokenizer.py b/tests/python_tests/test_tokenizer.py
index d71534c2f1..c1122fab7f 100644
--- a/tests/python_tests/test_tokenizer.py
+++ b/tests/python_tests/test_tokenizer.py
@@ -2,19 +2,19 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+import sys
 import pytest
 import numpy as np
 from transformers import AutoTokenizer
 from typing import Dict, Tuple, List
 import openvino_genai
 import json
-
-from common import delete_rt_info
+from common import delete_rt_info, convert_and_save_tokenizer
 from ov_genai_test_utils import (
     get_models_list,
     get_chat_models_list,
     read_model,
-    model_tmp_path
+    model_tmp_path,
 )
 
 
@@ -220,22 +220,31 @@ def test_set_chat_template():
     'Why is the Sun yellow?',
     'What was my first question?',
     ['Why is the Sun yellow?'],
-    "若我有一亿美元,在人工智能盛行的今天,我怎样投资才能收益最大化?",
+    "如果您有任何疑问,请联系我们,我们将予以解答。",
     "מחרוזת בדיקה",
     "Multiline\nstring!\nWow!",
 ]
 
+@pytest.mark.parametrize("model_id", [
+    "katuni4ka/tiny-random-phi3",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    # ("black-forest-labs/FLUX.1-dev", dict(subfolder="tokenizer")), # FLUX.1-dev has tokenizer in subfolder
+])
 @pytest.mark.precommit
 @pytest.mark.nightly
 @pytest.mark.parametrize("prompt", prompts)
-def test_encode_decode_with_special_tokens_option(prompt):
-    import numpy as np
-    model_descr = get_models_list()[0]
-    model_id, path, hf_tokenizer, model_opt, ov_pipe = read_model((model_descr[0], model_descr[1]))
-    ov_tokenzier = ov_pipe.get_tokenizer()
+def test_special_tokens(tmp_path, prompt, model_id):
+    if sys.platform.startswith('win') and isinstance(prompt, str) and (prompt.startswith('如') or prompt.endswith('ה')):
+        pytest.skip("CVS-160780 - Fails on Win with 'RuntimeError: No mapping for the Unicode character exists in the target multi-byte code page'")
+
+    model_id, hf_tok_load_params = (model_id[0], model_id[1]) if isinstance(model_id, tuple) else (model_id, {})
+
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_id, **hf_tok_load_params)
+    convert_and_save_tokenizer(hf_tokenizer, tmp_path)
+    ov_tokenizer = openvino_genai.Tokenizer(tmp_path)
     # Calling encode with 'add_special_tokens' will set state flag.
-    ov_res_add_spec = ov_tokenzier.encode(prompt, add_special_tokens=True).input_ids.data
-    ov_res_no_spec = ov_tokenzier.encode(prompt, add_special_tokens=False).input_ids.data
+    ov_res_add_spec = ov_tokenizer.encode(prompt, add_special_tokens=True).input_ids.data
+    ov_res_no_spec = ov_tokenizer.encode(prompt, add_special_tokens=False).input_ids.data
     hf_res_add_spec = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=True)["input_ids"]
     hf_res_no_spec = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=False)["input_ids"]
     assert np.all(ov_res_add_spec == hf_res_add_spec)
@@ -246,8 +255,8 @@
     assert hf_res_add_spec.size != hf_res_no_spec.size
 
     # Decode with 'skip_special_tokens'
-    decoded_genai_skip_spec = ov_tokenzier.decode(hf_res_add_spec, skip_special_tokens=True)[0]
-    decoded_genai_no_skip = ov_tokenzier.decode(hf_res_add_spec, skip_special_tokens=False)[0]
+    decoded_genai_skip_spec = ov_tokenizer.decode(hf_res_add_spec, skip_special_tokens=True)[0]
+    decoded_genai_no_skip = ov_tokenizer.decode(hf_res_add_spec, skip_special_tokens=False)[0]
     decoded_hf_skip_spec = hf_tokenizer.decode(hf_res_add_spec[0], skip_special_tokens=True)
     decoded_hf_no_skip = hf_tokenizer.decode(hf_res_add_spec[0], skip_special_tokens=False)
     assert decoded_genai_skip_spec == decoded_hf_skip_spec
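
Usage sketch (not part of the patch): a minimal end-to-end run of the stateful add_special_tokens flag that the MakeCombineSegmentsSatateful pass above wires into the tokenizer graph, using only APIs this diff touches (convert_tokenizer/save_model for conversion, openvino_genai.Tokenizer for inference). The output directory name is hypothetical; the model id is the tiny test model the tests already use.

    from pathlib import Path

    from transformers import AutoTokenizer
    from openvino_tokenizers import convert_tokenizer
    from openvino import save_model
    import openvino_genai

    models_path = Path("tok_model")  # hypothetical output directory
    models_path.mkdir(exist_ok=True)

    # Same conversion steps as convert_and_save_tokenizer() in common.py.
    hf_tokenizer = AutoTokenizer.from_pretrained("katuni4ka/tiny-random-phi3")
    tokenizer, detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
    save_model(tokenizer, models_path / "openvino_tokenizer.xml")
    save_model(detokenizer, models_path / "openvino_detokenizer.xml")

    ov_tokenizer = openvino_genai.Tokenizer(models_path)
    # The ReadValue/Select nodes gate every constant (special-token) segment,
    # so one compiled tokenizer honours both modes at run time.
    with_special = ov_tokenizer.encode("Why is the Sun yellow?", add_special_tokens=True).input_ids.data
    without_special = ov_tokenizer.encode("Why is the Sun yellow?", add_special_tokens=False).input_ids.data
    assert with_special.size != without_special.size  # this model adds special tokens

This mirrors the new test_special_tokens: because the flag is a model variable rather than a conversion-time constant, toggling it requires no re-conversion, and with this patch the Select is applied per segment, so tokenizers with several special-token segments (e.g. chat templates with BOS plus role markers) are handled too.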