diff --git a/README.md b/README.md index 6092db7ad8..c58a586fdc 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ Every model is written from scratch to maximize performance and remove layers of |----|----|----|----| | Llama 3, 3.1, 3.2 | 1B, 3B, 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | -| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | +| Mixtral MoE | 8x7B, 8x22B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) | | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) | @@ -117,6 +117,7 @@ Every model is written from scratch to maximize performance and remove layers of | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | +| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | @@ -124,16 +125,24 @@ Every model is written from scratch to maximize performance and remove layers of | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Llama 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) | +| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) | | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) | +| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) | | OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) | | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | | Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) | | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 
2023](https://arxiv.org/abs/2304.01373) | +| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | +| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | +| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | +| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | diff --git a/extensions/thunder/unsloth/executor.py b/extensions/thunder/unsloth/executor.py index a0ed54598a..1779daf8ee 100644 --- a/extensions/thunder/unsloth/executor.py +++ b/extensions/thunder/unsloth/executor.py @@ -240,7 +240,7 @@ def unsloth_apply_rope_meta( Q: TensorProxy, cos: TensorProxy, sin: TensorProxy ) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]: batch, n_heads, seq_len, head_dim = Q.shape - assert seq_len <= cos.shape[0] + assert seq_len <= cos.shape[-2] BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2) div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE) n_groups = div + (mod != 0) diff --git a/litgpt/adapter.py b/litgpt/adapter.py index f8e4ac51e4..bef77ece1b 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -132,8 +132,8 @@ def __init__(self, config: Config, block_idx: int) -> None: self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None self.block_idx = block_idx self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 3ce60c9471..6885f628aa 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -164,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) + self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) @@ -180,8 +180,8 @@ def __init__(self, config: Config, block_idx: int) -> None: self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None self.block_idx = block_idx self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) 
self.config = config diff --git a/litgpt/api.py b/litgpt/api.py index a114fdd512..ea156ce600 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -386,7 +386,7 @@ def distribute( model.eval() if generate_strategy == "sequential": - state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu") + state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False) model.load_state_dict(state_dict, assign=True) model = fabric.setup_module(model, move_to_device=False) @@ -405,7 +405,7 @@ def distribute( pbar = tqdm(total=fabric.world_size, desc="Loading model weights") for rank in range(fabric.world_size): if fabric.global_rank == rank: - state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu") + state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False) model.load_state_dict(state_dict, assign=True) # cannot use `.setup_module` because it will wrap with DDP diff --git a/litgpt/config.py b/litgpt/config.py index b218df849c..a4a70c8238 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -15,22 +15,23 @@ class Config: name: str = "" hf_config: dict = field(default_factory=dict) - scale_embeddings: bool = False - attention_scores_scalar: Optional[int] = None + # General size parameters block_size: int = 4096 - sliding_window_size: Optional[int] = None - sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None + n_layer: int = 16 + n_embd: int = 4096 vocab_size: int = 50254 padding_multiple: int = 512 padded_vocab_size: Optional[int] = None - n_layer: int = 16 + # Transformer block (structure, normalizations) + norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + post_attention_norm: bool = False + post_mlp_norm: bool = False + parallel_residual: bool = True + shared_attention_norm: bool = False + # Transformer block (self-attention) n_head: int = 32 head_size: Optional[int] = None - n_embd: int = 4096 - rotary_percentage: float = 0.25 - parallel_residual: bool = True - bias: bool = True - lm_head_bias: bool = False # to use multi-head attention (MHA), set this to `n_head` (default) # to use multi-query attention (MQA), set this to 1 # to use grouped-query attention (GQA), set this to a value in between @@ -52,20 +53,29 @@ class Config: # # credit https://arxiv.org/pdf/2305.13245.pdf n_query_groups: Optional[int] = None - shared_attention_norm: bool = False - norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" - post_attention_norm: bool = False - post_mlp_norm: bool = False - norm_eps: float = 1e-5 - mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" - gelu_approximate: str = "none" - intermediate_size: Optional[int] = None - rope_condense_ratio: int = 1 + attn_bias: bool = False + attention_scores_scalar: Optional[int] = None + sliding_window_size: Optional[int] = None + sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None + # if `attention_logit_softcapping` is used, cannot use optimized + # `torch.nn.functional.scaled_dot_product_attention` (which implements + # Flash attention), may result in higher memory and runtime footprint. 
+ attention_logit_softcapping: Optional[float] = None + # Rotary position embedding (RoPE) rope_base: int = 10000 + rotary_percentage: float = 0.25 + rope_condense_ratio: int = 1 rope_adjustments: Optional[dict] = None + # Transformer block (MLP) + intermediate_size: Optional[int] = None + bias: bool = True + mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" + gelu_approximate: str = "none" n_expert: int = 0 n_expert_per_token: int = 0 - attention_logit_softcapping: Optional[float] = None + # GPT before/after blocks + scale_embeddings: bool = False + lm_head_bias: bool = False final_logit_softcapping: Optional[float] = None def __post_init__(self): @@ -98,7 +108,7 @@ def __post_init__(self): self.rope_n_elem = int(self.rotary_percentage * self.head_size) if self.sliding_window_size is not None: - self.sliding_window_layer_placing = ( + self.sliding_window_layer_stride = ( 1 if (self.sliding_window_layer_placing is None or self.sliding_window_layer_placing == "all") else 2 ) @@ -440,6 +450,95 @@ def norm_class(self) -> Type: copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind) configs.append(copy) +falcon3 = [ + # https://huggingface.co/tiiuae/Falcon3-1B-Base/blob/main/config.json + dict( + name="Falcon3-1B{}", + hf_config=dict(org="tiiuae", name="Falcon3-1B{}"), + block_size=4096, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=18, + n_head=8, + n_query_groups=4, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8192, + ), + # https://huggingface.co/tiiuae/Falcon3-3B-Base/blob/main/config.json + dict( + name="Falcon3-3B{}", + hf_config=dict(org="tiiuae", name="Falcon3-3B{}"), + block_size=32768, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=22, + n_head=12, + n_query_groups=4, + n_embd=3072, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=9216, + ), + # https://huggingface.co/tiiuae/Falcon3-7B-Base/blob/main/config.json + dict( + name="Falcon3-7B{}", + hf_config=dict(org="tiiuae", name="Falcon3-7B{}"), + block_size=32768, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=28, + n_head=12, + n_query_groups=4, + n_embd=3072, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=23040, + ), + # https://huggingface.co/tiiuae/Falcon3-10B-Base/blob/main/config.json + dict( + name="Falcon3-10B{}", + hf_config=dict(org="tiiuae", name="Falcon3-10B{}"), + block_size=32768, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=40, + n_head=12, + n_query_groups=4, + n_embd=3072, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=23040, + ), +] +for c in falcon3: + for kind in ("-Base", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + ############################# # OpenLM Research Open LLaMA @@ -699,8 +798,31 @@ def norm_class(self) -> Type: rope_base=500000, rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192) 
), + # https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/config.json + dict( + name="Llama-3.3-70B-Instruct", + hf_config=dict(org="meta-llama", name="Llama-3.3-70B-Instruct"), + block_size=131072, + vocab_size=128000, + padded_vocab_size=128256, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + rope_base=500000, + rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192) + ), ] for c in llama_3: + if c["name"] == "Llama-3.3-70B-Instruct": + configs.append(c) + continue for kind in ("", "-Instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) @@ -1519,6 +1641,27 @@ def norm_class(self) -> Type: n_expert=8, n_expert_per_token=2, ), + # https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json + dict( + name="Mixtral-8x22B-{}v0.1", + hf_config=dict(org="mistralai", name="Mixtral-8x22B-{}v0.1"), + padded_vocab_size=32768, + block_size=65536, + n_layer=56, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + norm_eps=1e-05, + mlp_class_name="LLaMAMoE", + intermediate_size=16384, + n_head=48, + n_embd=6144, + rope_base=1000000, + n_expert=8, + n_expert_per_token=2, + ), ] for c in mistral: for kind in ("", "Instruct-"): @@ -1618,6 +1761,26 @@ def norm_class(self) -> Type: intermediate_size=28672, ) ) +configs.append( + # https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/config.json + dict( + name="Mistral-Large-Instruct-2411", + hf_config=dict(org="mistralai", name="Mistral-Large-Instruct-2411"), + padded_vocab_size=32768, + block_size=32768, + n_layer=88, + n_head=96, + n_embd=12288, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + norm_eps=1e-05, + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + ) +) ############ @@ -1704,4 +1867,518 @@ def norm_class(self) -> Type: configs.extend(llama_2_function_calling) +########## +# Qwen2.5 +########## +qwen_2_5 = [ + # https://huggingface.co/Qwen/Qwen2.5-0.5B/blob/main/config.json + dict( + name="Qwen2.5-0.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-0.5B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=24, + n_head=14, + n_embd=896, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=4864, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-1.5B/blob/main/config.json + dict( + name="Qwen2.5-1.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-1.5B{}"), + block_size=131072, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=12, + n_embd=1536, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8960, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/config.json + dict( + name="Qwen2.5-3B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-3B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=36, + n_head=16, + n_embd=2048, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + 
attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json + dict( + name="Qwen2.5-7B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-7B{}"), + block_size=131072, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=28, + n_head=28, + n_embd=3584, + n_query_groups=4, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=18944, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-14B/blob/main/config.json + dict( + name="Qwen2.5-14B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-14B{}"), + block_size=131072, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=48, + n_head=40, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=13824, + norm_eps=1e-5, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-32B/blob/main/config.json + dict( + name="Qwen2.5-32B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-32B{}"), + block_size=131072, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=64, + n_head=40, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=27648, + norm_eps=1e-5, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-72B/blob/main/config.json + dict( + name="Qwen2.5-72B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-72B{}"), + block_size=131072, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=29568, + norm_eps=1e-5, + rope_base=1000000 + ), +] + +qwen_2_5_coder = [ + # https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B/blob/main/config.json + dict( + name="Qwen2.5-Coder-0.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Coder-0.5B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=24, + n_head=14, + n_embd=896, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=4864, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B/blob/main/config.json + dict( + name="Qwen2.5-Coder-1.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Coder-1.5B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=12, + n_embd=1536, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8960, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Coder-3B/blob/main/config.json + dict( + name="Qwen2.5-Coder-3B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Coder-3B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=36, + n_head=16, + n_embd=2048, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + 
norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Coder-7B/blob/main/config.json + dict( + name="Qwen2.5-Coder-7B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Coder-7B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=28, + n_head=28, + n_embd=3584, + n_query_groups=4, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=18944, + norm_eps=1e-6, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Coder-14B/blob/main/config.json + dict( + name="Qwen2.5-Coder-14B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Coder-14B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=48, + n_head=40, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=13824, + norm_eps=1e-5, + rope_base=1000000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Coder-32B/blob/main/config.json + dict( + name="Qwen2.5-Coder-32B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Coder-32B{}"), + block_size=32768, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=64, + n_head=40, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=27648, + norm_eps=1e-5, + rope_base=1000000 + ), +] + +qwen_2_5.extend(qwen_2_5_coder) + +qwen_2_5_math = [ + # https://huggingface.co/Qwen/Qwen2.5-Math-1.5B/blob/main/config.json + dict( + name="Qwen2.5-Math-1.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-1.5B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=12, + n_embd=1536, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8960, + norm_eps=1e-6, + rope_base=10000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json + dict( + name="Qwen2.5-Math-7B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-7B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=28, + n_head=28, + n_embd=3584, + n_query_groups=4, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=18944, + norm_eps=1e-6, + rope_base=10000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Math-72B/blob/main/config.json + dict( + name="Qwen2.5-Math-72B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-72B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=29568, + norm_eps=1e-5, + rope_base=10000 + ), +] + +qwen_2_5.extend(qwen_2_5_math) + +for c in qwen_2_5: + for kind in ("", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + +qwq = [ + # https://huggingface.co/Qwen/QwQ-32B-Preview/blob/main/config.json + dict( + 
name="QwQ-32B-Preview", + hf_config=dict(org="Qwen", name="QwQ-32B-Preview"), + block_size=131072, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=64, + n_head=40, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=27648, + norm_eps=1e-5, + rope_base=1000000 + ), +] + +configs.extend(qwq) + + +############# +# Salamandra +############# +salamandra = [ + # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json + dict( + name="salamandra-2b{}", + hf_config=dict(org="BSC-LT", name="salamandra-2b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=24, + n_head=16, + n_embd=2048, + n_query_groups=16, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=5440, + norm_eps=1e-5, + rope_base=10000 + ), + # https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json + dict( + name="salamandra-7b{}", + hf_config=dict(org="BSC-LT", name="salamandra-7b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=32, + n_head=32, + n_embd=4096, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + rope_base=10000 + ), +] + +for c in salamandra: + for kind in ("", "-instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + +############### +# SmolLM2 +############### +smollm2 = [ + # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json + dict( + name="SmolLM2-135M{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=30, + n_head=9, + n_embd=576, + n_query_groups=3, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=1536, + rope_base=100000, + norm_eps=1e-5, + ), + # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json + dict( + name="SmolLM2-360M{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=32, + n_head=15, + n_embd=960, + n_query_groups=5, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=2560, + rope_base=100000, + norm_eps=1e-5, + ), + # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json + dict( + name="SmolLM2-1.7B{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=24, + n_head=32, + n_embd=2048, + n_query_groups=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8192, + rope_base=130000, + norm_eps=1e-5, + ), +] + +for c in smollm2: + for kind in ("", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/data/__init__.py 
b/litgpt/data/__init__.py index 0f0f3d6cf9..b6d7275e5e 100644 --- a/litgpt/data/__init__.py +++ b/litgpt/data/__init__.py @@ -33,6 +33,6 @@ "TextFiles", "TinyLlama", "TinyStories", - "MicroLlama" + "MicroLlama", "get_sft_collate_fn", ] diff --git a/litgpt/data/prepare_slimpajama.py b/litgpt/data/prepare_slimpajama.py new file mode 100644 index 0000000000..70d4d9e2aa --- /dev/null +++ b/litgpt/data/prepare_slimpajama.py @@ -0,0 +1,63 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import os +import time +from pathlib import Path + +from litgpt.tokenizer import Tokenizer +from litgpt.data.prepare_starcoder import DataChunkRecipe +from litgpt.utils import CLI, extend_checkpoint_dir + + +class SlimPajamaDataRecipe(DataChunkRecipe): + is_generator = True + + def __init__(self, tokenizer: Tokenizer, chunk_size: int): + super().__init__(chunk_size) + self.tokenizer = tokenizer + + def prepare_structure(self, input_dir): + files = Path(input_dir).rglob("*.zst") + return [str(file) for file in files] + + def prepare_item(self, filepath): + import zstandard as zstd + + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in f: + text = json.loads(row)["text"] + if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub": + continue # exclude the GitHub data since it overlaps with starcoder + text_ids = self.tokenizer.encode(string=text, bos=False, eos=True) + yield text_ids + + +def prepare( + input_dir: Path = Path("data/SlimPajama-627B/train"), + output_dir: Path = Path("data/slimpajama/train"), + tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"), + chunk_size: int = (2049 * 16384), + fast_dev_run: bool = False, +) -> None: + from litdata.processing.data_processor import DataProcessor + + tokenizer_path = extend_checkpoint_dir(tokenizer_path) + tokenizer = Tokenizer(tokenizer_path) + data_recipe = SlimPajamaDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size) + data_processor = DataProcessor( + input_dir=str(input_dir), + output_dir=str(output_dir), + fast_dev_run=fast_dev_run, + num_workers=os.cpu_count(), + num_downloaders=1, + ) + + start_time = time.time() + data_processor.run(data_recipe) + elapsed_time = time.time() - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + CLI(prepare) diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index d349502489..866947beea 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -230,7 +230,7 @@ def batched_generate_fn( Args: model: The model to use. prompts: A 2D tensor of shape [batch_size, prompt_length]. - max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + max_returned_tokens: The maximum number of tokens to return, including the prompt tokens. sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. 
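Editor's note on the `batched_generate_fn` docstring fix directly above: `max_returned_tokens` is a total token budget that includes the prompt, matching the non-batched generation path. A minimal caller-side sketch of the arithmetic follows; the tensor name, batch size, and `max_new_tokens` value are illustrative and not taken from this diff:

    import torch

    # prompts has shape [batch_size, prompt_length]
    prompts = torch.randint(0, 32000, (2, 16))
    max_new_tokens = 50  # assumed number of fresh tokens we want generated
    # the budget passed as `max_returned_tokens` counts the prompt tokens as well
    max_returned_tokens = prompts.size(1) + max_new_tokens  # 16 + 50 = 66

Passing `max_returned_tokens=prompts.size(1)` would therefore leave no room for newly generated tokens.
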
diff --git a/litgpt/lora.py b/litgpt/lora.py index e519d5445d..beca761c48 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -589,7 +589,7 @@ def __init__(self, config: Config, block_idx: int) -> None: lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout, enable_lora=(config.lora_query, config.lora_key, config.lora_value), - bias=config.bias, + bias=config.bias or config.attn_bias, # for MQA/GQA support head_size=config.head_size, n_head=config.n_head, @@ -608,7 +608,8 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/model.py b/litgpt/model.py index f3c426192d..5fbd0a8c24 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -75,11 +75,30 @@ def _init_weights(self, module: nn.Module) -> None: torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + idx (torch.Tensor): Input token indices, shape `(B, T)` + input_pos (torch.Tensor, optional): Contains input positions, + either with shape `(T,)` or `(B, T)`, if provided. This is used + for generative inference, where a KV cache is required. By + default, this assumes `input_dim == arange(T)` with all inputs + up to `T` provided upfront. + + Returns: + torch.Tensor: Output (logits), shape `(B, T, config.padded_vocab_size)` + """ + if idx.dim() != 2: + raise ValueError(f"idx must have 2 dimensions, idx.shape = {idx.shape}") T = idx.size(1) if self.max_seq_length < T: raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") if input_pos is not None: # use the kv cache + if input_pos.dim() > 2: + # otherwise, things go wrong in `apply_rope` + raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}") + if input_pos.shape[-1] != T: + raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same") cos = batched_index_select(self.cos, 0, input_pos) sin = batched_index_select(self.sin, 0, input_pos) if self.mask_cache is None: @@ -90,20 +109,22 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - # we get if input_pos has a batch dimension mask = mask.squeeze(1) else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None + # unsqueeze to have a batch dimension + cos = self.cos[:T].unsqueeze(0) + sin = self.sin[:T].unsqueeze(0) + # `cos`, `sin` have shape (1, T, config.rope_n_elem) + mask = None # defaults to causal mask - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) if self.config.scale_embeddings: x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) x = self.transformer.ln_f(x) - x = self.lm_head(x) # (b, t, vocab_size) + x = self.lm_head(x) # (B, T, padded_vocab_size) if self.config.final_logit_softcapping is not None: - x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping + x = do_softcapping(x, self.config.final_logit_softcapping) return x @classmethod @@ -125,10 +146,8 @@ def rope_cache(self, 
device: Optional[torch.device] = None) -> Tuple[torch.Tenso elif num_params_present == 4: # These parameters should always be used together so that we don't interfere with standard rope extra_config = { - "original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"], - "factor": self.config.rope_adjustments["factor"], - "low_freq_factor": self.config.rope_adjustments["low_freq_factor"], - "high_freq_factor": self.config.rope_adjustments["high_freq_factor"], + name: self.config.rope_adjustments[name] + for name in adjusted_params_required } else: # Some but not all parameters are specified; raise an error @@ -240,12 +259,13 @@ def forward( attention_output = self.post_attention_norm(attention_output) if self.config.parallel_residual: - x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x) - x = self.mlp(x_normed) + attention_output + x + if not self.config.shared_attention_norm: + x_normed = self.norm_2(x) + x = attention_output + x else: x = attention_output + x - x = self.post_mlp_norm(self.mlp(self.norm_2(x))) + x - return x + x_normed = self.norm_2(x) + return self.post_mlp_norm(self.mlp(x_normed)) + x class CausalSelfAttention(nn.Module): @@ -255,7 +275,7 @@ def __init__(self, config: Config, block_idx: int) -> None: self.qkv = nn.Linear( config.n_embd, (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries - bias=config.bias, + bias=config.bias or config.attn_bias, ) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` @@ -263,7 +283,8 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config @@ -326,9 +347,8 @@ def forward( # NOTE: flash attention requires it in training mode. # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. 
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - q_per_kv = self.config.n_head // self.config.n_query_groups - k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) - v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + k = k.expand(*q.shape) # (B, nh_q, T, hs) + v = v.expand(*q.shape) # (B, nh_q, T, hs) if self.apply_sliding_window_attention: """ @@ -366,11 +386,8 @@ def scaled_dot_product_attention( # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: - scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) scores = q @ k.mT * scale - scores = ( - torch.tanh(scores / self.config.attention_logit_softcapping) * self.config.attention_logit_softcapping - ) + scores = do_softcapping(scores, self.config.attention_logit_softcapping) if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) @@ -541,10 +558,11 @@ def batched_index_select(t, dim, idx): res = torch.index_select(t, dim, idx.reshape(-1)) # flat index # split out single batch idx res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) - # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors - dims = [dim] + list(range(res.dim())) - del dims[dim + 1] - res = res.permute(dims) + if dim > 0: + # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors + dims = [dim] + list(range(res.dim())) + del dims[dim + 1] + res = res.permute(dims) # unflatten batch dims res = res.view(*batch_shape, *res.shape[1:]) return res @@ -601,6 +619,8 @@ def batched_index_copy_(t, dim, idx, val): def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + # x: (B, nh, T, hs) + # sin, cos: (B, T, hs) or (1, T, hs) head_size = x.size(-1) x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) @@ -616,6 +636,10 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) +def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: + return torch.tanh(x / thresh) * thresh + + class KVCache(nn.Module): def __init__( self, diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 09fb86676c..48850efd51 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -112,6 +112,17 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: ) +class Falcon3(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n" + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + [tokenizer.token_to_id("<|endoftext|>")], + ) + + class Llama2FunctionCalling(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: # Has to be before the llama config @@ -274,11 +285,37 @@ def apply(self, prompt: str, **kwargs: str) -> str: return f"user\n{prompt}\nmodel\n" - - class OLMo(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n" + + +class ChatML(PromptStyle): + def __init__(self, system_message: str): + self.system_message = system_message + + def apply(self, prompt: str, **kwargs: str) -> str: + return f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + +class Qwen2_5(ChatML): + def 
__init__(self): + super().__init__("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.") + +class Qwen2_5_Math(ChatML): + def __init__(self): + super().__init__("Please reason step by step, and put your final answer within \\boxed{}.") + +class QwQ(ChatML): + def __init__(self): + super().__init__("You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.") + +class SmolLM2(ChatML): + def __init__(self): + super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face") + +class Salamandra(ChatML): + def __init__(self): + super().__init__("I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit.") # Maps prompt style names to PromptStyle classes @@ -304,6 +341,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: "gemma": Gemma, "llama3": Llama3, "olmo": OLMo, + "qwen2.5": Qwen2_5, + "qwen2.5-math": Qwen2_5_Math, + "qwq": QwQ, + "smollm2": SmolLM2, + "salamandra": Salamandra, } @@ -314,6 +356,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return StableLMZephyr() if re.search("stablecode-instruct", model_name): return StableCode() + if re.search(r"Falcon3.*-Instruct", model_name): + return Falcon3() if re.search(r"falcon.*-instruct", model_name): return Falcon() if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name): @@ -342,6 +386,16 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Gemma() if re.search(r"OLMo.*-hf", model_name): return OLMo() + if re.search(r"Qwen2\.5-Math-.*", model_name): + return Qwen2_5_Math() + if re.search(r"Qwen2\.5-.*", model_name): + return Qwen2_5() + if re.search(r"QwQ-.*", model_name): + return QwQ() + if re.search(r"SmolLM2.*-Instruct", model_name): + return SmolLM2() + if re.search(r"salamandra-.*-instruct", model_name): + return Salamandra() return Default() diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index b24f874680..2c0dbb6aad 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -376,6 +376,81 @@ def copy_weights_phi( if progress_per_file is not None: pbar.update(progress_per_file) +def copy_weights_qwen_2_5( + config: Config, + qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, + pbar: Optional[tqdm] = None, + progress_per_file: Optional[float] = None, + debug_mode: Optional[bool] = False +) -> None: + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.self_attn.q_proj.weight": None, + "model.layers.{}.self_attn.k_proj.weight": None, + "model.layers.{}.self_attn.v_proj.weight": None, + "model.layers.{}.self_attn.q_proj.bias": None, + "model.layers.{}.self_attn.k_proj.bias": None, + 
"model.layers.{}.self_attn.v_proj.bias": None, + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", + "model.norm.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + + if progress_per_file is not None: + progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) + + for name, param in hf_weights.items(): + if "model.layers" in name: + from_name, l = layer_template(name, 2) + qkv = qkv_weights.setdefault(l, defaultdict(dict)) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + to_name = weight_map[from_name] + if to_name is None: + continue + to_name = to_name.format(l) + else: + to_name = weight_map[name] + param = load_param(param, name, dtype, verbose=debug_mode) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + if progress_per_file is not None: + pbar.update(progress_per_file) + + if "lm_head.weight" not in state_dict: + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # split across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + q_per_kv = config.n_head // config.n_query_groups + qs = torch.split(q, config.head_size * q_per_kv) + ks = torch.split(k, config.head_size) + vs = torch.split(v, config.head_size) + cycled = [t for group in zip(qs, ks, vs) for t in group] + qkv = torch.cat(cycled) + state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + if progress_per_file is not None: + pbar.update(progress_per_file) def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config @@ -463,6 +538,10 @@ def convert_hf_checkpoint( # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_phi, config, qkv_weights) + elif model_name.lower().startswith(("qwen2.5","qwq")): + # holder to reconstitute the split q, k, v + qkv_weights = {} + copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index bab1ab57d2..5bb08ea4f6 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -291,6 +291,52 @@ def copy_weights_phi( state_dict[layer_name] = weight del gate_up_proj_weights[layer_idx] +def copy_weights_qwen_2_5( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + untie_weights: bool = False, + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + 
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", + "transformer.ln_f.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", + } + + for name, param in lit_weights.items(): + if name == "lm_head.weight" and untie_weights: + continue + if name.endswith((".attn.attn.weight", ".attn.attn.bias")): + from_name, l_idx = layer_template(name, 2) + qkv = load_param(param, name, None) + qp, kp, vp = qkv_split(qkv, config) + + weight_type = name.split(".")[-1] # weight or bias + q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}" + k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}" + v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}" + for to_name, param in zip((q, k, v), (qp, kp, vp)): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + else: + if "transformer.h" in name: + from_name, l_idx = layer_template(name, 2) + to_name = weight_map[from_name] + to_name = to_name.format(l_idx) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. @@ -334,6 +380,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: copy_fn = partial(copy_weights_gemma_2, config) elif config.name.lower().startswith("phi"): copy_fn = partial(copy_weights_phi, config) + elif config.name.lower().startswith(("qwen2.5","qwq")): + copy_fn = partial(copy_weights_qwen_2_5, config) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py index c1af2af133..fc6c153fad 100644 --- a/litgpt/scripts/download.py +++ b/litgpt/scripts/download.py @@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s with gated_repo_catcher(repo_id, access_token): info = repo_info(repo_id, token=access_token) filenames = [f.rfilename for f in info.siblings] - bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"])) + bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"])) safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"])) return bins, safetensors diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index a81c59aa2d..10f7d031f6 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int: raise ValueError(f"token {token!r} not found in the collection.") return id_ - def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: + def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file(): return False with open(tokenizer_config_path, encoding="utf-8") as fp: @@ -96,6 +96,8 @@ def 
check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: # `PreTrainedTokenizerFast` if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): return True + if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"): + return True if "add_bos_token" in config: return config["add_bos_token"] # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True. @@ -143,6 +145,9 @@ def decode(self, tensor: torch.Tensor) -> str: if len(tokens) == 1 and self.apply_decoding_fix: dummy_token_id = 33 # \x1e dummy_token = self.processor.decode([dummy_token_id]) + if dummy_token != "\x1e": + dummy_token_id = 165 # \x1e is different in salamandra tokenizers + dummy_token = self.processor.decode([dummy_token_id]) return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :] return self.processor.decode(tokens) diff --git a/pyproject.toml b/pyproject.toml index 406f5f2e1a..1dd1e53743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "litgpt" -version = "0.5.4.dev1" +version = "0.5.5.dev1" description = "Hackable implementation of state-of-the-art open-source LLMs" authors = [ { name = "Lightning AI", email = "contact@lightning.ai" }, @@ -42,7 +42,7 @@ all = [ "sentencepiece>=0.2.0", # llama-based models "requests>=2.31.0", # litgpt.data "litdata==0.2.17", # litgpt.data - "litserve>=0.1.5", # litgpt.deploy + "litserve<=0.2.4", # litgpt.deploy "zstandard>=0.22.0", # litgpt.data.prepare_slimpajama.py "pandas>=1.9.0", # litgpt.data.prepare_starcoder.py "pyarrow>=15.0.2", # litgpt.data.prepare_starcoder.py diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 2168454f11..9e0cd93c35 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -19,6 +19,7 @@ from transformers.models.phi.modeling_phi import PhiForCausalLM from transformers.models.phi3.configuration_phi3 import Phi3Config from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from litgpt import GPT, Config from litgpt.scripts.convert_lit_checkpoint import ( @@ -29,6 +30,7 @@ copy_weights_gpt_neox, copy_weights_llama, copy_weights_phi, + copy_weights_qwen_2_5, qkv_reassemble, ) from tests.conftest import RunIf @@ -159,9 +161,10 @@ def test_against_hf_llama2(ours_kwargs): @torch.inference_mode() -def test_against_mixtral(): +@pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1")) +def test_against_mixtral(model_name): ours_config = Config.from_name( - "Mixtral-8x7B-Instruct-v0.1", + model_name, padded_vocab_size=10000, n_layer=2, n_embd=32, @@ -524,6 +527,69 @@ def test_check_conversion_supported_lora(): with pytest.raises(ValueError, match=r"LoRA.*cannot be converted"): check_conversion_supported(lit_weights=lit_weights) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_qwen_2_5(model_name, device, dtype): + torch.set_default_dtype(dtype) + + T = 20 + ours_config = Config.from_name( + model_name, + block_size=T, + n_layer=2, + n_head=16, + n_embd=32, + intermediate_size=86, + ) + theirs_config = Qwen2Config( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=ours_config.block_size, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.attn_bias, + tie_word_embeddings=True, + ) + + assert ours_config.intermediate_size == theirs_config.intermediate_size + + ours_model = GPT(ours_config).to(device) + # tie weights + ours_model.lm_head.weight = ours_model.transformer.wte.weight + ours_state_dict = ours_model.state_dict() + theirs_state_dict = {} + copy_weights_qwen_2_5(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True) + theirs_model = Qwen2ForCausalLM(theirs_config).to(device) + keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) + assert not keys.unexpected_keys + + # test end to end + x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) def test_qkv_reassemble(): # MHA diff --git a/tests/test_generate.py b/tests/test_generate.py index 6fc561b945..592f2c3acf 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -93,7 +93,13 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue() + err_value = err.getvalue() + expected_parts = [ + "'padded_vocab_size': 512", + "'n_layer': 2", + "'n_head': 4", + ] + assert all(part in err_value for part in expected_parts) def test_cli(): diff --git a/tests/test_generate_adapter.py b/tests/test_generate_adapter.py index 6e57ff0c5e..a40672d03e 100644 --- a/tests/test_generate_adapter.py +++ b/tests/test_generate_adapter.py @@ -55,7 +55,15 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue() + err_value = err.getvalue() + expected_parts = [ + "'padded_vocab_size': 512", + "'n_layer': 2", + "'n_head': 4", + "'head_size': 2", + "'n_embd': 8", + ] + assert all(part in err_value for part in expected_parts) @pytest.mark.parametrize("version", ("", "_v2")) diff --git a/tests/test_lora.py b/tests/test_lora.py index 0db9ea5285..c417d588a4 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -684,7 +684,7 @@ def test_against_original_gemma_2(model_name): assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float - 
torch.testing.assert_close(ours_y, theirs_y) + torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @RunIf(min_cuda_gpus=1) diff --git a/tests/test_model.py b/tests/test_model.py index 9d696c9397..abd1a767bf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -28,6 +28,7 @@ from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.olmo import OlmoConfig, OlmoForCausalLM +from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM import litgpt.config as config_module from litgpt import GPT, Config @@ -38,6 +39,7 @@ copy_weights_gpt_neox, copy_weights_hf_llama, copy_weights_phi, + copy_weights_qwen_2_5, ) from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -222,6 +224,7 @@ def test_against_original_open_llama_3b(device, dtype): {"name": "Llama-3.1-8B-Instruct"}, {"name": "Llama-3.2-1B"}, {"name": "Llama-3.2-3B"}, + {"name": "Llama-3.3-70B-Instruct"}, ], ) @pytest.mark.parametrize( @@ -511,11 +514,12 @@ def test_against_mathstral_hf_models(device, dtype): @torch.inference_mode() -def test_against_hf_mixtral(): +@pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1")) +def test_against_hf_mixtral(model_name): device = torch.device("cpu") dtype = torch.float32 ours_config = Config.from_name( - "Mixtral-8x7B-Instruct-v0.1", + model_name, padded_vocab_size=10000, n_layer=2, n_embd=32, @@ -790,6 +794,247 @@ def test_against_original_gemma_2(model_name, device, dtype): torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_qwen_2_5(model_name, device, dtype): + torch.set_default_dtype(dtype) + + T = 20 + ours_config = Config.from_name( + model_name, + block_size=T, + n_layer=2, + n_head=16, + n_embd=32, + intermediate_size=86, + ) + theirs_config = Qwen2Config( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=ours_config.block_size, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.attn_bias, + tie_word_embeddings=True, + ) + + theirs_model = Qwen2ForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + # Qwen2.5 ties `lm_head.weight` to the token embeddings, so drop the duplicate entry before converting + theirs_state_dict.pop("lm_head.weight") + state_dict = {} + copy_weights_qwen_2_5(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention.
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_salamandra(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + n_query_groups=2, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_smollm2(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + n_query_groups=2, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("Falcon3-1B-Base", "Falcon3-7B-Base")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_hf_falcon3(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + n_query_groups=2, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @RunIf(dynamo=True) @torch.inference_mode() def test_model_compile(): diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index a823eb71cd..d5c7d12699 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -56,10 +56,15 @@ def test_tokenizer_against_hf(config): else: assert ours.vocab_size == config.vocab_size - if config.name.startswith("falcon") or config.name.startswith("stablecode"): + if config.name.startswith(("falcon", "stablecode", "Qwen2.5", "QwQ")): # even though their config defines it, it's set as None in HF assert isinstance(ours.bos_id, int) assert theirs.bos_token_id is None + elif config.name.startswith("Falcon3"): + if isinstance(ours.bos_id, int): + assert theirs.bos_token_id is None + else: + assert ours.bos_id is None and theirs.bos_token_id is None else: assert ours.bos_id == theirs.bos_token_id diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 9ab0041357..a170506c3d 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -12,6 +12,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) | | Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | +| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | Function Calling Llama 2 | 7B | Trelis | [Trelis et al.
2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | @@ -20,12 +21,14 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) | +| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | | Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) | +| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) | | Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | | OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) | | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | @@ -33,8 +36,14 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | +| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | +| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | +| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 
2023](https://github.com/jzhang38/TinyLlama) | @@ -58,6 +67,10 @@ The output is shown below: allenai/OLMo-1B-hf allenai/OLMo-7B-hf allenai/OLMo-7B-Instruct-hf +bsc-lt/salamandra-2b +bsc-lt/salamandra-2b-instruct +bsc-lt/salamandra-7b +bsc-lt/salamandra-7b-instruct codellama/CodeLlama-13b-hf codellama/CodeLlama-13b-Instruct-hf codellama/CodeLlama-13b-Python-hf @@ -111,6 +124,12 @@ google/gemma-2b-it google/gemma-7b google/gemma-7b-it h2oai/h2o-danube2-1.8b-chat +HuggingFaceTB/SmolLM2-135M +HuggingFaceTB/SmolLM2-135M-Instruct +HuggingFaceTB/SmolLM2-360M +HuggingFaceTB/SmolLM2-360M-Instruct +HuggingFaceTB/SmolLM2-1.7B +HuggingFaceTB/SmolLM2-1.7B-Instruct lmsys/longchat-13b-16k lmsys/longchat-7b-16k lmsys/vicuna-13b-v1.3 @@ -130,6 +149,7 @@ meta-llama/Llama-3.2-1B meta-llama/Llama-3.2-1B-Instruct meta-llama/Llama-3.2-3B meta-llama/Llama-3.2-3B-Instruct +meta-llama/Llama-3.3-70B-Instruct meta-llama/Meta-Llama-3-70B meta-llama/Meta-Llama-3-70B-Instruct meta-llama/Meta-Llama-3-8B @@ -152,8 +172,11 @@ mistralai/Mistral-7B-Instruct-v0.3 mistralai/Mistral-7B-v0.1 mistralai/Mistral-7B-v0.3 mistralai/Mistral-Large-Instruct-2407 +mistralai/Mistral-Large-Instruct-2411 mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mixtral-8x7B-v0.1 +mistralai/Mixtral-8x22B-Instruct-v0.1 +mistralai/Mixtral-8x22B-v0.1 NousResearch/Nous-Hermes-13b NousResearch/Nous-Hermes-llama-2-7b NousResearch/Nous-Hermes-Llama2-13b @@ -161,6 +184,39 @@ nvidia/Llama-3.1-Nemotron-70B-Instruct-HF openlm-research/open_llama_13b openlm-research/open_llama_3b openlm-research/open_llama_7b +Qwen/Qwen2.5-0.5B +Qwen/Qwen2.5-0.5B-Instruct +Qwen/Qwen2.5-1.5B +Qwen/Qwen2.5-1.5B-Instruct +Qwen/Qwen2.5-3B +Qwen/Qwen2.5-3B-Instruct +Qwen/Qwen2.5-7B +Qwen/Qwen2.5-7B-Instruct +Qwen/Qwen2.5-14B +Qwen/Qwen2.5-14B-Instruct +Qwen/Qwen2.5-32B +Qwen/Qwen2.5-32B-Instruct +Qwen/Qwen2.5-72B +Qwen/Qwen2.5-72B-Instruct +Qwen/Qwen2.5-Coder-0.5B +Qwen/Qwen2.5-Coder-0.5B-Instruct +Qwen/Qwen2.5-Coder-1.5B +Qwen/Qwen2.5-Coder-1.5B-Instruct +Qwen/Qwen2.5-Coder-3B +Qwen/Qwen2.5-Coder-3B-Instruct +Qwen/Qwen2.5-Coder-7B +Qwen/Qwen2.5-Coder-7B-Instruct +Qwen/Qwen2.5-Coder-14B +Qwen/Qwen2.5-Coder-14B-Instruct +Qwen/Qwen2.5-Coder-32B +Qwen/Qwen2.5-Coder-32B-Instruct +Qwen/Qwen2.5-Math-1.5B +Qwen/Qwen2.5-Math-1.5B-Instruct +Qwen/Qwen2.5-Math-7B +Qwen/Qwen2.5-Math-7B-Instruct +Qwen/Qwen2.5-Math-72B +Qwen/Qwen2.5-Math-72B-Instruct +Qwen/QwQ-32B-Preview stabilityai/FreeWilly2 stabilityai/stable-code-3b stabilityai/stablecode-completion-alpha-3b @@ -178,6 +234,14 @@ tiiuae/falcon-40b tiiuae/falcon-40b-instruct tiiuae/falcon-7b tiiuae/falcon-7b-instruct +tiiuae/Falcon3-1B-Base +tiiuae/Falcon3-1B-Instruct +tiiuae/Falcon3-3B-Base +tiiuae/Falcon3-3B-Instruct +tiiuae/Falcon3-7B-Base +tiiuae/Falcon3-7B-Instruct +tiiuae/Falcon3-10B-Base +tiiuae/Falcon3-10B-Instruct TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T togethercomputer/LLaMA-2-7B-32K
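The Qwen2.5 test additions above revolve around tied input/output embeddings (`tie_word_embeddings=True`), which is why the new tests either pop the redundant `lm_head.weight` entry from the Hugging Face state dict or call `copy_weights_qwen_2_5(..., untie_weights=True)` after tying the LitGPT model's head to its token embedding. The snippet below is a minimal, hypothetical PyTorch sketch (the `TinyTiedLM` class and its names are illustrative only, not LitGPT or Transformers code) of why a tied head shows up under two keys in `state_dict()` and therefore needs that handling during checkpoint conversion:

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Toy causal LM whose output head shares the token-embedding matrix."""

    def __init__(self, vocab_size: int = 16, n_embd: int = 8) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Tie the weights: both module attributes now reference one parameter.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.wte(idx))


model = TinyTiedLM()
sd = model.state_dict()

# The single shared tensor is exported under both names ...
assert set(sd) == {"wte.weight", "lm_head.weight"}
# ... and both entries point at the same storage.
assert sd["wte.weight"].data_ptr() == sd["lm_head.weight"].data_ptr()

# A converter therefore either drops one copy (as the tests do with
# `theirs_state_dict.pop("lm_head.weight")`) or clones it into an
# independent tensor when "untying" for a format that stores the head
# explicitly.
untied = {k: v.clone() for k, v in sd.items()}
assert untied["wte.weight"].data_ptr() != untied["lm_head.weight"].data_ptr()
```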