1. What Is Triplet Loss?
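Triplet loss is an important loss function in deep learning for learning feature representations (embeddings). It was popularized by the FaceNet paper and has since been widely used in face recognition, person re-identification (ReID), and related tasks. The core idea is to train on triplets made of an anchor sample (Anchor), a positive sample (Positive) with the same label as the anchor, and a negative sample (Negative) with a different label. The loss pulls same-class features closer together and pushes different-class features apart until the negative is farther from the anchor than the positive by at least a margin.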
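For a single triplet with Euclidean distance d and margin m, the standard (non-mined) form of the loss is \(L = \max(d(a,p) - d(a,n) + m,\ 0)\). It drops to zero once the negative lies at least m farther from the anchor than the positive. The pure-Python sketch below computes this quantity. It assumes feature vectors are plain lists of floats. The names `euclidean` and `triplet_loss` are hypothetical helpers for illustration, not a library API.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length feature vectors (hypothetical helper)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor, positive, negative, margin=0.3):
    """Hinge-style triplet loss for one (anchor, positive, negative) triplet (hypothetical helper)."""
    d_ap = euclidean(anchor, positive)
    d_an = euclidean(anchor, negative)
    # Penalize only when the negative is not at least `margin` farther than the positive
    return max(d_ap - d_an + margin, 0.0)


anchor = [0.0, 0.0]
positive = [0.0, 1.0]   # d(a, p) = 1.0
negative = [3.0, 4.0]   # d(a, n) = 5.0
print(triplet_loss(anchor, positive, negative))    # 0.0: already separated by more than the margin
print(triplet_loss(anchor, positive, [0.0, 1.1]))  # about 0.2: negative is only 0.1 farther
```

The margin keeps the loss active for "easy but not easy enough" negatives. Without it, the model could satisfy the objective by collapsing all embeddings to nearly the same point.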
2. Code Walkthrough
Complete example code:
```python
import torch
import torch.nn as nn


class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.

    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """

    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        # MarginRankingLoss(x1, x2, y) = mean(max(0, -y * (x1 - x2) + margin))
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (batch_size).
        """
        n = inputs.size(0)

        # Pairwise Euclidean distances via ||x_i||^2 + ||x_j||^2 - 2 * x_i . x_j
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # clamp for numerical stability before sqrt

        # mask[i][j] is True when samples i and j share the same label
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())

        # Batch-hard mining: farthest positive and closest negative for each anchor
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # y = 1: dist_an should exceed dist_ap by at least the margin
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)
```
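The `forward` method implements the batch-hard variant from Hermans et al. Every sample in the batch serves as an anchor. Each anchor is paired with its farthest same-label sample (hardest positive) and its closest different-label sample (hardest negative). The sketch below reproduces that logic with plain Python lists so it can be checked by hand. The name `batch_hard_triplet_loss` is hypothetical. Like the original, it assumes the batch contains at least two distinct labels; otherwise the `min` over negatives has nothing to choose from. The last step relies on `nn.MarginRankingLoss(x1, x2, y)` computing the mean of \(\max(0, -y(x_1 - x_2) + \text{margin})\). With x1 = dist_an, x2 = dist_ap, and y = 1, this equals \(\max(0, d_{ap} - d_{an} + \text{margin})\).

```python
import math


def batch_hard_triplet_loss(features, labels, margin=0.3):
    """Pure-Python sketch of batch-hard triplet loss (hypothetical helper)."""
    n = len(features)
    # Full pairwise Euclidean distance matrix, clamped like the tensor version
    dist = [
        [math.sqrt(max(sum((a - b) ** 2 for a, b in zip(features[i], features[j])), 1e-12))
         for j in range(n)]
        for i in range(n)
    ]
    losses = []
    for i in range(n):
        # Hardest positive: farthest sample with the same label (includes i itself)
        d_ap = max(dist[i][j] for j in range(n) if labels[j] == labels[i])
        # Hardest negative: closest sample with a different label
        d_an = min(dist[i][j] for j in range(n) if labels[j] != labels[i])
        # MarginRankingLoss with y = 1 reduces to max(0, d_ap - d_an + margin)
        losses.append(max(0.0, d_ap - d_an + margin))
    # nn.MarginRankingLoss averages over the batch by default (reduction='mean')
    return sum(losses) / n
```

Take features [[0, 0], [0, 2], [1, 0], [1, 2]] with labels [0, 0, 1, 1]. Every anchor's hardest positive is at distance 2 and its hardest negative at distance 1, so each term is 2 - 1 + 0.3 and the mean is about 1.3. Hermans et al. build batches from P identities with K samples each, which guarantees every anchor has both positives and negatives.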
Step 1: Compute the feature distance matrix
\[D_{ij} = \sqrt{\lVert x_i - x_j \rVert_2^2}\]
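The code does not evaluate the equation above pair by pair. Instead, `forward` expands the squared norm as \(\lVert x_i - x_j\rVert^2 = \lVert x_i\rVert^2 + \lVert x_j\rVert^2 - 2\,x_i^\top x_j\), so the whole matrix comes from one matrix product:

- `torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)` places \(\lVert x_i\rVert^2\) in row i.
- Adding `dist.t()` contributes \(\lVert x_j\rVert^2\).
- `addmm_(inputs, inputs.t(), beta=1, alpha=-2)` subtracts \(2XX^\top\).

Floating-point rounding can leave tiny negative values, which is why the code applies `clamp(min=1e-12)` before `sqrt()`. The pure-Python sketch below checks the expansion against a direct computation. Its helper names are hypothetical, and the reference version uses `math.dist`, which requires Python 3.8 or later.

```python
import math


def pairwise_dist_expanded(X):
    """D[i][j] via ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>, mirroring the tensor code (hypothetical helper)."""
    n = len(X)
    sq = [sum(v * v for v in x) for x in X]  # ||x_i||^2, like torch.pow(inputs, 2).sum(dim=1)
    D = []
    for i in range(n):
        row = []
        for j in range(n):
            dot = sum(a * b for a, b in zip(X[i], X[j]))  # entry (i, j) of inputs @ inputs.t()
            d2 = sq[i] + sq[j] - 2 * dot
            # Rounding can push d2 slightly below zero; clamp before sqrt as the code does
            row.append(math.sqrt(max(d2, 1e-12)))
        D.append(row)
    return D


def pairwise_dist_direct(X):
    """Reference: D[i][j] = sqrt(sum_k (x_ik - x_jk)^2) (hypothetical helper)."""
    return [[math.dist(a, b) for b in X] for a in X]


X = [[1.0, 2.0], [4.0, 6.0], [0.0, 0.0]]
print(pairwise_dist_expanded(X)[0][1])  # 5.0, since 5 + 52 - 2 * 16 = 25
```

The two versions agree up to the clamp: the expanded diagonal is 1e-6 (the square root of 1e-12) rather than exactly 0. The tensor version computes this with a single matrix multiplication, which is much faster on a GPU than an explicit double loop.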