diff --git a/.gitignore b/.gitignore index 7b302d698b24b..55efc1fbf6d2f 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,8 @@ lcov-report/ tags .build/ build* +release +debug !build-info.cmake !build-info.cpp.in !build-info.sh diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8fab0de6f14db..e68ff92445828 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -39,7 +39,7 @@ _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline.)_ -- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` to format the added code +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` (from clang-tools v15+) to format the added code - For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) - Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices - Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ diff --git a/README.md b/README.md index f70c8ae1ed851..f7f1a521e40a8 100644 --- a/README.md +++ b/README.md @@ -219,7 +219,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly - [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server - [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale - +- [llmaz](https://github.com/InftyAI/llmaz) - ☸️ Easy, advanced inference platform for large language models on Kubernetes.
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8b7c75d85a6f5..6358a94e9b55f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -699,6 +699,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5": # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B res = "deepseek-r1-qwen" + if chkhsh == "ccc2ef013c104be7bae2965776d611e1d7a8a2a9c547dd93a682c9a9fc80352e": + # ref: https://huggingface.co/Xenova/gpt-4o + res = "gpt-4o" if res is None: logger.warning("\n") @@ -2512,7 +2515,8 @@ def set_gguf_parameters(self): rms_eps = self.find_hparam(["rms_norm_eps"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) - rope_dims = n_embd // n_head + rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + rope_dims = int(rot_pct * n_embd) // n_head self.gguf_writer.add_context_length(max_pos_embds) self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) @@ -2536,7 +2540,8 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: n_head = self.find_hparam(["num_attention_heads", "n_head"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) - rope_dims = n_embd // n_head + rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + rope_dims = int(rot_pct * n_embd) // n_head # write rope scaling for long context (128k) model rope_scaling = self.find_hparam(['rope_scaling'], True) @@ -2565,7 +2570,7 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: - raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}. long_factors = {len(long_factors)}, short_factors = {len(short_factors)}.') yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index fa4989a80c544..07d3ce0e4eb78 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -109,6 +109,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, + {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", }, ] @@ -131,6 +132,10 @@ def download_model(model): files = ["config.json", "tokenizer.json", "tokenizer_config.json"] + if name == "gpt-4o": + # Xenova/gpt-4o is tokenizer-only, it does not contain config.json + files = ["tokenizer.json", "tokenizer_config.json"] + if tokt == TOKENIZER_TYPE.SPM: files.append("tokenizer.model") diff --git a/docs/function-calling.md b/docs/function-calling.md new file mode 100644 index 0000000000000..92cb6531ab34a --- /dev/null +++ b/docs/function-calling.md @@ -0,0 +1,390 @@ +# Function Calling + +[chat.h](../common/chat.h) (https://github.com/ggml-org/llama.cpp/pull/9639) adds support for [OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) and is used in: +- `llama-server` when started w/ `--jinja` flag +- `llama-cli` (WIP: https://github.com/ggml-org/llama.cpp/pull/11556) + +## Universal support w/ Native & Generic handlers + +Function calling is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639): + +- Native tool call formats supported: + - Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2 + - Functionary v3.1 / v3.2 + - Hermes 2/3, Qwen 2.5 + - Qwen 2.5 Coder (WIP: https://github.com/ggml-org/llama.cpp/pull/12034) + - Mistral Nemo + - Firefunction v2 + - Command R7B + - DeepSeek R1 (WIP / seems reluctant to call any tools?) + +- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs). + - Use `--chat-template-file` to override the template when appropriate (see examples below) + - Generic support may consume more tokens and be less efficient than a model's native format. + +
+Show some common templates and which format handler they use + +| Template | Format | +|----------|--------| +| Almawave-Velvet-14B.jinja | Hermes 2 Pro | +| AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x | +| CohereForAI-aya-expanse-8b.jinja | Generic | +| CohereForAI-c4ai-command-r-plus-default.jinja | Generic | +| CohereForAI-c4ai-command-r-plus-rag.jinja | Generic | +| CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic | +| CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) | +| CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) | +| CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) | +| CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic | +| DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic | +| Delta-Vector-Rei-12B.jinja | Mistral Nemo | +| EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo | +| FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro | +| FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic | +| HelpingAI-HAI-SER.jinja | Generic | +| HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic | +| HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic | +| HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic | +| INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic | +| Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro | +| Infinigence-Megrez-3B-Instruct.jinja | Generic | +| Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic | +| LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic | +| LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic | +| LatitudeGames-Wayfarer-12B.jinja | Generic | +| Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic | +| Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic | +| MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic | +| MiniMaxAI-MiniMax-Text-01.jinja | Generic | +| MiniMaxAI-MiniMax-VL-01.jinja | Generic | +| NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) | +| NexaAIDev-Octopus-v2.jinja | Generic | +| NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic | +| NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro | +| NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic | +| NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro | +| NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic | +| NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro | +| NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro | +| NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro | +| OnlyCheeini-greesychat-turbo.jinja | Generic | +| Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x | +| OrionStarAI-Orion-14B-Chat.jinja | Generic | +| PowerInfer-SmallThinker-3B-Preview.jinja | Generic | +| PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic | +| Qwen-QVQ-72B-Preview.jinja | Generic | +| Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro | +| Qwen-Qwen1.5-7B-Chat.jinja | Generic | +| Qwen-Qwen2-7B-Instruct.jinja | Generic | +| Qwen-Qwen2-VL-72B-Instruct.jinja | Generic | +| Qwen-Qwen2-VL-7B-Instruct.jinja | Generic | +| Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro | +| Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro | +| RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro | +| SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro | +| SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro | +| Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x | +| SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x | +| SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x | +| Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x | +| Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x | +| Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x | +| THUDM-glm-4-9b-chat.jinja | Generic | +| THUDM-glm-edge-1.5b-chat.jinja | Generic | +| Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x | +| TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic | +| TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic | +| UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic | +| ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x | +| abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic | +| ai21labs-AI21-Jamba-1.5-Large.jinja | Generic | +| allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic | +| allenai-Llama-3.1-Tulu-3-405B.jinja | Generic | +| allenai-Llama-3.1-Tulu-3-8B.jinja | Generic | +| arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro | +| arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro | +| arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro | +| avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic | +| bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro | +| bfuzzy1-acheron-m1a-llama.jinja | Generic | +| bofenghuang-vigogne-2-70b-chat.jinja | Generic | +| bytedance-research-UI-TARS-72B-DPO.jinja | Generic | +| bytedance-research-UI-TARS-7B-DPO.jinja | Generic | +| bytedance-research-UI-TARS-7B-SFT.jinja | Generic | +| carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic | +| cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) | +| cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) | +| databricks-dbrx-instruct.jinja | Generic | +| deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic | +| deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic | +| deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic | +| deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-V2-Lite.jinja | Generic | +| deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) | +| deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic | +| deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic | +| deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic | +| deepseek-ai-deepseek-llm-67b-chat.jinja | Generic | +| deepseek-ai-deepseek-llm-7b-chat.jinja | Generic | +| dicta-il-dictalm2.0-instruct.jinja | Generic | +| ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro | +| fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 | +| godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro | +| godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro | +| google-gemma-2-27b-it.jinja | Generic | +| google-gemma-2-2b-it.jinja | Generic | +| google-gemma-2-2b-jpn-it.jinja | Generic | +| google-gemma-7b-it.jinja | Generic | +| huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | +| huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro | +| ibm-granite-granite-3.1-8b-instruct.jinja | Generic | +| indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic | +| inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic | +| jinaai-ReaderLM-v2.jinja | Generic | +| kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro | +| knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo | +| langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic | +| lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) | +| mattshumer-Reflection-Llama-3.1-70B.jinja | Generic | +| meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 | +| meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 | +| meta-llama-Llama-2-7b-chat-hf.jinja | Generic | +| meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x | +| meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x | +| meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic | +| meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x | +| microsoft-Phi-3-medium-4k-instruct.jinja | Generic | +| microsoft-Phi-3-mini-4k-instruct.jinja | Generic | +| microsoft-Phi-3-small-8k-instruct.jinja | Generic | +| microsoft-Phi-3.5-mini-instruct.jinja | Generic | +| microsoft-Phi-3.5-vision-instruct.jinja | Generic | +| microsoft-phi-4.jinja | Generic | +| migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic | +| ministral-Ministral-3b-instruct.jinja | Generic | +| mistralai-Codestral-22B-v0.1.jinja | Generic | +| mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic | +| mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic | +| mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo | +| mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo | +| mistralai-Mistral-Large-Instruct-2411.jinja | Generic | +| mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo | +| mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic | +| mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic | +| mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | +| mlabonne-AlphaMonarch-7B.jinja | Generic | +| mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro | +| mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro | +| mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) | +| netcat420-MFANNv0.20.jinja | Generic | +| netcat420-MFANNv0.24.jinja | Generic | +| netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro | +| nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro | +| nvidia-Eagle2-1B.jinja | Hermes 2 Pro | +| nvidia-Eagle2-9B.jinja | Hermes 2 Pro | +| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x | +| onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) | +| open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro | +| openchat-openchat-3.5-0106.jinja | Generic | +| pankajmathur-orca_mini_v6_8b.jinja | Generic | +| princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic | +| princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic | +| princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic | +| prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro | +| prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x | +| prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic | +| prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x | +| prithivMLmods-Blaze-14B-xElite.jinja | Generic | +| prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro | +| prithivMLmods-Calme-Ties-78B.jinja | Generic | +| prithivMLmods-Calme-Ties2-78B.jinja | Generic | +| prithivMLmods-Calme-Ties3-78B.jinja | Generic | +| prithivMLmods-ChemQwen2-vL.jinja | Generic | +| prithivMLmods-GWQ2b.jinja | Generic | +| prithivMLmods-LatexMind-2B-Codec.jinja | Generic | +| prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x | +| prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro | +| prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro | +| prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro | +| prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro | +| prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro | +| prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro | +| prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) | +| prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | +| prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | +| prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro | +| qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro | +| rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro | +| rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro | +| silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic | +| simplescaling-s1-32B.jinja | Hermes 2 Pro | +| sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro | +| sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic | +| sthenno-tempesthenno-icy-0130.jinja | Generic | +| sumink-qwft.jinja | Hermes 2 Pro | +| teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic | +| thirdeyeai-elevate360m.jinja | Generic | +| tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro | +| unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) | +| unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | +| unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | +| unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic | +| upstage-solar-pro-preview-instruct.jinja | Generic | +| whyhow-ai-PatientSeek.jinja | Generic | +| xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro | +| xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro | + +This table can be generated with: + +```bash +./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null +``` + +
+ +# Usage - need tool-aware Jinja template + +First, start a server with any model, but make sure it has a tools-enabled template: you can verify this by inspecting the `chat_template` or `chat_template_tool_use` properties in `http://localhost:8080/props`). + +Here are some models known to work (w/ chat template override when needed): + +```shell +# Native support: + +llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M +llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L +llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M +llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M + +# Native support for DeepSeek R1 works best w/ our own template (official template buggy) + +llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \ +--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + +llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \ +--chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + +# Native support requires the right template for these GGUFs: + +llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ +--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use ) + +llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \ +--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use ) + +llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \ +--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use ) + +llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \ +--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use ) + +# Generic format support +llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0 +llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0 +llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K +``` + +> [!TIP] +> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills) + +Test in CLI (or with any library / software that can use OpenAI-compatible API backends): + +```bash +curl http://localhost:8080/v1/chat/completions -d '{ +"model": "gpt-3.5-turbo", +"tools": [ + { + "type":"function", + "function":{ + "name":"python", + "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters":{ + "type":"object", + "properties":{ + "code":{ + "type":"string", + "description":"The code to run in the ipython interpreter." + } + }, + "required":["code"] + } + } + } +], +"messages": [ + { + "role": "user", + "content": "Print a hello world message with python." + } +] +}' +``` + +
+Show output + +```json +{ +"choices": [ + { + "finish_reason": "tool", + "index": 0, + "message": { + "content": null, + "tool_calls": [ + { + "name": "python", + "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" + } + ], + "role": "assistant" + } + } +], +"created": 1727287211, +"model": "gpt-3.5-turbo", +"object": "chat.completion", +"usage": { + "completion_tokens": 16, + "prompt_tokens": 44, + "total_tokens": 60 +}, +"id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8" +} +``` + +
diff --git a/examples/llava/README-granitevision.md b/examples/llava/README-granitevision.md new file mode 100644 index 0000000000000..f08a21cc175b4 --- /dev/null +++ b/examples/llava/README-granitevision.md @@ -0,0 +1,190 @@ +# Granite Vision + +Download the model and point your `GRANITE_MODEL` environment variable to the path. + +```bash +$ git clone https://huggingface.co/ibm-granite/granite-vision-3.2-2b +$ export GRANITE_MODEL=./granite-vision-3.2-2b +``` + + +### 1. Running llava surgery v2. +First, we need to run the llava surgery script as shown below: + +`python llava_surgery_v2.py -C -m $GRANITE_MODEL` + +You should see two new files (`llava.clip` and `llava.projector`) written into your model's directory, as shown below. + +```bash +$ ls $GRANITE_MODEL | grep -i llava +llava.clip +llava.projector +``` + +We should see that the projector and visual encoder get split out into the llava files. Quick check to make sure they aren't empty: +```python +import os +import torch + +MODEL_PATH = os.getenv("GRANITE_MODEL") +if not MODEL_PATH: + raise ValueError("env var GRANITE_MODEL is unset!") + +encoder_tensors = torch.load(os.path.join(MODEL_PATH, "llava.clip")) +projector_tensors = torch.load(os.path.join(MODEL_PATH, "llava.projector")) + +assert len(encoder_tensors) > 0 +assert len(projector_tensors) > 0 +``` + +If you actually inspect the `.keys()` of the loaded tensors, you should see a lot of `vision_model` tensors in the `encoder_tensors`, and 5 tensors (`'multi_modal_projector.linear_1.bias'`, `'multi_modal_projector.linear_1.weight'`, `'multi_modal_projector.linear_2.bias'`, `'multi_modal_projector.linear_2.weight'`, `'image_newline'`) in the multimodal `projector_tensors`. + + +### 2. Creating the Visual Component GGUF +Next, create a new directory to hold the visual components, and copy the llava.clip/projector files, as shown below. + +```bash +$ ENCODER_PATH=$PWD/visual_encoder +$ mkdir $ENCODER_PATH + +$ cp $GRANITE_MODEL/llava.clip $ENCODER_PATH/pytorch_model.bin +$ cp $GRANITE_MODEL/llava.projector $ENCODER_PATH/ +``` + +Now, we need to write a config for the visual encoder. In order to convert the model, be sure to use the correct `image_grid_pinpoints`, as these may vary based on the model. You can find the `image_grid_pinpoints` in `$GRANITE_MODEL/config.json`. + +```json +{ + "_name_or_path": "siglip-model", + "architectures": [ + "SiglipVisionModel" + ], + "image_grid_pinpoints": [ + [384,384], + [384,768], + [384,1152], + [384,1536], + [384,1920], + [384,2304], + [384,2688], + [384,3072], + [384,3456], + [384,3840], + [768,384], + [768,768], + [768,1152], + [768,1536], + [768,1920], + [1152,384], + [1152,768], + [1152,1152], + [1536,384], + [1536,768], + [1920,384], + [1920,768], + [2304,384], + [2688,384], + [3072,384], + [3456,384], + [3840,384] + ], + "mm_patch_merge_type": "spatial_unpad", + "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + "layer_norm_eps": 1e-6, + "hidden_act": "gelu_pytorch_tanh", + "projection_dim": 0, + "vision_feature_layer": [-24, -20, -12, -1] +} +``` + +At this point you should have something like this: +```bash +$ ls $ENCODER_PATH +config.json llava.projector pytorch_model.bin +``` + +Now convert the components to GGUF; Note that we also override the image mean/std dev to `[.5,.5,.5]` since we use the SigLIP visual encoder - in the transformers model, you can find these numbers in the `preprocessor_config.json`. +```bash +$ python convert_image_encoder_to_gguf.py \ + -m $ENCODER_PATH \ + --llava-projector $ENCODER_PATH/llava.projector \ + --output-dir $ENCODER_PATH \ + --clip-model-is-vision \ + --clip-model-is-siglip \ + --image-mean 0.5 0.5 0.5 \ + --image-std 0.5 0.5 0.5 +``` + +This will create the first GGUF file at `$ENCODER_PATH/mmproj-model-f16.gguf`; we will refer to the absolute path of this file as the `$VISUAL_GGUF_PATH.` + + +### 3. Creating the LLM GGUF. +The granite vision model contains a granite LLM as its language model. For now, the easiest way to get the GGUF for LLM is by loading the composite model in `transformers` and exporting the LLM so that it can be directly converted with the normal conversion path. + +First, set the `LLM_EXPORT_PATH` to the path to export the `transformers` LLM to. +```bash +$ export LLM_EXPORT_PATH=$PWD/granite_vision_llm +``` + +```python +import os +import transformers + +MODEL_PATH = os.getenv("GRANITE_MODEL") +if not MODEL_PATH: + raise ValueError("env var GRANITE_MODEL is unset!") + +LLM_EXPORT_PATH = os.getenv("LLM_EXPORT_PATH") +if not LLM_EXPORT_PATH: + raise ValueError("env var LLM_EXPORT_PATH is unset!") + +tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH) + +# NOTE: granite vision support was added to transformers very recently (4.49); +# if you get size mismatches, your version is too old. +# If you are running with an older version, set `ignore_mismatched_sizes=True` +# as shown below; it won't be loaded correctly, but the LLM part of the model that +# we are exporting will be loaded correctly. +model = transformers.AutoModelForImageTextToText.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True) + +tokenizer.save_pretrained(LLM_EXPORT_PATH) +model.language_model.save_pretrained(LLM_EXPORT_PATH) +``` + +Now you can convert the exported LLM to GGUF with the normal converter in the root of the llama cpp project. +```bash +$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm.gguf +... +$ python convert_hf_to_gguf.py --outfile $LLM_GGUF_PATH $LLM_EXPORT_PATH +``` + + +### 4. Quantization +If you want to quantize the LLM, you can do so with `llama-quantize` as you would any other LLM. For example: +```bash +$ ./build/bin/llama-quantize $LLM_EXPORT_PATH/granite_llm.gguf $LLM_EXPORT_PATH/granite_llm_q4_k_m.gguf Q4_K_M +$ LLM_GGUF_PATH=$LLM_EXPORT_PATH/granite_llm_q4_k_m.gguf +``` + +Note that currently you cannot quantize the visual encoder because granite vision models use SigLIP as the visual encoder, which has tensor dimensions that are not divisible by 32. + + +### 5. Running the Model in Llama cpp +Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner. + +```bash +$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \ + --mmproj $VISUAL_GGUF_PATH \ + --image ./media/llama0-banner.png \ + -c 16384 \ + -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\\nWhat does the text in this image say?\n<|assistant|>\n" \ + --temp 0 +``` + +Sample output: `The text in the image reads "LLAMA C++ Can it run DOOM Llama?"` diff --git a/examples/llava/README.md b/examples/llava/README.md index 012451361763c..0e3c32032055b 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow ``` **note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096) + **note** llava-1.6 greatly benefits from batched prompt processing (defaults work) +**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model. + +```python +import os +import transformers + +model_path = ... +llm_export_path = ... + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) +model = transformers.AutoModelForImageTextToText.from_pretrained(model_path) + +tokenizer.save_pretrained(llm_export_path) +model.language_model.save_pretrained(llm_export_path) +``` + +Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures. + ## llava-cli templating and llava-1.6 prompting llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index e2150c06d3204..76d4a78520575 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -120,6 +121,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -444,8 +446,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; + std::vector image_grid_pinpoints; int32_t image_crop_resolution; + std::unordered_set vision_feature_layer; }; struct clip_layer { @@ -585,6 +588,7 @@ struct clip_ctx { struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; + int32_t max_feature_layer; float image_mean[3]; float image_std[3]; bool use_gelu = false; @@ -651,7 +655,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; - int n_layer = hparams.n_layer; const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; @@ -752,13 +755,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); } + std::vector embedding_stack; + const auto & vision_feature_layer = hparams.vision_feature_layer; + // loop over layers - if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { - n_layer += 1; - } - for (int il = 0; il < n_layer - 1; il++) { + for (int il = 0; il < ctx->max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + // If this is an embedding feature layer, save the output. + // NOTE: 0 index here refers to the input to the encoder. + if (vision_feature_layer.find(il) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + //const size_t nb_q_w = model.layers[il].q_w->nb[0]; // layernorm1 @@ -846,7 +855,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - } // post-layernorm @@ -857,6 +865,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } + // final layer is a vision feature layer + if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + + // If feature layers are explicitly set, stack them (if we have multiple) + if (!embedding_stack.empty()) { + embeddings = embedding_stack[0]; + for (size_t i = 1; i < embedding_stack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); + } + } + // llava projector if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -1443,14 +1464,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) { - hparams.image_grid_pinpoints[i] = pinpoints[i]; + for (int i = 0; i < n; ++i) { + hparams.image_grid_pinpoints.push_back(pinpoints[i]); } - if (n < 32) - hparams.image_grid_pinpoints[n] = 0; - } catch (std::runtime_error & /*e*/) { - hparams.image_grid_pinpoints[0]=0; - } + } catch (std::runtime_error & /*e*/) { } + + // Load the vision feature layer indices if they are explicitly provided; + // if multiple vision feature layers are present, the values will be concatenated + // to form the final visual features. + // NOTE: gguf conversions should standardize the values of the vision feature layer to + // be non-negative, since we use -1 to mark values as unset here. + try { + int idx = get_key_idx(ctx, KEY_FEATURE_LAYER); + int n = gguf_get_arr_n(ctx, idx); + + const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + + for (int i = 0; i < n; ++i) { + hparams.vision_feature_layer.insert(vision_feature_layer[i]); + } + } catch (std::runtime_error & /*e*/) { } try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); @@ -1476,6 +1509,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->image_std[i] = std_data[i]; } + // Calculate the deepest feature layer based on hparams and projector type + new_clip->max_feature_layer = get_deepest_feature_layer(new_clip); + if (verbosity >= 2) { LOG_INF("\n%s: vision model hparams\n", __func__); LOG_INF("image_size %d\n", hparams.image_size); @@ -1489,8 +1525,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); LOG_INF("v_image_grid_pinpoints: "); - for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) { - LOG_INF("%d ", hparams.image_grid_pinpoints[i]); + for (const auto & pp : hparams.image_grid_pinpoints) { + LOG_INF("%d ", pp); + } + LOG_INF("\n"); + LOG_INF("v_vision_feature_layer: "); + for (const auto & feature_layer: hparams.vision_feature_layer) { + LOG_INF("%d ", feature_layer); } LOG_INF("\n"); LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); @@ -2235,10 +2276,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } } else { - if (params.image_grid_pinpoints[0] != 0) { + if (!params.image_grid_pinpoints.empty()) { // "spatial_unpad" with "anyres" processing for llava-1.6 std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { + for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); } std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); @@ -2404,7 +2445,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_grid_pinpoints; + if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { + return &ctx->vision_model.hparams.image_grid_pinpoints.front(); + } + return nullptr; +} + +size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.image_grid_pinpoints.size(); } int clip_n_patches(const struct clip_ctx * ctx) { @@ -2929,6 +2977,28 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) { return ctx->has_qwen2vl_merger; } +// Determine the number of encoder layers to iterate over +int get_deepest_feature_layer(const struct clip_ctx * ctx) { + // Get the index of the second to last layer; this is the + // default for models that have a llava projector + const auto & hparams = ctx->vision_model.hparams; + int n_layer = hparams.n_layer - 1; + int deepest_feature_layer = -1; + + // Handle other projectors; incrementing here indicates that we + // should use the last encoder layer for the vision features. + if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { + n_layer += 1; + } + + // If we set explicit vision feature layers, only go up to the deepest one + for (const auto & feature_layer : hparams.vision_feature_layer) { + if (feature_layer > deepest_feature_layer) { + deepest_feature_layer = feature_layer; + } + } + return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer; +} bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { clip_image_f32 clip_img; diff --git a/examples/llava/clip.h b/examples/llava/clip.h index f557122008ff1..002c419653a01 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); +CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx); CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); @@ -73,8 +74,11 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); -/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */ -CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img); +/** + * Build image from pixels decoded by other libraries instead of stb_image.h for better performance. + * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes + */ +CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); @@ -92,11 +96,13 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); +CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx); + CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); -CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); #ifdef __cplusplus } diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 4fa1d6ceae1bb..de29687ec9236 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -6,7 +6,7 @@ import torch import numpy as np from gguf import * -from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel +from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel TEXT = "clip.text" VISION = "clip.vision" @@ -37,6 +37,18 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b def get_tensor_name(name: str) -> str: + # Standardize the transformers llava next keys for + # image newline / mm projector with the classes in haotian-liu LLaVA + if name == "image_newline": + return "model.image_newline" + if name.startswith("multi_modal_projector"): + name = name.replace("multi_modal_projector", "mm") + if "linear_1" in name: + name = name.replace("linear_1", "0") + if "linear_2" in name: + name = name.replace("linear_2", "2") + return name + if "projection" in name: return name if "mm_projector" in name: @@ -83,8 +95,14 @@ def bytes_to_unicode(): help="Save a vision-only model. It can't be used to encode texts") ap.add_argument("--clip-model-is-vision", action="store_true", required=False, help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") -ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, + +# Selectable visual encoders that are compatible with this script +encoder_group = ap.add_mutually_exclusive_group() +encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False, help="The clip model is from openclip (for ViT-SO400M type))") +encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False, + help="the visual encoder is Siglip.") + ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) @@ -109,7 +127,12 @@ def bytes_to_unicode(): # output in the same directory as the model if output_dir is None dir_model = args.model_dir -if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: +if ( + args.clip_model_is_vision or + not os.path.exists(dir_model + "/vocab.json") or + args.clip_model_is_openclip or + args.clip_model_is_siglip +): vocab = None tokens = None else: @@ -137,7 +160,10 @@ def bytes_to_unicode(): if args.use_f32: ftype = 0 -if args.clip_model_is_vision or args.clip_model_is_openclip: +if args.clip_model_is_siglip: + model = SiglipVisionModel.from_pretrained(dir_model) + processor = None +elif args.clip_model_is_vision or args.clip_model_is_openclip: model = CLIPVisionModel.from_pretrained(dir_model) processor = None else: @@ -187,26 +213,71 @@ def bytes_to_unicode(): if has_text_encoder: assert t_hparams is not None assert tokens is not None + if args.clip_model_is_siglip: + text_projection_dim = 0 + else: + text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"]) # text_model hparams fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) - fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32("clip.text.projection_dim", text_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) fout.add_token_list(tokens) + + +def get_non_negative_vision_feature_layers(v_hparams): + """ + Determine the vision feature layer(s) for the llava model, which are indices into the + hidden states of the visual encoder. Note that the hidden states array generally takes the + form: + + [, , ... ] + + so feature indices should be offset as n+1 to get the output of encoder block n. + We convert all vision feature layers to non-negative so that -1 can be used in + the model as an unset value. If no vision feature layer is found, we leave it unset. + """ + num_hidden_layers = v_hparams["num_hidden_layers"] + to_non_negative = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1 + feature_layers_key = None + # Key used for llava models in transformers + if "vision_feature_layer" in config: + feature_layers_key = "vision_feature_layer" + # Key used for llava models in the original format + elif "mm_vision_select_layer" in config: + feature_layers_key = "mm_vision_select_layer" + if feature_layers_key is not None: + feature_layers = config[feature_layers_key] + if isinstance(feature_layers, int): + feature_layers = [feature_layers] + return [to_non_negative(feature_layer) for feature_layer in feature_layers] + +# Determine if we have explicitly specified vision feature layers in our config +feature_layers = get_non_negative_vision_feature_layers(v_hparams) + if has_vision_encoder: - # vision_model hparams + # Siglip does not have a visual projector; set projection dim to 0 + if args.clip_model_is_siglip: + visual_projection_dim = 0 + else: + visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"]) + + # set vision_model hparams fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) - fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] + if feature_layers: + block_count = max(feature_layers) + else: + block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) # /** # "image_grid_pinpoints": [ @@ -258,7 +329,8 @@ def bytes_to_unicode(): fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"]) if "mm_projector_type" in v_hparams: fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"]) - + if feature_layers: + fout.add_array("clip.vision.feature_layer", feature_layers) if processor is not None: image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue] @@ -274,7 +346,13 @@ def bytes_to_unicode(): if has_llava_projector: - model.vision_model.encoder.layers.pop(-1) + # By default, we drop the last layer for llava projector + # models unless we have explicitly set vision feature layers + if feature_layers is None: + model.vision_model.encoder.layers.pop(-1) + else: + model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)] + projector = torch.load(args.llava_projector) for name, data in projector.items(): name = get_tensor_name(name) diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 0979413cdde22..7710ae73073f7 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -354,9 +354,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); + const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); std::vector> grid_pinpoints; - for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) { + for (size_t i = 0; i < num_gridpoints; i += 2) { grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } @@ -406,7 +407,8 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * } bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - int num_max_patches = 6; + // Granite vision uses up to 10 patches + base patch + int num_max_patches = 11; if (clip_is_minicpmv(ctx_clip)) { num_max_patches = 10; } diff --git a/examples/llava/llava_surgery_v2.py b/examples/llava/llava_surgery_v2.py index 2d5b32fe6b236..b07c3e323c4c6 100644 --- a/examples/llava/llava_surgery_v2.py +++ b/examples/llava/llava_surgery_v2.py @@ -33,6 +33,33 @@ def save_model(model, file_path, file_type): else: torch.save(model, file_path) +# Helpers to match weight names from specific components or +# determine if a saved shard contains that component +def is_vision_tower(weight_name): + return ( + weight_name.startswith("model.vision_tower") or + weight_name.startswith("vit.") or + weight_name.startswith("vision_tower") + ) + +def is_newline(weight_name): + return ( + weight_name.startswith("model.image_newline") or + weight_name.startswith("image_newline") + ) + +def is_mm_projector(weight_name): + return ( + weight_name.startswith("model.mm_projector") or + weight_name.startswith("vision_proj.") or + weight_name.startswith("multi_modal_projector") + ) + +def newline_criteria(checkpoint): + return any(is_newline(k) for k in checkpoint.keys()) + +def proj_criteria(checkpoint): + return any(is_mm_projector(k) for k in checkpoint.keys()) # Adapted function to clean vision tower from checkpoint def clean_vision_tower_from_checkpoint(checkpoint_path): @@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): # file_type = 'pytorch' model_path = os.path.dirname(checkpoint_path) print(f"Searching for vision tower tensors in {checkpoint_path}") - clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))] + clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)] if len(clip_tensors) > 0: print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}") @@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): return newline_checkpoint_path, projector_checkpoint_path -def newline_criteria(checkpoint): - return any(k.startswith("model.image_newline") for k in checkpoint.keys()) - -def proj_criteria(checkpoint): - return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys()) - # Command-line interface setup ap = argparse.ArgumentParser() @@ -123,14 +144,14 @@ def proj_criteria(checkpoint): if newline_checkpoint_path is not None: print(f"Taking newline from {newline_checkpoint_path}") first_checkpoint, file_type = load_model(newline_checkpoint_path) - first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")] + first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)] # Load the checkpoint mm_tensors = [] last_checkpoint = None if projector_checkpoint_path is not None: last_checkpoint, file_type = load_model(projector_checkpoint_path) - mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] + mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)] if len(mm_tensors) == 0: if last_checkpoint is not None: @@ -155,5 +176,5 @@ def proj_criteria(checkpoint): save_model(projector, f"{args.model}/llava.projector", 'pytorch') print("Done!") -print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") diff --git a/examples/server/README.md b/examples/server/README.md index a2ae614d7c489..a2a0903261e31 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -13,6 +13,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. * Multimodal (wip) * Monitoring endpoints * Schema-constrained JSON response format + * [Function calling](../../docs/function-calling.md) / tool use for ~any model The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggml-org/llama.cpp/issues/4216). @@ -1120,381 +1121,9 @@ curl http://localhost:8080/v1/chat/completions \ *Tool call support* -[Function calling](https://platform.openai.com/docs/guides/function-calling) is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639): - -- Requires `--jinja` flag -- Native tool call formats supported: - - Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2 - - Functionary v3.1 / v3.2 - - Hermes 2/3, Qwen 2.5 - - Mistral Nemo - - Firefunction v2 - - Command R7B - - DeepSeek R1 (WIP / seems reluctant to call any tools?) - -
- Show some common templates and which format handler they use - - | Template | Format | - |----------|--------| - | Almawave-Velvet-14B.jinja | Hermes 2 Pro | - | AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x | - | CohereForAI-aya-expanse-8b.jinja | Generic | - | CohereForAI-c4ai-command-r-plus-default.jinja | Generic | - | CohereForAI-c4ai-command-r-plus-rag.jinja | Generic | - | CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic | - | CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) | - | CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) | - | CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) | - | CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic | - | DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic | - | Delta-Vector-Rei-12B.jinja | Mistral Nemo | - | EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo | - | FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro | - | FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic | - | HelpingAI-HAI-SER.jinja | Generic | - | HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic | - | HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic | - | HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic | - | INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic | - | Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro | - | Infinigence-Megrez-3B-Instruct.jinja | Generic | - | Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic | - | LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic | - | LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic | - | LatitudeGames-Wayfarer-12B.jinja | Generic | - | Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic | - | Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic | - | MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic | - | MiniMaxAI-MiniMax-Text-01.jinja | Generic | - | MiniMaxAI-MiniMax-VL-01.jinja | Generic | - | NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) | - | NexaAIDev-Octopus-v2.jinja | Generic | - | NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic | - | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro | - | NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic | - | NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro | - | NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic | - | NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro | - | NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro | - | NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro | - | OnlyCheeini-greesychat-turbo.jinja | Generic | - | Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x | - | OrionStarAI-Orion-14B-Chat.jinja | Generic | - | PowerInfer-SmallThinker-3B-Preview.jinja | Generic | - | PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic | - | Qwen-QVQ-72B-Preview.jinja | Generic | - | Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro | - | Qwen-Qwen1.5-7B-Chat.jinja | Generic | - | Qwen-Qwen2-7B-Instruct.jinja | Generic | - | Qwen-Qwen2-VL-72B-Instruct.jinja | Generic | - | Qwen-Qwen2-VL-7B-Instruct.jinja | Generic | - | Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro | - | Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro | - | RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro | - | SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro | - | SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro | - | Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x | - | SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x | - | SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x | - | Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x | - | Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x | - | Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x | - | THUDM-glm-4-9b-chat.jinja | Generic | - | THUDM-glm-edge-1.5b-chat.jinja | Generic | - | Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x | - | TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic | - | TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic | - | UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic | - | ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x | - | abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic | - | ai21labs-AI21-Jamba-1.5-Large.jinja | Generic | - | allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic | - | allenai-Llama-3.1-Tulu-3-405B.jinja | Generic | - | allenai-Llama-3.1-Tulu-3-8B.jinja | Generic | - | arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro | - | arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro | - | arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro | - | avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic | - | bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro | - | bfuzzy1-acheron-m1a-llama.jinja | Generic | - | bofenghuang-vigogne-2-70b-chat.jinja | Generic | - | bytedance-research-UI-TARS-72B-DPO.jinja | Generic | - | bytedance-research-UI-TARS-7B-DPO.jinja | Generic | - | bytedance-research-UI-TARS-7B-SFT.jinja | Generic | - | carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic | - | cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) | - | cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) | - | databricks-dbrx-instruct.jinja | Generic | - | deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic | - | deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic | - | deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic | - | deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-V2-Lite.jinja | Generic | - | deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) | - | deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic | - | deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic | - | deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic | - | deepseek-ai-deepseek-llm-67b-chat.jinja | Generic | - | deepseek-ai-deepseek-llm-7b-chat.jinja | Generic | - | dicta-il-dictalm2.0-instruct.jinja | Generic | - | ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro | - | fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 | - | godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro | - | godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro | - | google-gemma-2-27b-it.jinja | Generic | - | google-gemma-2-2b-it.jinja | Generic | - | google-gemma-2-2b-jpn-it.jinja | Generic | - | google-gemma-7b-it.jinja | Generic | - | huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) | - | huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) | - | huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | - | huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) | - | huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | - | huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro | - | ibm-granite-granite-3.1-8b-instruct.jinja | Generic | - | indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic | - | inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic | - | jinaai-ReaderLM-v2.jinja | Generic | - | kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro | - | knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo | - | langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic | - | lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) | - | mattshumer-Reflection-Llama-3.1-70B.jinja | Generic | - | meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 | - | meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 | - | meta-llama-Llama-2-7b-chat-hf.jinja | Generic | - | meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x | - | meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x | - | meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x | - | meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x | - | meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x | - | meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic | - | meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x | - | microsoft-Phi-3-medium-4k-instruct.jinja | Generic | - | microsoft-Phi-3-mini-4k-instruct.jinja | Generic | - | microsoft-Phi-3-small-8k-instruct.jinja | Generic | - | microsoft-Phi-3.5-mini-instruct.jinja | Generic | - | microsoft-Phi-3.5-vision-instruct.jinja | Generic | - | microsoft-phi-4.jinja | Generic | - | migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic | - | ministral-Ministral-3b-instruct.jinja | Generic | - | mistralai-Codestral-22B-v0.1.jinja | Generic | - | mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic | - | mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic | - | mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo | - | mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo | - | mistralai-Mistral-Large-Instruct-2411.jinja | Generic | - | mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo | - | mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic | - | mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic | - | mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | - | mlabonne-AlphaMonarch-7B.jinja | Generic | - | mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro | - | mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro | - | mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) | - | netcat420-MFANNv0.20.jinja | Generic | - | netcat420-MFANNv0.24.jinja | Generic | - | netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro | - | nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro | - | nvidia-Eagle2-1B.jinja | Hermes 2 Pro | - | nvidia-Eagle2-9B.jinja | Hermes 2 Pro | - | nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x | - | onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) | - | open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro | - | openchat-openchat-3.5-0106.jinja | Generic | - | pankajmathur-orca_mini_v6_8b.jinja | Generic | - | princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic | - | princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic | - | princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic | - | prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro | - | prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x | - | prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic | - | prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x | - | prithivMLmods-Blaze-14B-xElite.jinja | Generic | - | prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro | - | prithivMLmods-Calme-Ties-78B.jinja | Generic | - | prithivMLmods-Calme-Ties2-78B.jinja | Generic | - | prithivMLmods-Calme-Ties3-78B.jinja | Generic | - | prithivMLmods-ChemQwen2-vL.jinja | Generic | - | prithivMLmods-GWQ2b.jinja | Generic | - | prithivMLmods-LatexMind-2B-Codec.jinja | Generic | - | prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x | - | prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro | - | prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro | - | prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro | - | prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro | - | prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro | - | prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro | - | prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro | - | prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) | - | prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | - | prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | - | prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | - | prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | - | prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro | - | qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro | - | rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro | - | rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 2 Pro | - | silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic | - | simplescaling-s1-32B.jinja | Hermes 2 Pro | - | sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro | - | sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic | - | sthenno-tempesthenno-icy-0130.jinja | Generic | - | sumink-qwft.jinja | Hermes 2 Pro | - | teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic | - | thirdeyeai-elevate360m.jinja | Generic | - | tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro | - | unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) | - | unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | - | unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | - | unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic | - | upstage-solar-pro-preview-instruct.jinja | Generic | - | whyhow-ai-PatientSeek.jinja | Generic | - | xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro | - | xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro | - - This table can be generated with: +[OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) is supported with the `--jinja` flag (and may require a `--chat-template-file` override to get the right tool-use compatible Jinja template; worst case, `--chat-template chatml` may also work). - ```bash - ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null - ``` - -
- -- Generic tool call is supported when the template isn't recognized by native format handlers (you'll see `Chat format: Generic` in the logs). - - Use `--chat-template-file` to override the template when appropriate (see examples below) - - Generic support may consume more tokens and be less efficient than a model's native format. - -- Run with: - - ```shell - # Native support: - - llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M - llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L - llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M - llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M - - # Native support for DeepSeek R1 works best w/ our own template (official template buggy) - - llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \ - --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja - - llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \ - --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja - - # Native support requires the right template for these GGUFs: - - llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ - --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use ) - - llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \ - --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use ) - - llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \ - --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use ) - - llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \ - --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use ) - - # Generic format support - llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0 - llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0 - llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K - ``` - -- Test in CLI: - - ```bash - curl http://localhost:8080/v1/chat/completions -d '{ - "model": "gpt-3.5-turbo", - "tools": [ - { - "type":"function", - "function":{ - "name":"python", - "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", - "parameters":{ - "type":"object", - "properties":{ - "code":{ - "type":"string", - "description":"The code to run in the ipython interpreter." - } - }, - "required":["code"] - } - } - } - ], - "messages": [ - { - "role": "user", - "content": "Print a hello world message with python." - } - ] - }' - ``` - -
- Show output - - ```json - { - "choices": [ - { - "finish_reason": "tool", - "index": 0, - "message": { - "content": null, - "tool_calls": [ - { - "name": "python", - "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" - } - ], - "role": "assistant" - } - } - ], - "created": 1727287211, - "model": "gpt-3.5-turbo", - "object": "chat.completion", - "usage": { - "completion_tokens": 16, - "prompt_tokens": 44, - "total_tokens": 60 - }, - "id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8" - } - ``` - -
+**See our [Function calling](../../docs/function-calling.md) docs** for more details, supported native tool call styles (generic tool call style is used as fallback) / examples of use. ### POST `/v1/embeddings`: OpenAI-compatible embeddings API diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 4db4c783afd87..6830c2e1a6fd0 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -521,8 +521,13 @@ static json oaicompat_completion_params_parse(const json & body) { throw std::runtime_error("Only one completion choice is allowed"); } + // Handle "echo" field + if (json_value(body, "echo", false)) { + throw std::runtime_error("Only no echo is supported"); + } + // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "best_of", "echo", "suffix" }; + static const std::vector unsupported_params { "best_of", "suffix" }; for (const auto & param : unsupported_params) { if (body.contains(param)) { throw std::runtime_error("Unsupported param: " + param); @@ -598,7 +603,7 @@ static json oaicompat_completion_params_parse( inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); inputs.grammar = grammar; - inputs.add_generation_prompt = true; + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.use_jinja = use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index b606edcaf5de9..e05b398a83d8a 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -212,6 +212,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) +include(GNUInstallDirs) + # # build the library # @@ -235,7 +237,6 @@ endif () # install # -include(GNUInstallDirs) include(CMakePackageConfigHelpers) # all public headers diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index bf39f9c007bec..823eb797b7007 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -112,7 +112,7 @@ foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") if(is_cpu_variant) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") set_target_properties(ggml::${_ggml_backend} PROPERTIES INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") @@ -124,7 +124,7 @@ foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) endif() else() - list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") set_target_properties(ggml::${_ggml_backend} PROPERTIES INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") @@ -139,6 +139,11 @@ foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) list(APPEND _ggml_all_targets ggml::${_ggml_backend}) endforeach() +list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}") +set_target_properties(ggml::ggml + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}") + add_library(ggml::all INTERFACE IMPORTED) set_target_properties(ggml::all PROPERTIES diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 23600eea99cb8..2cb150fd2a313 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -19,7 +19,7 @@ struct ggml_tallocr { }; GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer); -GGML_API void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor); +GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor); // Graph allocator /* diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a86b1a3180305..9002406d96e72 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -56,7 +56,7 @@ extern "C" { GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); @@ -343,8 +343,8 @@ extern "C" { GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); // Tensor initialization - GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); - GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); + GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); + GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor); // CPU buffer types are always available GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 7244a9cbb0605..a3d3f690133b0 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -89,7 +89,7 @@ struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) { return talloc; } -void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) { +enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) { size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor); size = GGML_PAD(size, talloc->alignment); @@ -104,7 +104,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso assert(((uintptr_t)addr % talloc->alignment) == 0); - ggml_backend_tensor_alloc(talloc->buffer, tensor, addr); + return ggml_backend_tensor_alloc(talloc->buffer, tensor, addr); } // dynamic tensor allocator @@ -933,42 +933,51 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { // utils +static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { + for (size_t i = 0; i < *n_buffers; i++) { + ggml_backend_buffer_free((*buffers)[i]); + } + free(*buffers); +} + static bool alloc_tensor_range(struct ggml_context * ctx, struct ggml_tensor * first, struct ggml_tensor * last, ggml_backend_buffer_type_t buft, size_t size, ggml_backend_buffer_t ** buffers, size_t * n_buffers) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); if (buffer == NULL) { -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); -#endif - for (size_t i = 0; i < *n_buffers; i++) { - ggml_backend_buffer_free((*buffers)[i]); - } - free(*buffers); + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); + free_buffers(buffers, n_buffers); return false; } + *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1)); + (*buffers)[(*n_buffers)++] = buffer; + struct ggml_tallocr tallocr = ggml_tallocr_new(buffer); for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) { + enum ggml_status status = GGML_STATUS_SUCCESS; if (t->data == NULL) { if (t->view_src == NULL) { - ggml_tallocr_alloc(&tallocr, t); + status = ggml_tallocr_alloc(&tallocr, t); } else if (t->buffer == NULL) { - ggml_backend_view_init(t); + status = ggml_backend_view_init(t); } } else { if (t->view_src != NULL && t->buffer == NULL) { // view of a pre-allocated tensor - ggml_backend_view_init(t); + status = ggml_backend_view_init(t); } } + if (status != GGML_STATUS_SUCCESS) { + GGML_LOG_ERROR("%s: failed to initialize tensor %s\n", __func__, t->name); + free_buffers(buffers, n_buffers); + return false; + } } - *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1)); - (*buffers)[(*n_buffers)++] = buffer; - return true; } diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index d1c2d76d8975e..c36c12d6579ac 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -44,7 +44,7 @@ extern "C" { // base address of the buffer void * (*get_base) (ggml_backend_buffer_t buffer); // (optional) initialize a tensor in the buffer (eg. add tensor extras) - void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + enum ggml_status (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // tensor data access void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index dba7be33b88c0..184f99af5fce4 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -126,11 +126,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return base; } -void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { // init_tensor is optional if (buffer->iface.init_tensor) { - buffer->iface.init_tensor(buffer, tensor); + return buffer->iface.init_tensor(buffer, tensor); } + return GGML_STATUS_SUCCESS; } void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -1641,7 +1642,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, // utils -void ggml_backend_view_init(struct ggml_tensor * tensor) { +enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->view_src != NULL); GGML_ASSERT(tensor->view_src->buffer != NULL); @@ -1649,10 +1650,10 @@ void ggml_backend_view_init(struct ggml_tensor * tensor) { tensor->buffer = tensor->view_src->buffer; tensor->data = (char *)tensor->view_src->data + tensor->view_offs; - ggml_backend_buffer_init_tensor(tensor->buffer, tensor); + return ggml_backend_buffer_init_tensor(tensor->buffer, tensor); } -void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { +enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); @@ -1662,7 +1663,7 @@ void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor tensor->buffer = buffer; tensor->data = addr; - ggml_backend_buffer_init_tensor(buffer, tensor); + return ggml_backend_buffer_init_tensor(buffer, tensor); } static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, @@ -1708,7 +1709,8 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_ struct ggml_tensor * dst = node_copies[id]; if (dst->view_src != NULL) { graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src); - ggml_backend_view_init(dst); + enum ggml_status status = ggml_backend_view_init(dst); + GGML_ASSERT(status == GGML_STATUS_SUCCESS); } else { ggml_backend_tensor_copy(src, dst); @@ -1823,7 +1825,6 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t assert(g1->n_nodes == g2->n_nodes); for (int i = 0; i < g1->n_nodes; i++) { - //printf("eval %d/%d\n", i, g1->n_nodes); struct ggml_tensor * t1 = g1->nodes[i]; struct ggml_tensor * t2 = g2->nodes[i]; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index d410c02445c27..b8d272cda600c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -796,11 +796,11 @@ static bool need_transform(ggml_type type) { * @param buffer The CANN buffer from which to initialize the tensor. * @param tensor Pointer to the tensor to be initialized. */ -static void ggml_backend_cann_buffer_init_tensor( +static enum ggml_status ggml_backend_cann_buffer_init_tensor( ggml_backend_buffer_t buffer, ggml_tensor* tensor) { if (tensor->view_src != NULL && tensor->view_offs == 0) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); - return; + return GGML_STATUS_SUCCESS; } // TODO: can backend doesn't support quantized yet. Just leave the code @@ -817,6 +817,7 @@ static void ggml_backend_cann_buffer_init_tensor( memset_size, 0, memset_size)); } } + return GGML_STATUS_SUCCESS; } // TODO: need handle tensor which has paddings. diff --git a/ggml/src/ggml-cann/kernels/dup.cpp b/ggml/src/ggml-cann/kernels/dup.cpp index c7ba38d10a0b2..d9b9574494b72 100644 --- a/ggml/src/ggml-cann/kernels/dup.cpp +++ b/ggml/src/ggml-cann/kernels/dup.cpp @@ -1,7 +1,5 @@ #include "kernel_operator.h" -#include - using namespace AscendC; #define BUFFER_NUM 2 @@ -183,7 +181,7 @@ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32( copy_to_ub(output_ne_gm, output_ne_ub, 32); copy_to_ub(output_nb_gm, output_nb_ub, 32); - DupByRows op; + DupByRows op; op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); op.dup(); } @@ -206,7 +204,7 @@ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16( copy_to_ub(output_ne_gm, output_ne_ub, 32); copy_to_ub(output_nb_gm, output_nb_ub, 32); - DupByRows op; + DupByRows op; op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); op.dup_with_cast(); } @@ -230,7 +228,7 @@ extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32( copy_to_ub(output_ne_gm, output_ne_ub, 32); copy_to_ub(output_nb_gm, output_nb_ub, 32); - DupByRows op; + DupByRows op; op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); op.dup_with_cast(); } diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 5ec5263ceb4ba..0f067137df006 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -50,10 +50,11 @@ static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) { return (void *) (buffer->context); } -static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +static enum ggml_status ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor); GGML_UNUSED(buffer); + return GGML_STATUS_SUCCESS; } static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 7b9baa54475ce..371b898c21fb2 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -4135,10 +4135,11 @@ static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(con return nullptr; } -static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +static enum ggml_status ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { tensor->extra = (void *) const_cast(ggml_aarch64_get_optimal_repack_type(tensor)); GGML_UNUSED(buffer); + return GGML_STATUS_SUCCESS; } static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index d0c407bd6ebdc..2679b71ffda2d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -4587,7 +4587,252 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r const int nb = n / QK_K; -#ifdef __ARM_NEON +#ifdef __ARM_FEATURE_SVE + const int vector_length = svcntb()*8; + const svuint8_t m3s = svdup_n_u8(0x3); + const svuint32_t m4s = svdup_n_u32(0xF); + const svint32_t vzero_sv = svdup_n_s32(0); + svfloat32_t acc_sum = svdup_n_f32(0); + svbool_t pred_s32 = svptrue_pat_b32(SV_VL4); + + switch (vector_length) { + case 128: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + svfloat32_t d_broad = svdup_n_f32((float32_t)d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8_sv = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc); + const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4); + const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums); + svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4); + + const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2)); + + mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8); + const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12); + const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8); + q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12); + + svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2)); + + svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1)); + + acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad); + + svint32_t sumi1 = svdup_n_s32(0); + + { + const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2); + svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s)); + svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s)); + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0)); + + const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16); + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3)); + + + const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3)); + + //------------------------------- + + q2 += 32; + const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s)); + const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0)); + + const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16); + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1)); + + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3)); + + + const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1)); + + + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3)); + } + acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad); + } + *s = svaddv_f32(svptrue_b32(), acc_sum); + break; + + case 256: + case 512: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + svfloat32_t d_broad = svdup_n_f32((float32_t)d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8_sv = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8; + const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s)); + const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4)); + svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums); + + const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); + const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s)); + const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4)); + + svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8); + + svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2))); + + acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad); + + svint32_t sumi1 = svdup_n_s32(0); + + { + const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); + svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s)); + svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); + + q2 += 32; + + const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); + } + acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad); + } + *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum); + break; + + default: + assert(false && "Unsupported vector length"); + break; + } + +#elif __ARM_NEON const uint8x16_t m3 = vdupq_n_u8(0x3); const uint8x16_t m4 = vdupq_n_u8(0xF); @@ -5265,6 +5510,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r #if defined(__ARM_FEATURE_SVE) + uint32_t aux[3]; uint32_t utmp[4]; const int8_t m32 = 32; @@ -5276,7 +5522,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1); const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2); const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3); - svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4)); float sum = 0; @@ -5289,7 +5534,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const int8_t * restrict q8_sv = y[i].qs; // Set up scales - uint32_t * aux = &x[i].scales; + memcpy(aux, x[i].scales, 12); utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ebb2ccae04065..d23686d161773 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -540,12 +540,12 @@ static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { return ctx->dev_ptr; } -static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; if (tensor->view_src != NULL) { assert(tensor->view_src->buffer->buft == buffer->buft); - return; + return GGML_STATUS_SUCCESS; } if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) { @@ -558,6 +558,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size)); } } + return GGML_STATUS_SUCCESS; } static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { @@ -792,7 +793,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff GGML_UNUSED(buffer); } -static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; @@ -838,6 +839,7 @@ static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buf } } tensor->extra = extra; + return GGML_STATUS_SUCCESS; } static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 0451c65f30277..f2aca1f2014e1 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -109,9 +109,9 @@ static constexpr __device__ int get_mmq_x_max_device() { #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #ifdef GGML_CUDA_FORCE_MMQ - return MMQ_DP4A_MAX_BATCH_SIZE; -#else // GGML_CUDA_FORCE_MMQ return 128; +#else // GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; #endif // GGML_CUDA_FORCE_MMQ #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 087e7f58149f6..c550142a7d07c 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SQRT, @@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); @@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } default: return false; }; @@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node( case GGML_OP_CPY: case GGML_OP_CONT: { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - id pipeline = nil; switch (src0t) { @@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node( switch (dstt) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q8_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); }; } break; default: GGML_ABORT("not implemented"); @@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SET: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 83e7ac9f411ef..d092a16906155 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl( } } +template +kernel void kernel_cpy_q_f32( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + + device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + T4x4 temp; + dequantize_func(src_data + i00/nl, i00%nl, temp); + dst_data[i00] = temp; + } +} + +typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; + +template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; + +template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; + kernel void kernel_concat( constant ggml_metal_kargs_concat & args, device const char * src0, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3891eca72890c..d78ff1e721fe7 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -467,24 +467,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->gpu_family = GPU_FAMILY::ADRENO; backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); - // we check the device version buffer if we cannot detect it from name - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(device_ver_buffer); - } + // Use wave size of 64 for all Adreno GPUs. + backend_ctx->adreno_wave_size = 64; - // Default wave size is 128, A8x uses 64. - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A8X) { - backend_ctx->adreno_wave_size = 64; - } else if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A7X || - backend_ctx->adreno_gen == ADRENO_GPU_GEN::X1E) { - backend_ctx->adreno_wave_size = 128; - } else { - backend_ctx->adreno_wave_size = 128; - GGML_LOG_WARN("ggml_opencl: Unsupported Adreno GPU: %s, " - "using wave size %d, " - "may not work as expected\n", - backend_ctx->device_name.c_str(), backend_ctx->adreno_wave_size); - } } else if (strstr(default_device->name, "Intel")) { backend_ctx->gpu_family = GPU_FAMILY::INTEL; } else { @@ -1233,7 +1218,7 @@ static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) GGML_UNUSED(buffer); } -static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; ggml_cl2_init(buffer->buft->device); @@ -1273,6 +1258,7 @@ static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; } } + return GGML_STATUS_SUCCESS; } // The optimized gemm and gemv kernels are used for large matrices without batch. @@ -1387,6 +1373,11 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, int M = tensor->ne[1]; // ne01 int K = tensor->ne[0]; // ne00 + //For matrix-vector multiplication kernel, we assume K is a multiple of 32 + GGML_ASSERT(K % 32 == 0); + //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4 + GGML_ASSERT(M % 4 == 0); + // transpose is out of place, so we need to allocate transposed buffers // <----------------------------------------------------------------------------------> // // use sub_buffer of max buffer size instead @@ -1427,36 +1418,36 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_mem qT_d_image1D; cl_mem dT_d_image1D; - cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT }; + cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; cl_image_desc img_desc_1d; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 8 / 4; + img_desc_1d.image_width = M * K / 4 / 4; img_desc_1d.buffer = extra->q; q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); - img_fmt_1d = { CL_RGBA, CL_FLOAT }; + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 8 / 4; + img_desc_1d.image_width = M * K / 4 / 4; img_desc_1d.buffer = qT_d; qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); - img_fmt_1d = { CL_RGBA, CL_FLOAT }; + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4 / 2; + img_desc_1d.image_width = M * K / 32 / 4; img_desc_1d.buffer = extra->d; d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); - img_fmt_1d = { CL_RGBA, CL_FLOAT }; + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4 / 2; + img_desc_1d.image_width = M * K / 32 / 4; img_desc_1d.buffer = dT_d; dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); @@ -1465,8 +1456,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // set up and call the transpose kernels // <----------------------------------------------------------------------------------> // // weights - int height_q = M / 8; - int width_q = K / 8 / 4; + int height_q = M / 4; + int width_q = K / 4 / 4; kernel = backend_ctx->kernel_transpose_16; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); @@ -1480,8 +1471,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clWaitForEvents(1, &evt)); // scales - int height_s = M / 8; - int width_s = K / 32 / 8; + int height_s = M / 4; + int width_s = K / 32 / 4; kernel = backend_ctx->kernel_transpose_16; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); @@ -1875,7 +1866,6 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso void * buf_d; #endif -#ifdef GGML_USE_OPENCL // Make sure everything is done. CL_CHECK(clFinish(queue)); @@ -1911,7 +1901,6 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); #endif // GGML_OPENCL_SOA_Q -#endif // GGML_USE_OPENCL // Open file and dump. char fname[512]; @@ -2876,6 +2865,9 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(status); int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } int width_B = K/4; int padded_height_B = (N + padding)/4; @@ -3024,11 +3016,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } if (N == 1) { - local_work_size[0] = backend_ctx->adreno_wave_size; // localsize + size_t wavesize = backend_ctx->adreno_wave_size; + local_work_size[0] = wavesize; // localsize local_work_size[1] = 4; // reduce factor local_work_size[2] = 1; - global_work_size[0] = M / 2; + global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize; global_work_size[1] = 4; // reduce factor global_work_size[2] = 1; } diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index d3cfb2f91e130..8882a8c9c6225 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -1797,6 +1797,9 @@ kernel void kernel_mul_mat_f16_f16( //------------------------------------------------------------------------------ // mul_mat_f16_f32_1row //------------------------------------------------------------------------------ +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif kernel void kernel_mul_mat_f16_f32_1row( global char * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl index 5e195411d690e..ee5c79f000d69 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl @@ -1,9 +1,11 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable -#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#ifdef cl_qcom_reqd_sub_group_size #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif // assume #define QK4_0 32 @@ -186,8 +188,9 @@ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ - -__attribute__((qcom_reqd_sub_group_size("full"))) +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif __kernel void kernel_gemv_noshuffle( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl index 5bdd4d067639a..469d3edef00cc 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl @@ -1,9 +1,11 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable -#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#ifdef cl_qcom_reqd_sub_group_size #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif // assume #define QK4_0 32 @@ -186,8 +188,9 @@ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ - -__attribute__((qcom_reqd_sub_group_size("full"))) +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif __kernel void kernel_gemv_noshuffle( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl index 57768c80334eb..ecb577b993339 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl @@ -7,7 +7,16 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -__attribute__((qcom_reqd_sub_group_size("full"))) +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + kernel void kernel_mul_mat_Ab_Bi_8x4( global const ushort * src0_q, // quantized A global const half * src0_d, // A scales diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl index d59a0c05ddfd0..cd4e0afbad279 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl @@ -1,4 +1,6 @@ -// 16-bit transpose, loading/storing an 8x8 tile of elements +// 16-bit transpose, loading/storing a 4x4 tile of elements + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable kernel void kernel_transpose_16( __read_only image1d_buffer_t input, @@ -9,24 +11,16 @@ kernel void kernel_transpose_16( const int i = get_global_id(0); const int j = get_global_id(1); - const int i_3 = i<<3; - const int j_3 = j<<3; + const int i_2 = i<<2; + const int j_2 = j<<2; - ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i)); - ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i)); - ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i)); - ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i)); - ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i)); - ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i)); - ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i)); - ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i)); + half4 temp0 = read_imageh(input, (j_2+0)*cols+i); + half4 temp1 = read_imageh(input, (j_2+1)*cols+i); + half4 temp2 = read_imageh(input, (j_2+2)*cols+i); + half4 temp3 = read_imageh(input, (j_2+3)*cols+i); - write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0))); - write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1))); - write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2))); - write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3))); - write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4))); - write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5))); - write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6))); - write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7))); + write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); } diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 97873acc77dee..6c3b80b0883c9 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -464,7 +464,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { return result; } -static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; // CUDA backend on the server pads everything to 512 due to CUDA limitations. @@ -478,6 +478,7 @@ static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, gg bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); GGML_ASSERT(status); } + return GGML_STATUS_SUCCESS; } static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 792e0569ca6cf..d804e66061721 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -323,14 +323,14 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { return ctx->dev_ptr; } -static void +static enum ggml_status ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor) try { ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; if (tensor->view_src != NULL) { assert(tensor->view_src->buffer->buft == buffer->buft); - return; + return GGML_STATUS_SUCCESS; } ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; @@ -348,6 +348,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, padded_size - original_size).wait())); } } + return GGML_STATUS_SUCCESS; } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -729,7 +730,7 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff GGML_UNUSED(buffer); } -static void +static enum ggml_status ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor) try { GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported @@ -804,6 +805,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, } } tensor->extra = extra; + return GGML_STATUS_SUCCESS; } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8f50b48c50915..673be8d322db0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -241,15 +241,19 @@ struct vk_device_struct { vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_gelu_f32; vk_pipeline pipeline_gelu_quick_f32; vk_pipeline pipeline_silu_f32; + vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_relu_f32; vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_tanh_f32; + vk_pipeline pipeline_sigmoid_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_soft_max_back_f32; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; @@ -504,6 +508,7 @@ struct vk_op_rope_push_constants { uint32_t s1; uint32_t s2; int32_t sections[4]; + uint32_t is_back; }; struct vk_op_soft_max_push_constants { @@ -1987,6 +1992,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } } else if (device->vendor_id == VK_VENDOR_ID_INTEL) rm_stdq = 2; + uint32_t rm_iq = 2 * rm_kq; for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -2001,15 +2007,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -2023,15 +2029,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -2046,15 +2052,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -2121,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -2180,9 +2187,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -2190,6 +2199,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -4183,7 +4193,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (qy_needs_dequant) { d_Y = ctx->prealloc_y; - GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -4760,7 +4770,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (qy_needs_dequant) { d_Y = ctx->prealloc_y; - GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -5283,6 +5293,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CONT: case GGML_OP_DUP: return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_SILU_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_back_f32; + } + return nullptr; case GGML_OP_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_norm_f32; @@ -5298,6 +5313,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_f32; } return nullptr; + case GGML_OP_RMS_NORM_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_back_f32; + } + return nullptr; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: @@ -5325,6 +5345,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_tanh_f32; } break; + case GGML_UNARY_OP_SIGMOID: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sigmoid_f32; + } + break; default: break; } @@ -5344,7 +5369,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; } return nullptr; + case GGML_OP_SOFT_MAX_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_back_f32; + } + return nullptr; case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: { const int mode = ((const int32_t *) dst->op_params)[2]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; @@ -5672,7 +5703,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: case GGML_OP_ARGMAX: { @@ -5696,6 +5729,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } break; case GGML_OP_DIAG_MASK_INF: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; break; case GGML_OP_GET_ROWS: @@ -5791,7 +5825,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); - } else if (op == GGML_OP_ROPE) { + } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer vk_subbuffer subbuf_z; if (use_src2) { @@ -6313,6 +6347,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; @@ -6335,6 +6373,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } +static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } @@ -6370,7 +6413,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, }, dryrun); } -static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -6398,7 +6446,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, - sections[0], sections[1], sections[2], sections[3], + sections[0], sections[1], sections[2], sections[3], backprop }, dryrun); } @@ -7295,6 +7343,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: break; default: return false; @@ -7319,12 +7368,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: + case GGML_OP_SILU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_ARGSORT: @@ -7377,13 +7430,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: + case GGML_OP_SILU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -7475,6 +7532,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_DUP: ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SILU_BACK: + ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_NORM: ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); @@ -7487,6 +7548,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM: ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_RMS_NORM_BACK: + ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { @@ -7495,6 +7560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); break; default: @@ -7508,9 +7574,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SOFT_MAX: ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_SOFT_MAX_BACK: + ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_ROPE: - ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); + + break; + case GGML_OP_ROPE_BACK: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); break; case GGML_OP_ARGSORT: @@ -7636,12 +7710,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: + case GGML_OP_SILU_BACK: case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: @@ -7670,6 +7748,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: buf = tensor->buffer; break; default: @@ -7844,11 +7923,12 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { UNUSED(buffer); } -static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); } + return GGML_STATUS_SUCCESS; } static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { @@ -8371,6 +8451,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); default: return false; @@ -8560,6 +8641,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -8576,6 +8658,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: + case GGML_OP_SILU_BACK: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: @@ -8585,6 +8669,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -8976,15 +9061,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); + } else if (tensor->op == GGML_OP_SILU_BACK) { + tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else { tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } + } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { + tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); - } else if (tensor->op == GGML_OP_ROPE) { + } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; @@ -8997,9 +9089,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const float beta_slow = ((float *) tensor->op_params)[10]; if (mode & GGML_ROPE_TYPE_MROPE) { int32_t *sections = ((int32_t *) tensor->op_params) + 11; - tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } } else { - tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { @@ -9018,6 +9118,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { case GGML_UNARY_OP_TANH: tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SIGMOID: + tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 10318e8766074..8835c442ecfd8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -82,9 +82,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; - uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; - return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); + const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]); + const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]); + return vec4(v0.x, v0.y, v1.x, v1.y); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 4770469eddcab..4ccbe613af2ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 const uint iqs = idx; // Load 16b and select the byte for this element - int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; + int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; float16_t ret = float16_t(qs) * d; return ret; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index c9f855687dcb5..cfd645a38a8ba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -1,5 +1,7 @@ #version 450 +#extension GL_EXT_control_flow_attributes : enable + #include "types.comp" #include "generic_binary_head.comp" #include "dequant_funcs.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 122b1e93fb496..09aa849e8815c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -40,6 +40,20 @@ void main() { const uint batch = gl_GlobalInvocationID.z / p.IC; const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint src_base = ic * p.offset_delta + batch * p.batch_offset; + const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + + const uint base_linear_idx = gidx * NUM_ITER; + + const uint max_ky = ksize / p.OW; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + A_TYPE values[NUM_ITER]; uint offset_dst[NUM_ITER]; [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { @@ -48,36 +62,35 @@ void main() { [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + const uint linear_idx = base_linear_idx + idx; - const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); - const uint kx = i / ksize; - const uint kd = kx * ksize; - const uint ky = (i - kd) / p.OW; - const uint ix = i % p.OW; + if (linear_idx >= p.pelements) { + continue; + } - const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - offset_dst[idx] = - ((batch * p.OH + oh) * p.OW + ix) * p.CHW + - (ic * (p.KW * p.KH) + ky * p.KW + kx); + offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; - if (i >= p.pelements) { - continue; + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; } - if (iih < p.IH && iiw < p.IW) { - const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; - values[idx] = data_a[offset_src + iih * p.IW + iiw]; + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == max_ky) { + current_ky = 0; + current_kx++; + } } } [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + const uint linear_idx = base_linear_idx + idx; - if (i >= p.pelements) { + if (linear_idx >= p.pelements) { continue; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp new file mode 100644 index 0000000000000..9718a05e5adb2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + const uint qh = data_a[ibi].qh[ib32]; + const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]); + const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint8_t sign = sign16[l]; + const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); + const uvec2 grid = iq2s_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp new file mode 100644 index 0000000000000..c496043241072 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[2 * itid + l]; + const uint sign = qs >> 9; + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); + const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp new file mode 100644 index 0000000000000..94d4b92e1ee69 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[4 * ib32 + 2], + data_a_packed16[ibi].qs[4 * ib32 + 3])); + const float db = d * 0.25 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x)); + const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp new file mode 100644 index 0000000000000..af48f32902fe2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const float dscale = d * (1 + 2 * scale); + const uint qh = data_a[ibi].qh[ib32]; + FLOAT_TYPE sum[NUM_COLS]; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + sum[j] = 0.0; + } + [[unroll]] for (uint l = 0; l < 4; ++l) { + const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]); + const uint sign = data_a[ibi].signs[4 * ib32 + l]; + const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); + const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + sum[j] = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + sum[j])))))))); + } + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + temp[j][n] = fma(dscale, sum[j], temp[j][n]); + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp new file mode 100644 index 0000000000000..3fe9dc3a4113a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -0,0 +1,88 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32], + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1])); + const float db = d * 0.5 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l]; + const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0])); + const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 39657195cfc8d..a8fd93fdeadee 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -32,6 +32,13 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; @@ -243,74 +250,100 @@ void main() { #endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; - - const uint ib = idx / 16; - const uint iqs = idx & 0xF; - - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; - - const uint ib = idx / 16; - const uint iqs = idx & 0xF; - - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint uint_qh = data_a[ib].qh; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 16; - const uint iqs = (idx & 0xF) * 2; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]); + const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]); + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -623,17 +656,18 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; + buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; #endif } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp new file mode 100644 index 0000000000000..76009f3df6783 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -0,0 +1,55 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_xx[BLOCK_SIZE]; +shared FLOAT_TYPE sum_xg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5 + + // partial sums for thread in warp + sum_xx[tid] = FLOAT_TYPE(0.0f); + sum_xg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]); + sum_xx[tid] += xi * xi; + sum_xg[tid] += xi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_xx[tid] += sum_xx[tid + s]; + sum_xg[tid] += sum_xg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE eps = FLOAT_TYPE(p.param1); + const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale_g = inversesqrt(mean + eps); + const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE( + scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) + + scale_x * FLOAT_TYPE(data_b[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 38075b75557f9..96c9c4cbd307c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -29,6 +29,7 @@ layout (push_constant) uniform parameter { uint s1; uint s2; int sections[4]; + uint is_back; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { @@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } cos_theta = cos(theta) * mscale; sin_theta = sin(theta) * mscale; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp new file mode 100644 index 0000000000000..776581e2c4e26 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i]))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp new file mode 100644 index 0000000000000..f9afa9b13c1f2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2 + + const float xi = float(data_x[i]); + const float s = 1.0f / (1.0f + exp(-xi)); + data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp new file mode 100644 index 0000000000000..29bd77d7e1c88 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -0,0 +1,50 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.comp" +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// In this shader Y = softmax(X) and X is not provided as input. + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_y[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_yg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + FLOAT_TYPE scale = p.param1; + + // partial sums for thread in warp + sum_yg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]); + const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]); + sum_yg[tid] += yi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_yg[tid] += sum_yg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE dot_yg = sum_yg[0]; + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale + * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg) + * FLOAT_TYPE(data_y[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index dfa16cda516bd..f01179326e7fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -139,7 +139,7 @@ struct block_q8_0 struct block_q8_0_packed16 { float16_t d; - uint16_t qs[32/2]; + int16_t qs[32/2]; }; #if defined(DATA_A_Q8_0) @@ -466,10 +466,13 @@ shared uint16_t iq1s_grid[2048]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) { - u16vec2 g = unpack16(iq1s_grid_const[i]); - iq1s_grid[2*i+0] = g.x; - iq1s_grid[2*i+1] = g.y; + [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) { + u16vec2 g = unpack16(iq1s_grid_const[idx]); + iq1s_grid[2*idx+0] = g.x; + iq1s_grid[2*idx+1] = g.y; + } } barrier(); } @@ -565,8 +568,10 @@ shared uvec2 iq2xxs_grid[256]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { - iq2xxs_grid[i] = iq2xxs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) { + if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) { + iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -733,8 +738,10 @@ shared uvec2 iq2xs_grid[512]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { - iq2xs_grid[i] = iq2xs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) { + if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) { + iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -756,6 +763,14 @@ struct block_iq2_s uint8_t scales[QUANT_K_IQ2_S/32]; }; +struct block_iq2_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_S/8]; + uint16_t qh[QUANT_K_IQ2_S/64]; + uint16_t scales[QUANT_K_IQ2_S/64]; +}; + #if defined(DATA_A_IQ2_S) const uvec2 iq2s_grid_const[1024] = { @@ -1023,8 +1038,10 @@ shared uvec2 iq2s_grid[1024]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { - iq2s_grid[i] = iq2s_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) { + if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) { + iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1032,6 +1049,7 @@ void init_iq_shmem(uvec3 wgsize) #define QUANT_K QUANT_K_IQ2_S #define QUANT_R QUANT_R_IQ2_S #define A_TYPE block_iq2_s +#define A_TYPE_PACKED16 block_iq2_s_packed16 #endif #define QUANT_K_IQ3_XXS 256 @@ -1092,8 +1110,10 @@ shared uint32_t iq3xxs_grid[256]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) { - iq3xxs_grid[i] = iq3xxs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) { + if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) { + iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1200,8 +1220,10 @@ shared uint32_t iq3s_grid[512]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { - iq3s_grid[i] = iq3s_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) { + if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) { + iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 82c96975428cf..4b6b3b265741b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -326,11 +326,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); for (const auto& tname : type_names) { + std::string load_vec_quant = "2"; + if ((tname == "q4_0") || (tname == "q4_1")) + load_vec_quant = "8"; + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + load_vec_quant = "4"; + std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant; // For aligned matmul loads - std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant; // don't generate f32 variants for coopmat2 if (!coopmat2) { @@ -397,7 +403,7 @@ void process_shaders() { for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -428,6 +434,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); @@ -478,14 +485,17 @@ void process_shaders() { string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); diff --git a/gguf-py/examples/reader.py b/gguf-py/examples/reader.py index d841048c6d701..703b782b5fa66 100644 --- a/gguf-py/examples/reader.py +++ b/gguf-py/examples/reader.py @@ -2,12 +2,14 @@ import logging import sys from pathlib import Path -from gguf.gguf_reader import GGUFReader logger = logging.getLogger("reader") +# Necessary to load the local gguf package sys.path.insert(0, str(Path(__file__).parent.parent)) +from gguf.gguf_reader import GGUFReader + def read_gguf_file(gguf_file_path): """ diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index e17a4e83147d5..5991cdb76beac 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -6,6 +6,7 @@ import logging import os +import sys from collections import OrderedDict from typing import Any, Literal, NamedTuple, TypeVar, Union @@ -15,7 +16,6 @@ from .quants import quant_shape_to_byte_shape if __name__ == "__main__": - import sys from pathlib import Path # Allow running file in package as a script. @@ -28,6 +28,7 @@ GGUF_VERSION, GGMLQuantizationType, GGUFValueType, + GGUFEndian, ) logger = logging.getLogger(__name__) @@ -53,6 +54,48 @@ class ReaderField(NamedTuple): types: list[GGUFValueType] = [] + def contents(self, index_or_slice: int | slice = slice(None)) -> Any: + if self.types: + to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731 + main_type = self.types[0] + + if main_type == GGUFValueType.ARRAY: + sub_type = self.types[-1] + + if sub_type == GGUFValueType.STRING: + indices = self.data[index_or_slice] + + if isinstance(index_or_slice, int): + return to_string(self.parts[indices]) # type: ignore + else: + return [to_string(self.parts[idx]) for idx in indices] # type: ignore + else: + # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too + + # Check if it's unsafe to perform slice optimization on data + # if any(True for idx in self.data if len(self.parts[idx]) != 1): + # optim_slice = slice(None) + # else: + # optim_slice = index_or_slice + # index_or_slice = slice(None) + + # if isinstance(optim_slice, int): + # return self.parts[self.data[optim_slice]].tolist()[0] + # else: + # return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice] + + if isinstance(index_or_slice, int): + return self.parts[self.data[index_or_slice]].tolist()[0] + else: + return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()] + + if main_type == GGUFValueType.STRING: + return to_string(self.parts[-1]) + else: + return self.parts[-1].tolist()[0] + + return None + class ReaderTensor(NamedTuple): name: str @@ -101,10 +144,19 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = # If we get 0 here that means it's (probably) a GGUF file created for # the opposite byte order of the machine this script is running on. self.byte_order = 'S' - temp_version = temp_version.newbyteorder(self.byte_order) + temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order)) version = temp_version[0] if version not in READER_SUPPORTED_VERSIONS: raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle') + if sys.byteorder == "little": + # Host is little endian + host_endian = GGUFEndian.LITTLE + swapped_endian = GGUFEndian.BIG + else: + # Sorry PDP or other weird systems that don't use BE or LE. + host_endian = GGUFEndian.BIG + swapped_endian = GGUFEndian.LITTLE + self.endianess = swapped_endian if self.byte_order == "S" else host_endian self.fields: OrderedDict[str, ReaderField] = OrderedDict() self.tensors: list[ReaderTensor] = [] offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) @@ -146,9 +198,7 @@ def _get( itemsize = int(np.empty([], dtype = dtype).itemsize) end_offs = offset + itemsize * count arr = self.data[offset:end_offs].view(dtype=dtype)[:count] - if override_order is None: - return arr - return arr.view(arr.dtype.newbyteorder(override_order)) + return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order)) def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: @@ -190,6 +240,7 @@ def _get_field_parts( offs += int(alen.nbytes) aparts: list[npt.NDArray[Any]] = [raw_itype, alen] data_idxs: list[int] = [] + # FIXME: Handle multi-dimensional arrays properly instead of flattening for idx in range(alen[0]): curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0]) if idx == 0: diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 962c27b204464..e807f434689de 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -121,19 +121,39 @@ def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]: if not model_card_path.is_file(): return {} - # The model card metadata is assumed to always be in YAML + # The model card metadata is assumed to always be in YAML (frontmatter) # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473 + yaml_content: str = "" with open(model_card_path, "r", encoding="utf-8") as f: - if f.readline() == "---\n": - raw = f.read().partition("---\n")[0] - data = yaml.safe_load(raw) - if isinstance(data, dict): - return data + content = f.read() + lines = content.splitlines() + lines_yaml = [] + if len(lines) == 0: + # Empty file + return {} + if len(lines) > 0 and lines[0] != "---": + # No frontmatter + return {} + for line in lines[1:]: + if line == "---": + break # End of frontmatter else: - logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict") - return {} + lines_yaml.append(line) + yaml_content = "\n".join(lines_yaml) + "\n" + + # Quick hack to fix the Norway problem + # https://hitchdev.com/strictyaml/why/implicit-typing-removed/ + yaml_content = yaml_content.replace("- no\n", "- \"no\"\n") + + if yaml_content: + data = yaml.safe_load(yaml_content) + if isinstance(data, dict): + return data else: + logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict") return {} + else: + return {} @staticmethod def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: diff --git a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py index 837831799b31c..0e0febaa79178 100755 --- a/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -20,22 +20,15 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None: - if np.uint32(1) == np.uint32(1).newbyteorder("<"): - # Host is little endian - host_endian = "little" - swapped_endian = "big" + file_endian = reader.endianess.name + if reader.byte_order == 'S': + host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE' else: - # Sorry PDP or other weird systems that don't use BE or LE. - host_endian = "big" - swapped_endian = "little" - if reader.byte_order == "S": - file_endian = swapped_endian - else: - file_endian = host_endian - order = host_endian if args.order == "native" else args.order - logger.info(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian") + host_endian = file_endian + order = host_endian if args.order == "native" else args.order.upper() + logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian") if file_endian == order: - logger.info(f"* File is already {order.upper()} endian. Nothing to do.") + logger.info(f"* File is already {order} endian. Nothing to do.") sys.exit(0) logger.info("* Checking tensors for conversion compatibility") for tensor in reader.tensors: @@ -47,7 +40,7 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None gguf.GGMLQuantizationType.Q6_K, ): raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") - logger.info(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}") + logger.info(f"* Preparing to convert from {file_endian} to {order}") if args.dry_run: return logger.warning("*** Warning *** Warning *** Warning **") diff --git a/gguf-py/gguf/scripts/gguf_dump.py b/gguf-py/gguf/scripts/gguf_dump.py index 20f23d729f4b9..e282892d645c7 100755 --- a/gguf-py/gguf/scripts/gguf_dump.py +++ b/gguf-py/gguf/scripts/gguf_dump.py @@ -9,8 +9,6 @@ from pathlib import Path from typing import Any -import numpy as np - # Necessary to load the local gguf package if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) @@ -21,11 +19,11 @@ def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: - host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG' + file_endian = reader.endianess.name if reader.byte_order == 'S': - file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE' + host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE' else: - file_endian = host_endian + host_endian = file_endian return (host_endian, file_endian) @@ -45,12 +43,20 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: pretty_type = str(field.types[-1].name) log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}' - if len(field.types) == 1: + if field.types: curr_type = field.types[0] if curr_type == GGUFValueType.STRING: - log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60])) - elif field.types[0] in reader.gguf_scalar_to_np: - log_message += ' = {0}'.format(field.parts[-1][0]) + content = field.contents() + if len(content) > 60: + content = content[:57] + '...' + log_message += ' = {0}'.format(repr(content)) + elif curr_type in reader.gguf_scalar_to_np: + log_message += ' = {0}'.format(field.contents()) + else: + content = repr(field.contents(slice(6))) + if len(field.data) > 6: + content = content[:-1] + ', ...]' + log_message += ' = {0}'.format(content) print(log_message) # noqa: NP100 if args.no_tensors: return @@ -82,15 +88,9 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: curr["array_types"] = [t.name for t in field.types][1:] if not args.json_array: continue - itype = field.types[-1] - if itype == GGUFValueType.STRING: - curr["value"] = [str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data] - else: - curr["value"] = [pv for idx in field.data for pv in field.parts[idx].tolist()] - elif field.types[0] == GGUFValueType.STRING: - curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8") + curr["value"] = field.contents() else: - curr["value"] = field.parts[-1].tolist()[0] + curr["value"] = field.contents() if not args.no_tensors: for idx, tensor in enumerate(reader.tensors): tensors[tensor.name] = { diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index a8cfc9d581b3d..7aff6c9259a1f 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -8,7 +8,6 @@ import json from pathlib import Path -import numpy as np from tqdm import tqdm from typing import Any, Sequence, NamedTuple @@ -27,45 +26,10 @@ class MetadataDetails(NamedTuple): description: str = '' -def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian: - if np.uint32(1) == np.uint32(1).newbyteorder("<"): - # Host is little endian - host_endian = gguf.GGUFEndian.LITTLE - swapped_endian = gguf.GGUFEndian.BIG - else: - # Sorry PDP or other weird systems that don't use BE or LE. - host_endian = gguf.GGUFEndian.BIG - swapped_endian = gguf.GGUFEndian.LITTLE - - if reader.byte_order == "S": - return swapped_endian - else: - return host_endian - - -def decode_field(field: gguf.ReaderField | None) -> Any: - if field and field.types: - main_type = field.types[0] - - if main_type == gguf.GGUFValueType.ARRAY: - sub_type = field.types[-1] - - if sub_type == gguf.GGUFValueType.STRING: - return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data] - else: - return [pv for idx in field.data for pv in field.parts[idx].tolist()] - if main_type == gguf.GGUFValueType.STRING: - return str(bytes(field.parts[-1]), encoding='utf-8') - else: - return field.parts[-1][0] - - return None - - def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: field = reader.get_field(key) - return decode_field(field) + return field.contents() if field else None def find_token(token_list: Sequence[int], token: str) -> Sequence[int]: @@ -93,7 +57,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Removing {field.name}') continue - old_val = MetadataDetails(field.types[0], decode_field(field)) + old_val = MetadataDetails(field.types[0], field.contents()) val = new_metadata.get(field.name, old_val) if field.name in new_metadata: @@ -192,7 +156,6 @@ def main() -> None: reader = gguf.GGUFReader(args.input, 'r') arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE) - endianess = get_byteorder(reader) token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or [] @@ -230,7 +193,7 @@ def main() -> None: sys.exit(0) logger.info(f'* Writing: {args.output}') - writer = gguf.GGUFWriter(args.output, arch=arch, endianess=endianess) + writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess) alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT) if alignment is not None: diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index b4a47333dd15f..d214e8720122b 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.15.0" +version = "0.16.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/include/llama.h b/include/llama.h index f014173875d80..d1b781dd9d59d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -105,6 +105,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, }; enum llama_rope_type { @@ -477,6 +478,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); diff --git a/models/ggml-vocab-gpt-4o.gguf.inp b/models/ggml-vocab-gpt-4o.gguf.inp new file mode 100644 index 0000000000000..9baf7d77ae6b5 --- /dev/null +++ b/models/ggml-vocab-gpt-4o.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-4o.gguf.out b/models/ggml-vocab-gpt-4o.gguf.out new file mode 100644 index 0000000000000..478df726fa9ba --- /dev/null +++ b/models/ggml-vocab-gpt-4o.gguf.out @@ -0,0 +1,46 @@ + 1165 220 19 220 27124 5503 + 37 19194 259 + + 220 + 256 + 271 + 197 + 198 + 279 + 2499 + 2775 + 13225 2375 + 32949 2375 + 13225 5922 + 32949 5922 + 32949 5922 0 + 13225 11 2375 0 + 32949 11 2375 0 + 495 382 9552 99 247 13 17159 + 86 45404 220 22 10191 2852 22924 4750 6916 + 3907 53641 1235 185386 8118 + 11400 107516 15867 20804 22851 134178 77431 32010 104312 37984 16329 27751 89335 + 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 350 7393 74471 484 853 1617 2316 6602 8 + 13225 + 32949 + 220 32949 + 256 32949 + 271 32949 + 271 32949 198 271 32949 + 350 + 198 314 + 6 6837 + 13225 11 342 70653 0 3253 553 481 22861 223 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 + 147475 + 18 + 2546 + 15517 + 15517 18 + 15517 2546 + 15517 15517 + 15517 15517 18 + 15517 15517 2546 + 15517 15517 15517 + 34 60213 53904 + 2960 3098 + 126470 25980 160432 16609 2775 4066 172261 19432 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 9552 99 247 4103 99 247 220 18 220 2546 220 15517 220 15517 18 220 15517 2546 220 15517 15517 220 15517 15517 18 220 15517 15517 2546 220 18 13 18 220 18 485 18 220 18 1008 18 44735 107516 15867 20804 22851 134178 77431 32010 104312 156437 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 105024 106657 1967 53641 1235 185386 8118 22434 39336 26178 26178 168394 194663 27271 147475 25883 6961 9790 1339 461 83 1280 19016 1354 11 461 1099 481 3239 30 461 44 625 3239 17291 1520 480 11 461 35 481 1299 1236 17966 30 1416 6 27493 261 54602 43 diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f64c3afa02983..1da4eae7e63e2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2202,13 +2202,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PHI3: { - const int64_t n_embd_head = n_embd / n_head; - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -2223,8 +2226,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } } break; case LLM_ARCH_PHIMOE: @@ -3838,6 +3841,10 @@ int32_t llama_model_n_head(const struct llama_model * model) { return model->hparams.n_head(); } +int32_t llama_model_n_head_kv(const struct llama_model * model) { + return model->hparams.n_head_kv(); +} + // deprecated int32_t llama_n_ctx_train(const struct llama_model * model) { return llama_model_n_ctx_train(model); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index ad9ffe66aa749..163ff64f77973 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; break; + case LLAMA_VOCAB_PRE_TYPE_GPT4O: + regex_exprs = { + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1592,6 +1599,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "gpt-4o") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e1f7e6758b62e..1dc2cdda30fc2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -467,6 +468,7 @@ struct test_case { // allocate ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1); + if (buf == NULL) { printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1)); ggml_free(ctx); @@ -588,14 +590,13 @@ struct test_case { /* .mem_base = */ NULL, /* .no_alloc = */ true, }; - ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); - ggml_tensor * out = build_graph(ctx); + ggml_tensor * out = build_graph(ctx.get()); if (op_name != nullptr && op_desc(out) != op_name) { //printf(" %s: skipping\n", op_desc(out).c_str()); - ggml_free(ctx); return true; } @@ -605,7 +606,6 @@ struct test_case { // check if backends support op if (!ggml_backend_supports_op(backend, out)) { printf("not supported\n"); - ggml_free(ctx); return true; } @@ -618,22 +618,26 @@ struct test_case { printf("%*s", last - len, ""); // allocate - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr + if (buf == NULL) { printf("failed to allocate tensors\n"); - ggml_free(ctx); return false; } // randomize tensors - initialize_tensors(ctx); + initialize_tensors(ctx.get()); // build graph - ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false); ggml_build_forward_expand(gf, out); // warmup run - ggml_backend_graph_compute(backend, gf); + ggml_status status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } // determine number of runs int n_runs; @@ -684,7 +688,11 @@ struct test_case { int total_runs = 0; do { int64_t start_time = ggml_time_us(); - ggml_backend_graph_compute(backend, gf); + ggml_status status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } int64_t end_time = ggml_time_us(); total_time_us += end_time - start_time; @@ -722,10 +730,6 @@ struct test_case { } printf("\n"); - ggml_backend_buffer_free(buf); - - ggml_free(ctx); - return true; } @@ -738,17 +742,16 @@ struct test_case { /* .mem_base = */ NULL, /* .no_alloc = */ true, }; - ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); - gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true); - gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true); + gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true); + gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true); - ggml_tensor * out = build_graph(ctx); + ggml_tensor * out = build_graph(ctx.get()); if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) { //printf(" %s: skipping\n", op_desc(out).c_str()); - ggml_free(ctx); return true; } @@ -756,7 +759,6 @@ struct test_case { fflush(stdout); if (out->type != GGML_TYPE_F32) { - ggml_free(ctx); printf("not supported [%s->type != FP32]\n", out->name); return true; } @@ -764,7 +766,7 @@ struct test_case { // check if the backend supports the ops bool supported = true; bool any_params = false; - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (!ggml_backend_supports_op(backend, t)) { printf("not supported [%s] ", ggml_backend_name(backend)); supported = false; @@ -785,40 +787,38 @@ struct test_case { } if (!supported) { printf("\n"); - ggml_free(ctx); return true; } int64_t ngrads = 0; - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (t->flags & GGML_TENSOR_FLAG_PARAM) { ngrads += ggml_nelements(t); } } if (ngrads > grad_nmax()) { printf("skipping large tensors for speed \n"); - ggml_free(ctx); return true; } if (!ggml_is_scalar(out)) { - out = ggml_sum(ctx, out); + out = ggml_sum(ctx.get(), out); ggml_set_name(out, "sum_of_out"); } ggml_set_loss(out); ggml_build_forward_expand(gf, out); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx, ctx, gb, false); + ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false); if (expect.size() != 1 || expect[0] != 0.0f) { GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf)); - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE); } } - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (!ggml_backend_supports_op(backend, t)) { printf("not supported [%s] ", ggml_backend_name(backend)); supported = false; @@ -832,27 +832,32 @@ struct test_case { } if (!supported) { printf("\n"); - ggml_free(ctx); return true; } // allocate - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr if (buf == NULL) { printf("failed to allocate tensors [%s] ", ggml_backend_name(backend)); - ggml_free(ctx); return false; } - - initialize_tensors(ctx); // Randomizes all tensors (including gradients). + initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients). ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise. - ggml_backend_graph_compute(backend, gf); - ggml_backend_graph_compute(backend, gb); + ggml_status status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } + status = ggml_backend_graph_compute(backend, gb); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } bool ok = true; - for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) { continue; } @@ -897,20 +902,36 @@ struct test_case { float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float)); - ggml_backend_graph_compute(backend, gf); + status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out)); ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float)); - ggml_backend_graph_compute(backend, gf); + status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out)); if (grad_precise()) { ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float)); - ggml_backend_graph_compute(backend, gf); + status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out)); ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float)); - ggml_backend_graph_compute(backend, gf); + status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status)); + return false; + } ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out)); gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps); @@ -936,10 +957,6 @@ struct test_case { printf("compare failed "); } - ggml_backend_buffer_free(buf); - - ggml_free(ctx); - if (ok) { printf("\033[1;32mOK\033[0m\n"); return true;