语言模型Sampling方法

在text generation模型预测时,如果我们总是抽取最有可能的单词,标准语言模型训练目标会容易陷入“I don’t know. I don’t know. I don’t know.” 这种循环中。所以有了sample based generation方法。但是,它有一个潜在问题:

假如依照logit softmax生成的分布进行sample,假设有60%的词的概率极低以至于基本不会被选择,但是这60%的词的总的CDF占了30%,这意味着模型预测方向可能有30%的概率偏离了“正确”的方向。

而如果是在预测前期发生偏离,那么由于错误向后预测的累积,直接导致了预测的效果变差。

已有论文研究发现,经常被使用的Beam search方法,其生成效果和人类的表达有着一定的gap。

image-20210102130824216

[^]: Humans often choose words that surprise language models (Holtzman et al 2019) https://arxiv.org/abs/1904.09751

解决方法:temperature sampling和top k sampling.

Temperature sampling

借鉴热力学中现象,温度越高,则低energy的状态出现的概率会增加。

以logits作为“energy”,在进行softmax之前,除以temperature。

1
2
3
4
5
6
7
8
9
10
11
>>> import torch
>>> import torch.nn.functional as F
>>> a = torch.tensor([1,2,3,4.])
>>> F.softmax(a, dim=0)
tensor([0.0321, 0.0871, 0.2369, 0.6439])
>>> F.softmax(a/.5, dim=0)
tensor([0.0021, 0.0158, 0.1171, 0.8650])
>>> F.softmax(a/1.5, dim=0)
tensor([0.0708, 0.1378, 0.2685, 0.5229])
>>> F.softmax(a/1e-6, dim=0)
tensor([0., 0., 0., 1.])

NOTE:temperature越大,分布越趋向均匀

Top k sampling

Top k sampling是指根据概率进行排序将第k个token以下的概率都归零。

但是存在一个问题:某些时候,分布较均匀,可以选择的token大于k;某些时候,分布较集中,可以选择的token小于k。

直接导致了预测错误的概率增大。

Top p samplingnucleus sampling

1,按概率sort预测分布;

2,计算CDF;

3,将CDF大于某个设定p值之后的logit值,设为一个很大的负值;

4,softmax,之后进行采样。

这样就能动态的改变可选择的token数量,且错误概率相对降低。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn.functional as F
import numpy as np


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (..., vocabulary size)
top_k >0: keep only top k tokens with highest probability (top-k filtering).
top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
"""
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value

if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs >= top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove )
logits[indices_to_remove] = filter_value
return logits

End

虽然有这些方法来改进模型生成的效果,但是这些仅仅是模型的“补丁”。如何提高模型本身的性能,如何让模型能够直接生成多样性的、更“人类”的语句?emmm...


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!