Refactor into multiple files
AAnoosheh committed Jan 21, 2025
1 parent 1cf6f96 commit 1538557
Showing 5 changed files with 613 additions and 526 deletions.
18 changes: 18 additions & 0 deletions nemo/collections/llm/distillation/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loss import LogitsKLLoss
from .model import DistillationGPTModel

__all__ = ["LogitsKLLoss", "DistillationGPTModel"]
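
For orientation, a minimal sketch (not part of this commit) of how the re-exported names can be imported; treating these two as the subpackage's public entry points is an assumption based on `__all__` above.

# Hypothetical import of the new subpackage's public surface, per __all__ above.
from nemo.collections.llm.distillation import DistillationGPTModel, LogitsKLLoss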
182 changes: 182 additions & 0 deletions nemo/collections/llm/distillation/loss.py
@@ -0,0 +1,182 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta
from typing import TYPE_CHECKING, Tuple

import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from torch import Tensor
from torch.nn.modules.loss import _Loss

if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig


class BaseLoss(_Loss, metaclass=ABCMeta):
"""Abstract base class for Megatron distillation losses."""

def __init__(self, model_config: "TransformerConfig"):
"""
Constructor.
Args:
model_config: MCore transformer config.
"""
super().__init__()
self._config = model_config

def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""Prepares inputs safely for loss computation."""
if isinstance(predictions, tuple):
# `ColumnParallelLinear` returns bias too
predictions, targets = predictions[0], targets[0]
targets = targets.detach()

return predictions, targets

    def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tuple[Tensor, bool]:
        """Reshapes the loss from [s, b] to [b, s] for subsequent loss masking and returns it along with the TP-reduce flag."""
        loss = loss.transpose(0, 1).contiguous()
        return loss, tp_reduce


class LogitsKLLoss(BaseLoss):
"""Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim."""

def __init__(self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False):
"""
Constructor.
Args:
model_config: MCore transformer config.
temperature: Divide tensors by this value prior to calculating loss.
            reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher).
"""
super().__init__(model_config)
self._temperature = temperature
self._reverse = reverse

    def forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, bool]:
        """
        Forward function.
        Args:
            predictions: Student model tensors (size [s, b, h])
            targets: Teacher model tensors (size [s, b, h])
        Returns:
            KLD loss of the tensors (size [b, s]), plus a flag indicating a tensor-parallel reduction is still required.
        """
predictions, targets = self.pre_forward(predictions, targets)

        # Division by temperature should happen prior to finding the max for both student and teacher.
        # Currently we don't use temperature in any of our runs (temp=1.0).
output_teacher = targets.float() / self._temperature
output_student = predictions.float() / self._temperature

        # Compute the local softmax, then reweight it to compute the global softmax.
if self._config.tensor_model_parallel_size > 1:

# Maximum value along vocab dimension across all GPUs.
teacher_logits_max, _ = torch.max(output_teacher, dim=-1)
torch.distributed.all_reduce(
teacher_logits_max,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1)

denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1)
            # We can't use the standard reduction function here since the computation
            # that follows it isn't identical across TP ranks.
denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group())

# Maximum value along vocab dimension across all GPUs.
student_logits_max, _ = torch.max(output_student, dim=-1)
torch.distributed.all_reduce(
student_logits_max,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach()

denom_student = torch.sum(torch.exp(output_student), dim=-1)
denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group())

slen, bsz, sharded_vocab_size = output_student.shape
student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand(
slen, bsz, sharded_vocab_size
)
teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand(
slen, bsz, sharded_vocab_size
)

if self._reverse:
loss = torch.sum(
F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True),
dim=-1,
)
else:
loss = torch.sum(
F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True),
dim=-1,
)

else:
if self._reverse:
loss = torch.sum(
F.kl_div(
F.log_softmax(output_teacher, dim=-1),
F.softmax(output_student, dim=-1),
reduction="none",
),
dim=-1,
)
else:
loss = torch.sum(
F.kl_div(
F.log_softmax(output_student, dim=-1),
F.softmax(output_teacher, dim=-1),
reduction="none",
),
dim=-1,
)

return self.post_forward(loss, tp_reduce=True)


class _AllReduce(torch.autograd.Function):
"""Implementation from old PyTorch `torch.distributed.nn.parallel`."""

@staticmethod
def forward(ctx, op, group, tensor):
ctx.group, ctx.op = group, op
tensor = tensor.clone()
torch.distributed.all_reduce(tensor, op=op, group=group)
return tensor

@staticmethod
def backward(ctx, grad_output):
return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output))


def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD):
"""Custom all-reduce function.
    Needed in place of the other available all-reduce functions when the computation that follows
    the all-reduce call differs per rank. In the KL loss, this corresponds to the numerators, which differ across ranks.
"""
return _AllReduce.apply(op, group, tensor)
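
For context on the call contract of `LogitsKLLoss`, below is a minimal single-rank usage sketch (not part of the commit) that exercises the non-tensor-parallel branch of `forward`. It assumes the NeMo and Megatron-Core dependencies are importable, and it stands in for the MCore `TransformerConfig` with a bare namespace exposing only `tensor_model_parallel_size`. On the tensor-parallel path, the loss instead forms a global log-softmax by all-reducing the per-rank logit maxima and summing the softmax denominators across ranks through `all_reduce_autograd`, so that gradients flow back through the reduction.

# Illustrative sketch only -- the config stub is a hypothetical stand-in for an
# MCore TransformerConfig; only the tensor_model_parallel_size == 1 branch runs here.
from types import SimpleNamespace

import torch

from nemo.collections.llm.distillation import LogitsKLLoss

config_stub = SimpleNamespace(tensor_model_parallel_size=1)
loss_fn = LogitsKLLoss(config_stub, temperature=1.0, reverse=False)

seq_len, batch, vocab = 8, 2, 32
student_logits = torch.randn(seq_len, batch, vocab, requires_grad=True)  # [s, b, h]
teacher_logits = torch.randn(seq_len, batch, vocab)                      # [s, b, h]

# Returns the per-token KL loss reshaped to [b, s] plus a flag telling the caller
# whether a tensor-parallel reduction of the loss is still pending.
loss, needs_tp_reduce = loss_fn(student_logits, teacher_logits)
print(loss.shape, needs_tp_reduce)  # torch.Size([2, 8]) True
loss.mean().backward()  # gradients reach only the student logits (teacher is detached)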