Fix edge cases in expert parallelism with dynamic MoE
skavulya committed Feb 14, 2025
1 parent f924d31 commit b3c604e
Showing 1 changed file with 9 additions and 17 deletions.
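The diff addresses two related edge cases that appear when experts_per_rank does not divide evenly into slices: the floor division used to size each slice could leave the trailing experts on a rank unassigned, and the per-slice expert range could run past the rank's last expert. The fix sizes slices with math.ceil and clamps each slice's range to the rank boundary; worked examples follow the hunks below.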
@@ -629,7 +629,7 @@ def __init__(self, config):
         self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
 
         self.expert_slice = math.ceil(self.experts_per_rank / SLICE_MAX_EXPERT)
-        self.expert_chunk = self.experts_per_rank // self.expert_slice
+        self.expert_chunk = math.ceil(self.experts_per_rank / self.expert_slice)
 
     def forward(self, hidden_states):
         identity = hidden_states
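As a concrete illustration (the values are assumed, not from the commit): with experts_per_rank = 5 and SLICE_MAX_EXPERT = 4, expert_slice = ceil(5 / 4) = 2. The old floor division gave expert_chunk = 5 // 2 = 2, so the two slices covered only experts 0-3 and silently dropped expert 4; with math.ceil, expert_chunk = 3, and together with the clamped range in the next hunk the slices cover experts 0-2 and 3-4.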
@@ -669,20 +669,12 @@ def forward(self, hidden_states):
             (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
         for idx in range(self.expert_slice):
-            expert_offset = self.ep_rank * self.experts_per_rank
-            experts_range = range(self.expert_chunk)
-            gate_proj_list = [
-                self.experts[idx * self.expert_chunk + i + expert_offset].gate_proj.weight.squeeze()
-                for i in experts_range
-            ]
-            down_proj_list = [
-                self.experts[idx * self.expert_chunk + i + expert_offset].down_proj.weight.squeeze()
-                for i in experts_range
-            ]
-            up_proj_list = [
-                self.experts[idx * self.expert_chunk + i + expert_offset].up_proj.weight.squeeze()
-                for i in experts_range
-            ]
+            experts_min = (self.ep_rank * self.experts_per_rank) + (self.expert_chunk * idx)
+            experts_max = min((experts_min + self.expert_chunk), (self.ep_rank + 1) * self.experts_per_rank)
+            experts_range = range(experts_min, experts_max)
+            gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
+            down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
+            up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]
             hidden_states_slice = torch.ops.hpu.mixture_of_experts(
                 hidden_states=hidden_states,
                 expert_routing_table=topk_idx,
@@ -692,8 +684,8 @@
                 w3=down_proj_list,
                 permuted_weights=True,
                 activation="silu",
-                experts_min=(self.expert_chunk * idx + expert_offset),
-                experts_max=(self.expert_chunk * (idx + 1) - 1 + expert_offset),
+                experts_min=experts_min,
+                experts_max=experts_max - 1,
             )
             final_hidden_states = final_hidden_states + hidden_states_slice
             htcore.mark_step()
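
For a quick sanity check of the before/after behavior, here is a minimal, self-contained Python sketch (not part of the commit) that reproduces the two slicing schemes outside the model; SLICE_MAX_EXPERT = 4 and the rank/expert counts are assumed values chosen to expose the edge case.

import math

SLICE_MAX_EXPERT = 4  # assumed cap on experts handled per kernel call

def old_slices(ep_rank, experts_per_rank):
    # Pre-fix scheme: floor division can undersize the chunk.
    expert_slice = math.ceil(experts_per_rank / SLICE_MAX_EXPERT)
    expert_chunk = experts_per_rank // expert_slice
    expert_offset = ep_rank * experts_per_rank
    return [
        list(range(expert_offset + idx * expert_chunk,
                   expert_offset + (idx + 1) * expert_chunk))
        for idx in range(expert_slice)
    ]

def new_slices(ep_rank, experts_per_rank):
    # Post-fix scheme: ceil-sized chunks, clamped to the rank boundary.
    expert_slice = math.ceil(experts_per_rank / SLICE_MAX_EXPERT)
    expert_chunk = math.ceil(experts_per_rank / expert_slice)
    slices = []
    for idx in range(expert_slice):
        experts_min = ep_rank * experts_per_rank + expert_chunk * idx
        experts_max = min(experts_min + expert_chunk,
                          (ep_rank + 1) * experts_per_rank)
        slices.append(list(range(experts_min, experts_max)))
    return slices

print(old_slices(1, 5))  # [[5, 6], [7, 8]] -- expert 9 is never dispatched
print(new_slices(1, 5))  # [[5, 6, 7], [8, 9]] -- all five experts covered

The old scheme skips the last experts on each rank whenever experts_per_rank is not divisible by expert_slice; the new scheme covers every expert and simply makes the final slice shorter.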
