Triplet Loss

Triplet Loss

【对比学习】| Triplet loss

MXNet/Gluon 中 Triplet Loss 算法

1.什么是triplet loss 损失函数?

triplet loss 是深度学习的一种损失函数,主要是用于训练差异性小的样本,比如人脸,细粒度分类等;其次在训练目标是得到样本的embedding任务中,triplet loss 也经常使用,比如文本、图片的embedding。本文主要讨论,对于训练样本差异小的问题。

2.tripletloss原理

损失函数公式:

https://pic1.zhimg.com/80/v2-90daccead9843f3bf5b1489e58dbd578_720w.webp

。输入是一个三元组,包括锚(Anchor)示例、正(Positive)示例、负(Negative)示例,通过优化锚示例与正示例的距离小于锚示例与负示例的距离,实现样本之间的相似性计算。a:anchor,锚示例;p:positive,与a是同一类别的样本;n:negative,与a是不同类别的样本;margin是一个大于0的常数。最终的优化目标是拉近a和p的距离,拉远a和n的距离。其中样本可以分为三类:

**easy triplets:**即

https://pic2.zhimg.com/80/v2-dd0386738a3604977a84ea8eb7487a21_720w.webp

,这种情况不需要优化,天然a和p的距离很近,a和n的距离很远,如下图:

https://pic1.zhimg.com/80/v2-ad02e580b107eae95d8feedcd1cfc50c_720w.webp

easy triplets示例

**hard triplets:**即d(a,n)<d(a,p)

https://pic4.zhimg.com/80/v2-c0f8f2ddcb833953d4746f5cab18b3db_720w.webp

,a和n的距离近,a和p的距离远,这种情况损失最大,需要优化,如下图:

https://pic1.zhimg.com/80/v2-0e26c64925ef2bcb9806eaf18c2aad0c_720w.webp

hard triplets示例

**semi-hard triplets:**即

https://pic3.zhimg.com/80/v2-e4dfc15a84b8f0c2edfcb0943bf3895e_720w.webp

,即a和p的距离比a和n的距离近,但是近的不够多,不满足margin,这种情况存在损失,但损失比hard triplets要小,也需要优化,如下图:

https://pic2.zhimg.com/80/v2-a84e1ad1b49496a5806c14613f662b31_720w.webp

semi-hard triplets示例

3.Margin的作用

  • 避免模型走捷径,将negative和positive的embedding训练成很相近,因为如果没margin,triplets loss公式就变成了,那么只要就可以满足上式,也就是锚点a和正例p与锚点a和负例n的距离一样即可,这样模型很难正确区分正例和负例。

    https://pic4.zhimg.com/80/v2-94655266dc5db1158df340927e4e8f3f_720w.webp

    https://pic2.zhimg.com/80/v2-656b1ac96526f9140a11ae0c021e36d1_720w.webp

  • 设定一个margin常量,可以迫使模型努力学习,能让锚点a和负例n的distance值更大,同时让锚点a和正例p的distance值更小。

  • 由于margin的存在,使得triplets loss多了一个参数,margin的大小需要调参。如果margin太大,则模型的损失会很大,而且学习到最后,loss也很难趋近于0,甚至导致网络不收敛,但是可以较有把握的区分较为相似的样本,即a和p更好区分;如果margin太小,loss很容易趋近于0,模型很好训练,但是较难区分a和p。

在训练的时候,一个重要的选择就是对于负样本进行挑选。称之为,负样本选择或者三元组采集(triplet mining)。一个原则时,easy triplet应该尽量避免被采集到,因为loss为0,所以对训练并没有贡献。

Triplet Loss使用案例


在此任务中使用Triplet Loss,可以通过以下方式实现模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch.nn as nn
import torch.optim as optim

class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x

model = MyModel(input_size=10, hidden_size=5, output_size=2)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
optimizer = optim.SGD(model.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
for i, (anchor, pos, neg) in enumerate(train_loader):
optimizer.zero_grad()
anchor_output = model(anchor)
pos_output = model(pos)
neg_output = model(neg)
loss = triplet_loss(anchor_output, pos_output, neg_output)
loss.backward()
optimizer.step()

# Validation loop
with torch.no_grad():
for anchor, pos, neg in val_loader:
anchor_output = model(anchor)
pos_output = model(pos)
neg_output = model(neg)
dist_pos = F.pairwise_distance(anchor_output, pos_output)
dist_neg = F.pairwise_distance(anchor_output, neg_output)
total += 1
if dist_pos < dist_neg:
correct += 1
accuracy = 100 * correct / total

# Testing loop
with torch.no_grad():
for anchor, pos, neg in test_loader:
anchor_output = model(anchor)
pos_output = model(pos)
neg_output = model(neg)
dist_pos = F.pairwise_distance(anchor_output, pos_output)
dist_neg = F.pairwise_distance(anchor_output, neg_output)
total += 1
if dist_pos < dist_neg:
correct += 1
accuracy = 100 * correct / total

在这个示例代码中,我们定义了一个自定义模型,并使用nn.TripletMarginLoss作为损失函数来最小化锚点、正样本和负样本之间的距离。在训练循环中,我们首先通过模型获取锚点、正样本和负样本的输出,然后将它们传递给nn.TripletMarginLoss来计算损失并更新模型参数。

在验证循环和测试循环中,我们首先通过模型获取锚点、正样本和负样本的输出,然后使用PyTorch内置函数F.pairwise_distance来计算锚点和正样本之间的欧几里得距离(或其他距离度量),以及锚点和负样本之间的距离。如果锚点和正样本之间的距离小于锚点和负样本之间的距离,则认为预测正确。

请注意,在上面的示例代码中,模型的forward函数只接受一个输入,并且我们假设锚点、正样本和负样本都是从train_loaderval_loadertest_loader中获取的。如果你的输入数据包含两个样本,你可以修改forward函数来接受两个输入:

1
2
3
4
5
6
def forward(self, x1, x2):
x = torch.cat((x1, x2), dim=1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x

然后在训练循环中,您需要分别传递两个输入分别计算输出。在验证和测试循环中,您需要在计算欧几里得距离之前将两个输入传递给模型以获得它们的特征向量。

with torch.no_grad()是在PyTorch中的上下文管理器,用于在执行代码期间禁用梯度计算和自动求导。这在测试模型时非常有用,因为在测试时我们不需要计算梯度,而且如果开启自动求导,会增加计算时间和内存占用。在with torch.no_grad()中执行的所有操作都不会计算梯度,所以可以加速代码执行速度和节省内存消耗。

在PyTorch中,model.train()model.eval()方法是用于控制模型训练和评估模式的。当调用model.train()时,模型将被设置为训练模式,这意味着它将启用dropout和batch normalization等训练特定的操作。当调用model.eval()时,模型将被设置为评估模式,这意味着它将禁用dropout和batch normalization等特定于训练的操作,并使用整个测试集对模型进行评估。在测试或验证期间,应该始终调用model.eval(),以确保模型不会受到dropout等操作的影响,从而获得准确的测试结果。