From f4c03c0966a35528df8ba1cc3342c5b1da4e209e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 24 Jun 2024 17:39:41 +0200 Subject: [PATCH] llama : add handling of byte tokens in UGM tokenizer (same as in SPM) llama : fix preventing crashes when precompiled_charsmap is not present --- llama.cpp | 76 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/llama.cpp b/llama.cpp index 5c6f80d4a1aa4..d7050894250fd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13335,7 +13335,8 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto & token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { auto buf = token_data.text.substr(3, 2); return strtol(buf.c_str(), NULL, 16); } @@ -13355,7 +13356,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; auto token = vocab.token_to_id.find(buf); if (token != vocab.token_to_id.end()) { @@ -14242,36 +14244,38 @@ struct llm_tokenizer_ugm { size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); - - // Find the longest normalized sequence matching the input prefix by walking - // the XOR-compressed compact double array (XCDA) starting from the root node - // We find the index of the next node by calculating BASE[s] ^ c where s is - // the index of the previous node and c is a numerical character value - uint32_t node_index = 0; - // get BASE of the root node - node_index = xcda_view.get_base(node_index); - for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) { - unsigned char c = input[prefix_offset]; - if (c == 0) { - break; - } - node_index ^= c; - // if value of LCHECK is not c it means that this is not a child of - // the previous node, so we stop matching - if (xcda_view.get_lcheck(node_index) != c) { - break; - } - bool is_leaf = xcda_view.get_leaf(node_index); - // get BASE of the current node - node_index ^= xcda_view.get_base(node_index); - // if LEAF of the current node is true, it means that its BASE points to the node - // containing index of replacement sequence for currently matched input prefix - if (is_leaf) - { - longest_prefix_length = prefix_offset - input_offset + 1; - // get index of replacement sequence for currently matched input prefix - longest_prefix_offset = xcda_view.get_value(node_index); + if (xcda_array_size > 0) { + struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + + // Find the longest normalized sequence matching the input prefix by walking + // the XOR-compressed compact double array (XCDA) starting from the root node + // We find the index of the next node by calculating BASE[s] ^ c where s is + // the index of the previous node and c is a numerical character value + uint32_t node_index = 0; + // get BASE of the root node + node_index = xcda_view.get_base(node_index); + for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) { + unsigned char c = input[prefix_offset]; + if (c == 0) { + break; + } + node_index ^= c; + // if value of LCHECK is not c it means that this is not a child of + // the previous node, so we stop matching + if (xcda_view.get_lcheck(node_index) != c) { + break; + } + bool is_leaf = xcda_view.get_leaf(node_index); + // get BASE of the current node + node_index ^= xcda_view.get_base(node_index); + // if LEAF of the current node is true, it means that its BASE points to the node + // containing index of replacement sequence for currently matched input prefix + if (is_leaf) + { + longest_prefix_length = prefix_offset - input_offset + 1; + // get index of replacement sequence for currently matched input prefix + longest_prefix_offset = xcda_view.get_value(node_index); + } } } @@ -14299,11 +14303,11 @@ struct llm_tokenizer_ugm { // escaped space symbol - U+2581 (Lower One Eighth Block) const std::string escaped_space = "\xE2\x96\x81"; - char * prefix_replacements; - size_t prefix_replacements_size; + char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; - uint32_t * xcda_array; - size_t xcda_array_size; + uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; struct naive_trie user_defined_token_matcher;