Skip to content

Commit

Permalink
Fix embed_tokens for the last layer in Qwen models
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jan 28, 2025
1 parent af171f0 commit 9c1bea9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition and 1 deletion — exo/inference/mlx/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, args: ModelArgs):
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0

if self.args.shard.is_first_layer():
if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

self.layers = []
Expand Down

0 comments on commit 9c1bea9

Please sign in to comment.