Skip to content

Commit

Permalink
Add triplet margin loss (#912)
Browse files Browse the repository at this point in the history
Add support for triplet margin loss

Closes #911
  • Loading branch information
pglusacTT authored Feb 25, 2025
1 parent 389821d commit d3e67bd
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 2 deletions.
82 changes: 80 additions & 2 deletions forge/forge/op/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from forge.op.tm import Broadcast, Unsqueeze
from ..module import ForgeModule
from .constant import Constant
from .eltwise_unary import Log, Abs, Sigmoid
from .eltwise_binary import Add, GreaterEqual, Less, Subtract, Multiply
from .eltwise_unary import Clip, Log, Abs, Sigmoid, Sqrt
from .eltwise_binary import Add, GreaterEqual, Less, Max, Min, Subtract, Multiply
from .nn import Softmax
from .reduce import ReduceSum, ReduceAvg

Expand Down Expand Up @@ -291,3 +291,81 @@ def forward(self, prediction, labels):
sigmoid_prediction = Sigmoid("sigmoid", prediction)
loss = self.bce_loss(sigmoid_prediction, labels)
return loss


class TripletMarginLoss(ForgeModule):
"""
Triplet Margin Loss
Used to measure the relative similarity between anchor, positive and negative samples.
loss = reduce(max(dist(anchor, positive) - dist(anchor, negative) + margin, 0), dim=0)
when swap is True:
neg_dist = min(dist(anchor, negative), dist(positive, negative))
loss = reduce(max(dist(anchor, positive) - neg_dist + margin, 0), dim=0)
"""

def __init__(
self,
name: str,
margin: float = 1.0,
p: float = 2.0,
reduction: str = "mean",
swap: bool = False,
eps: float = 1e-8,
):
super().__init__(name)
self.margin = margin
# pow is fixed to 2.0
# https://github.com/tenstorrent/tt-mlir/issues/1033
self.p = 2.0
self.reduction = reduction
self.swap = swap
self.eps = eps
self.is_loss = True

@validate_shapes(min_dim=2, max_dim=2)
def forward(self, anchor, positive, negative):
eps_const = Constant("eps_pos", constant=self.eps)

# Squared distance for positive pair (anchor-positive)
sub_pos = Subtract("sub_pos", anchor, positive)
# Add eps for numerical stability
sub_pos = Add("add_eps_pos", sub_pos, eps_const)
square_pos = Multiply("square_pos", sub_pos, sub_pos)
square_pos = ReduceSum("reduce_pos_sum_dim_1", square_pos, 1)
pos_dist = Sqrt("sqrt_pos", square_pos)

# Squared distance for negative pair (anchor-negative)
sub_neg = Subtract("sub_neg", anchor, negative)
# Add eps for numerical stability
sub_neg = Add("add_eps_neg", sub_neg, eps_const)
square_neg = Multiply("square_neg", sub_neg, sub_neg)
square_neg = ReduceSum("reduce_neg_sum_dim_1", square_neg, 1)
neg_dist = Sqrt("sqrt_neg", square_neg)

if self.swap:
# Squared distance for positive-negative pair
sub_pos_neg = Subtract("sub_pos_neg", positive, negative)
# Add eps for numerical stability
sub_pos_neg = Add("add_eps_pos_neg", sub_pos_neg, eps_const)
square_pos_neg = Multiply("square_pos_neg", sub_pos_neg, sub_pos_neg)
square_pos_neg = ReduceSum("reduce_pos_neg_sum_dim_1", square_pos_neg, 1)
pos_neg_dist = Sqrt("sqrt_pos_neg", square_pos_neg)

# Take minimum of dist(positive, negative) and dist(anchor, negative)
neg_dist = Min("min_diff", pos_neg_dist, neg_dist)

# Compute original triplet loss
dist_diff = Subtract("dist_diff", pos_dist, neg_dist)
margin = Constant("margin", constant=self.margin)
margin = Unsqueeze("expand_margin", margin, -1)
margin = Broadcast("broadcast_margin", margin, 0, dist_diff.shape[0])
margin = Broadcast("broadcast_margin_dim_1", margin, 1, dist_diff.shape[1])

dist_with_margin = Add("dist_with_margin", dist_diff, margin)

# Clip to (0, inf)
clip_dist = Clip("clip_dist", dist_with_margin, 0.0, float("inf"))
loss = reduce_loss(self.reduction, clip_dist)
return loss
35 changes: 35 additions & 0 deletions forge/test/mlir/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,38 @@ def test_bce_with_logits_loss(prediction_shape, reduction):
torch_loss_out = torch_loss(prediction, target)

assert torch.allclose(torch_loss_out, forge_loss_out[0], rtol=5e-2, atol=5e-3)


@pytest.mark.parametrize(
"prediction_shape",
[
(2, 2),
(3, 5),
(32, 32),
(33, 127),
(128, 20),
(128, 128),
],
)
@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("margin", [0.5, 2.0])
@pytest.mark.parametrize("eps", [1e-6, 1e-2])
@pytest.mark.parametrize("swap", [True, False])
def test_triplet_margin_loss(prediction_shape, reduction, margin, eps, swap):
forge_loss = forge.op.loss.TripletMarginLoss(
"triplet_margin_loss", margin=margin, reduction=reduction, eps=eps, swap=swap
)
torch_loss = torch.nn.TripletMarginLoss(margin=margin, p=2.0, reduction=reduction, eps=eps, swap=swap)

anchor = torch.randn(prediction_shape, requires_grad=True)
anchor_forge = forge.tensor.Tensor.create_from_torch(anchor)
positive = torch.randn(prediction_shape, requires_grad=True)
positive_forge = forge.tensor.Tensor.create_from_torch(positive)
negative = torch.randn(prediction_shape, requires_grad=True)
negative_forge = forge.tensor.Tensor.create_from_torch(negative)

forge_loss = forge.compile(forge_loss, sample_inputs=[anchor_forge, positive_forge, negative_forge])
forge_loss_out = forge_loss(anchor_forge, positive_forge, negative_forge)
torch_loss_out = torch_loss(anchor, positive, negative)

assert torch.allclose(torch_loss_out, forge_loss_out[0], rtol=5e-2)

0 comments on commit d3e67bd

Please sign in to comment.