对比学习损失使用

对于自监督学习,一般分为两种。一种是AutoEncoder这种通过一个表征向量,从自己到自己的还原过程,这类称为生成式自监督学习。一种是以学习区分两种不同类事物的关键特征为目标,通过构建正负例子,学习表征向量的方法,这类称为判别式自监督学习,也叫做对比学习。

对比学习的通过互信息,衡量一个表征的好坏,与正例相似而远离负例。这里记录两个常用的对比学习损失。

NTXentLoss

NTXentLoss也就是InfoNCE使用的损失: \[ L = -log \frac{exp(q \cdot k_+ / \tau)}{\sum^{K}_{i=0}exp(q \cdot k_i / \tau)} \] 分子在最小化损失函数时,会使表征 \(q\) 与正例 \(k_+\) 的相似度增加。

可以直接使用 PyTorch Metric Learning 包调用损失函数类。

1
2
3
4
5
from pytorch_metric_learning.losses import NTXentLoss

...
loss_func = NTXentLoss(temperature=temperature)
...

其基本流程如下:

1
2
3
4
5
6
7
8
9
for anchor, positive in pos_pairs:
numerator = torch.exp(torch.matmul(anchor, positive) / (temperature * torch.norm(anchor) * torch.norm(positive)))
denominator = numerator.clone()

for (candidate, negetive) in neg_pairs:
tmp = torch.exp(torch.matmul(anchor, negative) / (temperature * torch.norm(anchor) * torch.norm(negative)))
denominator += tmp

total_loss += -torch.log(numerator / denominator)

SupConLoss

SupConLoss(Supervised Contrastive)是在监督数据中使用对比学习的损失函数。由于有监督数据的支撑,正例不再来源于样本自身,而且可以来自监督标签中属于同一类的样本。其计算公式的区别也在于多了监督标签的部分。 \[ L^{sup}= \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} log \frac{exp(z_p \cdot z_i / \tau)}{ \sum_{a \in A(i)} exp(z_a \cdot z_i / \tau)} \] 对每个正例除以包含该正例的positive pairs的数量。具体看代码比较直接。这个公式有两种形式,还有一种是将对 \(|P(i)|\) 求平均的操作放置于 log 之内。

可以直接使用 PyTorch Metric Learning 包调用损失函数类。

1
2
3
from pytorch_metric_learning.losses import SupConLoss

loss_func = SupConLoss(temperature=temperature)

其基本流程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
losses = torch.zeros(num_of_classes, dtype=torch.float64)

for anchor, positive in pos_pairs:
numerator = torch.exp(torch.matmul(anchor, positive) / (temperature * torch.norm(anchor) * torch.norm(positive)))
denominator = numerator.clone()

for (candidate, negetive) in neg_pairs:
tmp = torch.exp(torch.matmul(anchor, negative) / (temperature * torch.norm(anchor) * torch.norm(negative)))
denominator += tmp

losses[anchor_idx] += -torch.log(numerator / denominator)


total_loss = torch.mean(losses / num_of_positive_pairs_per_anchor)

Gather操作

和之前主题无关,只是记在一起。

1
output = tensor.gather(dim, index)

tensor与index是两个维度相同的张量。

output中的下标为 (i, j) 的值来自:

dim = 0,从tensor中取值时,0维的坐标值来自index张量的 (i, j) 位置的值,1维的坐标值就是 j (output中本来的坐标值)。

dim = 1,从tensor中取值时,1维的坐标值来自index张量的 (i, j) 位置的值,0维的坐标值就是 i (output中本来的坐标值)。


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!