从几何视角手撕Self-Attention:用Python实现与三维可视化理解
在深度学习领域,Transformer架构已经成为自然语言处理、计算机视觉等领域的基石。而Self-Attention机制作为Transformer的核心组件,其重要性不言而喻。然而,很多学习者在初次接触Self-Attention时,往往会被Q(Query)、K(Key)、V(Value)这三个矩阵搞得晕头转向,陷入公式记忆的泥潭。本文将带你从几何角度重新审视Self-Attention,通过Python代码实现和三维可视化,让你真正理解其本质。
1. 向量投影:理解Self-Attention的几何基础
要理解Self-Attention,首先需要掌握向量内积的几何意义。两个向量的内积,本质上是一个向量在另一个向量方向上的投影长度与被投影向量长度的乘积。这个简单的几何概念,正是Self-Attention机制的核心。
让我们用NumPy来实现向量内积,并可视化这一过程:
import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # 定义两个三维向量 v1 = np.array([2, 1, 0]) v2 = np.array([1, 3, 0]) # 计算内积 dot_product = np.dot(v1, v2) print(f"向量v1和v2的内积为: {dot_product}") # 可视化 fig = plt.figure(figsize=(10, 8)) ax = fig.add_subplot(111, projection='3d') # 绘制向量 ax.quiver(0, 0, 0, v1[0], v1[1], v1[2], color='r', label='向量v1', arrow_length_ratio=0.1) ax.quiver(0, 0, 0, v2[0], v2[1], v2[2], color='b', label='向量v2', arrow_length_ratio=0.1) # 计算v1在v2上的投影向量 projection_length = dot_product / np.linalg.norm(v2) projection_vector = (dot_product / np.dot(v2, v2)) * v2 ax.quiver(0, 0, 0, projection_vector[0], projection_vector[1], projection_vector[2], color='g', linestyle='dotted', label='v1在v2上的投影', arrow_length_ratio=0.1) ax.set_xlim([0, 3]) ax.set_ylim([0, 3]) ax.set_zlim([0, 1]) ax.set_xlabel('X轴') ax.set_ylabel('Y轴') ax.set_zlabel('Z轴') ax.legend() plt.title("向量内积的几何意义:投影") plt.show()这段代码展示了两个三维向量的内积计算及其几何意义。从可视化结果中,我们可以直观地看到:
- 红色向量v1在蓝色向量v2上的绿色投影
- 内积值的大小反映了两个向量的相似程度
- 当两个向量夹角为90度时,内积为零,表示完全不相关
提示:在实际的Self-Attention中,这种投影关系决定了不同位置之间的注意力权重。相似的向量(夹角小)会获得更高的注意力分数。
2. 从单词嵌入到注意力矩阵:一步步构建Self-Attention
现在,让我们把这些几何概念扩展到多个向量组成的矩阵运算中。假设我们有一个简单的句子"深度学习有趣",经过嵌入层后,我们得到三个词向量:
# 三个词的嵌入向量 (维度为4) deep = np.array([1.0, 0.5, 0.2, -0.3]) learning = np.array([0.3, 1.2, -0.7, 0.8]) fun = np.array([0.8, -0.2, 1.1, 0.4]) # 组成输入矩阵X (3个词x4维) X = np.vstack([deep, learning, fun]) print("输入矩阵X:") print(X)Self-Attention的第一步是计算注意力分数,也就是词与词之间的相关性。这通过查询向量(Query)和键向量(Key)的点积来实现:
# 初始化随机权重矩阵 (在实际训练中这些是学习得到的) W_Q = np.random.randn(4, 3) # 查询变换矩阵 W_K = np.random.randn(4, 3) # 键变换矩阵 W_V = np.random.randn(4, 3) # 值变换矩阵 # 计算Q, K, V Q = X @ W_Q K = X @ W_K V = X @ W_V print("\n查询矩阵Q:") print(Q) print("\n键矩阵K:") print(K) print("\n值矩阵V:") print(V)接下来,我们计算注意力分数矩阵:
# 计算注意力分数 (未缩放) attention_scores = Q @ K.T print("\n未缩放的注意力分数矩阵:") print(attention_scores) # 缩放注意力分数 d_k = K.shape[1] # 键向量的维度 scaled_attention = attention_scores / np.sqrt(d_k) print("\n缩放后的注意力分数矩阵:") print(scaled_attention)缩放操作是为了防止点积结果过大导致softmax函数进入梯度饱和区。然后我们应用softmax归一化:
# 应用softmax得到注意力权重 attention_weights = np.exp(scaled_attention) / np.sum(np.exp(scaled_attention), axis=1, keepdims=True) print("\n注意力权重矩阵:") print(attention_attention_weights)最后,我们用注意力权重对值向量V进行加权求和:
# 计算加权和 output = attention_weights @ V print("\nSelf-Attention输出:") print(output)这个输出矩阵就是经过Self-Attention机制处理后的新表示,其中每个词向量都包含了句子中其他相关词的信息。
3. 三维可视化:注意力权重的几何解释
为了更直观地理解上述过程,我们可以将矩阵运算可视化。让我们创建一个三维图来展示查询、键和值向量的关系:
# 可视化Q, K, V的关系 fig = plt.figure(figsize=(18, 6)) # 查询向量可视化 ax1 = fig.add_subplot(131, projection='3d') for i, vec in enumerate(Q): ax1.quiver(0, 0, 0, vec[0], vec[1], vec[2], color=['r', 'g', 'b'][i], label=f'Query {i}', arrow_length_ratio=0.1) ax1.set_title('查询向量(Q)') ax1.legend() # 键向量可视化 ax2 = fig.add_subplot(132, projection='3d') for i, vec in enumerate(K): ax2.quiver(0, 0, 0, vec[0], vec[1], vec[2], color=['r', 'g', 'b'][i], label=f'Key {i}', arrow_length_ratio=0.1) ax2.set_title('键向量(K)') ax2.legend() # 注意力权重可视化 ax3 = fig.add_subplot(133) cax = ax3.matshow(attention_weights) fig.colorbar(cax) ax3.set_xticks([0, 1, 2]) ax3.set_yticks([0, 1, 2]) ax3.set_xticklabels(['deep', 'learning', 'fun']) ax3.set_yticklabels(['deep', 'learning', 'fun']) ax3.set_title('注意力权重热力图') plt.tight_layout() plt.show()从可视化中我们可以观察到:
- 查询向量和键向量的方向决定了注意力权重的大小
- 相似方向的查询和键向量会产生更高的注意力分数
- 热力图直观展示了词与词之间的关注程度
4. 完整Self-Attention层的Python实现
现在,我们将上述步骤整合成一个完整的Self-Attention类实现:
class SelfAttention: def __init__(self, input_dim, d_k, d_v): """ 初始化Self-Attention层 :param input_dim: 输入向量的维度 :param d_k: 查询和键的维度 :param d_v: 值的维度 """ self.d_k = d_k # 初始化权重矩阵 self.W_Q = np.random.randn(input_dim, d_k) self.W_K = np.random.randn(input_dim, d_k) self.W_V = np.random.randn(input_dim, d_v) def forward(self, X): """ 前向传播 :param X: 输入矩阵 (seq_len, input_dim) :return: 输出矩阵 (seq_len, d_v) """ # 计算Q, K, V Q = X @ self.W_Q K = X @ self.W_K V = X @ self.W_V # 计算注意力分数 attention_scores = Q @ K.T / np.sqrt(self.d_k) # 应用softmax attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), axis=1, keepdims=True) # 计算输出 output = attention_weights @ V return output, attention_weights # 使用示例 input_dim = 4 # 输入维度 d_k = 3 # 查询和键的维度 d_v = 3 # 值的维度 attention = SelfAttention(input_dim, d_k, d_v) output, attn_weights = attention.forward(X) print("\n完整Self-Attention层输出:") print(output) print("\n对应的注意力权重:") print(attn_weights)这个实现包含了Self-Attention的所有关键步骤:
- 线性变换生成Q、K、V
- 计算缩放的点积注意力分数
- 应用softmax归一化
- 对值向量进行加权求和
5. 多头注意力:扩展Self-Attention的能力
单头注意力有时可能无法捕捉输入的不同方面的特征。多头注意力通过并行运行多个注意力头来解决这个问题:
class MultiHeadAttention: def __init__(self, input_dim, d_model, num_heads): """ 初始化多头注意力层 :param input_dim: 输入维度 :param d_model: 模型维度 :param num_heads: 注意力头数量 """ assert d_model % num_heads == 0, "d_model必须能被num_heads整除" self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads self.d_v = d_model // num_heads # 初始化权重矩阵 self.W_Q = np.random.randn(input_dim, d_model) self.W_K = np.random.randn(input_dim, d_model) self.W_V = np.random.randn(input_dim, d_model) self.W_O = np.random.randn(d_model, d_model) def split_heads(self, x, batch_size): """ 将输入分割为多个头 """ x = x.reshape(batch_size, -1, self.num_heads, self.d_k) return x.transpose(0, 2, 1, 3) # (batch_size, num_heads, seq_len, d_k) def forward(self, X): batch_size = X.shape[0] # 线性变换 Q = X @ self.W_Q # (batch_size, seq_len, d_model) K = X @ self.W_K V = X @ self.W_V # 分割多头 Q = self.split_heads(Q, batch_size) K = self.split_heads(K, batch_size) V = self.split_heads(V, batch_size) # 计算缩放点积注意力 attention_scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(self.d_k) attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), axis=-1, keepdims=True) output = attention_weights @ V # (batch_size, num_heads, seq_len, d_v) # 合并多头 output = output.transpose(0, 2, 1, 3) # (batch_size, seq_len, num_heads, d_v) output = output.reshape(batch_size, -1, self.d_model) # 最终线性变换 output = output @ self.W_O return output, attention_weights # 使用示例 (假设批量大小为1) X_batch = np.expand_dims(X, axis=0) # 添加批次维度 multihead_attn = MultiHeadAttention(input_dim=4, d_model=6, num_heads=2) output, attn_weights = multihead_attn.forward(X_batch) print("\n多头注意力输出:") print(output[0]) # 去掉批次维度 print("\n其中一个注意力头的权重:") print(attn_weights[0, 0]) # 第一个批次的第一个头多头注意力的优势在于:
- 每个头可以学习关注输入的不同方面
- 并行计算提高了效率
- 增强了模型的表达能力
6. Self-Attention的实际应用与优化
理解了Self-Attention的基本原理后,让我们看看在实际应用中需要考虑的一些优化和变体:
6.1 掩码自注意力
在处理序列数据时,我们经常需要防止当前位置关注到未来的位置。这可以通过注意力掩码实现:
def masked_self_attention(X): Q = X @ W_Q K = X @ W_K V = X @ W_V # 计算注意力分数 attention_scores = Q @ K.T / np.sqrt(d_k) # 创建掩码 (下三角矩阵) seq_len = X.shape[0] mask = np.tril(np.ones((seq_len, seq_len))) masked_scores = attention_scores * mask - 1e10 * (1 - mask) # 应用softmax attention_weights = np.exp(masked_scores) / np.sum(np.exp(masked_scores), axis=1, keepdims=True) # 计算输出 output = attention_weights @ V return output, attention_weights masked_output, masked_weights = masked_self_attention(X) print("\n掩码自注意力输出:") print(masked_output) print("\n掩码注意力权重:") print(masked_weights)6.2 相对位置编码
原始的Self-Attention不包含位置信息,可以通过添加位置编码来注入序列顺序信息:
def positional_encoding(seq_len, d_model): position = np.arange(seq_len)[:, np.newaxis] div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model)) pe = np.zeros((seq_len, d_model)) pe[:, 0::2] = np.sin(position * div_term) pe[:, 1::2] = np.cos(position * div_term) return pe # 添加位置编码 pe = positional_encoding(X.shape[0], X.shape[1]) X_pe = X + pe # 使用带位置编码的输入 output_pe, weights_pe = attention.forward(X_pe) print("\n带位置编码的Self-Attention输出:") print(output_pe)6.3 计算效率优化
对于长序列,标准的Self-Attention计算复杂度为O(n²),可以采用以下优化策略:
| 优化方法 | 原理 | 复杂度 | 适用场景 |
|---|---|---|---|
| 稀疏注意力 | 只计算部分位置的注意力 | O(n√n) | 长序列处理 |
| 局部注意力 | 限制注意力窗口大小 | O(nk) | 局部相关性强的数据 |
| 低秩近似 | 使用低秩矩阵近似注意力 | O(n) | 对精度要求不高的场景 |
| 内存压缩 | 减少中间存储需求 | O(n) | 内存受限环境 |
在实际项目中,根据具体需求选择合适的优化策略可以显著提高模型效率。