可以参考:

基础八股

多头注意力机制?

多头注意力机制保证了 Transformer 注意到不同子空间的信息, 捕捉到更丰富的特征信息. 解决了注意力过于集中于自身位置的问题 (权重矩阵中, 对角线注意力值最大)

多头注意力的主要作用

  1. 表达不同的特征子空间
  2. 增强模型的泛化能力
  3. 减少每个头的计算负担
  4. 在不显著增加模型参数和计算负担的前提下, 提高参数效率

QK 为什么使用不同的权重矩阵生成, 为什么不能进行自身的点乘?

使用 qkv 不相同可以保证在不同空间进行投影, 增强了表达能力, 提高了泛化能力

为什么 Q 和 K 不能用相同的映射矩阵?

点乘的物理意义是计算两个向量之间的相似度. Q 和 K 的点乘是为了计算 attention matrix, QK 被映射到不同空间上, 从而增加了表达能力, 这样得到的 attention matrix 的泛化能力更高

如果不用 Q, K 直接自己点乘, 得到的 attention matrix 是一个对称矩阵. 同一个矩阵被投影到同一个空间, 所以泛化能力很差

Attention 为什么选择点乘而不是加法? 在计算复杂度和效果上的区别?

点乘可以用用来衡量相似性, 点乘的计算复杂度是 O(n2d)O(n^2 \cdot d), 这种计算方式可以有效利用现代硬件的并行计算能力. 点乘和加法的计算量相似但是在并行计算的效率上可能略低于点乘.

为什么要在 softmax 之前对 Attention 进行 scaled? 为什么是 dk\sqrt{d_k}

极大的点积值将 softmax 推向梯度平缓区, 使得收敛困难, 可能出现梯度消失; 由于进行矩阵乘积, attention matrix 的方差扩大了 dkd_k 倍, 所以需要除以 dk\sqrt{d_k} 进行缩放

具体的解释

dkd_k 较小的时候, Add 和 Mul 两种注意力计算方法的效果接近, 在 dkd_k 增加的时候, 加法的效果开始显著超越点乘.

Add:score(h,s)=<v,tanh(W2h+W2s)>Mul:score(h,s)=<W1h,W2s>Add:score(h, s) = <v, tanh(W_2h + W_2s)>\\ Mul: score(h, s)=<W_1h, W_2s>

  1. 为什么 Add 天然不需要 scaled, Mul 在 dkd_k 较大时必须做 scaled 呢?

Add 中的矩阵乘法只有随机变量 XX 和参数矩阵 WW 相乘, Mul 中包含随机变量 XX 与随机变量 XX 间的乘法

  1. 为什么一定是 dkd_k 呢?

对于 Mul, 如果 sshh 都分布在 [0,1][0, 1], 相乘时引入一次对所有位置的求和, 整体的分布就会扩大到 [0,dk][0, d_k], 为什么扩大到 dkd_k 参见 3. . 整体除以 dk\sqrt{d_k}, 此时的分布将恢复为 [0,1][0, 1]

  1. dkd_k 变大, qkTq\cdot k^T 的方差会变大 dkd_k

假设 q 和 k 的向量长度均为 dkd_k, 均值为 0, 方差为 1, 则 qkTq\cdot k^T 的点积的方差为

var[qkT]=var[i=1dkqi×ki]=i=1dkvar[qi]×var[ki]=i=1dk1=dk var[q\cdot k^T] = var[\sum^{d_k}_{i=1}q_i\times k_i] = \sum_{i=1}^{d_k} var[q_i]\times var[k_i] = \sum_{i=1}^{d_k} 1 = d_k

  1. 不 scaled 时, softmax 将退化为 argmax

计算 attention score 时如何对 padding 做 mask 操作?

mask 矩阵对应的 padding 位置的值为负无穷 (一般设置为一个很大的数, 如-1000)

为什么使用 LayerNorm 而不是 BatchNorm

LayerNorm 能够保留特征的语义信息, 并且不受 batch 大小和序列长度的限制

  1. 序列长度变化: NLP 中输入序列长度可变, BN 需要对每个 batch 中每个特征维度进行归一化, 处理可变长度序列时有困难
  2. 小批次问题: batch 比较小的时候, BatchNorm 性能下降
  3. 位置不变性: LayerNorm对每个样本独立归一化, 输入序列中的位置信息保持不变
  4. 训练和推理一致性: BN 在训练和推理时使用不同的统计量 (训练时使用 batch, 推理时使用全局), 这可能导致训练和推理间的不一致
  5. LayerNorm 计算复杂度与序列长度成线性关系, BN 的计算复杂度与批次大小和序列长度成线性关系. 序列长度较长时, LN 的计算效率更高

Encoder 模块和 Decoder 模块有什么关系