diff --git a/docs/dev/Port-Models.md b/docs/dev/Port-Models.md
index f76d0a6d8..a0242fb47 100644
--- a/docs/dev/Port-Models.md
+++ b/docs/dev/Port-Models.md
@@ -217,10 +217,12 @@ For modules like Attention, Mlp, and Embeddings, you can read the weight from Levanter
 ```python
 # initialize the module in Levanter
+import haliax
+
 attention = LlamaAttention.init(config=config, key=random.PRNGKey(0))
 
 # read the weights from Levanter
-state = attention.to_state_dict()
+state = haliax.state_dict.to_torch_compatible_state_dict(attention.state_dict())
 state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}
 
 # load the weights into HuggingFace
diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py
index a19ac00e3..a4b3a9516 100644
--- a/src/levanter/main/eval_lm.py
+++ b/src/levanter/main/eval_lm.py
@@ -101,7 +101,7 @@ def compute_loss(model: LmHeadModel, example: LmExample):
         if config.hf_checkpoint is not None:
             # load the huggingface model
             model_config = config.model
-            if not hasattr(model_config, "default_hf_checkpoint_converter"):
+            if not hasattr(model_config, "hf_checkpoint_converter"):
                 raise ValueError("Model config does not have an HF checkpoint converter. Can't load HF checkpoint.")
             converter: HFCheckpointConverter = model_config.hf_checkpoint_converter()
             converter = converter.replaced(reference_checkpoint=config.hf_checkpoint, tokenizer=tokenizer)
diff --git a/src/levanter/models/qwen.py b/src/levanter/models/qwen.py
index 7f8afa951..f41911faf 100644
--- a/src/levanter/models/qwen.py
+++ b/src/levanter/models/qwen.py
@@ -1,3 +1,4 @@
+import dataclasses
 from dataclasses import dataclass
 from typing import Dict, Optional, Type
 
@@ -269,6 +270,50 @@ class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization):
     embeddings: LlamaEmbedding  # Can reuse Llama embeddings
     lm_head: Optional[hnn.Linear]
 
+    @property
+    def config(self) -> QwenConfig:
+        return self.transformer.config
+
+    @property
+    def Vocab(self) -> Axis:
+        return self.embeddings.Vocab
+
+    def activations(
+        self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
+    ) -> NamedArray:
+        """
+        Compute the activations for the next token in a sequence.
+        Args:
+            input_ids: token IDs with shape {Pos}
+            attn_mask: attention mask with shape {Pos, KeyPos}
+            key: PRNGKey for random number generation
+
+        Returns:
+            NamedArray: activations with shape {Pos, Embed}
+
+        """
+        x = self.embeddings.embed(input_ids)
+        x = self.transformer(x, attn_mask=attn_mask, key=key)
+
+        return x
+
+    def get_lm_head(self) -> hax.NamedArray:
+        if self.lm_head is None:
+            return self.embeddings.token_embeddings.weight
+        else:
+            return self.lm_head.weight
+
+    def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
+        new_Vocab = self.Vocab.resize(new_size)
+        k1, k2 = maybe_rng_split(key, 2)
+        new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
+        if self.lm_head is not None:
+            new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
+            new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
+            return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        else:
+            return dataclasses.replace(self, embeddings=new_embeddings)
+
     @classmethod
     def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
         k_t, k_emb = jrandom.split(key, 2)
@@ -280,3 +325,6 @@ def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
             lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
 
         return QwenLMHeadModel(transformer, embeddings, lm_head)
+
+    def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
+        return {"transformer": "model", "embeddings": None}
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 9c08043d8..6f682ec5d 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -154,12 +154,13 @@ def test_llama_attention(use_flash, num_kv_heads, test_seq_len):
     hf_attention = HFLlamaAttention(config.to_hf_config(32000))
     hf_attention.load_state_dict(state, strict=True)
 
-    x, mask = _get_random_inputs(config)
-    x_torch = torch.from_numpy(np.array(x.array))
-    batch_size = x_torch.shape[0]
     test_Pos = config.Pos.resize(test_seq_len)
     test_KeyPos = config.KeyPos.resize(test_seq_len)
 
+    x, mask = _get_random_inputs(config, test_Pos)
+    x_torch = torch.from_numpy(np.array(x.array))
+    batch_size = x_torch.shape[0]
+
     explicit_mask = torch.from_numpy(np.array(mask.materialize(test_Pos, test_KeyPos).array))
     mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1))
@@ -167,12 +168,9 @@
     mask_torch = (mask_torch == 0).float() * -1e9
 
     out = attention(x, mask)
-    position_ids = torch.arange(config.Pos.size).reshape(1, -1)
+    position_ids = torch.arange(test_Pos.size).reshape(1, -1)
     hf_out = hf_attention(x_torch, position_ids=position_ids, attention_mask=mask_torch)
 
-    # assert np.isclose(
-    #     hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4
-    # ).all(), f"{hf_out[0]} != {out}"
     chex.assert_trees_all_close(hf_out[0].detach().cpu().numpy(), out.array, rtol=1e-4, atol=1e-4)
@@ -343,9 +341,12 @@ def _get_llama_config(use_flash=False, num_kv_heads=4, seq_len=128) -> LlamaConfig:
     )
 
 
-def _get_random_inputs(config: LlamaConfig):
+def _get_random_inputs(config: LlamaConfig, override_Pos=None):
     Embed = config.Embed
-    Pos = config.Pos
+    if override_Pos is not None:
+        Pos = override_Pos
+    else:
+        Pos = config.Pos
     Batch = hax.Axis("batch", 2)
     x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed))
     mask = AttentionMask.causal()
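
As a quick smoke test of the ported model (outside the patch above), the sketch below initializes a small `QwenLMHeadModel` and exercises the new `activations` and `get_lm_head` methods. It is illustrative only: the tiny `QwenConfig` values are assumptions, with field names inherited from `LlamaConfig`, and the all-zero token IDs are dummy inputs.

```python
import haliax as hax
import jax.numpy as jnp
import jax.random as jrandom

from levanter.models.attention import AttentionMask
from levanter.models.qwen import QwenConfig, QwenLMHeadModel

# Tiny, illustrative config values; the field names are assumed from LlamaConfig, which QwenConfig extends.
config = QwenConfig(seq_len=32, hidden_dim=64, intermediate_dim=128, num_layers=2, num_heads=4, num_kv_heads=2)
Vocab = hax.Axis("vocab", 256)

model = QwenLMHeadModel.init(Vocab, config, key=jrandom.PRNGKey(0))

# Dummy all-zero token IDs over the Pos axis; activations() accepts {Pos}-shaped inputs per its docstring.
input_ids = hax.zeros(config.Pos, dtype=jnp.int32)
mask = AttentionMask.causal()

x = model.activations(input_ids, attn_mask=mask)
print(x.axes)                    # expected axes: (Pos, Embed)
print(model.get_lm_head().axes)  # expected axes: (Vocab, Embed), whether or not embeddings are tied
```

Because `_state_dict_key_map` maps `transformer` to `model` and flattens `embeddings`, exporting this module with the `to_torch_compatible_state_dict` helper shown in the Port-Models.md snippet should yield HF-style `model.*` keys.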