Fix some issues due to not running torch tests in CI (#835)
dlwh authored Dec 8, 2024
1 parent 9f7421b commit 8b15485
Showing 4 changed files with 62 additions and 11 deletions.
4 changes: 3 additions & 1 deletion docs/dev/Port-Models.md
@@ -217,10 +217,12 @@ For modules like Attention, Mlp, and Embeddings, you can read the weight from Le

```python
# initialize the module in Levanter
+import haliax
+
attention = LlamaAttention.init(config=config, key=random.PRNGKey(0))

# read the weights from Levanter
-state = attention.to_state_dict()
+state = haliax.state_dict.to_torch_compatible_state_dict(attention.state_dict())
state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}

# load the weights into HuggingFace
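The excerpt above cuts off before the HuggingFace load step. As a rough, hypothetical sketch of that step (not part of this diff, and assuming the target module is transformers' LlamaAttention, mirroring tests/test_llama.py below):

```python
# Hypothetical continuation of the doc example (not from the diff).
# Assumes `config` and the torch `state` dict built in the excerpt above.
from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention

hf_attention = HFLlamaAttention(config.to_hf_config(32000))  # 32000 = assumed vocab size
hf_attention.load_state_dict(state, strict=True)
```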
2 changes: 1 addition & 1 deletion src/levanter/main/eval_lm.py
@@ -101,7 +101,7 @@ def compute_loss(model: LmHeadModel, example: LmExample):
 if config.hf_checkpoint is not None:
     # load the huggingface model
     model_config = config.model
-    if not hasattr(model_config, "default_hf_checkpoint_converter"):
+    if not hasattr(model_config, "hf_checkpoint_converter"):
         raise ValueError("Model config does not have an HF checkpoint converter. Can't load HF checkpoint.")
     converter: HFCheckpointConverter = model_config.hf_checkpoint_converter()
     converter = converter.replaced(reference_checkpoint=config.hf_checkpoint, tokenizer=tokenizer)
48 changes: 48 additions & 0 deletions src/levanter/models/qwen.py
@@ -1,3 +1,4 @@
import dataclasses
from dataclasses import dataclass
from typing import Dict, Optional, Type

@@ -269,6 +270,50 @@ class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization)
    embeddings: LlamaEmbedding  # Can reuse Llama embeddings
    lm_head: Optional[hnn.Linear]

    @property
    def config(self) -> QwenConfig:
        return self.transformer.config

    @property
    def Vocab(self) -> Axis:
        return self.embeddings.Vocab

    def activations(
        self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
    ) -> NamedArray:
        """
        Compute the activations for the next token in a sequence.
        Args:
            input_ids: token IDs with shape {Pos}
            attn_mask: attention mask with shape {Pos, KeyPos}
            key: PRNGKey for random number generation
        Returns:
            NamedArray: activations with shape {Pos, Embed}
        """
        x = self.embeddings.embed(input_ids)
        x = self.transformer(x, attn_mask=attn_mask, key=key)

        return x

    def get_lm_head(self) -> hax.NamedArray:
        if self.lm_head is None:
            return self.embeddings.token_embeddings.weight
        else:
            return self.lm_head.weight

    def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
        new_Vocab = self.Vocab.resize(new_size)
        k1, k2 = maybe_rng_split(key, 2)
        new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
        if self.lm_head is not None:
            new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
            new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
            return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
        else:
            return dataclasses.replace(self, embeddings=new_embeddings)

    @classmethod
    def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
        k_t, k_emb = jrandom.split(key, 2)
@@ -280,3 +325,6 @@ def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
        lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)

        return QwenLMHeadModel(transformer, embeddings, lm_head)

    def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
        return {"transformer": "model", "embeddings": None}
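For orientation, a minimal usage sketch of the newly added QwenLMHeadModel methods follows. It is not part of the diff; QwenConfig's default constructor, the toy vocab size, and hax.random.randint are assumptions about the surrounding API:

```python
# Hypothetical usage of the methods added above (a sketch, not repository code).
import haliax as hax
import jax.random as jrandom

from levanter.models.qwen import QwenConfig, QwenLMHeadModel

config = QwenConfig()  # assumed: the default config is small enough to build directly
Vocab = hax.Axis("vocab", 1000)  # toy vocab size for illustration

model = QwenLMHeadModel.init(Vocab, config, key=jrandom.PRNGKey(0))

# Token IDs over the positional axis; activations() returns a NamedArray over {Pos, Embed}.
input_ids = hax.random.randint(jrandom.PRNGKey(1), (config.Pos,), 0, Vocab.size)  # assumed helper
acts = model.activations(input_ids)

lm_head = model.get_lm_head()  # falls back to the tied embedding matrix when lm_head is None
larger = model.resize_vocab(Vocab.size + 128, key=jrandom.PRNGKey(2))
```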
19 changes: 10 additions & 9 deletions tests/test_llama.py
@@ -154,25 +154,23 @@ def test_llama_attention(use_flash, num_kv_heads, test_seq_len):
     hf_attention = HFLlamaAttention(config.to_hf_config(32000))
     hf_attention.load_state_dict(state, strict=True)

-    x, mask = _get_random_inputs(config)
-    x_torch = torch.from_numpy(np.array(x.array))
-    batch_size = x_torch.shape[0]
     test_Pos = config.Pos.resize(test_seq_len)
     test_KeyPos = config.KeyPos.resize(test_seq_len)

+    x, mask = _get_random_inputs(config, test_Pos)
+    x_torch = torch.from_numpy(np.array(x.array))
+    batch_size = x_torch.shape[0]
+
     explicit_mask = torch.from_numpy(np.array(mask.materialize(test_Pos, test_KeyPos).array))
     mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1))

     # the torch mask is really a bias, so we need to invert it and make it a big negative number
     mask_torch = (mask_torch == 0).float() * -1e9

     out = attention(x, mask)
-    position_ids = torch.arange(config.Pos.size).reshape(1, -1)
+    position_ids = torch.arange(test_Pos.size).reshape(1, -1)
     hf_out = hf_attention(x_torch, position_ids=position_ids, attention_mask=mask_torch)

-    # assert np.isclose(
-    #     hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4
-    # ).all(), f"{hf_out[0]} != {out}"
     chex.assert_trees_all_close(hf_out[0].detach().cpu().numpy(), out.array, rtol=1e-4, atol=1e-4)


@@ -343,9 +341,12 @@ def _get_llama_config(use_flash=False, num_kv_heads=4, seq_len=128) -> LlamaConf
)


-def _get_random_inputs(config: LlamaConfig):
+def _get_random_inputs(config: LlamaConfig, override_Pos=None):
     Embed = config.Embed
-    Pos = config.Pos
+    if override_Pos is not None:
+        Pos = override_Pos
+    else:
+        Pos = config.Pos
     Batch = hax.Axis("batch", 2)
     x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed))
     mask = AttentionMask.causal()
