结合底部的测试用例(batch_size=2, seq_len=10, model_dim=64, num_heads=8),逐步跟踪数据在每个阶段的维度变化和计算逻辑。
零、初始化阶段(__init__)
1
| self.head_dim = model_dim // num_heads
|
每个头的维度 dhead = 8。然后创建四个线性投影层:
1 2 3 4
| self.w_q = nn.Linear(64, 64) self.w_k = nn.Linear(64, 64) self.w_v = nn.Linear(64, 64) self.w_o = nn.Linear(64, 64)
|
注意这里每个投影层的维度是 (model_dim, model_dim) = (64, 64),而不是 (64, 8)。这是因为所有头的投影被合并在一个大矩阵里,后面通过 reshape 来拆分成多个头。这样做的好处是只需要一次矩阵乘法就能完成所有头的投影,效率远高于分别为每个头做投影。
一、线性投影
测试用例中调用方式是 mha(x, x),即 x_query = x_context = x,这是自注意力。
1 2 3
| q = self.w_q(x_query) k = self.w_k(x_query) v = self.w_v(x_query)
|
此时 Q、K、V 的形状都是 [2, 10, 64],包含了所有 8 个头的信息,还没有拆分。
如果是交叉注意力(比如 x_context 来自编码器输出,形状为 [2, 20, 64]),那么 K 和 V 的 seq_len 就是 20 而不是 10,这也完全兼容后续的计算。
二、分头处理(reshape + transpose)
这一步是多头注意力实现的关键技巧:
1
| q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
拆解来看分两步。第一步 view:
1 2
| [2, 10, 64] → [2, 10, 8, 8] batch seq model_dim batch seq heads head_dim
|
这里把最后一个维度 64 拆成了 8 个头 × 每头 8 维。本质上就是把一个 64 维的"大向量"理解为 8 个 8 维的"小向量"。
第二步 transpose(1, 2):
1 2
| [2, 10, 8, 8] → [2, 8, 10, 8] batch seq heads head_dim batch heads seq head_dim
|
把 heads 维移到 seq 前面。这样做的目的是:让 (seq, head_dim) 成为最内层的两个维度,方便后续对每个头独立做矩阵运算。transpose 之后,可以把 [2, 8, 10, 8] 理解为“2 个样本,每个样本有 8 个头,每个头看到 10 个位置,每个位置用 8 维表示”。
K 和 V 也做完全相同的变换,最终:
1 2 3
| q: [2, 8, 10, 8] k: [2, 8, 10, 8] v: [2, 8, 10, 8]
|
三、缩放点积注意力
3.1 计算注意力分数
1
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
先看 k.transpose(-2, -1),这是对最后两个维度做转置:
1 2 3
| k: [2, 8, 10, 8] → k.T: [2, 8, 8, 10] ↑ ↑ head_dim seq
|
然后 q 和 k.T 做矩阵乘法。batch 和 heads 这两个维度作为“批次维度”不参与乘法,实际运算发生在最后两个维度上:
1 2 3 4
| q: [2, 8, 10, 8] k.T: [2, 8, 8, 10] ↓ matmul scores: [2, 8, 10, 10]
|
scores 中每个元素 scores[b][h][i][j] 表示:第 b 个样本中,第 h 个头里,位置 i 的 query 对位置 j 的 key 的点积相似度。
然后除以 √dhead = √8 ≈ 2.83。这里除的是每个头的维度 √8,而不是整个模型维度 √64。
3.2 应用掩码(可选)
1 2
| if mask is not None: scores = scores.masked_fill(mask == 0, -1e9)
|
在自回归语言模型(如 GPT、Qwen)中,会传入一个因果掩码(causal mask),形状为 [1, 1, 10, 10] 的下三角矩阵。mask 为 0 的位置(即未来位置)会被填充为 -1e9(一个极大的负数),经过 softmax 后这些位置的注意力权重就趋近于 0,从而阻止模型“看到未来的 token”。
本测试用例中 mask=None,所以跳过这一步。
3.3 Softmax 归一化
1
| attn_weights = F.softmax(scores, dim=-1)
|
对 scores 的最后一个维度(seq_len_k)做 softmax:
1 2 3
| scores: [2, 8, 10, 10] ↑ 对这个维度 softmax attn_weights: [2, 8, 10, 10] (每行之和 = 1)
|
attn_weights[b][h][i] 是一个长度为 10 的概率分布,表示位置 i 对所有 10 个位置的注意力权重分配。
然后应用 Dropout:
1
| attn_weights = self.dropout(attn_weights)
|
训练时随机将部分注意力权重置零,起正则化作用。
3.4 加权求和
1
| context = torch.matmul(attn_weights, v)
|
用注意力权重对 V 做加权求和:
1 2 3 4
| attn_weights: [2, 8, 10, 10] v: [2, 8, 10, 8] ↓ matmul context: [2, 8, 10, 8]
|
context[b][h][i] 是一个 8 维向量,它是位置 i 根据注意力权重对所有位置的 value 向量做加权平均的结果。如果位置 i 对位置 j 的注意力权重很高,那么位置 j 的 value 向量就会对 context[b][h][i] 贡献更多。
四、合并多头
现在需要把 8 个头的结果拼接回去:
1 2 3
| context = context.transpose(1, 2) context = context.contiguous() output = context.view(batch_size, -1, self.model_dim)
|
transpose(1, 2) 把 heads 和 seq 换回来。contiguous() 是因为 transpose 后张量在内存中可能不连续,而 view 要求内存连续,所以需要先调用 contiguous 重新整理内存布局。最后 view 把 (8 heads, 8 dim) 合并回 64 维。
最后通过输出投影:
1
| output = self.w_o(output)
|
Wᴼ 让不同头学到的信息互相融合。最终输出形状 [2, 10, 64],和输入完全一致,可以无缝接入残差连接和后续层。
五、完整数据流总结
用测试用例(batch=2, seq=10, d_model=64, heads=8, d_head=8):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| 输入 x: [2, 10, 64] ↓ w_q / w_k / w_v 线性投影 Q, K, V: [2, 10, 64] ↓ view + transpose 分头 Q, K, V: [2, 8, 10, 8] (8个头各自独立) ↓ QKᵀ / √8 scores: [2, 8, 10, 10] (每个头的注意力分数) ↓ mask(可选)+ softmax + dropout attn_weights: [2, 8, 10, 10] (归一化的注意力权重) ↓ × V 加权求和 context: [2, 8, 10, 8] (每个头的输出) ↓ transpose + view 合并多头 merged: [2, 10, 64] (拼接回完整维度) ↓ w_o 输出投影 output: [2, 10, 64] (最终输出)
|
整个过程的核心洞察是:通过 view 和 transpose 这两个零计算开销的操作,巧妙地将"多头并行计算"转化为了标准的批量矩阵乘法,让 GPU 能高效并行执行。所有头的计算在同一次 matmul 中完成,完全没有显式的循环。
完整代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
| """ 多头注意力(Multi-Head Attention)
Transformer 的核心组件,通过并行运行多个注意力头来捕捉不同子空间的特征。 每个头独立计算注意力,最后将结果拼接并通过线性层融合。 """
import torch import torch.nn as nn import torch.nn.functional as F import math
class MultiHeadAttention(nn.Module): """ 多头注意力模块
支持自注意力(Self-Attention)和交叉注意力(Cross-Attention): - 自注意力:Q = K = V = x_query - 交叉注意力:Q = x_query, K = V = x_context
Args: model_dim: 模型隐藏维度 num_heads: 注意力头数 dropout_p: Dropout 概率,默认 0.0 """
def __init__(self, model_dim, num_heads, dropout_p=0.0): super().__init__()
assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"
self.model_dim = model_dim self.num_heads = num_heads self.head_dim = model_dim // num_heads
self.w_q = nn.Linear(model_dim, model_dim) self.w_k = nn.Linear(model_dim, model_dim) self.w_v = nn.Linear(model_dim, model_dim)
self.w_o = nn.Linear(model_dim, model_dim)
self.dropout = nn.Dropout(dropout_p)
def forward(self, x_query, x_context, mask=None): """ 前向传播
Args: x_query: 查询输入 [batch_size, seq_len_q, model_dim] x_context: 上下文输入(用于生成 K 和 V)[batch_size, seq_len_k, model_dim] 如果为 None,则使用 x_query(自注意力) mask: 注意力掩码 [batch_size, 1, seq_len_q, seq_len_k] 或 [1, 1, seq_len_q, seq_len_k]
Returns: output: 注意力输出 [batch_size, seq_len_q, model_dim] """ batch_size = x_query.size(0)
q = self.w_q(x_query)
if x_context is not None: k = self.w_k(x_context) v = self.w_v(x_context) else: k = self.w_k(x_query) v = self.w_v(x_query)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None: scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1) attn_weights = self.dropout(attn_weights)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2) context = context.contiguous() output = context.view(batch_size, -1, self.model_dim)
output = self.w_o(output)
return output
if __name__ == "__main__": x = torch.randn(2, 10, 64) mha = MultiHeadAttention(model_dim=64, num_heads=8) out = mha(x, x) print(f"Input shape: {x.shape}") print(f"Output shape: {out.shape}")
|