class TripletMarginLoss(ForgeModule):
    """
    Triplet Margin Loss

    Used to measure the relative similarity between anchor, positive and negative samples.

    loss = reduce(max(dist(anchor, positive) - dist(anchor, negative) + margin, 0), dim=0)
    when swap is True:
        neg_dist = min(dist(anchor, negative), dist(positive, negative))
        loss = reduce(max(dist(anchor, positive) - neg_dist + margin, 0), dim=0)

    dist(x, y) is computed here as sqrt(sum((x - y + eps) ** 2, dim=1)), i.e. the
    Euclidean distance with eps added to the element-wise difference before squaring.

    Parameters
    ----------
    name: str
        Module name, forwarded to ForgeModule.
    margin: float
        Constant added to the distance difference before clipping at zero.
    p: float
        Norm degree. Only 2.0 is supported; any other value triggers a
        UserWarning and 2.0 is used instead.
    reduction: str
        Forwarded unchanged to reduce_loss.
    swap: bool
        When True, the negative distance is the element-wise minimum of
        dist(anchor, negative) and dist(positive, negative).
    eps: float
        Value added to each element-wise difference before squaring.
    """

    def __init__(
        self,
        name: str,
        margin: float = 1.0,
        p: float = 2.0,
        reduction: str = "mean",
        swap: bool = False,
        eps: float = 1e-8,
    ):
        super().__init__(name)
        if p != 2.0:
            # The argument used to be dropped silently; tell the caller that the
            # loss they get is not the one they asked for.
            import warnings

            warnings.warn(
                f"TripletMarginLoss only supports p=2.0; ignoring p={p} and using 2.0.",
                UserWarning,
                stacklevel=2,
            )
        self.margin = margin
        # pow is fixed to 2.0
        # https://github.com/tenstorrent/tt-mlir/issues/1033
        self.p = 2.0
        self.reduction = reduction
        self.swap = swap
        self.eps = eps
        self.is_loss = True

    @staticmethod
    def _pairwise_distance(tag, lhs, rhs, eps_const):
        """Build the ops for sqrt(sum((lhs - rhs + eps) ** 2, dim=1)); op names are suffixed with `tag`."""
        diff = Subtract(f"sub_{tag}", lhs, rhs)
        # Add eps for numerical stability
        diff = Add(f"add_eps_{tag}", diff, eps_const)
        squared = Multiply(f"square_{tag}", diff, diff)
        summed = ReduceSum(f"reduce_{tag}_sum_dim_1", squared, 1)
        return Sqrt(f"sqrt_{tag}", summed)

    @validate_shapes(min_dim=2, max_dim=2)
    def forward(self, anchor, positive, negative):
        """
        Build the triplet margin loss graph for three 2D inputs.

        Parameters
        ----------
        anchor, positive, negative
            Input tensors; validate_shapes restricts them to exactly 2 dimensions.

        Returns
        -------
        The result of reduce_loss applied to the clipped per-sample losses.
        """
        # A single eps constant is shared by every distance computation.
        eps_const = Constant("eps_pos", constant=self.eps)

        # Distance for the positive pair (anchor-positive)
        pos_dist = self._pairwise_distance("pos", anchor, positive, eps_const)

        # Distance for the negative pair (anchor-negative)
        neg_dist = self._pairwise_distance("neg", anchor, negative, eps_const)

        if self.swap:
            # Distance for the positive-negative pair
            pos_neg_dist = self._pairwise_distance("pos_neg", positive, negative, eps_const)

            # Take minimum of dist(positive, negative) and dist(anchor, negative)
            neg_dist = Min("min_diff", pos_neg_dist, neg_dist)

        # Compute original triplet loss
        dist_diff = Subtract("dist_diff", pos_dist, neg_dist)

        # Expand the scalar margin to the shape of dist_diff before adding it.
        margin = Constant("margin", constant=self.margin)
        margin = Unsqueeze("expand_margin", margin, -1)
        margin = Broadcast("broadcast_margin", margin, 0, dist_diff.shape[0])
        margin = Broadcast("broadcast_margin_dim_1", margin, 1, dist_diff.shape[1])

        dist_with_margin = Add("dist_with_margin", dist_diff, margin)

        # Clip to [0, inf), i.e. max(dist_with_margin, 0)
        clip_dist = Clip("clip_dist", dist_with_margin, 0.0, float("inf"))
        loss = reduce_loss(self.reduction, clip_dist)
        return loss
@pytest.mark.parametrize(
    "prediction_shape",
    [
        (2, 2),
        (3, 5),
        (32, 32),
        (33, 127),
        (128, 20),
        (128, 128),
    ],
)
@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("margin", [0.5, 2.0])
@pytest.mark.parametrize("eps", [1e-6, 1e-2])
@pytest.mark.parametrize("swap", [True, False])
def test_triplet_margin_loss(prediction_shape, reduction, margin, eps, swap):
    """Check the compiled forge TripletMarginLoss against torch.nn.TripletMarginLoss on random inputs."""
    module = forge.op.loss.TripletMarginLoss(
        "triplet_margin_loss", margin=margin, reduction=reduction, eps=eps, swap=swap
    )
    reference = torch.nn.TripletMarginLoss(margin=margin, p=2.0, reduction=reduction, eps=eps, swap=swap)

    # Inputs are drawn in the order anchor, positive, negative.
    torch_inputs = [torch.randn(prediction_shape, requires_grad=True) for _ in range(3)]
    forge_inputs = [forge.tensor.Tensor.create_from_torch(sample) for sample in torch_inputs]

    compiled = forge.compile(module, sample_inputs=forge_inputs)
    actual = compiled(*forge_inputs)
    expected = reference(*torch_inputs)

    # The first compiled output is compared against the torch loss.
    assert torch.allclose(expected, actual[0], rtol=5e-2)