# Patch summary: adds Qwen2.5-Math (1.5B, 7B, 72B; each as base and "-Instruct") to LitGPT.
# - litgpt/config.py: new qwen_2_5_math config list, appended to qwen_2_5 before the ""/"-Instruct" expansion loop.
# - litgpt/prompts.py: new Qwen2_5_Math prompt style; its name regex is tested before the generic Qwen2\.5-.* regex, which would otherwise match the Math model names first.
# - tests/ and docs: "Qwen2.5-Math-1.5B" added to two parametrized tests; models listed in README.md and tutorials/download_model_weights.md.
diff --git a/README.md b/README.md index 06a1d81903..e12368fb7d 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ Every model is written from scratch to maximize performance and remove layers of | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [Yang, An et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | diff --git a/litgpt/config.py b/litgpt/config.py index 9c64d2ae48..475f017e50 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2033,6 +2033,74 @@ def norm_class(self) -> Type: qwen_2_5.extend(qwen_2_5_coder) +qwen_2_5_math = [ + # https://huggingface.co/Qwen/Qwen2.5-Math-1.5B/blob/main/config.json + dict( + name="Qwen2.5-Math-1.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-1.5B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=12, + n_embd=1536, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8960, + norm_eps=1e-6, + rope_base=10000, + ), + # https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json + dict( + name="Qwen2.5-Math-7B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-7B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=28, + 
n_head=28, + n_embd=3584, + n_query_groups=4, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=18944, + norm_eps=1e-6, + rope_base=10000, + ), + # https://huggingface.co/Qwen/Qwen2.5-Math-72B/blob/main/config.json + dict( + name="Qwen2.5-Math-72B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-72B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=29568, + norm_eps=1e-5, + rope_base=10000, + ), +] + +qwen_2_5.extend(qwen_2_5_math) + for c in qwen_2_5: for kind in ("", "-Instruct"): copy = deepcopy(c) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index f5b59e4e90..ab9cfee594 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -284,6 +284,10 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Qwen2_5_Math(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "Please reason step by step, and put your final answer within \\boxed{}." 
+ return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" class QwQ(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: @@ -326,6 +330,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "llama3": Llama3, "olmo": OLMo, "qwen2.5": Qwen2_5, + "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, "smollm2": SmolLM2, # SmolLM uses a different template "salamandra": Salamandra, @@ -367,6 +372,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Gemma() if re.search(r"OLMo.*-hf", model_name): return OLMo() + if re.search(r"Qwen2\.5-Math-.*", model_name): + return Qwen2_5_Math() if re.search(r"Qwen2\.5-.*", model_name): return Qwen2_5() if re.search(r"QwQ-.*", model_name): diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 5e24827cef..5809f0063d 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -524,7 +524,7 @@ def test_check_conversion_supported_lora(): check_conversion_supported(lit_weights=lit_weights) @torch.inference_mode() -@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "QwQ-32B-Preview")) +@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview")) @pytest.mark.parametrize( ("device", "dtype"), [ diff --git a/tests/test_model.py b/tests/test_model.py index 1a997f3134..e3aad3bb0e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -792,7 +792,7 @@ def test_against_original_gemma_2(model_name, device, dtype): @torch.inference_mode() -@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "QwQ-32B-Preview")) +@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview")) @pytest.mark.parametrize( ("device", "dtype"), [ diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 
fe276d3eac..876db1916a 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -37,6 +37,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [Yang, An et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | @@ -207,6 +208,12 @@ Qwen/Qwen2.5-Coder-14B Qwen/Qwen2.5-Coder-14B-Instruct Qwen/Qwen2.5-Coder-32B Qwen/Qwen2.5-Coder-32B-Instruct +Qwen/Qwen2.5-Math-1.5B +Qwen/Qwen2.5-Math-1.5B-Instruct +Qwen/Qwen2.5-Math-7B +Qwen/Qwen2.5-Math-7B-Instruct +Qwen/Qwen2.5-Math-72B +Qwen/Qwen2.5-Math-72B-Instruct Qwen/QwQ-32B-Preview stabilityai/FreeWilly2 stabilityai/stable-code-3b