diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index bfd7fc1c1740f..1e87737abfb71 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG MUSA_VERSION=rc3.1.0 +ARG MUSA_VERSION=rc3.1.1 # Target the MUSA build image ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION} diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6841ba5897809..62f4ed8742778 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -403,6 +403,34 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 1800 + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT + fi + + - name: Pack artifacts + id: pack_artifacts + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + run: | + cp LICENSE ./build/bin/ + cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + + - name: Upload artifacts + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip + name: llama-bin-ubuntu-vulkan-x64.zip + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.0.2 @@ -443,7 +471,7 @@ jobs: ubuntu-22-cmake-musa: runs-on: ubuntu-22.04 - container: mthreads/musa:rc3.1.0-devel-ubuntu22.04 + container: mthreads/musa:rc3.1.1-devel-ubuntu22.04 steps: - name: Clone diff --git a/README.md b/README.md index 11f3d0286aa2b..7629647d7e420 100644 --- a/README.md +++ b/README.md @@ -235,6 +235,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [HIP](docs/build.md#hip) | AMD GPU | | [Vulkan](docs/build.md#vulkan) | GPU | | [CANN](docs/build.md#cann) | Ascend NPU | +| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU | ## Building the project @@ -518,5 +519,18 @@ If your issue is with model generation quality, then please at least scan the fo - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) -#### References - +## Completions +Command-line completion is available for some environments. + +#### Bash Completion +```bash +$ build/bin/llama-cli --completion-bash > ~/.llama-completion.bash +$ source ~/.llama-completion.bash +``` +Optionally this can be added to your `.bashrc` or `.bash_profile` to load it +automatically. For example: +```console +$ echo "source ~/.llama-completion.bash" >> ~/.bashrc +``` + +## References diff --git a/common/arg.cpp b/common/arg.cpp index f7af09c7166cd..144e93fa9b867 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -365,6 +365,108 @@ static void common_params_print_usage(common_params_context & ctx_arg) { print_options(specific_options); } +static void common_params_print_completion(common_params_context & ctx_arg) { + std::vector common_options; + std::vector sparam_options; + std::vector specific_options; + + for (auto & opt : ctx_arg.options) { + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + + printf("_llama_completions() {\n"); + printf(" local cur prev opts\n"); + printf(" COMPREPLY=()\n"); + printf(" cur=\"${COMP_WORDS[COMP_CWORD]}\"\n"); + printf(" prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n"); + + printf(" opts=\""); + auto print_options = [](const std::vector & options) { + for (const common_arg * opt : options) { + for (const char * arg : opt->args) { + printf("%s ", arg); + } + } + }; + + print_options(common_options); + print_options(sparam_options); + print_options(specific_options); + printf("\"\n\n"); + + printf(" case \"$prev\" in\n"); + printf(" --model)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" --grammar-file)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.gbnf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" *)\n"); + printf(" COMPREPLY=( $(compgen -W \"${opts}\" -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" esac\n"); + printf("}\n\n"); + + std::set executables = { + "llama-batched", + "llama-batched-bench", + "llama-bench", + "llama-cli", + "llama-convert-llama2c-to-ggml", + "llama-cvector-generator", + "llama-embedding", + "llama-eval-callback", + "llama-export-lora", + "llama-gbnf-validator", + "llama-gen-docs", + "llama-gguf", + "llama-gguf-hash", + "llama-gguf-split", + "llama-gritlm", + "llama-imatrix", + "llama-infill", + "llama-llava-cli", + "llama-llava-clip-quantize-cli", + "llama-lookahead", + "llama-lookup", + "llama-lookup-create", + "llama-lookup-merge", + "llama-lookup-stats", + "llama-minicpmv-cli", + "llama-parallel", + "llama-passkey", + "llama-perplexity", + "llama-q8dot", + "llama-quantize", + "llama-quantize-stats", + "llama-qwen2vl-cli", + "llama-retrieval", + "llama-run", + "llama-save-load-state", + "llama-server", + "llama-simple", + "llama-simple-chat", + "llama-speculative", + "llama-speculative-simple", + "llama-tokenize", + "llama-tts", + "llama-vdot" + }; + + for (const auto& exe : executables) { + printf("complete -F _llama_completions %s\n", exe.c_str()); + } +} + static std::vector parse_device_list(const std::string & value) { std::vector devices; auto dev_names = string_split(value, ','); @@ -426,6 +528,10 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e } exit(0); } + if (ctx_arg.params.completion) { + common_params_print_completion(ctx_arg); + exit(0); + } } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); ctx_arg.params = params_org; @@ -494,6 +600,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex exit(0); } )); + add_opt(common_arg( + {"--completion-bash"}, + "print source-able bash completion script for llama.cpp", + [](common_params & params) { + params.completion = true; + } + )); add_opt(common_arg( {"--verbose-prompt"}, string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"), @@ -946,6 +1059,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.min_p = std::stof(value); } ).set_sparam()); + add_opt(common_arg( + {"--top-nsigma"}, "N", + string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma), + [](common_params & params, const std::string & value) { + params.sampling.top_n_sigma = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), @@ -1975,6 +2095,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.use_jinja = true; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--reasoning-format"}, "FORMAT", + "reasoning format (default: deepseek; allowed values: deepseek, none)\n" + "controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n" + "only supported for non-streamed responses", + [](common_params & params, const std::string & value) { + /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } + else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } + else { std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( diff --git a/common/chat.cpp b/common/chat.cpp index ef1c6fb3d3978..5b8e280aae341 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -12,11 +12,13 @@ std::string common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)"; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; + case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)"; default: throw std::runtime_error("Unknown chat format"); } @@ -105,7 +107,6 @@ static common_chat_msg parse_json_tool_calls( std::sregex_iterator rend; std::sregex_iterator rit(it, end, function_regex); if (rit == rend) { - fprintf(stderr, "No more tool calls found\n"); result.content += std::string(it, end); break; } @@ -115,14 +116,21 @@ static common_chat_msg parse_json_tool_calls( json arguments; if (!parse_json(it, end, arguments)) { - throw std::runtime_error("Failed to parse json tool call arguments"); + throw std::runtime_error("Failed to parse json tool call arguments: " + input); } if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern"); + throw std::runtime_error("Malformed input, missing closing pattern: " + input); } it = match.suffix().first; result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); } + + if (!result.tool_calls.empty()) { + if (!string_strip(result.content).empty()) { + LOG_WRN("Content found with tool calls: %s\n", result.content.c_str()); + } + result.content = ""; + } return result; } @@ -134,11 +142,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in result.role = "assistant"; const auto process_tool_calls = [&](const json & tool_calls) { for (const auto & tool_call : tool_calls) { - const auto & arguments = tool_call["arguments"]; + const auto & arguments = tool_call.at("arguments"); result.tool_calls.push_back({ - tool_call["name"], + tool_call.at("name"), arguments.is_string() ? arguments.get() : arguments.dump(), - tool_call.contains("id") ? tool_call["id"] : "", + tool_call.contains("id") ? tool_call.at("id") : "", }); } }; @@ -155,7 +163,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { - if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) { + if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str()); continue; } @@ -190,27 +198,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp auto tool_call_schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); auto tool_schema = json { {"type", "object"}, {"properties", { {"name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"arguments", function["parameters"]}, + {"arguments", function.at("parameters")}, }}, {"required", json::array({"name", "arguments"})}, }; if (function.contains("description")) { - tool_schema["description"] = function["description"]; + tool_schema["description"] = function.at("description"); } if (inputs.parallel_tool_calls) { - tool_schema["properties"]["id"] = { + tool_schema.at("properties")["id"] = { {"type", "string"}, {"minLength", 4}, }; - tool_schema["required"].push_back("id"); + tool_schema.at("required").push_back("id"); } tool_call_schemas.emplace_back(tool_schema); }); @@ -275,21 +283,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) { common_chat_msg result; result.role = "assistant"; if (data.contains("tool_calls")) { - for (const auto & tool_call : data["tool_calls"]) { + for (const auto & tool_call : data.at("tool_calls")) { result.tool_calls.push_back({ - tool_call["name"], - tool_call["arguments"].dump(), - tool_call.contains("id") ? tool_call["id"] : "", + tool_call.at("name"), + tool_call.at("arguments").dump(), + tool_call.contains("id") ? tool_call.at("id") : "", }); } } else if (data.contains("tool_call")) { result.tool_calls.push_back({ - data["tool_call"]["name"], - data["tool_call"]["arguments"].dump(), + data.at("tool_call").at("name"), + data.at("tool_call").at("arguments").dump(), /* id= */ "", }); } else if (data.contains("response")) { - const auto & response = data["response"]; + const auto & response = data.at("response"); result.content = response.is_string() ? response.get() : response.dump(2); } return result; @@ -301,7 +309,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); schemas.push_back({ {"type", "object"}, {"properties", { @@ -309,9 +317,9 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. {"name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"arguments", function["parameters"]}, + {"arguments", function.at("parameters")}, {"id", { {"type", "string"}, // Nemo's template expects a 9-character alphanumeric ID. @@ -346,7 +354,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); schemas.push_back({ {"type", "object"}, {"properties", { @@ -357,9 +365,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ }}, {"tool_name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"parameters", function["parameters"]}, + {"parameters", function.at("parameters")}, }}, {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, }); @@ -382,39 +390,65 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ "<|END_THINKING|>", "<|END_ACTION|>", }; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); - data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + if (has_reasoning_content && has_tool_calls) { + auto adjusted_message = msg; + adjusted_message["tool_plan"] = msg.at("reasoning_content"); + adjusted_message.erase("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } + data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B; return data; } -static common_chat_msg common_chat_parse_command_r7b(const std::string & input) { - static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); - static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); +static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { + static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)"); + static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); + static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); + std::smatch match; common_chat_msg result; result.role = "assistant"; - if (std::regex_match(input, match, response_regex)) { - result.content = match[1].str(); - } else if (std::regex_match(input, match, thought_action_regex)) { - result.tool_plan = match[1].str(); - auto actions_str = match[2].str(); + + std::string rest = input; + + if (std::regex_match(rest, match, thought_regex)) { + if (extract_reasoning) { + result.reasoning_content = match[2].str(); + } else if (!match[2].str().empty()) { + // Let the unparsed thinking tags through in content only if their insides aren't empty. + result.content = match[1].str(); + } + rest = match[3].str(); + } + if (std::regex_match(rest, match, action_regex)) { + auto actions_str = match[1].str(); auto actions = json::parse(actions_str); for (const auto & action : actions) { result.tool_calls.push_back({ - /* .name = */ action["tool_name"], - /* .arguments = */ action["parameters"].dump(), - /* .id = */ action["tool_call_id"], + /* .name = */ action.at("tool_name"), + /* .arguments = */ action.at("parameters").dump(), + /* .id = */ action.at("tool_call_id"), }); } + } else if (std::regex_match(rest, match, response_regex)) { + auto response = match[1].str(); + result.content += response; } else { - LOG_ERR("Failed to parse command_r output"); - result.content = input; + result.content += rest; } return result; } static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { - if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { + if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); } const auto & parameters_properties = parameters.at("properties"); @@ -468,9 +502,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com }; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); builder.resolve_refs(parameters); // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime @@ -546,34 +580,90 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto args_rule = builder.add_schema(name + "-args", parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); - }); - data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); - data.preserved_tokens = { - "<|tool▁sep|>", - "<|tool▁call▁end|>", - }; - builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); - }, grammar_options); + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + auto args_rule = builder.add_schema(name + "-args", parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" + "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false}); + data.preserved_tokens = { + "", + "", + "<|tool▁sep|>", + "<|tool▁calls▁end|", + "<|tool▁call▁end|>", + }; + }, grammar_options); + } auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + + // Hacks to fix the official (broken) prompt. + // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, + // until the official template is fixed. + if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) { + // Don't leave the chat dangling after tool results + if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) { + prompt += "<|end▁of▁sentence|>"; + if (inputs.add_generation_prompt) { + prompt += "<|Assistant|>"; + } + } + // Fix up tool call delta example added by Minja + prompt = std::regex_replace( + prompt, + std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"), + "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); + } data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1; return data; } -static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) { - static std::regex trigger_regex("<|tool▁calls▁begin|>"); +static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); - static std::regex close_regex("```<|tool▁call▁end|>"); - return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); + static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + static std::regex reasoning_content_regex("((?:)?([\\s\\S\\r\\n]*?))?([\\s\\S\\r\\n]*)"); + static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>"); + common_chat_msg msg; + msg.role = "assistant"; + std::smatch match; + if (std::regex_match(input, match, reasoning_content_regex)) { + std::string rest; + if (extract_reasoning) { + msg.reasoning_content = string_strip(match[2].str()); + } else { + msg.content = match[1].str(); + } + rest = match[3].str(); + + if (std::regex_search(rest, match, tool_calls_regex)) { + auto tool_calls = match[1].str(); + auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); + msg.tool_calls = std::move(msg2.tool_calls); + } else { + msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end()); + } + } else { + msg.content = input; + } + return msg; } static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { @@ -583,20 +673,20 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c {"datetime", "Jan 29 2025 13:00:00 GMT"}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }); - if (!inputs.tools.is_null() && !inputs.tools.empty()) { + if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); schemas.push_back({ {"type", "object"}, {"properties", { {"name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"arguments", function["parameters"]}, + {"arguments", function.at("parameters")}, }}, {"required", json::array({"name", "arguments", "id"})}, }); @@ -628,15 +718,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; - if (!inputs.tools.is_null() && !inputs.tools.empty()) { + if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); @@ -716,9 +806,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - const auto & parameters = function["parameters"]; - std::string name = function["name"]; + const auto & function = tool.at("function"); + const auto & parameters = function.at("parameters"); + std::string name = function.at("name"); if (name == "python" || name == "ipython") { if (!parameters.contains("type")) { throw std::runtime_error("Missing type in python tool"); @@ -789,9 +879,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); builder.resolve_refs(parameters); tool_rules.push_back(builder.add_schema(name + "-call", { {"type", "object"}, @@ -839,9 +929,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) if (!parse_json(it, end, call)) { throw std::runtime_error("Failed to parse json tool call"); } - const auto & arguments = call["arguments"]; + const auto & arguments = call.at("arguments"); result.tool_calls.push_back({ - call["name"], + call.at("name"), arguments.dump(), // arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ "", @@ -884,47 +974,72 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha } common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { - auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none"; - LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false"); + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); - if (has_tools && !inputs.grammar.empty()) { - throw std::runtime_error("Cannot specify grammar with tools"); + if (inputs.tools.is_array()) { + if (inputs.tool_choice != "none" && !inputs.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + if (caps.supports_tool_calls && !caps.supports_tools) { + LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); + } } - const auto & src = tmpl.source(); + // DeepSeek R1: use handler in all cases except json schema (thinking / tools). + if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) { + return common_chat_params_init_deepseek_r1(tmpl, inputs); + } + + // Command R7B: : use handler in all cases except json schema (thinking / tools). + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) { + return common_chat_params_init_command_r7b(tmpl, inputs); + } + + // Use generic handler when mixing tools + JSON schema. + // TODO: support that mix in handlers below. + if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) { + return common_chat_params_init_generic(tmpl, inputs); + } + + // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. if (src.find(">>>all") != std::string::npos) { - // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when return common_chat_params_init_functionary_v3_2(tmpl, inputs); } + + // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. if (src.find(" functools[") != std::string::npos) { - // Firefunction v2 requires datetime and functions in the context, even w/o tools. return common_chat_params_init_firefunction_v2(tmpl, inputs); } - if (!has_tools) { + // Plain handler (no tools) + if (inputs.tools.is_null() || inputs.tool_choice == "none") { return common_chat_params_init_without_tools(tmpl, inputs); } + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) if (src.find("") != std::string::npos) { return common_chat_params_init_hermes_2_pro(tmpl, inputs); } + + // Functionary v3.1 (w/ tools) if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) { auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); } - if (src.find("<|tool▁calls▁begin|>") != std::string::npos) { - return common_chat_params_init_deepseek_r1(tmpl, inputs); - } + + // Mistral Nemo (w/ tools) if (src.find("[TOOL_CALLS]") != std::string::npos) { return common_chat_params_init_mistral_nemo(tmpl, inputs); } - if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) { - return common_chat_params_init_command_r7b(tmpl, inputs); - } + + // Generic fallback return common_chat_params_init_generic(tmpl, inputs); } @@ -949,7 +1064,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true); case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - return common_chat_parse_deepseek_r1(input); + return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false); + case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: + return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true); case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return common_chat_parse_functionary_v3_2(input); case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: @@ -959,7 +1076,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return common_chat_parse_firefunction_v2(input); case COMMON_CHAT_FORMAT_COMMAND_R7B: - return common_chat_parse_command_r7b(input); + return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false); + case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: + return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true); default: throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); } diff --git a/common/chat.hpp b/common/chat.hpp index 33e64a430d51e..ba1632f669cf7 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -19,6 +19,7 @@ struct common_chat_inputs { bool stream; std::string grammar; bool add_generation_prompt = true; + bool extract_reasoning = true; }; enum common_chat_format { @@ -28,11 +29,13 @@ enum common_chat_format { COMMON_CHAT_FORMAT_LLAMA_3_X, COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_COMMAND_R7B, + COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/common/common.h b/common/common.h index b208d0c7ece59..98b9a4464787a 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f;// -1.0 = disabled float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; @@ -202,6 +203,11 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; +enum common_reasoning_format { + COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -292,6 +298,7 @@ struct common_params { bool kl_divergence = false; // compute KL divergence bool usage = false; // print usage + bool completion = false; // print source-able completion script bool use_color = false; // use color to distinguish generations and inputs bool special = false; // enable special token output bool interactive = false; // interactive mode @@ -346,6 +353,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; std::vector api_keys; @@ -424,13 +432,13 @@ bool set_process_priority(enum ggml_sched_priority prio); // #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) #endif LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) @@ -623,7 +631,7 @@ struct common_chat_msg { std::string role; std::string content; std::vector tool_calls; - std::string tool_plan = ""; + std::string reasoning_content = ""; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid diff --git a/common/log.cpp b/common/log.cpp index 4bfbecf15e314..52b31470c46bd 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -1,5 +1,6 @@ #include "log.h" +#include #include #include #include diff --git a/common/log.h b/common/log.h index 4ebc6314b25d6..c56bb50d95db0 100644 --- a/common/log.h +++ b/common/log.h @@ -15,7 +15,7 @@ #ifndef __GNUC__ # define LOG_ATTRIBUTE_FORMAT(...) -#elif defined(__MINGW32__) +#elif defined(__MINGW32__) && !defined(__clang__) # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #else # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) diff --git a/common/sampling.cpp b/common/sampling.cpp index e4b21ca1011dd..37a0d9c85ae30 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -134,11 +134,11 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, + top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); @@ -151,12 +151,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co lparams.no_perf = params.no_perf; - std::vector trigger_words; - trigger_words.reserve(params.grammar_trigger_words.size()); - for (const auto & str : params.grammar_trigger_words) { - trigger_words.push_back(str.word.c_str()); - } - struct llama_sampler * grmr; if (params.grammar.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE @@ -165,6 +159,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { + std::vector trigger_words; + trigger_words.reserve(params.grammar_trigger_words.size()); + for (const auto & str : params.grammar_trigger_words) { + trigger_words.push_back(str.word.c_str()); + } + grmr = params.grammar_lazy ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", trigger_words.data(), trigger_words.size(), @@ -188,45 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.logit_bias.data())); if (params.mirostat == 0) { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case COMMON_SAMPLER_TYPE_DRY: - { - std::vector c_breakers; - c_breakers.reserve(params.dry_sequence_breakers.size()); - for (const auto & str : params.dry_sequence_breakers) { - c_breakers.push_back(str.c_str()); + if (params.top_n_sigma >= 0) { + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + } else { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } + + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } - - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); - } - break; - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - break; - case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); - break; - case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); - break; - case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); - break; - default: - GGML_ASSERT(false && "unknown sampler type"); + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); + } } } llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md new file mode 100644 index 0000000000000..a604058cbeb97 --- /dev/null +++ b/docs/backend/OPENCL.md @@ -0,0 +1,205 @@ +# llama.cpp for OpenCL + +- [Background](#background) +- [OS](#os) +- [Hardware](#hardware) +- [DataType Supports](#datatype-supports) +- [Model Preparation](#model-preparation) +- [CMake Options](#cmake-options) +- [Android](#android) +- [Windows 11 Arm64](#windows-11-arm64) +- [Known Issue](#known-issues) +- [TODO](#todo) + +## Background + +OpenCL (Open Computing Language) is an open, royalty-free standard for cross-platform, parallel programming of diverse accelerators found in supercomputers, cloud servers, personal computers, mobile devices and embedded platforms. OpenCL specifies a programming language (based on C99) for programming these devices and application programming interfaces (APIs) to control the platform and execute programs on the compute devices. Similar to CUDA, OpenCL has been widely used to program GPUs and is supported by most GPU vendors. + +### Llama.cpp + OpenCL + +The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal. + +## OS + +| OS | Status | Verified | +|---------|---------|------------------------------------------------| +| Android | Support | Snapdragon 8 Gen 3, Snapdragon 8 Elite | +| Windows | Support | Windows 11 Arm64 with Snapdragon X Elite | +| Linux | Support | Ubuntu 22.04 WSL2 with Intel 12700H | + +## Hardware + +### Adreno GPU + +**Verified devices** + +| Adreno GPU | Status | +|:------------------------------------:|:-------:| +| Adreno 750 (Snapdragon 8 Gen 3) | Support | +| Adreno 830 (Snapdragon 8 Elite) | Support | +| Adreno X85 (Snapdragon X Elite) | Support | + +## DataType Supports + +| DataType | Status | +|:----------------------:|:--------------------------:| +| Q4_0 | Support | +| Q6_K | Support, but not optimized | + +## Model Preparation + +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration. + +Currently we support `Q4_0` quantization and have optimize for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize`. For example, + +```sh +./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0 +``` + +Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization. + +## CMake Options + +The OpenCL backend has the following CMake options that control the behavior of the backend. + +| CMake options | Default value | Description | +|:---------------------------------:|:--------------:|:------------------------------------------| +| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. | +| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. | + +## Android + +Ubuntu 22.04 is used for targeting Android. Make sure the following tools are accessible from command line, + +* Git +* CMake 3.29 +* Ninja +* Python3 + +### I. Setup Environment + +1. **Install NDK** + +```sh +cd ~ +wget https://dl.google.com/android/repository/commandlinetools-linux-8512546_latest.zip && \ +unzip commandlinetools-linux-8512546_latest.zip && \ +mkdir -p ~/android-sdk/cmdline-tools && \ +mv cmdline-tools latest && \ +mv latest ~/android-sdk/cmdline-tools/ && \ +rm -rf commandlinetools-linux-8512546_latest.zip + +yes | ~/android-sdk/cmdline-tools/latest/bin/sdkmanager "ndk;26.3.11579264" +``` + +2. **Install OpenCL Headers and Library** + +```sh +mkdir -p ~/dev/llm +cd ~/dev/llm + +git clone https://github.com/KhronosGroup/OpenCL-Headers && \ +cd OpenCL-Headers && \ +cp -r CL ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include + +cd ~/dev/llm + +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \ +cd OpenCL-ICD-Loader && \ +mkdir build_ndk26 && cd build_ndk26 && \ +cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \ + -DOPENCL_ICD_LOADER_HEADERS_DIR=$HOME/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=24 \ + -DANDROID_STL=c++_shared && \ +ninja && \ +cp libOpenCL.so ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android +``` + +### II. Build llama.cpp + +```sh +cd ~/dev/llm + +git clone https://github.com/ggerganov/llama.cpp && \ +cd llama.cpp && \ +mkdir build-android && cd build-android + +cmake .. -G Ninja \ + -DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_OPENCL=ON + +ninja +``` + +## Windows 11 Arm64 + +A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the following tools are accessible from command line, + +* Git +* CMake 3.29 +* Clang 19 +* Ninja +* Visual Studio 2022 + +Powershell is used for the following instructions. + +### I. Setup Environment + +1. **Install OpenCL Headers and Library** + +```powershell +mkdir -p ~/dev/llm + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers +mkdir build && cd build +cmake .. -G Ninja ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader +mkdir build && cd build +cmake .. -G Ninja ` + -DCMAKE_BUILD_TYPE=Release ` + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" ` + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install +``` + +### II. Build llama.cpp + +```powershell + +mkdir -p ~/dev/llm +cd ~/dev/llm + +git clone https://github.com/ggerganov/llama.cpp && cd llama.cpp +mkdir build && cd build + +cmake .. -G Ninja ` + -DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" ` + -DCMAKE_BUILD_TYPE=Release ` + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" ` + -DBUILD_SHARED_LIBS=OFF ` + -DGGML_OPENCL=ON +ninja +``` + +## Known Issues + +- Qwen2.5 0.5B model produces gibberish output with Adreno kernels. + +## TODO + +- Fix Qwen2.5 0.5B +- Optimization for Q6_K +- Support and optimization for Q4_K diff --git a/docs/docker.md b/docs/docker.md index dac9a9ec164ff..cab5ae9572349 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -104,7 +104,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment The defaults are: -- `MUSA_VERSION` set to `rc3.1.0` +- `MUSA_VERSION` set to `rc3.1.1` The resulting images, are essentially the same as the non-MUSA images: diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index b5f3feb9f82e6..395e2aa47c301 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -3,6 +3,7 @@ #include "log.h" #include "llama.h" +#include #include #include #include diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 4ac19ca86ec56..f518d02d38689 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -876,8 +876,8 @@ static std::vector get_cmd_params_instances(const cmd_param struct test { static const std::string build_commit; static const int build_number; - static const std::string cpu_info; - static const std::string gpu_info; + const std::string cpu_info; + const std::string gpu_info; std::string model_filename; std::string model_type; uint64_t model_size; @@ -903,7 +903,10 @@ struct test { std::string test_time; std::vector samples_ns; - test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) { + test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) : + cpu_info(get_cpu_info()), + gpu_info(get_gpu_info()) { + model_filename = inst.model; char buf[128]; llama_model_desc(lmodel, buf, sizeof(buf)); @@ -1058,8 +1061,6 @@ struct test { const std::string test::build_commit = LLAMA_COMMIT; const int test::build_number = LLAMA_BUILD_NUMBER; -const std::string test::cpu_info = get_cpu_info(); -const std::string test::gpu_info = get_gpu_info(); struct printer { virtual ~printer() {} diff --git a/examples/main/README.md b/examples/main/README.md index ceaed42f649df..ea71591bd939d 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -265,6 +265,14 @@ Being experimental and unique, XTC is disabled by default. The recommended combi Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1` +### Top-nσ Sampling + +- `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled). + +Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focous on the more informative region of the sampling space. + +Example usage: `--top-nsigma 1` + ### Logit Bias - `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion. diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9bf6c57433ab2..5d07421e827d1 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -3,6 +3,7 @@ #include "log.h" #include "llama.h" +#include #include #include #include diff --git a/examples/server/README.md b/examples/server/README.md index d0b262f0e1f75..751d4db9ee9d7 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -127,6 +127,7 @@ The project is under active development, and we are [looking for feedback and co | `--grammar-file FNAME` | file to read grammar from | | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | | `--jinja` | Enable experimental Jinja templating engine (required for tool use) | +| `--reasoning-format FORMAT` | Controls extraction of model thinking traces and the format / field in which they are returned (default: `deepseek`; allowed values: `deepseek`, `none`; requires `--jinja`). `none` will leave thinking traces inline in `message.content` in a model-specific format, while `deepseek` will return them separately under `message.reasoning_content` | **Example-specific params** @@ -1136,61 +1137,252 @@ curl http://localhost:8080/v1/chat/completions \ | Template | Format | |----------|--------| - | CohereForAI-c4ai-command-r-plus-default.jinja | generic tool calls | - | CohereForAI-c4ai-command-r-plus-rag.jinja | generic tool calls | - | CohereForAI-c4ai-command-r-plus-tool_use.jinja | generic tool calls | - | MiniMaxAI-MiniMax-Text-01.jinja | generic tool calls | - | NexaAIDev-Octopus-v2.jinja | generic tool calls | - | NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | generic tool calls | - | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | hermes 2 pro tool calls | - | NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | generic tool calls | - | NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | hermes 2 pro tool calls | - | NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | generic tool calls | - | NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | hermes 2 pro tool calls | - | OrionStarAI-Orion-14B-Chat.jinja | generic tool calls | - | Qwen-QwQ-32B-Preview.jinja | hermes 2 pro tool calls | - | Qwen-Qwen2-7B-Instruct.jinja | generic tool calls | - | Qwen-Qwen2-VL-7B-Instruct.jinja | generic tool calls | - | Qwen-Qwen2.5-7B-Instruct.jinja | hermes 2 pro tool calls | - | Qwen-Qwen2.5-Math-7B-Instruct.jinja | hermes 2 pro tool calls | - | TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | generic tool calls | - | abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | generic tool calls | - | bofenghuang-vigogne-2-70b-chat.jinja | generic tool calls | - | databricks-dbrx-instruct.jinja | generic tool calls | - | deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | generic tool calls | - | deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | deepseek r1 tool calls | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | deepseek r1 tool calls | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | deepseek r1 tool calls | - | deepseek-ai-DeepSeek-V2.5.jinja | deepseek r1 tool calls | - | deepseek-ai-deepseek-coder-33b-instruct.jinja | generic tool calls | - | google-gemma-2-2b-it.jinja | generic tool calls | - | google-gemma-7b-it.jinja | generic tool calls | - | indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | generic tool calls | - | mattshumer-Reflection-Llama-3.1-70B.jinja | generic tool calls | - | meetkai-functionary-medium-v3.2.jinja | functionary v3.2 tool calls | - | meta-llama-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) | - | meta-llama-Llama-3.2-3B-Instruct.jinja | llama 3.x tool calls | - | meta-llama-Llama-3.3-70B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) | - | meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) | - | microsoft-Phi-3-medium-4k-instruct.jinja | generic tool calls | - | microsoft-Phi-3-mini-4k-instruct.jinja | generic tool calls | - | microsoft-Phi-3-small-8k-instruct.jinja | generic tool calls | - | microsoft-Phi-3.5-mini-instruct.jinja | generic tool calls | - | microsoft-Phi-3.5-vision-instruct.jinja | generic tool calls | - | mistralai-Mistral-7B-Instruct-v0.2.jinja | generic tool calls | - | mistralai-Mistral-Large-Instruct-2407.jinja | mistral nemo tool calls | - | mistralai-Mistral-Large-Instruct-2411.jinja | generic tool calls | - | mistralai-Mistral-Nemo-Instruct-2407.jinja | mistral nemo tool calls | - | mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | generic tool calls | - | mlabonne-AlphaMonarch-7B.jinja | generic tool calls | - | nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | llama 3.x tool calls (w/ builtin tools) | - | openchat-openchat-3.5-0106.jinja | generic tool calls | - | teknium-OpenHermes-2.5-Mistral-7B.jinja | generic tool calls | + | Almawave-Velvet-14B.jinja | Hermes 2 Pro | + | AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x | + | CohereForAI-aya-expanse-8b.jinja | Generic | + | CohereForAI-c4ai-command-r-plus-default.jinja | Generic | + | CohereForAI-c4ai-command-r-plus-rag.jinja | Generic | + | CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic | + | CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) | + | CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) | + | CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) | + | CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic | + | DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic | + | Delta-Vector-Rei-12B.jinja | Mistral Nemo | + | EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo | + | FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro | + | FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic | + | HelpingAI-HAI-SER.jinja | Generic | + | HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic | + | HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic | + | HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic | + | INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic | + | Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro | + | Infinigence-Megrez-3B-Instruct.jinja | Generic | + | Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic | + | LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic | + | LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic | + | LatitudeGames-Wayfarer-12B.jinja | Generic | + | Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic | + | Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic | + | MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic | + | MiniMaxAI-MiniMax-Text-01.jinja | Generic | + | MiniMaxAI-MiniMax-VL-01.jinja | Generic | + | NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) | + | NexaAIDev-Octopus-v2.jinja | Generic | + | NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic | + | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro | + | NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic | + | NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro | + | NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic | + | NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro | + | NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro | + | NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro | + | OnlyCheeini-greesychat-turbo.jinja | Generic | + | Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x | + | OrionStarAI-Orion-14B-Chat.jinja | Generic | + | PowerInfer-SmallThinker-3B-Preview.jinja | Generic | + | PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic | + | Qwen-QVQ-72B-Preview.jinja | Generic | + | Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro | + | Qwen-Qwen1.5-7B-Chat.jinja | Generic | + | Qwen-Qwen2-7B-Instruct.jinja | Generic | + | Qwen-Qwen2-VL-72B-Instruct.jinja | Generic | + | Qwen-Qwen2-VL-7B-Instruct.jinja | Generic | + | Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro | + | RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro | + | SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro | + | SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro | + | Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x | + | SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x | + | SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x | + | Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x | + | Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x | + | Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x | + | THUDM-glm-4-9b-chat.jinja | Generic | + | THUDM-glm-edge-1.5b-chat.jinja | Generic | + | Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x | + | TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic | + | TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic | + | UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic | + | ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x | + | abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic | + | ai21labs-AI21-Jamba-1.5-Large.jinja | Generic | + | allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic | + | allenai-Llama-3.1-Tulu-3-405B.jinja | Generic | + | allenai-Llama-3.1-Tulu-3-8B.jinja | Generic | + | arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro | + | arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro | + | arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro | + | avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic | + | bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro | + | bfuzzy1-acheron-m1a-llama.jinja | Generic | + | bofenghuang-vigogne-2-70b-chat.jinja | Generic | + | bytedance-research-UI-TARS-72B-DPO.jinja | Generic | + | bytedance-research-UI-TARS-7B-DPO.jinja | Generic | + | bytedance-research-UI-TARS-7B-SFT.jinja | Generic | + | carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic | + | cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) | + | cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) | + | databricks-dbrx-instruct.jinja | Generic | + | deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic | + | deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic | + | deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic | + | deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-V2-Lite.jinja | Generic | + | deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic | + | deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic | + | deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic | + | deepseek-ai-deepseek-llm-67b-chat.jinja | Generic | + | deepseek-ai-deepseek-llm-7b-chat.jinja | Generic | + | dicta-il-dictalm2.0-instruct.jinja | Generic | + | ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro | + | fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 | + | godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro | + | godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro | + | google-gemma-2-27b-it.jinja | Generic | + | google-gemma-2-2b-it.jinja | Generic | + | google-gemma-2-2b-jpn-it.jinja | Generic | + | google-gemma-7b-it.jinja | Generic | + | huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro | + | ibm-granite-granite-3.1-8b-instruct.jinja | Generic | + | indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic | + | inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic | + | jinaai-ReaderLM-v2.jinja | Generic | + | kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro | + | knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo | + | langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic | + | lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) | + | mattshumer-Reflection-Llama-3.1-70B.jinja | Generic | + | meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 | + | meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 | + | meta-llama-Llama-2-7b-chat-hf.jinja | Generic | + | meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x | + | meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic | + | meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x | + | microsoft-Phi-3-medium-4k-instruct.jinja | Generic | + | microsoft-Phi-3-mini-4k-instruct.jinja | Generic | + | microsoft-Phi-3-small-8k-instruct.jinja | Generic | + | microsoft-Phi-3.5-mini-instruct.jinja | Generic | + | microsoft-Phi-3.5-vision-instruct.jinja | Generic | + | microsoft-phi-4.jinja | Generic | + | migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic | + | ministral-Ministral-3b-instruct.jinja | Generic | + | mistralai-Codestral-22B-v0.1.jinja | Generic | + | mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic | + | mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic | + | mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo | + | mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo | + | mistralai-Mistral-Large-Instruct-2411.jinja | Generic | + | mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo | + | mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic | + | mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic | + | mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | + | mlabonne-AlphaMonarch-7B.jinja | Generic | + | mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro | + | mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro | + | mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) | + | netcat420-MFANNv0.20.jinja | Generic | + | netcat420-MFANNv0.24.jinja | Generic | + | netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro | + | nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro | + | nvidia-Eagle2-1B.jinja | Hermes 2 Pro | + | nvidia-Eagle2-9B.jinja | Hermes 2 Pro | + | nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x | + | onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) | + | open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro | + | openchat-openchat-3.5-0106.jinja | Generic | + | pankajmathur-orca_mini_v6_8b.jinja | Generic | + | princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic | + | princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic | + | princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic | + | prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro | + | prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x | + | prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic | + | prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x | + | prithivMLmods-Blaze-14B-xElite.jinja | Generic | + | prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro | + | prithivMLmods-Calme-Ties-78B.jinja | Generic | + | prithivMLmods-Calme-Ties2-78B.jinja | Generic | + | prithivMLmods-Calme-Ties3-78B.jinja | Generic | + | prithivMLmods-ChemQwen2-vL.jinja | Generic | + | prithivMLmods-GWQ2b.jinja | Generic | + | prithivMLmods-LatexMind-2B-Codec.jinja | Generic | + | prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x | + | prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro | + | prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro | + | prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro | + | prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro | + | prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro | + | prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro | + | prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) | + | prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | + | prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro | + | qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro | + | rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro | + | rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro | + | silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic | + | simplescaling-s1-32B.jinja | Hermes 2 Pro | + | sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro | + | sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic | + | sthenno-tempesthenno-icy-0130.jinja | Generic | + | sumink-qwft.jinja | Hermes 2 Pro | + | teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic | + | thirdeyeai-elevate360m.jinja | Generic | + | tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro | + | unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) | + | unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | + | unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | + | unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic | + | upstage-solar-pro-preview-instruct.jinja | Generic | + | whyhow-ai-PatientSeek.jinja | Generic | + | xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro | + | xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro | This table can be generated with: ```bash ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null + ``` @@ -1202,11 +1394,20 @@ curl http://localhost:8080/v1/chat/completions \ ```shell # Native support: + llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M + # Native support for DeepSeek R1 works best w/ our own template (official template buggy) + + llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \ + --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + + llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \ + --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + # Native support requires the right template for these GGUFs: llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ @@ -1236,17 +1437,17 @@ curl http://localhost:8080/v1/chat/completions \ { "type":"function", "function":{ - "name":"get_current_weather", - "description":"Get the current weather in a given location", + "name":"python", + "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters":{ "type":"object", "properties":{ - "location":{ + "code":{ "type":"string", - "description":"The city and state, e.g. San Francisco, CA" + "description":"The code to run in the ipython interpreter." } }, - "required":["location"] + "required":["code"] } } } @@ -1254,7 +1455,7 @@ curl http://localhost:8080/v1/chat/completions \ "messages": [ { "role": "user", - "content": "What is the weather like in Istanbul?." + "content": "Print a hello world message with python." } ] }' diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 9311410e1c815..1925b334b463f 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5da3e187f740a..71151183b81da 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -173,6 +173,7 @@ struct slot_params { {"grammar_trigger_words", grammar_trigger_words}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_format)}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -724,9 +725,19 @@ struct server_task_result_cmpl_final : server_task_result { msg.content = content; } - json tool_calls; + json message { + {"role", "assistant"}, + }; + if (!msg.reasoning_content.empty()) { + message["reasoning_content"] = msg.reasoning_content; + } + if (msg.content.empty() && !msg.tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = msg.content; + } if (!msg.tool_calls.empty()) { - tool_calls = json::array(); + auto tool_calls = json::array(); for (const auto & tc : msg.tool_calls) { tool_calls.push_back({ {"type", "function"}, @@ -737,15 +748,7 @@ struct server_task_result_cmpl_final : server_task_result { {"id", tc.id}, }); } - } - - json message { - {"content", msg.content}, - {"tool_calls", tool_calls}, - {"role", "assistant"}, - }; - if (!msg.tool_plan.empty()) { - message["tool_plan"] = msg.tool_plan; + message["tool_calls"] = tool_calls; } json choice { @@ -2073,8 +2076,8 @@ struct server_context { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict); slot.params.n_predict = slot.n_predict; - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } if (slot.params.ignore_eos && has_eos_token) { @@ -2279,7 +2282,7 @@ struct server_context { for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { result.probs.push_back({ cur_p->data[i].id, - common_detokenize(ctx, {cur_p->data[i].id}, special), + common_token_to_piece(ctx, cur_p->data[i].id, special), cur_p->data[i].p }); } @@ -2301,7 +2304,7 @@ struct server_context { for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { result.probs.push_back({ cur[i].id, - common_detokenize(ctx, {cur[i].id}, special), + common_token_to_piece(ctx, cur[i].id, special), cur[i].p }); } @@ -4060,7 +4063,7 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -4073,7 +4076,7 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index 4a551404f22a9..ba3367b4f332d 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -92,6 +92,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -155,11 +156,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + # (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + # (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), @@ -175,7 +176,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + # (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), # TODO: fix these # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), @@ -214,6 +215,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -273,7 +275,6 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t @pytest.mark.slow @pytest.mark.parametrize("hf_repo,template_override", [ - ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), @@ -298,13 +299,16 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), + + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server n_predict = 512 server.n_slots = 1 @@ -323,6 +327,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predict, "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, {"role": "user", "content": "What is the weather in Istanbul?"}, ], "tools": [WEATHER_TOOL], @@ -332,6 +337,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] actual_arguments = json.loads(tool_call["function"]["arguments"]) assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" @@ -340,22 +346,166 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' +@pytest.mark.slow +@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + + # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) + ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +]) +def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server + # n_predict = 512 + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."}, + {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_6789", + "type": "function", + "function": { + "name": "calculate", + "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}" + } + } + ] + }, + { + "role": "tool", + "name": "calculate", + "content": 0.55644242476, + "tool_call_id": "call_6789" + } + ], + "tools": [ + { + "type":"function", + "function":{ + "name":"calculate", + "description":"A calculator function that computes values of arithmetic expressions in the Python syntax", + "parameters":{ + "type":"object", + "properties":{ + "expression":{ + "type":"string", + "description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)" + } + }, + "required":["expression"] + } + } + } + ] + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls is None, f'Expected no tool call in {choice["message"]}' + content = choice["message"].get("content") + assert content is not None, f'Expected content in {choice["message"]}' + if result_override is not None: + assert re.match(result_override, content), f'Expected {result_override}, got {content}' + else: + assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \ + f'Expected something like "The y coordinate is 0.56.", got {content}' + + +@pytest.mark.slow +@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [ + (128, 'deepseek', "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + + (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'none', "\n?I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + + (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +]) +def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.reasoning_format = reasoning_format + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "user", "content": "What's the sum of 102 and 7?"}, + ] + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + content = choice["message"].get("content") + if expect_content is None: + assert content is None, f'Expected no content in {choice["message"]}' + else: + assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' + + reasoning_content = choice["message"].get("reasoning_content") + if expect_reasoning_content is None: + assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}' + else: + assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}' + + @pytest.mark.slow @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ + (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), - (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - ('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), @@ -371,15 +521,13 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), - - # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): +def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server server.n_slots = 1 server.jinja = True server.n_ctx = 8192 - server.n_predict = 128 + server.n_predict = 512 # High because of DeepSeek R1 server.model_hf_repo = hf_repo server.model_hf_file = None if isinstance(template_override, tuple): @@ -406,6 +554,7 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] actual_arguments = tool_call["function"]["arguments"] if expected_arguments_override is not None: diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index ce06806620c0b..a82504235ff54 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -78,6 +78,7 @@ class ServerProcess: draft_max: int | None = None no_webui: bool | None = None jinja: bool | None = None + reasoning_format: Literal['deepseek', 'none'] | None = None chat_template: str | None = None chat_template_file: str | None = None @@ -172,6 +173,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.append("--no-webui") if self.jinja: server_args.append("--jinja") + if self.reasoning_format is not None: + server_args.extend(("--reasoning-format", self.reasoning_format)) if self.chat_template: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5f97df5fde639..86de0e6d78977 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -578,6 +578,7 @@ static json oaicompat_completion_params_parse(const json & body) { static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, + common_reasoning_format reasoning_format, const common_chat_templates & chat_templates) { json llama_params; @@ -633,9 +634,10 @@ static json oaicompat_completion_params_parse( throw std::runtime_error("Cannot use custom grammar constraints with tools."); } common_chat_inputs inputs; - inputs.messages = body.at("messages"); - inputs.tools = tools; - inputs.tool_choice = tool_choice; + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.messages = body.at("messages"); + inputs.tools = tools; + inputs.tool_choice = tool_choice; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); diff --git a/examples/server/webui/src/components/ChatMessage.tsx b/examples/server/webui/src/components/ChatMessage.tsx index 2ffe08b371e78..68be7c751db54 100644 --- a/examples/server/webui/src/components/ChatMessage.tsx +++ b/examples/server/webui/src/components/ChatMessage.tsx @@ -254,12 +254,12 @@ export default function ChatMessage({ 🔄 Regenerate )} - )} + )} diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5bd8d9c8b5023..dd0c6a96eaee8 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -198,7 +198,7 @@ #ifndef __GNUC__ # define GGML_ATTRIBUTE_FORMAT(...) -#elif defined(__MINGW32__) +#elif defined(__MINGW32__) && !defined(__clang__) # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #else # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 27ec149356543..1b4bd66e80c68 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -742,7 +742,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1030,7 +1030,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1644,7 +1644,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1 //===================================== Q8_K ============================================== void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { +#ifdef __wasm_simd128__ + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + block_q8_K * restrict yc = y; // Cast to proper type + + for (int i = 0; i < nb; i++) { + const float * x_block = x + i * QK_K; + + v128_t min_vec = wasm_v128_load(x_block); + v128_t max_vec = min_vec; + + for (int j = 4; j < QK_K; j += 4) { + v128_t x_vec = wasm_v128_load(x_block + j); + max_vec = wasm_f32x4_pmax(max_vec, x_vec); + min_vec = wasm_f32x4_pmin(min_vec, x_vec); + } + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1)); + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2)); + float max = wasm_f32x4_extract_lane(max_vec, 0); + float min = wasm_f32x4_extract_lane(min_vec, 0); + float amax = -min > max ? min : max; + + if (amax == 0.0f) { + yc[i].d = 0.0f; + const v128_t zero = wasm_i8x16_splat(0); + for (int j = 0; j < QK_K; j += 16) { + wasm_v128_store(yc[i].qs + j, zero); + } + continue; + } + + const float iscale = -127.0f / amax; + const v128_t scale_vec = wasm_f32x4_splat(iscale); + + // Process 16 elements per iteration + for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) { + // Load and quantize 16 floats + v128_t x0 = wasm_v128_load(x_block + j); + v128_t x1 = wasm_v128_load(x_block + j + 4); + v128_t x2 = wasm_v128_load(x_block + j + 8); + v128_t x3 = wasm_v128_load(x_block + j + 12); + + v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec)); + v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec)); + v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec)); + v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec)); + + // Convert to i32 with saturation + v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0); + v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1); + v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2); + v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3); + + // Pack into 16 i8 values + v128_t i8 = wasm_i8x16_narrow_i16x8( + wasm_i16x8_narrow_i32x4(i0, i1), + wasm_i16x8_narrow_i32x4(i2, i3) + ); + wasm_v128_store(yc[i].qs + j, i8); + + // Calculate bsums using SIMD + v128_t sum16 = wasm_i16x8_add( + wasm_i16x8_extend_low_i8x16(i8), + wasm_i16x8_extend_high_i8x16(i8) + ); + v128_t sum32 = wasm_i32x4_add( + wasm_i32x4_extend_low_i16x8(sum16), + wasm_i32x4_extend_high_i16x8(sum16) + ); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1)); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2)); + yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0); + } + + yc[i].d = 1.0f / iscale; + } +#else quantize_row_q8_K_ref(x, y, k); +#endif } //===================================== Dot products ================================= @@ -2002,6 +2082,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t s8b = wasm_i8x16_splat(0x8); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // Load and process x0 + v128_t v0_0 = wasm_v128_load(x0->qs); + v128_t v0_0l = wasm_v128_and(v0_0, m4b); + v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + + // Load y0 vectors + v128_t y0_l = wasm_v128_load(y0->qs); + v128_t y0_h = wasm_v128_load(y0->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls); + v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls); + v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs); + v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs); + + v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l); + v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l); + v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h); + v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h); + + v128_t dp0 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0l, dy0ll), + wasm_i32x4_dot_i16x8(dx0h, dy0lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0hl, dy0hl), + wasm_i32x4_dot_i16x8(dx0hh, dy0hh) + ) + ); + + // Load and process x1 + v128_t v0_1 = wasm_v128_load(x1->qs); + v128_t v0_1l = wasm_v128_and(v0_1, m4b); + v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + + // Load y1 vectors + v128_t y1_l = wasm_v128_load(y1->qs); + v128_t y1_h = wasm_v128_load(y1->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls); + v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls); + v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs); + v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs); + + v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l); + v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l); + v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h); + v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h); + + v128_t dp1 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1l, dy1ll), + wasm_i32x4_dot_i16x8(dx1h, dy1lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1hl, dy1hl), + wasm_i32x4_dot_i16x8(dx1hh, dy1hh) + ) + ); + + // Accumulate results with scaling + float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); + + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2688,10 +2856,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -2702,12 +2870,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; + tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3049,12 +3217,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); float summs = 0.0f; - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -3067,12 +3235,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; + tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3565,6 +3733,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + for (; ib < nb; ++ib) { + const block_q8_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const v128_t x0_0 = wasm_v128_load(x0->qs); + const v128_t x0_1 = wasm_v128_load(x0->qs + 16); + const v128_t y0_0 = wasm_v128_load(y0->qs); + const v128_t y0_1 = wasm_v128_load(y0->qs + 16); + + // Extend 8-bit to 16-bit + const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); + const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); + const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); + const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); + + const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); + const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); + const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); + const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); + + // Compute dot products + const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); + const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); + const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); + const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); + + // Sum all dot products + const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); + + // Convert to float and accumulate + const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -4439,6 +4646,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + // Vectorized summs calculation + v128_t summs_vec = wasm_i32x4_splat(0); + { + v128_t sc_vec = wasm_v128_load(sc); + v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); + + v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); + v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); + + v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); + v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); + + summs_vec = wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), + wasm_i32x4_dot_i16x8(sc_high, bsums2)), + summs_vec + ); + + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); + } + int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0); + + // Vectorized isum calculation + int32_t isum = 0; + const uint8_t * sc_ptr = sc; + const int k_iters = QK_K/128; + + for (int k = 0; k < k_iters; ++k) { + v128_t isum_vec = wasm_i32x4_splat(0); + int shift = 0; + + for (int j = 0; j < 4; ++j) { + const int d0 = (sc_ptr[0] & 0xF); + const int d1 = (sc_ptr[1] & 0xF); + sc_ptr += 2; + + // Process first 16 elements + v128_t q2_0 = wasm_v128_load(q2); + v128_t q8_0 = wasm_v128_load(q8); + v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); + v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); + + // Process next 16 elements + v128_t q2_1 = wasm_v128_load(q2 + 16); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift); + v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03)); + + // Calculate dot products + v128_t p0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_0), + wasm_i16x8_extend_low_i8x16(q2_bits_0) + ); + v128_t p1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_0), + wasm_i16x8_extend_high_i8x16(q2_bits_0) + ); + v128_t p2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_1), + wasm_i16x8_extend_low_i8x16(q2_bits_1) + ); + v128_t p3 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_1), + wasm_i16x8_extend_high_i8x16(q2_bits_1) + ); + + // Accumulate scaled results + v128_t scaled = wasm_i32x4_add( + wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), + wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) + ); + + isum_vec = wasm_i32x4_add(isum_vec, scaled); + q8 += 32; + shift += 2; + } + q2 += 32; + + // Horizontal sum of isum_vec + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1)); + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2)); + isum += wasm_i32x4_extract_lane(isum_vec, 0); + } + + const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf += dall * isum - dmin * summs; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; @@ -5121,6 +5428,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K]; + float sums[8] = {0}; + uint32_t auxs[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + // Process blocks with SIMD + int8_t * a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int shift = 0; shift <= 6; shift += 2) { + v128_t v_m = wasm_i8x16_splat(m); + for (int l = 0; l < 32; l += 16) { + v128_t v_q3 = wasm_v128_load(q3 + l); + v128_t v_shift = wasm_i8x16_shr(v_q3, shift); + v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); + + v128_t v_hm = wasm_v128_load(hm + l); + v128_t v_mask = wasm_v128_and(v_hm, v_m); + v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); + + v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); + wasm_v128_store(a + l, v_low2); + } + a += 32; + m <<= 1; + } + q3 += 32; + } + + // Extract scales + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + const int8_t * scales = (const int8_t *)auxs; + + // SIMD dot product with register accumulators + v128_t v_acc0 = wasm_i32x4_splat(0); + v128_t v_acc1 = wasm_i32x4_splat(0); + a = aux8; + for (int j = 0; j < QK_K/16; ++j) { + const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32); + + // Process 16 elements per iteration + for (int k = 0; k < 2; ++k) { + const v128_t v_q8 = wasm_i16x8_load8x8(q8); + const v128_t v_a = wasm_i16x8_load8x8(a); + + v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); + v_prod = wasm_i16x8_mul(v_prod, v_scale); + + v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); + v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); + + q8 += 8; + a += 8; + } + } + + // Accumulate results + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const v128_t v_d = wasm_f32x4_splat(d); + v128_t v_sum = wasm_f32x4_add( + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) + ); + + // Accumulate into sums vector + wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); + } + + // Horizontal sum + v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4)); + sumf = wasm_f32x4_extract_lane(v_sum, 0) + + wasm_f32x4_extract_lane(v_sum, 1) + + wasm_f32x4_extract_lane(v_sum, 2) + + wasm_f32x4_extract_lane(v_sum, 3); + + *s = sumf; + #elif defined __riscv_v_intrinsic uint32_t aux[3]; @@ -5646,7 +6041,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r } } *s = sumf; -#elif __ARM_NEON +#elif defined __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); @@ -5709,6 +6104,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined __wasm_simd128__ + const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + // Load 64 4-bit weights (32 bytes) + const v128_t q4x0 = wasm_v128_load(q4); + const v128_t q4x1 = wasm_v128_load(q4 + 16); + q4 += 32; + + // Split into low/high nibbles + const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); + const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); + const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); + const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); + + // Load 64 8-bit values (64 bytes) + const v128_t q8x0 = wasm_v128_load(q8); + const v128_t q8x1 = wasm_v128_load(q8 + 16); + const v128_t q8x2 = wasm_v128_load(q8 + 32); + const v128_t q8x3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Low nibble products + v128_t vacc1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l0), + wasm_i16x8_extend_low_i8x16(q8x0) + ); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l0), + wasm_i16x8_extend_high_i8x16(q8x0) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l1), + wasm_i16x8_extend_low_i8x16(q8x1) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l1), + wasm_i16x8_extend_high_i8x16(q8x1) + )); + + // High nibble products + v128_t vacc2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h0), + wasm_i16x8_extend_low_i8x16(q8x2) + ); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h0), + wasm_i16x8_extend_high_i8x16(q8x2) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h1), + wasm_i16x8_extend_low_i8x16(q8x3) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h1), + wasm_i16x8_extend_high_i8x16(q8x3) + )); + + // Accumulate scaled results + int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); + sumi1 += vacc1_sum * scales[2*j]; + + int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); + sumi2 += vacc2_sum * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; + #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); @@ -6459,6 +6955,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc) + summs; +#elif defined __wasm_simd128__ + //const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi_mins = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi_mins; // Correct subtraction + + v128_t qh0 = wasm_v128_load(qh); + v128_t qh1 = wasm_v128_load(qh + 16); + const uint8_t * sc = (const uint8_t *)utmp; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const int shift = j * 2; + v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift); + v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift); + + v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3); + v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3); + + v128_t q5_0 = wasm_v128_load(q5); + v128_t q5_1 = wasm_v128_load(q5 + 16); + q5 += 32; + + v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0); + v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0); + v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1); + v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1); + + v128_t q8_0 = wasm_v128_load(q8); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q8_2 = wasm_v128_load(q8 + 32); + v128_t q8_3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Process low quants + v128_t pl0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_0), + wasm_i16x8_extend_low_i8x16(q8_0) + ); + pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_0), + wasm_i16x8_extend_high_i8x16(q8_0) + )); + v128_t pl1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_1), + wasm_i16x8_extend_low_i8x16(q8_1) + ); + pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_1), + wasm_i16x8_extend_high_i8x16(q8_1) + )); + v128_t sum_low = wasm_i32x4_add(pl0, pl1); + + // Process high quants + v128_t ph0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_0), + wasm_i16x8_extend_low_i8x16(q8_2) + ); + ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_0), + wasm_i16x8_extend_high_i8x16(q8_2) + )); + v128_t ph1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_1), + wasm_i16x8_extend_low_i8x16(q8_3) + ); + ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_1), + wasm_i16x8_extend_high_i8x16(q8_3) + )); + v128_t sum_high = wasm_i32x4_add(ph0, ph1); + + // Accumulate with scale factors + int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) + + wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3); + int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) + + wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3); + + sumi += sl * sc[2*j] + sh * sc[2*j+1]; + } + + sumf += d * sumi; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -7122,6 +7730,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K] __attribute__((aligned(16))); + int32_t aux32[8] __attribute__((aligned(16))) = {0}; + float sums[8] __attribute__((aligned(16))) = {0}; + + for (int i = 0; i < nb; ++i) { + // Unpack 6-bit quantized data into aux8 (unchanged) + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + int8_t * a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + + const int8_t * restrict a_ptr = aux8; + const int8_t * restrict q8 = y[i].qs; + v128_t acc0 = wasm_i32x4_splat(0); + v128_t acc1 = wasm_i32x4_splat(0); + + for (int j = 0; j < QK_K/16; ++j) { + const int scale = x[i].scales[j]; + const v128_t vscale = wasm_i32x4_splat(scale); + + // Load 16 elements from a and q8 + const v128_t a_vec = wasm_v128_load(a_ptr); + const v128_t q8_vec = wasm_v128_load(q8); + + // Process low 8 elements + v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec); + v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec); + v128_t prod_low = wasm_i16x8_mul(a_low, q8_low); + v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low); + v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low); + + // Process high 8 elements + v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec); + v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec); + v128_t prod_high = wasm_i16x8_mul(a_high, q8_high); + v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high); + v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high); + + // Scale and accumulate + prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale); + prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale); + prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale); + prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale); + + acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo)); + acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi)); + + a_ptr += 16; + q8 += 16; + } + + // Store accumulated results + wasm_v128_store(&aux32[0], acc0); + wasm_v128_store(&aux32[4], acc1); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) { + sums[l] += d * aux32[l]; + } + } + + // Sum final results + float sumf = 0; + for (int l = 0; l < 8; ++l) { + sumf += sums[l]; + } + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index fdb430a43cc21..0cbf8318bedb9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7,10 +7,8 @@ #include "ggml-cpu-impl.h" #include "ggml-cpu.h" #include "ggml-impl.h" -#include "ggml-quants.h" #include "ggml-cpu-quants.h" #include "ggml-threading.h" -#include "amx/amx.h" #include "ggml.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -1291,7 +1289,7 @@ struct ggml_threadpool { atomic_int n_graph; // incremented when there is work to be done (i.e each graph) atomic_int GGML_CACHE_ALIGN n_barrier; atomic_int GGML_CACHE_ALIGN n_barrier_passed; - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. // these are atomic as an annotation for thread-sanitizer atomic_bool stop; // Used for stopping the threadpool altogether @@ -7490,6 +7488,7 @@ UseGgmlGemm1:; if (src1->type != vec_dot_type) { char * wdata = params->wdata; + const size_t nbw0 = ggml_type_size(vec_dot_type); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -7497,6 +7496,7 @@ UseGgmlGemm1:; assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); + #if 0 for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { @@ -7506,6 +7506,20 @@ UseGgmlGemm1:; } } } + #else + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } + } + } + #endif } if (ith == 0) { @@ -7593,7 +7607,6 @@ UseGgmlGemm2:; if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { num_rows_per_vec_dot = 1; } - ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); if (nth >= nchunk0 * nchunk1) { @@ -7606,6 +7619,84 @@ UseGgmlGemm2:; // ggml_compute_forward_mul_mat_id +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)] + +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_compute_forward_mul_mat_id_one_chunk( + struct ggml_tensor * dst, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * ids, + const int64_t cur_a, + const int64_t ir0_start, + const int64_t ir0_end, + const int64_t ir1_start, + const int64_t ir1_end, + const char * src0_cur, + const struct mmid_row_mapping * matrix_rows, + const size_t row_size, + const bool src1_cont, + const void * wdata) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + float tmp[16]; + + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) { + const int64_t _i12 = ir1; // logical row index for this expert + + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); + } + + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float)); + } + } + } +} + +static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { + + void * ptr = *p; + ptr = (void *) GGML_PAD((uintptr_t) ptr, align); + *p = (void *) ((char *) ptr + size); + return ptr; +} + static void ggml_compute_forward_mul_mat_id( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; @@ -7641,21 +7731,27 @@ static void ggml_compute_forward_mul_mat_id( const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert - char * wdata_src1_end = (src1->type == vec_dot_type) ? - (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + void * wdata_cur = params->wdata; - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; + if (src1->type != vec_dot_type) { + incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + } + + int64_t * matrix_row_counts = // [n_as] + incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t)); + + struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]] + incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as] + incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE); + + GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata)); if (src1->type != vec_dot_type) { char * wdata = params->wdata; + const size_t nbw0 = ggml_type_size(vec_dot_type); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -7663,19 +7759,32 @@ static void ggml_compute_forward_mul_mat_id( assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); +#if 0 for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + for (int64_t i12 = ith; i12 < ne12; i12 += nth) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), ne10); } } } +#else + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } + } + } +#endif } -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - if (ith == 0) { // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); @@ -7693,9 +7802,14 @@ static void ggml_compute_forward_mul_mat_id( } } + // reset current_chunk + for (int cur_a = ith; cur_a < n_as; cur_a += nth) { + atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a); + *current_chunk_ctr = nth; + } + ggml_barrier(params->threadpool); - // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { const int64_t cne1 = matrix_row_counts[cur_a]; @@ -7703,84 +7817,64 @@ static void ggml_compute_forward_mul_mat_id( continue; } - const char * src0_cur = (const char *) src0->data + cur_a*nb02; - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const char * src0_cur = (const char *) src0->data + cur_a * nb02; + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + const int64_t nr0 = ne01; + const int64_t nr1 = cne1; - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); - - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); + int chunk_size = 16; + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } - // threads with no work simply yield (not sure if it helps) - //if (ir010 >= ir011 || ir110 >= ir111) { - // sched_yield(); - // continue; - //} +#if defined(__aarch64__) + // disable for ARM + const bool disable_chunking = true; +#else + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); +#endif // defined(__aarch64__) - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - // attempt to reduce false-sharing (does not seem to make a difference) - float tmp[16]; + if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) { + nchunk0 = nr0 > nr1 ? nth : 1; + nchunk1 = nr0 > nr1 ? 1 : nth; + } - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t _i12 = ir1; // logical row index for this expert + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); - const int id = row_mapping.i1; // selected expert index + int current_chunk = ith; - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 + atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a); - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11)*row_size - : (i11*nb11 + i12*nb12)); + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} + ggml_compute_forward_mul_mat_id_one_chunk( + dst, src0, src1, ids, cur_a, + ir0_start, ir0_end, ir1_start, ir1_end, + src0_cur, matrix_rows, row_size, src1_cont, wdata + ); - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); - } - - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } + if (nth >= nchunk0 * nchunk1) { + break; } + + current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed); } } - -#undef MMID_MATRIX_ROW } // ggml_compute_forward_out_prod @@ -9074,10 +9168,6 @@ static void ggml_compute_forward_clamp_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } - float min; float max; memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); @@ -13717,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan( cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * ids = node->src[2]; const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + const int n_as = src0->ne[2]; + // src1 if (src1->type != vec_dot_type) { - cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); + cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t); } - const int n_as = src0->ne[2]; - cur += GGML_PAD(cur, sizeof(int64_t)); // align - cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + // matrix_row_counts + cur += n_as * sizeof(int64_t) + sizeof(int64_t); + // matrix_rows + cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t); + // atomic_current_chunk + cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE; } break; case GGML_OP_OUT_PROD: { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 2ccb4b472d612..be4eadcd021f6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context { &hKey) == ERROR_SUCCESS) { DWORD cpu_brand_size = 0; if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), + "ProcessorNameString", NULL, NULL, NULL, &cpu_brand_size) == ERROR_SUCCESS) { description.resize(cpu_brand_size); if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), + "ProcessorNameString", NULL, NULL, (LPBYTE)&description[0], // NOLINT @@ -534,9 +534,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_dotprod()) { features.push_back({ "DOTPROD", "1" }); } - if (ggml_cpu_has_matmul_int8()) { - features.push_back({ "MATMUL_INT8", "1" }); - } if (ggml_cpu_get_sve_cnt() > 0) { static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c22a662876c4a..e0482c59377fd 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -280,14 +280,6 @@ template <> inline __m256bh load(const float *p) { } #endif -//////////////////////////////////////////////////////////////////////////////////////////////////// -// CONSTANTS - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl); -#endif - //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION @@ -614,6 +606,14 @@ class tinyBLAS_Q0_AVX { TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + const int8_t kvalues_iq4nl[16] = { + -127, -104, -83, -65, + -49, -35, -22, -10, + 1, 13, 25, 38, + 53, 69, 89, 113 + }; + + iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); } void matmul(int64_t m, int64_t n) { @@ -1038,6 +1038,7 @@ class tinyBLAS_Q0_AVX { const int64_t ldc; const int ith; const int nth; + __m128i iq4nlt; }; #endif // __AVX__ diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2a3244428521b..fd4dcfa941d4b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -165,11 +165,11 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else #define GGML_CUDA_ASSUME(x) -#endif // CUDART_VERSION >= 11100 +#endif // CUDART_VERSION >= 11010 #ifdef GGML_CUDA_F16 typedef half dfloat; // dequantize float diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c95728b08bfe8..093ad70991b5a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -178,11 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() { int major_version = 0; size_t version_length = 0; if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { - std::string version(version_length, '\0'); + std::vector version(version_length+1, '\0'); if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { - version.resize(::strlen(version.c_str())); + version.resize(::strlen(version.data())); int parsed_value = 0; - if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { + if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) { major_version = parsed_value; } } @@ -1480,12 +1480,7 @@ static void ggml_cuda_op_mul_mat( const size_t nbytes_data = ggml_nbytes(src0); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); - // TODO: remove this for MUSA once the Guilty Lockup issue is resolved -#ifndef GGML_USE_MUSA CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); -#else // GGML_USE_MUSA - CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); -#endif // !GGML_USE_MUSA } // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: @@ -2840,7 +2835,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { return false; } -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error @@ -2852,8 +2847,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { } return true; #else + GGML_UNUSED(buffer); + GGML_UNUSED(size); return false; -#endif +#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) } void ggml_backend_cuda_unregister_host_buffer(void * buffer) { diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 5dacd131ed55c..10f2ebb1cb43d 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index e0dafc1d2049d..f9589080a0c3b 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -1,6 +1,6 @@ -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 #define USE_CUB -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 #ifdef USE_CUB #include diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bffe95086af7d..99d50afda2d44 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1430,6 +1430,7 @@ static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); // some shaders have a minimum subgroup size + const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); @@ -1492,13 +1493,13 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; - l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; - m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; - l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; - m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; - s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; diff --git a/include/llama.h b/include/llama.h index 3784f7d3950e5..1f5f3a09b311e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1172,6 +1172,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. diff --git a/models/templates/README.md b/models/templates/README.md new file mode 100644 index 0000000000000..72c30d1e1e08e --- /dev/null +++ b/models/templates/README.md @@ -0,0 +1,22 @@ +These templates can be updated with the following commands: + +```bash +./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use > models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja +./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 default > models/templates/CohereForAI-c4ai-command-r7b-12-2024-default.jinja +./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 rag > models/templates/CohereForAI-c4ai-command-r7b-12-2024-rag.jinja +./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use > models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Llama-8B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +./scripts/get_chat_template.py google/gemma-2-2b-it > models/templates/google-gemma-2-2b-it.jinja +./scripts/get_chat_template.py meetkai/functionary-medium-v3. > models/templates/meetkai-functionary-medium-v3.jinja +./scripts/get_chat_template.py meetkai/functionary-medium-v3.2 > models/templates/meetkai-functionary-medium-v3.2.jinja +./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja +./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja +./scripts/get_chat_template.py meta-llama/Llama-3.3-70B-Instruct > models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja +./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct > models/templates/microsoft-Phi-3.5-mini-instruct.jinja +./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja +./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja +./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja +./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +``` \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja index 02a1c3bce33f4..c2066bd7391c2 100644 --- a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja @@ -1 +1 @@ -{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %} \ No newline at end of file +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\n'}}{% endif %} \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja index 2ebfe7c1e32ab..c2066bd7391c2 100644 --- a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja @@ -1,56 +1 @@ -{% if not add_generation_prompt is defined %} -{% set add_generation_prompt = false %} -{% endif %} -{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %} -{%- for message in messages %} -{%- if message['role'] == 'system' %} -{% set ns.system_prompt = message['content'] %} -{%- endif %} -{%- endfor %} -{{bos_token}} -{{ns.system_prompt}} -{%- for message in messages %} -{%- if message['role'] == 'user' %} -{%- set ns.is_tool = false -%} -{{'<|User|>' + message['content']}} -{%- endif %} -{%- if message['role'] == 'assistant' and message['content'] is none %} -{%- set ns.is_tool = false -%} -{%- for tool in message['tool_calls']%} -{%- if not ns.is_first %} -{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} -{%- set ns.is_first = true -%} -{%- else %} -{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} -{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} -{%- endif %} -{%- endfor %} -{%- endif %} -{%- if message['role'] == 'assistant' and message['content'] is not none %} -{%- if ns.is_tool %} -{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} -{%- set ns.is_tool = false -%} -{%- else %} -{% set content = message['content'] %} -{% if '' in content %} -{% set content = content.split('')[-1] %} -{% endif %} -{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} -{%- endif %} -{%- endif %} -{%- if message['role'] == 'tool' %} -{%- set ns.is_tool = true -%} -{%- if ns.is_output_first %} -{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} -{%- set ns.is_output_first = false %} -{%- else %} -{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} -{%- endif %} -{%- endif %} -{%- endfor -%} -{% if ns.is_tool %} -{{'<|tool▁outputs▁end|>'}} -{% endif %} -{% if add_generation_prompt and not ns.is_tool %} -{{'<|Assistant|>'}} -{% endif %} \ No newline at end of file +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\n'}}{% endif %} \ No newline at end of file diff --git a/models/templates/llama-cpp-deepseek-r1.jinja b/models/templates/llama-cpp-deepseek-r1.jinja new file mode 100644 index 0000000000000..fcb1732eb8fe7 --- /dev/null +++ b/models/templates/llama-cpp-deepseek-r1.jinja @@ -0,0 +1,76 @@ +{%- if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- set ns = namespace(is_first=false, is_tool_outputs=false, is_output_first=true, system_prompt='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.system_prompt = message['content'] -%} + {%- endif -%} +{%- endfor -%} +{{bos_token}} +{%- if tools %} +You can call any of the following function tools to satisfy the user's requests: {{tools | map(attribute='function') | tojson(indent=2)}} + +Example function tool call syntax: + +<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>example_function_name +```json +{ + "arg1": "some_value" + ... +} +``` +<|tool▁call▁end|><|tool▁calls▁end|> + +{% endif -%} +{{ns.system_prompt}} +{%- macro flush_tool_outputs() -%} + {%- if ns.is_tool_outputs -%} + {{- '<|tool▁outputs▁end|><|end▁of▁sentence|>' -}} + {%- set ns.is_tool_outputs = false -%} + {%- endif -%} +{%- endmacro -%} +{{- flush_tool_outputs() -}} +{%- for message in messages -%} + {%- if message['role'] != 'tool' -%} + {{- flush_tool_outputs() -}} + {%- endif -%} + {%- if message['role'] == 'user' -%} + {{- '<|User|>' + message['content'] + '<|end▁of▁sentence|>' -}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is none -%} + {{- '<|Assistant|><|tool▁calls▁begin|>' -}} + {%- set ns.is_first = true -%} + {%- for tc in message['tool_calls'] -%} + {%- if ns.is_first -%} + {%- set ns.is_first = false -%} + {%- else -%} + {{- '\n' -}} + {%- endif -%} + {%- set tool_name = tc['function']['name'] -%} + {%- set tool_args = tc['function']['arguments'] -%} + {{- '<|tool▁call▁begin|>' + tc['type'] + '<|tool▁sep|>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<|tool▁call▁end|>' -}} + {%- endfor -%} + {{- '<|tool▁calls▁end|><|end▁of▁sentence|>' -}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is not none -%} + {{- flush_tool_outputs() -}} + {%- set content = message['content'] -%} + {%- if '' in content -%} + {%- set content = content.split('')[-1] -%} + {%- endif -%} + {{- '<|Assistant|>' + content + '<|end▁of▁sentence|>' -}} + {%- endif -%} + {%- if message['role'] == 'tool' -%} + {%- set ns.is_tool_outputs = true -%} + {%- if ns.is_output_first -%} + {{- '<|tool▁outputs▁begin|>' -}} + {%- set ns.is_output_first = false -%} + {%- endif -%} + {{- '\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>' -}} + {%- endif -%} +{%- endfor -%} +{{- flush_tool_outputs() -}} +{%- if add_generation_prompt and not ns.is_tool_outputs -%} + {{- '<|Assistant|>\n' -}} +{%- endif -%} \ No newline at end of file diff --git a/scripts/get_chat_template.py b/scripts/get_chat_template.py old mode 100644 new mode 100755 index e8982d11ad7ba..d8143e4005dec --- a/scripts/get_chat_template.py +++ b/scripts/get_chat_template.py @@ -7,9 +7,8 @@ ./scripts/get_chat_template.py model_id [variant] Examples: - ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct - ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use - ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct + ./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use + ./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct ''' import json diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 255109ab71e0c..f99625dca0396 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -08b538031f7f944e84f472483ef5d26bf5190ead +98a61a0d0b43cba06c3ac1c603813639552a0701 diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 9b518d1ac64a5..46e27a96ed728 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1186,7 +1186,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token return; } } - LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str()); + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); return; } } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 252d54d4c9f45..b143d834cfabe 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -116,7 +116,7 @@ struct llama_grammar { llama_partial_utf8 partial_utf8; // lazy grammars wait for trigger words or tokens before constraining the sampling. - // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. // (useful e.g. for tool_choice=required) bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only diff --git a/src/llama-impl.h b/src/llama-impl.h index 12d1fb0828d3a..02b1d07f8400d 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -6,13 +6,13 @@ #include #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) +# define LLAMA_ATTRIBUTE_FORMAT(...) #endif // diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index dca6f3998c645..1ed688e3b5b7e 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -37,7 +37,7 @@ struct llama_kv_cache { bool can_shift = false; // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_internal also uses it, so it + // for a free KV slot. llama_decode_impl also uses it, so it // cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 990b6129746de..f40bf2db83a80 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1698,6 +1698,73 @@ struct llama_sampler * llama_sampler_init_penalties( ); } +// top-n-sigma + +struct llama_sampler_top_n_sigma { + const float n; +}; + +static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { + return "top-n-sigma"; +} + +static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + + // find max logit and calculate mean + float max = cur_p->data[0].logit; + float logits_sum = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + } + float mean = logits_sum/cur_p->size; + + // calculate standard deviation + float acc = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + acc += pow(cur_p->data[i].logit - mean, 2); + } + float std = sqrt(acc/cur_p->size); + + //apply mask + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < max - (ctx->n * std)) { + cur_p->data[i].logit = -INFINITY; + } + } + llama_sampler_softmax_impl(cur_p); +} + +static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; + return llama_sampler_init_top_n_sigma(ctx->n); +} + +static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_n_sigma *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_n_sigma_i = { + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, +}; + +struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_n_sigma_i, + /* .ctx = */ new llama_sampler_top_n_sigma { + /* .n = */ n, + } + ); +} + // DRY struct llama_sampler_dry { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index b78da2cdb48ba..2836caf6a71a3 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -24,7 +24,10 @@ static common_chat_msg msg_from_json(const json & message) { ret.content = message.at("content"); } if (message.contains("tool_plan")) { - ret.tool_plan = message.at("tool_plan"); + ret.reasoning_content = message.at("tool_plan"); + } + if (message.contains("reasoning_content")) { + ret.reasoning_content = message.at("reasoning_content"); } auto has_tool_calls = message.contains("tool_calls"); if (has_tool_calls) { @@ -105,6 +108,7 @@ static std::string dump(const json & j) { static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); + assert_equals(expected.reasoning_content, actual.reasoning_content); assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { const auto & expected_tool_call = expected.tool_calls[i]; @@ -176,13 +180,15 @@ struct delta_data { static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools, - const json & tool_choice) { + const json & tool_choice, + bool think = false) { common_chat_inputs inputs; inputs.parallel_tool_calls = true; inputs.messages = json::array(); inputs.messages.push_back(user_message); inputs.tools = tools; inputs.tool_choice = tool_choice; + inputs.extract_reasoning = think; auto params_prefix = common_chat_params_init(tmpl, inputs); inputs.messages.push_back(delta_message); @@ -192,17 +198,24 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto std::string prefix = params_prefix.prompt; std::string full = params_full.prompt; - // Check full starts with prefix - if (full.find(prefix) != 0) { - fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str()); - throw std::runtime_error("Full message does not start with prefix"); - } - if (full == prefix) { throw std::runtime_error("Full message is the same as the prefix"); } - auto delta = full.substr(prefix.size()); + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { + if (prefix[i] != full[i]) { + break; + } + if (prefix[i] == '<') { + // DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + // but it removes thinking tags for past messages. + // The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. + continue; + } + common_prefix_length = i + 1; + } + auto delta = full.substr(common_prefix_length); // Strip end tokens for (const auto & end_token : end_tokens) { @@ -223,7 +236,9 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto */ static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", - bool expect_grammar_triggered = true) { + bool expect_grammar_triggered = true, + bool test_grammar_if_triggered = true, + bool think = false) { common_chat_msg expected_msg = msg_from_json(test_message); auto user_message = json{ @@ -232,7 +247,7 @@ static void test_template(const common_chat_template & tmpl, const std::vectorI'm thinkingHello, world!\nWhat's up?" }, + }; + json message_assist_thoughts_unparsed_r7b { + { "role", "assistant" }, + { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" }, + }; + json message_assist_thoughts { { "role", "assistant" }, { "content", "Hello, world!\nWhat's up?" }, + { "reasoning_content", "I'm thinking" }, }; json tool_calls = json::array({{ { "type", "function" }, { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, }}); - json tool_call_message { + json message_assist_call { { "role", "assistant"}, { "content", {}}, { "tool_calls", { @@ -305,7 +337,34 @@ static void test_template_output_parsers() { }, }}, }; - json tool_call_message_with_id { + json message_assist_call_thoughts = { + { "role", "assistant" }, + { "content", nullptr }, + { "reasoning_content", "I'm\nthinking" }, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, + }; + json message_assist_call_thoughts_unparsed = { + { "role", "assistant" }, + { "content", "I'm\nthinking" }, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, + }; + json message_assist_call_id { { "role", "assistant"}, { "content", {}}, { "tool_calls", { @@ -322,10 +381,9 @@ static void test_template_output_parsers() { { "content", {} }, { "tool_calls", tool_calls } }; - json tool_call_plan_message_with_idx { + json message_assist_call_idx { { "role", "assistant"}, { "content", {}}, - { "tool_plan", "I'm not so sure"}, { "tool_calls", { { { "type", "function" }, @@ -341,8 +399,10 @@ static void test_template_output_parsers() { { "content", {} }, { "tool_calls", tool_calls } }; + json message_assist_call_tool_plan_idx = message_assist_call_idx; + message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking"; - auto python_tool_call_message = json{ + auto python_message_assist_call = json{ { "role", "assistant" }, { "content", {} }, { "tool_calls", json{ { @@ -357,7 +417,7 @@ static void test_template_output_parsers() { } }, } } } }; - auto code_interpreter_tool_call_message = json{ + auto code_interpreter_message_assist_call = json{ { "role", "assistant" }, { "content", {} }, { "tool_calls", json{ { @@ -374,17 +434,27 @@ static void test_template_output_parsers() { }; common_chat_inputs inputs_no_tools; - inputs_no_tools.messages = { - { { "role", "user" }, { "content", "Hey\nThere" } } - }; + inputs_no_tools.messages = json::array({message_user}); + inputs_no_tools.extract_reasoning = false; - common_chat_inputs inputs_tools = inputs_no_tools; - inputs_tools.tools = json::array(); - inputs_tools.tools.push_back(special_function_tool); + common_chat_inputs inputs_no_tools_think; + inputs_no_tools_think.messages = json::array({message_user}); + inputs_no_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools_builtin = inputs_no_tools; - inputs_tools_builtin.tools = json::array(); - inputs_tools_builtin.tools.push_back(python_tool); + common_chat_inputs inputs_tools; + inputs_tools.messages = json::array({message_user}); + inputs_tools.tools = json::array({special_function_tool}); + inputs_tools.extract_reasoning = false; + + common_chat_inputs inputs_tools_think; + inputs_tools_think.messages = json::array({message_user}); + inputs_tools_think.tools = json::array({special_function_tool}); + inputs_tools_think.extract_reasoning = true; + + common_chat_inputs inputs_tools_builtin; + inputs_tools_builtin.messages = json::array({message_user}); + inputs_tools_builtin.tools = json::array({python_tool}); + inputs_tools_builtin.extract_reasoning = false; { // Not supported yet @@ -395,15 +465,53 @@ static void test_template_output_parsers() { const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); - - test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools, - "<|START_THINKING|>I'm not so sure<|END_THINKING|>" + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + + assert_msg_equals(msg_from_json(message_assist), + common_chat_parse( + "Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist), + common_chat_parse( + "Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist), + common_chat_parse( + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + common_chat_parse( + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + common_chat_parse( + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + + assert_msg_equals(msg_from_json(message_assist_thoughts), + common_chat_parse( + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING)); + + test_template(tmpl, end_tokens, message_assist_call_idx, tools, + "<|START_THINKING|><|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" "]<|END_ACTION|>"); - test_template(tmpl, end_tokens, text_message, tools, + test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools, + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* think= */ true); + test_template(tmpl, end_tokens, message_assist, tools, "<|START_RESPONSE|>Hello, world!\n" "What's up?<|END_RESPONSE|>", /* expect_grammar_triggered= */ false); @@ -423,12 +531,12 @@ static void test_template_output_parsers() { // Generic tool calls doesn't generate / parse content-only messages symmetrically. - assert_msg_equals(msg_from_json(text_message), + assert_msg_equals(msg_from_json(message_assist), common_chat_parse("{\n" " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", common_chat_params_init(tmpl, inputs_tools).format)); - test_template(tmpl, end_tokens, tool_call_message_with_id, tools, + test_template(tmpl, end_tokens, message_assist_call_id, tools, "{\n" " \"tool_calls\": [\n" " {\n" @@ -448,9 +556,9 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template( - tmpl, end_tokens, tool_call_message_with_id, tools, + tmpl, end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } { @@ -473,12 +581,12 @@ static void test_template_output_parsers() { inputs_tools) .format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" ""); - test_template(tmpl, end_tokens, python_tool_call_message, tools, + test_template(tmpl, end_tokens, python_message_assist_call, tools, "\n" "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" ""); @@ -498,12 +606,12 @@ static void test_template_output_parsers() { inputs_tools_builtin) .format); - // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools, + // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, python_tool_call_message, tools, + test_template(tmpl, end_tokens, python_message_assist_call, tools, "<|python_tag|>python.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { @@ -513,8 +621,8 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { @@ -525,8 +633,8 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, "{\"arg1\": 1}"); } { @@ -537,12 +645,12 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format); assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, {}, + test_template(tmpl, end_tokens, message_assist, {}, "all\n" "Hello, world!\n" "What's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist_call, tools, "special_function\n" "{\"arg1\": 1}"); } @@ -553,23 +661,79 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } { + // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt. const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "", ""); std::vector end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(msg_from_json(message_assist_thoughts), + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + assert_msg_equals(msg_from_json(message_assist_thoughts), + // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true. + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + // test_template(tmpl, end_tokens, message_assist_call, tools, + // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + // "```json\n" + // "{\"arg1\": 1}\n" + // // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic) + // "```<|tool▁call▁end|>", + // /* expect_grammar_triggered= */ true, + // /* test_grammar_if_triggered= */ false); + } + { + // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all. + const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"), + "", ""); + std::vector end_tokens{ "<|end▁of▁sentence|>" }; - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, - "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|>"); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(msg_from_json(message_assist_thoughts), + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + + assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed), + common_chat_parse( + "I'm\nthinking\n\n" + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>", + COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(msg_from_json(message_assist_call_thoughts), + common_chat_parse( + "I'm\nthinking\n\n" + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + test_template(tmpl, end_tokens, message_assist_call, tools, + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>"); } } @@ -586,16 +750,20 @@ int main(int argc, char ** argv) { std::cout << "|----------|--------|\n"; for (int i = 1; i < argc; i++) { - std::string path = argv[i]; - if (path.rfind(".jinja") != path.size() - 6) { - std::cerr << "Skipping non-jinja file: " << path << std::endl; - continue; + try { + std::string path = argv[i]; + if (path.rfind(".jinja") != path.size() - 6) { + std::cerr << "Skipping non-jinja file: " << path << std::endl; + continue; + } + common_chat_template tmpl(read_file(path), "", ""); + auto parts = string_split(path, "/"); + auto name = parts[parts.size() - 1]; + auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format); + std::cout << "| " << name << " | " << format << " |\n"; + } catch (const std::exception & e) { + std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl; } - common_chat_template tmpl(read_file(path), "", ""); - auto parts = string_split(path, "/"); - auto name = parts[parts.size() - 1]; - std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format) - << " |\n"; } } else #endif diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp index 6ed696328d71a..eaf572c666410 100644 --- a/tests/test-gguf.cpp +++ b/tests/test-gguf.cpp @@ -697,8 +697,8 @@ static std::pair test_handcrafted_file(const unsigned int seed) { #ifdef _WIN32 if (!file) { - printf("%s: failed to create tmpfile(), needs elevated privileges on Windows"); - printf("%s: skipping tests"); + printf("failed to create tmpfile(), needs elevated privileges on Windows"); + printf("skipping tests"); continue; } #else @@ -1086,8 +1086,8 @@ static std::pair test_roundtrip(ggml_backend_dev_t dev, const unsigned #ifdef _WIN32 if (!file) { - printf("%s: failed to create tmpfile(), needs elevated privileges on Windows"); - printf("%s: skipping tests"); + printf("failed to create tmpfile(), needs elevated privileges on Windows"); + printf("skipping tests"); return std::make_pair(0, 0); } #else diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 61bd6785044ea..f1f87d454d464 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -181,6 +181,17 @@ static void test_dry( tester.check(); } +static void test_top_n_sigma(const std::vector & probs, const std::vector & probs_expected, int n) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_top_n_sigma(n)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { sampler_tester tester(n_vocab); @@ -348,6 +359,10 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f); + test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);