DRAFT 实现GPT:多头注意力机制
约 934 字大约 3 分钟
2025-10-30
现代LLM的核心是Transformer架构,而Transformer的核心是(多头)注意力机制。
下面我们来分析如何实现多头注意力机制的代码。
PyTorch 中一个典型的多头注意力(Multi-Head Self-Attention)实现代码如下:
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x):
B, T, C = x.shape
qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
mask = torch.tril(torch.ones(T, T, device=x.device))
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
out = att @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.proj(out)1. 概念说明
**多头注意力机制(Multi-Head Attention)**是 Transformer 模型的核心模块之一。 它的主要作用是:
- 将输入序列映射为 查询(Q)、键(K)、值(V)。
- 计算注意力分数,让每个位置的表示关注序列中不同位置的信息。
- 将每个注意力头的输出进行合并,并线性映射回原始维度。
参数说明:
d_model:输入 embedding 的维度。n_heads:注意力头的数量。- 每个头的维度:
head_dim = d_model // n_heads。
2. 代码分析
2.1 初始化(__init__ 方法)
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.proj = nn.Linear(d_model, d_model)qkv:一次性将输入映射为 Q、K、V,输出维度是3*d_model。proj:将多头注意力的输出重新投影回原始维度d_model。head_dim:每个头的维度,保证能均分d_model。
2.2 前向传播(forward 方法)
1) 输入形状
B, T, C = x.shapeB:batch size(一次同时处理多少个序列)T:Time steps / Sequence Length 时间步长 / 序列长度(一个序列中有多少个token)C:Channels / Embedding dimension 通道数 / 嵌入维度(每个token用多少维的向量表示)
2) 生成 Q、K、V
qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)self.qkv(x)输出(B, T, 3*d_model)。.view(B, T, 3, n_heads, head_dim)将其维度重构,原来的最后一个维度分为三部分。unbind(dim=2)拆分为q, k, v三个张量,每个张量形状为(B, T, n_heads, head_dim)。
3) 计算注意力分数
att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)- Q 和 K 的点积表示相关性。
- 除以
sqrt(head_dim)进行缩放,防止数值过大导致 softmax 梯度消失。 - 输出形状:
(B, n_heads, T, T)。
4) 施加因果掩码(Causal Mask)
mask = torch.tril(torch.ones(T, T, device=x.device))
att = att.masked_fill(mask == 0, float('-inf'))- 下三角掩码保证每个位置只能关注当前位置及之前的 token。
- 这是自回归生成模型(如 GPT)的必要步骤。
5) Softmax 归一化
att = F.softmax(att, dim=-1)- 将注意力分数转换为概率分布。
- 注意力矩阵每行的和为 1。
6) 应用注意力权重到 V
out = att @ v- 输出形状
(B, n_heads, T, head_dim)。 - 每个头的输出是对序列中值的加权求和。
7) 合并头并线性映射
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.proj(out)transpose调整维度为(B, T, n_heads, head_dim)。view(B, T, C)将多头拼接成原始维度。proj将拼接后的输出映射回d_model,作为最终输出。
3. 总结
多头注意力机制的核心流程可以概括为:
- 线性映射生成 Q、K、V
- 计算缩放点积注意力
- 施加掩码(可选)
- Softmax 获得注意力权重
- 将注意力权重应用于 V
- 合并多头输出并投影
通过多头注意力,模型可以在不同子空间捕捉序列中不同位置的依赖关系,增强表示能力。
