adds ignore_index to sp cross entropy
Ronald Rogers committed Dec 17, 2024
1 parent da771ed commit 8d865bd
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions deepspeed/sequence/cross_entropy.py
@@ -9,17 +9,21 @@


 class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
+    def forward(ctx, vocab_seq_parallel_logits, target, sp_group, ignore_index=-100):
         # vocab_seq_parallel_logits: [S/P, B, V]
         # target: [S/P, B]
         # return: [S, B]

         # Need softmax for backward
         softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
         ctx.vocab_size = vocab_seq_parallel_logits.size(2)
-        loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none')
+        loss = torch.nn.functional.nll_loss(
+            softmax.log().view(-1, ctx.vocab_size),
+            target.view(-1),
+            ignore_index=ignore_index,
+            reduction="none",
+        )

         sp_world_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
@@ -56,5 +60,9 @@ def backward(ctx, grad_output):
         return grad_input, None, None, None


-def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
-    return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
+def vocab_sequence_parallel_cross_entropy(
+    vocab_parallel_logits, target, sp_group, ignore_index=-100
+):
+    return _VocabSequenceParallelCrossEntropy.apply(
+        vocab_parallel_logits, target, sp_group, ignore_index
+    )
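
The behavioral core of the change is nll_loss's ignore_index argument: with reduction="none", any position whose target equals ignore_index yields an exact 0.0 loss entry, so padded or shifted-out positions drop out of the per-token losses before the sequence-parallel ranks gather them. A minimal standalone sketch of that behavior, mirroring the patched forward's softmax().log() formulation (the shapes and values here are illustrative, not from the commit):

    import torch
    import torch.nn.functional as F

    # Toy case: 4 tokens, vocab size 3; position 2 is padding labeled -100.
    logits = torch.randn(4, 3)
    target = torch.tensor([0, 2, -100, 1])

    # Same softmax().log() formulation as the patched forward.
    loss = F.nll_loss(F.softmax(logits, dim=-1).log(), target,
                      ignore_index=-100, reduction="none")

    print(loss)  # loss[2] == 0.0: the ignored position contributes nothing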

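For callers, the new keyword is optional and defaults to -100, matching the default ignore_index of PyTorch's own nll_loss, so existing call sites keep their behavior. A hedged usage sketch follows; the shard_loss wrapper, the shard shapes, and the all-reduce-based averaging are assumptions about a typical sequence-parallel setup, not part of this commit:

    import torch
    import torch.distributed as dist
    from deepspeed.sequence.cross_entropy import vocab_sequence_parallel_cross_entropy

    def shard_loss(logits_shard, labels_shard, sp_group, pad_id=-100):
        # logits_shard: [S/P, B, V] on this rank; labels_shard: [S/P, B], with
        # pad_id marking positions to skip (padding, prompt tokens, etc.).
        per_token = vocab_sequence_parallel_cross_entropy(logits_shard, labels_shard,
                                                          sp_group, ignore_index=pad_id)
        # per_token is [S, B] per the forward's docstring comment; ignored
        # positions are exact zeros, so average over the valid count only.
        valid = (labels_shard != pad_id).sum()
        dist.all_reduce(valid, group=sp_group)  # valid-token count across all shards
        return per_token.sum() / valid.clamp(min=1)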