Architecture / Sequence Modeling
注意力机制:Q/K/V、softmax 权重与上下文向量
把“看哪里”写成一个可微分的相似度矩阵:query 提问,key 被检索,value 被加权汇总。
Mechanism Lab
动画:Q/K/V 如何把 token 变成上下文向量
动画先把 token 投影成 Q/K/V,再展示一个 query 行如何和所有 key 形成分数矩阵,经过缩放、mask、softmax 后,对 value 做加权求和。
Step 1 / 5
Tokens
输入序列形成隐藏矩阵 H,每一行是一个 token 表示。
H in R^{n x d_model}Animation Control
Reduced-motion users receive the same step states without continuous motion.
01 / 直觉
核心直觉
注意力不是一个神秘模块,而是一次可学习的检索:当前位置的 query 与所有 token 的 key 比较,得到一行权重。
权重经过 softmax 后非负且和为 1,因此输出是 value 向量的凸组合,像一个按相关性加权的上下文摘要。
缩放因子 sqrt(d_k) 的作用是控制点积方差,避免维度变大时 softmax 过早饱和、梯度变小。
mask 决定哪些位置可见:encoder 可看全句,decoder 的因果 mask 只能看过去,padding mask 排除无意义 token。
02 / 数学
从 token 表示推导 scaled dot-product attention
01 / 线性投影
令 H in R^{n x d_model} 表示 n 个 token 的隐藏向量。三组可学习矩阵把同一批 token 投影到 query、key、value 空间。
Q=H W_Q, K=H W_K, V=H W_V02 / 相似度分数
第 i 个 query 与第 j 个 key 的点积给出“i 应该关注 j 多少”的未归一化分数。矩阵形式一次得到全部 pairwise 分数。
S_ij = q_i^T k_j, S = QK^T03 / 为什么除以 sqrt(d_k)
若 q_i 和 k_j 的各维近似独立、均值为 0、方差为 1,则点积的方差随 d_k 线性增长。除以 sqrt(d_k) 后分数方差约为 1。
Var(q_i^T k_j)=d_k -> Var(S_ij/sqrt(d_k))=104 / softmax 归一化
对每个 query 的一行做 softmax。得到的 alpha_ij 非负且每行求和为 1,因此注意力矩阵 A 是 row-stochastic。
A_ij = exp(S_ij)/sum_l exp(S_il)05 / 加权求和
输出 z_i 是所有 value 的加权平均。加上 mask M 后,不可见位置的分数设为 -infinity,softmax 权重变为 0。
Z = softmax(QK^T/sqrt(d_k)+M)V06 / 多头分解
多头注意力并行学习多个相似度空间,再把各头输出拼接。每个头可以捕捉不同关系,如语义对应、位置依赖或表格列关系。
MHA(H)=Concat(head_1,...,head_m)W_O03 / 代码
NumPy 演示:手写 scaled dot-product attention
下面的代码显式计算 Q/K/V、缩放分数、因果 mask、softmax 权重和最终上下文向量。
import numpy as np
import pandas as pd
def softmax(a, axis=-1):
a = a - np.max(a, axis=axis, keepdims=True)
exp = np.exp(a)
return exp / exp.sum(axis=axis, keepdims=True)
def scaled_dot_attention(X, Wq, Wk, Wv, mask=None):
Q = X @ Wq
K = X @ Wk
V = X @ Wv
d_k = Q.shape[-1]
scores = (Q @ K.T) / np.sqrt(d_k)
if mask is not None:
scores = np.where(mask, scores, -1e9)
weights = softmax(scores, axis=-1)
context = weights @ V
return context, weights, scores
tokens = ["policy", "raises", "wages", "jobs"]
rng = np.random.default_rng(7)
X = rng.normal(size=(len(tokens), 6))
Wq = rng.normal(size=(6, 4)) / np.sqrt(6)
Wk = rng.normal(size=(6, 4)) / np.sqrt(6)
Wv = rng.normal(size=(6, 4)) / np.sqrt(6)
# Decoder-style causal mask: token i cannot attend to future tokens j > i.
causal_mask = np.tril(np.ones((len(tokens), len(tokens)), dtype=bool))
context, weights, scores = scaled_dot_attention(X, Wq, Wk, Wv, causal_mask)
print(pd.DataFrame(weights, index=tokens, columns=tokens).round(3))
print("row sums:", weights.sum(axis=1).round(6))
print("context shape:", context.shape)04 / 案例
案例:研究助理如何用注意力连接问题、证据和结论
- 设用户问题是“最低工资政策会怎样影响就业?”模型内部的 query 来自当前位置的任务表示,key/value 来自问题、论文段落、变量定义、回归表和引用上下文。
- 当生成“就业效应取决于识别设计”时,query 会同时比较“minimum wage”“employment”“DID table”“control group”等 key,并把相关 value 加权汇总进当前表示。
- 在表格理解中,一个单元格的 query 可以关注列名、行标签、单位、标准误说明和脚注,从而避免把标准误、系数和样本量混在一起。
- 在实证 Agent 里,注意力本身不是因果证据;它只是表示层的可微检索。真正的可信性仍来自数据来源、识别假设、代码复跑和诊断检验。
05 / 风险
常见误区
参考资料
- Vaswani et al. (2017), Attention Is All You Needhttps://arxiv.org/abs/1706.03762
- NeurIPS Proceedings: Attention Is All You Needhttps://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html
- Bahdanau, Cho, and Bengio (2014), Neural Machine Translation by Jointly Learning to Align and Translatehttps://arxiv.org/abs/1409.0473