# XLNet: Generalized Autoregressive Pretraining for Language Understanding
XLNet发表在NeurIPS-19上,是BERT之后相当有名的一个工作,这里简单总结一下它的要点。
# 要解决的痛点
XLNet,如果做一个总结的话,可以理解为是结合了BERT、GPT、Transformer-XL这些代表性工作各自优点的一个综合体。那么自然,GPT-2、BERT都是各自有个字的缺点。
# BERT的优缺点
BERT在pre-training中主要的技术就是Masked-language-modeling(MLM),MLM为了实现双向语言模型,在训练过程中,会随机挑选15%的token用[MASK]
来替换掉,然后用带[MASK]
的上下文来预测这个token。这就会造成预训练和微调时的不匹配(pretrain-finetune discrepancy)问题。虽然BERT采用了一些trick来缓解这个问题(15%的选中词里面,有80%的概率会被MASK,剩下的要么替换成随机词,要么不变),但这是杯水车薪的。
这就是BERT存在的一个问题,但是MLM毕竟实现了双向的语言模型,所以比传统的单向语言模型还是更好的。
# GPT模型的优缺点
GPT模型,跟BERT是不一样的路子,是单向语言模型,采用auto-regressive的方式,用前面的词去预测下一个词,这么做自然就不需要对token进行MASK,所以不存在pretrain-finetune discrepancy问题。但是这种单向语言模型,在语义表示的方面,自然没有双向的好。
XLNet,就是为了解决BERT和GPT各自的问题,想把他们各自的优点结合在一起的一个工作。
# Permutation Language Modeling(PLM)
这个PLM,就是XLNet最主要的贡献和亮点。
对于BERT的使用MASK token导致的pretrain-finetune discrepancy问题,如果还是用BERT采用的denoising auto-encoding的方式的话,那是没办法的。想不用MASK,就只能用auto-regressive方式,那如何在auto-regressive方式下还能进行双向建模呢?作者们就提出了这个PLM的想法。
一图胜千言,原文理论部分不是很容易看明白,但他们在附录里花了一张图,就很容易懂:
如上图所示,一个序列正常的顺序假设是1-2-3-4,然后我们想对位置3的token进行双向语言模型预测。
PLM的做法就是,先把1-2-3-4的顺序打乱(假设序列长度为L
,那么就有 L!
种token的不同排列组合):
- ① 3-2-4-1
- ② 2-4-3-1
- ③ 1-4-2-3
- ④ 4-3-1-2
- ...
然后,对于每一种排列,我们使用经典的auto-regressive的方式对3进行预测,那么:
- 对于①,3的上文是空的,所以在通过self-attention的时候,没有一个被attend,相当于凭空预测3;
- 对于②,3的上文是2、4,所以2、4的位置被attend,相当于用2、4来预测3;
- 对于③,3的上文是1、4、2,所以1、2、4的位置被attend,相当于用1、2、4来预测3;
- 对于④,3的上文是4,所以只有4的位置被attend,相当于用4来预测3。
这就是PLM的思想,通过这种方式,某个token的上下文,实际上都有可能参与预测该token,也就实现了双向的语言模型,这种设计还是挺精妙的。
**Notice!**虽然PLM把顺序都打乱了,但实际上输入模型的,都是原始的顺序,只是在进行language model预测的时候,对所谓的“上文”进行了各种采样,因此实现了实际上的上下文建模。所以我们不用担心这里的permutation对语义的影响,模型学习的还是正常的句子。
其实BERT还存在一个问题,那就是BERT预训练中随机MASK掉的那些词,在训练的时候没有考虑到彼此之间的关系,比方作者举的这个例子:
[New, York, is, a, city]这个句子,如果把New, York都给mask掉了,那么BERT在预测New的时候,使用的上下文就只有[is, a, city],预测York的时候使用的上下文也只有[is, a, city],而New和York之间的依赖关系就被忽略了。
而XLNet不同,假设对句子随机采样的一个排序是[is, a, city, New, York], 那么预测New的时候,使用的上文是[is, a, city],预测York的时候使用的上文则是[New,is, a, city],比BERT对了一个对New的考虑。
其实你细想一下为啥XLNet可以做到?因为XLNet不会同时去预测New和York,一次只预测一个token,而BERT则是同时预测所有被mask掉的词,那自然就没法考虑彼此之间的关系了。
# 目标函数
目标函数,就是岁所有可能permutation序列的联合概率分布的期望,比方sequence length是3,那么就有6种排列,那目标函数怎么计算呢?对于每种排列,依次计算每个位置token的概率,然后连乘再取log,最后把6种排列的结果进行平均。
实际上,这么算的话计算开销是很大的,所以作者实际使用的方式,是名为Partial Prediction的方式,即对于一个排列,我们只对最后的几个position进行预测,前面的不管了。比方1-2-3-4的一个排列是1-3-2-4,我们设置一个截断长度=2,那么我们只做1,3->2和1,2,3->4这两个LM预测。
# 使用PLM要解决的问题
上面描述PLM感觉很美好,但直接使用会存在问题。
首先我们看看auto-regressive是如何预测下一个词的概率的:
假设一个sequence的排列是a-b-c-d,那我们预测c的时候使用的上文就是a-b,但是套用上面的公式的话,
原始的sequence可能是c-a-b-d,也可能是a-b-c-d,还可能是a-c-d-b,所有的可能,只要在a-b-c-d的排列下,计算出来的概率就是一模一样的。也就是说,还用传统的计算方法的话,会忽略要预测词的位置信息,这就肯定很影响学习的效果。
所以,我们要做的改进,就是把要预测词的位置信息加进去:
但是,我们加入的也只能是位置信息,不能把内容信息给加进去了,不然预测就没有意义了。
上面这个公式实际上是Transformer的最后一层的处理,如何把要预测的词的位置,通过层层的self-attention,把位置信息给传上来的同时不传递内容信息呢?另外,我一个sequence输进来,也不能每次只预测一个token吧,那样效率就太低了,那对于要预测的token1如果只有位置信息,而要预测的token2如果要使用token1的信息的话,那岂不是矛盾了。
说得再具体一点,对于一个序列x1-x2-x3-x4,我们在预测x3的时候,不希望使用x3的内容,只使用它的位置;然后我们还想预测x4,这个时候,我们有需要同时获取x3的内容和位置。要想做到这一点,传统的Transformer结构是无法做到的。那么,我们能怎么办呢:
- 首先得把token representation分成embedding和position两部分,不能混在一起
- 得有两套self-attention机制来帮忙传递信息
这就是作者提出的Two-Stream Self-Attention方法:
# Two-Stream Self-Attention
content stream attention和query stream attention:
上图中,我们可以理解蓝色的部分就是content,绿色的部分就是position。
- (a)代表Content stream attention,它就是传统Transformer中一模一样的self-attention。它的作用是保证每个token的内容信息的传递。
- (b)代表Query stream attention,它对于当前的token,只访问position,对于其他的token,则访问content。当前token的position作为query,所有其他token的content作为K和V,然后使用position跟K计算attention权重,再对V进行加权求和,就得到下一层的position。总之,这个stream的作用就是保证每个token只有位置传递上去。
- 另外注意,content stream只使用了content,但是query stream同事使用了position和content,这种交互使得在参数不断更新之后,content里面也会包含position的信息,而position中本来也包含了content的信息,只是不包含当前位置的content。
综合起来,我们的decoder就是这样的:
图中最上面红色虚线,代表预测,比如要预测x3的话,我们使用的context hidden representation就是图中的g3,g3只包含x3的位置信息,以及x3的上文的所有信息;而要预测x4的时候,如果x3在x4的上文中,那么x3的所有信息都可以被g4获取。这样,上面提到的矛盾就可以解决了。
# 借鉴Transformer-XL
实际上XLNet的名字就是沿用了Transformer-XL来取的,因为这里在backbone上主要就采用了Transformer-XL的设计,即采用相对位置编码(relative positional encoding)和片段重用机制( segment recurrence mechanism),这样可以让模型接收更长的序列,从而对长文本的表示更好。具体这里不展开了。
# XLNet的效果
XLNet的其他细节,就暂不列出了,下面贴一下跟bert的公平对比,反正棒就完事儿了: