Attention
Published:
注意力机制通过注意力汇聚将查询(自主性提示,Q)和键(非自主性提示,K)结合在一起,实现对值(感官输入,V)的选择倾向,其中键和值是成对的。
注意力机制与全连接层的区别在于其加入了自主性提示
注意力汇聚
非参数注意力汇聚
Nadaraya-Watson核回归(Nadaraya-Watson kernel regression)是一个典型的非参数注意力汇聚模型,如式$\eqref{eq:NW_kernel}$所示:
\[\begin{align} f(q) &= \sum_{i=1}^n \frac{K(q, k_i)}{\sum_{j=1}^n K(q, k_j)} v_i \label{eq:NW_kernel}\\ &= \sum_{i = 1}^n \alpha(q, k_i)v_i \label{eq:non_para} \end{align}\]其中$q$为查询,$(k_i, v_i)$为键值对,$K(k_i, v_i)$为核函数,用来衡量查询和键之间的距离。非参数注意力汇聚可以简化式$\eqref{eq:non_para}$所示,即将查询和键之间映射为注意力权重,查询结果是对值的加权平均,这里需要保证$\alpha(q, k_i) \ge 0$且$\sum_{i = 1}^n \alpha(q, k_i) = 1$,因此,可以使用注意力得分函数$s(q, k_i)$后使用$Softmax$函数映射为$\alpha(q, k_i) $。
如果键越接近给定的查询,那么相应的值的权重越大,对该值的倾向性(注意力)也就越大。同时,如果有足够的数据,非参数注意力汇聚会收敛到最优结果。
参数注意力汇聚
参数注意力汇聚即在注意力汇聚中的权重是可以学习的,可以简单的表述为式$\eqref{eq:para_atten}$所示:
\[f(q) = softmax\left(\boldsymbol{s}(q, \boldsymbol{k})\cdot\boldsymbol{\omega}\right)\cdot\boldsymbol{v}^{T} \label{eq:para_atten}\]其中$\omega$为可学习的参数。
注意力得分数
加性注意力
当查询和键是不同长度的矢量时,可以使用加性注意力作为得分函数,如式$\eqref{eq:add_atten}$所示:
\[\begin{equation} s(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R}, \label{eq:add_atten} \end{equation}\]其中,$\mathbf{q} \in \mathbb{R}^q$、$\mathbf{k} \in \mathbb{R}^k$,参数$\mathbf W_q\in\mathbb R^{h\times q}$、$\mathbf W_k\in\mathbb R^{h\times k}$、$\mathbf w_v\in\mathbb R^{h}$可通过网络进行学习。
缩放点积注意力
点积注意力要求查询和键具有相同的长度$d$,其计算如式$\eqref{eq:dot_atten}$所示:
\[\begin{equation} f(\mathbf Q) = softmax\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \label{eq:dot_atten} \end{equation}\]其中$d$为向量的长度。
多头注意力
多头注意力简单来说就是给定相同的查询、键和值的集合,使用$h$组不同的线性映射来变换查询、键和值,然后基于相同的注意力汇聚学习到不同的行为, 最后将得到的不同行为线性加权,产生最终输出。其中每一个注意力汇聚都被称为一个头。
在多头注意力中,通常使用缩放点积的注意力汇聚方式,其计算方式如下所示:
\[\begin{gather} \mathbf Q_i = f_{Q_i}(\mathbf Q), \mathbf K_i = f_{K_i}(\mathbf K), \mathbf V_i = f_{V_i}(\mathbf V)\\ \mathbf h_{i} = softmax\left(\frac{\mathbf Q_i \mathbf K_i^\top }{\sqrt{d}}\right) \mathbf V_i \\ f(\mathbf Q) = \sum\omega_i \mathbf h_i \end{gather}\]其中,$f$为线性映射函数。
自注意力
自注意力的查询、键和值来自同一组输入,只关注序列内信息。
交叉注意力
交叉注意力中的查询来自第一组输入,而键和值来自第二组输入,结合了编码器输出的上下文信息。