1. What Is Triplet Loss?

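Triplet loss is a widely used loss function for learning feature representations (metric learning) in deep learning. It was popularized by the FaceNet paper and has since been applied broadly to face recognition, person re-identification (ReID), and similar tasks. The core idea is to form triplets of an anchor sample, a positive sample (same class as the anchor), and a negative sample (different class), and to train the network so that the anchor's feature ends up closer to the positive than to the negative by at least a margin. As a result, same-class features cluster together and different-class features are pushed apart.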
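In the usual formulation, for one triplet (a, p, n), embedding network f, distance d, and margin α (the `margin` argument in the code below, 0.3 by default), the loss is a hinge on the gap between the positive and negative distances. FaceNet uses the squared Euclidean distance for d. The batch-hard variant of Hermans et al. used below instead takes plain Euclidean distances and, for each anchor in a batch of size B, picks the farthest positive and the nearest negative:

\[L(a, p, n) = \max\bigl(d(f(a), f(p)) - d(f(a), f(n)) + \alpha,\ 0\bigr)\]

\[L_{\text{batch-hard}} = \frac{1}{B}\sum_{i=1}^{B} \max\Bigl(\max_{j:\,y_j = y_i} D_{ij} - \min_{k:\,y_k \neq y_i} D_{ik} + \alpha,\ 0\Bigr)\]

Here y_i is the label of sample i and D_{ij} is the Euclidean distance between the features of samples i and j. A triplet contributes zero loss once the negative is at least α farther from the anchor than the positive.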

2. Code Walkthrough

Complete example code:


```python
import torch
import torch.nn as nn


class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.

    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """

    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (batch_size).
        """
        n = inputs.size(0)

        # Step 1: compute the pairwise Euclidean distance matrix
        # via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
        # (torch.cdist(inputs, inputs) computes the same matrix in recent PyTorch)
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # Step 2: for each anchor, find the hardest positive and negative
        # mask[i][j] is True when samples i and j share the same label
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))   # farthest same-class sample
            dist_an.append(dist[i][~mask[i]].min().unsqueeze(0))  # nearest different-class sample
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Step 3: ranking hinge loss, i.e. mean of max(0, dist_ap - dist_an + margin)
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)
```
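The per-anchor Python loop in `forward` can be replaced by masked reductions over the whole distance matrix. The sketch below is an illustrative alternative, not the open-reid implementation: the function name `batch_hard_triplet_loss` is hypothetical, it assumes a PyTorch version that provides `torch.cdist`, and it writes the hinge with `torch.clamp` instead of `nn.MarginRankingLoss`. The two are equivalent here because `MarginRankingLoss(x1, x2, y)` computes the mean of `max(0, -y * (x1 - x2) + margin)`, which with `x1 = dist_an`, `x2 = dist_ap`, `y = 1` is the mean of `max(0, dist_ap - dist_an + margin)`.

```python
import torch


def batch_hard_triplet_loss(features, labels, margin=0.3):
    """Loop-free batch-hard triplet loss (hypothetical helper for illustration)."""
    # Pairwise Euclidean distances, shape (n, n)
    dist = torch.cdist(features, features, p=2)
    # same[i, j] is True when samples i and j share a label (diagonal included)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    # Hardest positive: largest distance among same-label samples
    dist_ap = dist.masked_fill(~same, float("-inf")).max(dim=1).values
    # Hardest negative: smallest distance among different-label samples
    # (becomes +inf if an anchor has no negative in the batch)
    dist_an = dist.masked_fill(same, float("inf")).min(dim=1).values
    # Hinge written out explicitly: mean of max(0, d_ap - d_an + margin)
    return torch.clamp(dist_ap - dist_an + margin, min=0).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    feats = torch.randn(8, 128)
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
    print(batch_hard_triplet_loss(feats, labels))
```

One behavioral difference from the loop version: if every sample in the batch has the same label, `dist[i][~mask[i]].min()` raises an error on an empty tensor, whereas this sketch yields a loss of 0. Batch-hard mining is normally paired with a sampler that puts several identities, each with several images, into every batch, so both versions see valid positives and negatives.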

Step 1: Compute the feature distance matrix

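For a batch of n feature vectors x_1, ..., x_n (the rows of `inputs`), the Euclidean distance matrix is

\[D_{ij} = \lVert x_i - x_j \rVert_2 = \sqrt{\lVert x_i \rVert_2^2 + \lVert x_j \rVert_2^2 - 2\, x_i^{\top} x_j}\]

The expanded form needs only the squared row norms and a single matrix product, which is what `torch.pow(...).sum(...)`, `dist + dist.t()`, and `addmm_` compute in `forward`. The final `clamp(min=1e-12)` before `sqrt` absorbs slightly negative values caused by floating-point round-off and keeps the gradient of the square root finite when two features coincide.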
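The norm expansion of the squared distance can be checked numerically. The standalone sketch below mirrors the first lines of `forward` using broadcasting instead of `expand` and `addmm_`; the helper name `pairwise_euclidean` and its `eps` parameter are illustrative, with `eps` matching the `1e-12` clamp in the class.

```python
import torch


def pairwise_euclidean(x, eps=1e-12):
    """D[i, j] = ||x_i - x_j||_2 via ||x_i||^2 + ||x_j||^2 - 2 * x_i . x_j (illustrative helper)."""
    sq_norms = (x ** 2).sum(dim=1, keepdim=True)          # shape (n, 1): ||x_i||^2
    sq_dist = sq_norms + sq_norms.t() - 2.0 * x @ x.t()   # broadcasts to shape (n, n)
    # Round-off can make entries (especially the diagonal) slightly negative
    return sq_dist.clamp(min=eps).sqrt()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 3)
    print(pairwise_euclidean(x))
```

Because of the clamp, diagonal entries come out as about `1e-6` rather than exactly 0, which is negligible for mining the hardest positive and negative.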