Songqian Li's Blog

去历史上留点故事

本文希望从 Base-Attention 到 Transformer 逐层递进的解释其中的计算细节,从而更好的理解 Transformer 模型。本文主要对模型可能存在的盲区进行解释,可能思路有些跳跃,请谅解。参考资料见文章末尾Reference.

Base-Attention


如左图所示,为了实现Attention机制,我们需要一个 Z0Z0,与 h1h1 计算得到一个scaleα10α01,这个 α10α01 可以简单看做 hh 与 zz 的相似度.
将 Z0Z0 与每一个 hihi 进行相乘然后 SoftmaxSoftmax 计算得到distribution αi0α0i,就变成了右图。这个distribution就是Attention机制的Weight。得到Attention之后将αi0α0i 乘上 hihi 后加到一起形成VectorC0C0,C0C0 就是预测第一个单词的特征向量。
注:本例中的Value是与Key相同的,但是实际上Value可以与Key不同,因为Key是用来计算AttentionWeightαα 的,而Value是用来与 αα 相乘得到Vector的,因此KeyValue不需要一定相同。

Dot-Product Attention是 Base 的一种计算方式,图中表示了 Attention 的计算公式以及相对应的 Q,K,VQ,K,V 的维度。

Self-Attention


Self-Attention主要解决了训练时的并行问题,计算每个单词Attention时不用等前一个单词计算完成即可开始计算,以图中a为例。

如图以 e2e2 为例,把自身的vector作为Query,Dot-Product 的Key就是 e1e1。

而在 decoder 的时候需要把 e2e2 后面的词给 mask 掉,因为inference阶段是没有后面输入。

Multi-Head Attention


Multi-Head Attention的相关解释如图所示。

Scaled Dot-Product Attentionh指的是Multi-Head Attention的数量。

Transformer


Transformer 就是用 Self-Attention 替换了序列模型的 RNN 单元。

本图主要解释了加入Scaled过程的原因,如上图中间的图像所示,蓝色是 QKTQKT,橙色是 softmax(QKT)softmax(QKT),明显可以看出函数变的很尖,这样导致梯度变得很小,只有高数值的信息才会使模型更新,因此为了解决这个问题,在计算 softmax 之前除以 √dkdk。
这里可以理解为两个独立的矩阵,均值为 0,方差为 1。经过矩阵相乘之后,均值还是 0,方差会变成 dkdk,这样就会导致 Softmax 之后梯度变得很小。

residual connection 是指经过一个 F(x)F(x)后,输出结果是 F(x)+xF(x)+x,这样做的目的是为了解决模型在多次进行 F(x)F(x)计算之后没有更多的signore的时候,模型可以弹性的选择哪部分要经过 F(x)F(x),哪部分可以跳过 F(x)F(x),而且模型也不需要更多的参数来进行训练。

Decoder中,因为序列的预测过程是从左往右,所以右面单词的Attention是未知的,这就是绿色框中Masked与红色框之间有区别的原因。
在蓝色框中,Decoder提供 Attention 的QueryEncoder提供 Attenion 的KeyQuery

Training Trick

Byte-pair encodings

BPE 算法的基本思路是把经常出现的byte pair用一个新的byte来代替,例如假设('A', ’B‘)经常顺序出现,则用一个新的标志'AB'来代替它们。
例如,假设我们的文本库中出现的单词及其出现次数为{'l o w': 5, 'l o w e r': 2, 'n e w e s t': 6, 'w i d e s t': 3},我们的初始词汇库为{ 'l', 'o', 'w', 'e', 'r', 'n', 'w', 's', 't', 'i', 'd'},出现频率最高的 ngram pair 是('e','s') 9 次,所以我们将’es’作为新的词汇加入到词汇库中,由于’es’作为一个整体出现在词汇库中,这时文本库可表示为 {'l o w': 5, 'l o w e r': 2, 'n e w es t': 6, 'w i d es t': 3},这时出现频率最高的 ngram pair 是('es','t')9 次,将’est’加入到词汇库中,文本库更新为{'l o w': 5, 'l o w e r': 2, 'n e w est': 6, 'w i d est': 3},新的出现频率最高的 ngram pair 是(‘l’,’o’)7 次,将’lo’加入到词汇库中,文本库更新为{'lo w': 5, 'lo w e r': 2, 'n e w est': 6, 'w i d est': 3}。以此类推,直到词汇库大小达到我们所设定的目标。这个例子中词汇量较小,对于词汇量很大的实际情况,我们就可以通过 BPE 逐步建造一个较小的基于 subword unit 的词汇库来表示所有的词汇。

Checkpoint averaging

对训练过程中 checkpoint 保存的模型进行参数平均

  1. step 1:在预测第一个单词的时候选择前Beam width个概率最大的单词作为第一个单词;
  2. step 2:假设词汇表长10000,B=3按照 P(y<1>,y<2>|x)=P(y<1>|x)P(y<2>|x,y<1>)P(y<1>,y<2>|x)=P(y<1>|x)P(y<2>|x,y<1>)最大,就会在开头为(y<1>,y<2>)(y<1>,y<2>)的30000种选择中选取概率最大的3种,并将(y<1>,y<2>)(y<1>,y<2>)作为新的开头;
  3. step 3:将step 2中选取的新的开头作为条件,重复计算step 2,如果遇到终止符则停止。

改进集束搜索(改进其打分规则):
由于argmaxt=1TyP(yx,y1,,yt1)\displaystyle argmax \prod_{t=1}^{T_y}P(y^{}|x,y^{1},\dots,y^{t-1}),是多个小于 1 的概率相乘,最终得到的数值会很小很小,出现数值下溢,因此对上述公式取对数:
argmaxTy∑t=1log,P(y|x,y1,…,yt−1)argmax∑t=1Tylog,P(y|x,y1,…,yt−1)
但是 log 函数在(0,1)是负数,序列越长数值越低,所以模型会更倾向于短文本输出而不是简洁的文本。于是第二次改进:
1TαyTy∑t=1log,P(y|x,y1,…,yt−1)1Tyα∑t=1Tylog,P(y|x,y1,…,yt−1)
对和函数进行归一化,但是不是单纯的除以单词数量,而是除以 TyTy 的 αα 次幂,这里的 αα 是一个超参数。

Reference

  1. 台大《应用深度学习》国语课程(2020) by 陈蕴侬
  2. 序列模型 by Andrew Ng
  3. CS224N 笔记(十二):Subword 模型
相关文章
评论
分享
  • 使用PyTorch可视化必须知道的TensorBoard参数

    亲测可用的PyTorch和TensorflowBoard版本,不会出现绘制模型结构图片时空白的情况。 1234torch==1.2.0tensorboard==2.1.1tensorflow==2.1.0tensorboardX==2...

    使用PyTorch可视化必须知道的TensorBoard参数
  • 你看不懂的BERT解读

    本文主要是针对近年来序列模型的发展,例如 BERT、Transformer-XL、XLNet、RoBERTa 以及 XLM 等模型的思路整理。 BERT: Bidirectional Encoder Representations ...

    你看不懂的BERT解读
  • 正则化

    模型越复杂越容易出现过拟合状态,所以需要一种机制来保证我们模型的“简单”,这样我们的模型才能有较好的泛化能力,正则化是这类机制之一。 欧几里得范数: L2 范数: L1 范数: 推导过程 泰勒公式   为什么可以减少过拟...

    正则化
  • 机器学习常用算法原理

    Logistic Regression 逻辑回归的假设函数: 其中是输入,是要求解的参数。 函数图像: 一个机器学习模型实际上是把决策函数限定在某组条件下,这组限定条件决定了模型的假设空间,逻辑回归的假设空间: 它的意思是在给...

    机器学习常用算法原理
  • 《操作系统真象还原》:第十章 输入输出系统

    上一章中我们遇到的字符混乱和 GP 异常问题,根本原因是由于临界区代码的资源竞争,这需要一些互斥的方法来保证操作的原子性。 10.1 同步机制——锁 10.1.1 排查 GP 异常,理解原子操作 多线程执行刷屏时光标值越界导致...

    《操作系统真象还原》:第十章 输入输出系统
  • 《操作系统真象还原》:第九章 线程

    线程和进程将分两部分实现,本章先讲解线程。 9.1 实现内核线程 9.1.1 执行流 在处理器数量不变的情况下,多任务操作系统采用多道程序设计的方式,使处理器在所有任务之间来回切换,这称为“伪并行”,由操作系统中的任务调度器决定当...

    《操作系统真象还原》:第九章 线程
  • GPU虚拟化

    用户层虚拟化 本地 API 拦截和 API formwarding 在用户态实现一个函数库,假设叫 libwrapper, 它要实现底层库的所有 API; 让 APP 调用这个 libwrapper。如何做? libwrap...

    GPU虚拟化
  • 硬件虚拟化

    硬件虚拟化介绍 硬件虚拟化要做的事情 体系结构支持 体系结构 实现功能 作用 模式切换 Host CPU <-> Guest CPU 切换 CPU 资源隔离 二阶段地址转换 GVA-> GPA...

    硬件虚拟化
  • 《操作系统真象还原》:第八章 内存管理系统

    8.1 makefile 简介 这部分可参考阮一峰的讲解:https://www.ruanyifeng.com/blog/2015/02/make.html 8.1.1 makefile 是什么 makefile 是 Linu...

    《操作系统真象还原》:第八章 内存管理系统
  • 《操作系统真象还原》:第七章 中断

    7.1 中断是什么,为什么要有中断 运用中断能够显著提升并发,从而大幅提升效率。 7.2 操作系统是中断驱动的 略 7.3 中断分类 把中断按事件来源分类,来自 CPU 外部的中断就称为外部中断,来自 CPU 内部的中断称为内部...

    《操作系统真象还原》:第七章 中断