Skip to content

Commit

Permalink
phi-4 (#1904)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysjprojects authored Jan 9, 2025
1 parent 8db3ef5 commit a5021be
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ Every model is written from scratch to maximize performance and remove layers of
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
Expand Down
19 changes: 19 additions & 0 deletions litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,25 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
# https://huggingface.co/microsoft/phi-4/blob/main/config.json
dict(
name="phi-4",
hf_config=dict(org="microsoft", name="phi-4"),
vocab_size=100352,
padded_vocab_size=100352,
block_size=16384,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=10,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=17920,
rope_base=250000,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
]
configs.extend(phi)

Expand Down
6 changes: 6 additions & 0 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ class Phi3(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
    """Wrap ``prompt`` in a fixed system turn, a user turn, and an open assistant turn.

    Extra keyword arguments are accepted but not used.
    """
    system_turn = "<|system|>\nYou are a helpful assistant.<|end|>\n"
    user_turn = f"<|user|>\n{prompt}<|end|>\n"
    assistant_header = "<|assistant|>\n"
    return system_turn + user_turn + assistant_header

class Phi4(PromptStyle):
    """Prompt style that emits a single user turn followed by an open assistant turn."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` framed with the im_start / im_sep / im_end markers.

        Extra keyword arguments are accepted but not used.
        """
        user_turn = f"<|im_start|>user<|im_sep|>{prompt}<|im_end|>"
        assistant_header = "<|im_start|>assistant<|im_sep|>"
        return user_turn + assistant_header

class TinyLlama(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
Expand Down Expand Up @@ -337,6 +340,7 @@ def __init__(self):
"phi-1": Phi1,
"phi-2": Phi2,
"phi-3": Phi3,
"phi-4": Phi4,
"tinyllama": TinyLlama,
"gemma": Gemma,
"llama3": Llama3,
Expand Down Expand Up @@ -380,6 +384,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Phi2()
if re.search("Phi-3", model_name):
return Phi3()
if re.search("phi-4", model_name):
return Phi4()
if re.search(r"tiny-llama.*chat", model_name):
return TinyLlama()
if re.search(r"(Code)?Gemma.*-it", model_name):
Expand Down
2 changes: 1 addition & 1 deletion litgpt/scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def copy_weights_phi(
"lm_head.bias": "lm_head.bias",
}

if config.name.startswith("Phi-3"):
if config.name.startswith(("Phi-3", "phi-4")):
weight_map.update(
{
"model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight",
Expand Down
2 changes: 1 addition & 1 deletion litgpt/scripts/convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def copy_weights_phi(
"lm_head.weight": "lm_head.weight",
"lm_head.bias": "lm_head.bias",
}
if config.name.startswith("Phi-3"):
if config.name.startswith(("Phi-3", "phi-4")):
weight_map.update(
{
"transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def test_against_hf_phi(model_name, device, dtype):


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct", "Phi-3-mini-128k-instruct", "Phi-3.5-mini-instruct"))
@pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct", "Phi-3-mini-128k-instruct", "Phi-3.5-mini-instruct", "phi-4"))
@pytest.mark.parametrize(
("device", "dtype"),
[
Expand All @@ -352,6 +352,7 @@ def test_against_hf_phi_3(model_name, device, dtype):
padded_vocab_size=10000,
n_layer=2,
n_head=4,
n_query_groups=4,
n_embd=256,
)
T = 5
Expand Down
2 changes: 2 additions & 0 deletions tutorials/download_model_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
Expand Down Expand Up @@ -165,6 +166,7 @@ microsoft/phi-2
microsoft/Phi-3-mini-128k-instruct
microsoft/Phi-3-mini-4k-instruct
microsoft/Phi-3.5-mini-instruct
microsoft/phi-4
mistralai/mathstral-7B-v0.1
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
Expand Down

0 comments on commit a5021be

Please sign in to comment.