TODO: 这里写 PagedAttention 的核心抽象:block/page、block table、逻辑 token 到物理 KV block 的映射。

基础:tensor 级拆请求的形状(大量细节)

定义符号:$B$ 是 batch size,$T$ 是 seq_len,$D$ 是 token_dim,$d_q$ 是把 embedding token 投影到 $Q$ 后的维度。

推理框架拿到的请求是:\(R \in \mathbb{R}^{B \times T}\)。

$R_{b,t}$ 是一个最最基本的 token id 标量。

raw 请求经过 embedding lookup,做的操作是把这个 token 标量映射成一个高维向量。假设原先 token 是 1234 这个标量,现在就把 token 映射成 [0.1, 0.2, 0.3, 0.4] 这样的向量。

所以 $R$ 经过 embedding lookup 之后,得到:\(X \in \mathbb{R}^{B \times T \times D}\)。

因为我们目前只考虑推理场景,所以把 $W_Q$、$W_K$、$W_V$ 之类的矩阵当成固定的模型参数。

然后很多博客会直接写:\(Q = XW_Q\)。

这样写很容易产生误解,因为这好像是把整个原始请求丢给了 $Q$ 去投影。从感觉上这也没问题,但是对深入理解帮助比较小。

实际上注意力投影的最小粒度是每个 token embedding vector,也就是 \(X_{b,t,:}\)。

所以在最小粒度的视图下,$x$ 的形状是 \(x \in \mathbb{R}^{D}\)。

如果没有多头注意力,$W_Q$ 的形状就是 \(W_Q \in \mathbb{R}^{D \times d_q}\)。

在最常见的单头简化设定下,可以取 $d_q = D$。为了让 $q$ 和 $k$ 能做点积,通常需要 $d_q = d_k$。

那么投影后,对于每个 token 的 $q$ 就变成了 $[d_q]$ 形状。$K/V$ 也是同理。单头下最小的逐 token 操作粒度如下:

公式视角是:\(q = xW_Q\),\(k = xW_K\),\(v = xW_V\)。

q = x @ W_Q   [d_q]
k = x @ W_K   [d_k]
v = x @ W_V   [d_v]

每个 embedding token 被投影成三个向量 $q$、$k$、$v$。

将其扩展到整个 batch + 整个 seq:

公式视角是:\(Q = XW_Q\),\(K = XW_K\),\(V = XW_V\)。

Q = X @ W_Q   [B, T, d_q]
K = X @ W_K   [B, T, d_k]
V = X @ W_V   [B, T, d_v]

这里“扩展”的意思是:对于每一个 batch index $b$ 和每一个位置 $t$,都执行一次同样的逐 token 投影。固定某个 $b,t$ 后,投影结果是:

公式视角是:\(Q_{b,t,:} = X_{b,t,:}W_Q\),\(K_{b,t,:} = X_{b,t,:}W_K\),\(V_{b,t,:} = X_{b,t,:}W_V\)。

Q[b, t, :]   [d_q]
K[b, t, :]   [d_k]
V[b, t, :]   [d_v]

写成代码就是:

for b in range(B):
    for t in range(T):
        Q[b, t, :] = X[b, t, :] @ W_Q

这里说明了一件事情:对于不同 batch 且不同位置的 token,都是用同一个 $W_Q$ 去投影的。

这是合理的,因为 $W_Q$ 表示的是这一层学到的通用投影规则。一个 embedding token 如何被投影,本质上取决于当前层这个位置的 hidden vector $X_{b,t,:}$ 和该层共享的 $W_Q$。

batch index $b$ 本身不携带语义,它只是并行计算时的编号。真正决定“这是哪个请求”的,是 $X_{b,:,:}$ 里的 token 内容。

position index $t$ 也不是通过换一套 $W_Q$ 来发挥作用的。位置信息通常已经通过 position embedding、RoPE、causal mask,或者前面层的上下文化 hidden state 体现在 $X_{b,t,:}$ 或 attention 计算里。

所以我们不需要给不同 batch 或不同位置各自准备一套 $W_Q$。否则参数量会随着 batch size 或 sequence length 增长,而且会破坏 Transformer 对不同位置共享同一套 token 处理规则的设计。

在计算注意力分数的时候,依然从逐 token 去看:如定义所述,就是让一个 batch 内的第 $i$ 个 token 的 $q$ 和第 $j$ 个 token 的 $k$ 做一次乘法。所以注意力分数矩阵应该是如下形状:

[B, T, T]

现代的模型有很多都用了 GQA。GQA 的改动简单来说就是让一个 token 被多个 $Q$ 和多个 $KV$ 去投影。$Q$ 的数量是 $KV$ 的整数倍。既然有整数倍就有组的对应,这里不做赘述。

定义符号:$H_Q$ 是 $Q$ 的组数,$H_{KV}$ 是 $KV$ 的组数。

相应也得修改 $W_Q/W_K/W_V$ 的形状。这里有一个非常重要的工程细节:线性层投影出来的结果,通常不是直接得到 [head_num, head_dim] 的二维结构,而是先得到一个扁平的一维向量,然后再 reshape 成 head 形式。

从逐 token 视角看:

公式视角是:\(W_Q \in \mathbb{R}^{D \times (H_Q d_h)}\),\(W_K \in \mathbb{R}^{D \times (H_{KV} d_h)}\),\(W_V \in \mathbb{R}^{D \times (H_{KV} d_h)}\)。

x: [D]

W_Q: [D, H_Q  * d_h]
W_K: [D, H_KV * d_h]
W_V: [D, H_KV * d_h]

所以先得到 raw 投影结果:

公式视角是:\(q_{\text{raw}} = xW_Q\),\(k_{\text{raw}} = xW_K\),\(v_{\text{raw}} = xW_V\)。

q_raw = x @ W_Q   [H_Q  * d_h]
k_raw = x @ W_K   [H_KV * d_h]
v_raw = x @ W_V   [H_KV * d_h]

这里 [H_Q * d_h] 中间是乘法,表示这是一个长度为 $H_Q \times d_h$ 的一维向量。它还没有显式拆成 head 维度。

然后再 reshape:

公式视角是:\(q = \operatorname{reshape}(q_{\text{raw}}, [H_Q, d_h])\),\(k = \operatorname{reshape}(k_{\text{raw}}, [H_{KV}, d_h])\),\(v = \operatorname{reshape}(v_{\text{raw}}, [H_{KV}, d_h])\)。

q = reshape(q_raw, [H_Q,  d_h])
k = reshape(k_raw, [H_KV, d_h])
v = reshape(v_raw, [H_KV, d_h])

也就是:

q: [H_Q,  d_h]
k: [H_KV, d_h]
v: [H_KV, d_h]

所以 [H_Q * d_h][H_Q, d_h] 的元素总数相同,但含义不同:

[H_Q * d_h] 是一个扁平向量;
[H_Q, d_h] 是已经拆成 H_Q 个 head、每个 head 是 d_h 维的二维结构。

$q$、$k$、$v$ 也都变成了一组向量。但这个时候 $q$ 和 $kv$ 的向量形状不一定相等了。

计算注意力分数的时候,注意力矩阵从原来的 $[B,T,T]$ 变成:

[B, T, T, H_Q]

注意,这里计算完 $KV$ 之后,缓存的 KVCache 其实就是:

k: [H_KV, d_h]
v: [H_KV, d_h]

如果是 MLA,还有一些变化。这里不纠结具体的数学原理,只看关键的形状部分。

MLA 不直接缓存 $k$ 和 $v$。

还是从逐 token 的视角去看。对于一个 token:

公式视角是:\(x = X_{b,t,:}\)。

x = X[b, t, :]
x: [D]

先得到 KV latent:\(c_{kv} = xW_{DKV}\)。

其中:

W_DKV: [D, d_c]
c_kv:  [d_c]

所以 MLA 推理时核心缓存的是:

C_KV_cache per token: [d_c]

扩展到 batch + seq:

C_KV_cache: [B, T, d_c]

如果为了理解,把 latent 显式还原成 $K/V$,可以写成:

公式视角是:\(k_{\text{raw}} = c_{kv}W_{UK}\),\(v_{\text{raw}} = c_{kv}W_{UV}\)。

k_raw = c_kv @ W_UK   [H_Q * d_h]
v_raw = c_kv @ W_UV   [H_Q * d_h]

W_UK: [d_c, H_Q * d_h]
W_UV: [d_c, H_Q * d_h]

注意这里同样是先得到扁平向量 [H_Q * d_h],再 reshape:

公式视角是:\(k = \operatorname{reshape}(k_{\text{raw}}, [H_Q, d_h])\),\(v = \operatorname{reshape}(v_{\text{raw}}, [H_Q, d_h])\)。

k = reshape(k_raw, [H_Q, d_h])
v = reshape(v_raw, [H_Q, d_h])

所以:

k: [H_Q, d_h]
v: [H_Q, d_h]

$Q$ 可以先简化理解为仍然从 $x$ 直接投影出来:

公式视角是:\(q_{\text{raw}} = xW_Q\),\(q = \operatorname{reshape}(q_{\text{raw}}, [H_Q, d_h])\)。

q_raw = x @ W_Q   [H_Q * d_h]
q = reshape(q_raw, [H_Q, d_h])

于是注意力分数仍然是:

[B, T, T, H_Q]

其实这里可以把 GQA 纳入到 MLA 的框架中去,因为 MLA 是通过 $W_{UK}$ 和 $W_{UV}$ 两个矩阵去从 $C_{KV}$ 里面还原出 $KV$,这是基于一个假设:$C_{KV}$ 有能力保存 $K$ 和 $V$ 的低秩压缩。

那么假设我们这里将 $C_{KV}$ 变得足够大,使其刚好等于 $K$ 和 $V$,并且让 $W_{UK}$ 和 $W_{UV}$ 矩阵仅做选择功能,这个时候,只要我们对应好 $Q$ 和 $KV$ 的 group 关系,其实也能在 MLA 的一套框架下实现出一个 GQA。我希望后续的工程设计上能统一这部分的代码设计。

基础:自回归时的数据流形态

用语言描述这个过程很简单。笔者这里以某一层为例:prefill 阶段计算注意力和缓存 KV,得到一个 [B, T, T, H_Q] 的注意力分数矩阵和两个 [B, T, H_KV, d_h] 的 KV 矩阵,并且取计算最后一个 token 时的输出 logits 作为新的 token。

decode 阶段则是重复做如下操作:

计算新 token qkv
-> 写 kvcache
-> 计算注意力分数
-> 写注意力分数
-> 取模型输出的 logits

直到输出 EOF。

如果要考虑多层的话,attention score 和 KVCache 的存储 tensor 形状还要加一个 layer 的维度。

这里还有一个需要强调的点,就是工程实现上不会直接保存 [B, T, T, H_Q] 的注意力分数矩阵,因为有 FA、PA 这种优化在。

所以每一步 decode,数据和流向就是这样:

input token:
X_new: [B, 1, D]

output qkv:
Q_new: [B, 1, H_Q,  d_h]  -- 用于计算 logits
K_new: [B, 1, H_KV, d_h]  -- 写入 kvcache
V_new: [B, 1, H_KV, d_h]  -- 写入 kvcache

公式视角是:\(Q_{\text{new}} = X_{\text{new}}W_Q\),\(K_{\text{new}} = X_{\text{new}}W_K\),\(V_{\text{new}} = X_{\text{new}}W_V\)。

如果是 MLA,$k$ 和 $v$ 会被合并成一个东西。所以工程上可以考虑把 abstract kv cache info 作为一个抽象的实现基类,而多态体现在如何得到 abstract kv cache info 上。

但是笔者目前除了 GQA 系列和 MLA,还没看过其他的 QKV 映射方案。所以这部分的设计感觉还是不够 general,等待后续增加见识后确定如何设计再动手优化工程设计。

此外 MLA还涉及很多复杂的实现细节,比如QK的hiddendim需要加上rope,而v不需要。此外还有MLA的各种变体,例如TPA,MFA等。这些都要等笔者了解之后,思考一下MLA类注意力的kv是在什么地方做创新的,再去给出最好的工程实现。这个部分会单开一部分讲。

paged attention原理

前面铺垫了很久,到这里解释起来就非常容易。 前面说了,生成出来的KV Tensor形状是[L,B,T,HKV,DH],但是这只是理想情况,考虑一下实际场景会发生什么。我们可以重新审视这个形状,KV Tensor的形状过于规整了,这是因为理想中这个KV Tensor的每个batch的每个请求都有着相同的T。但实际上并不是。如果静态按照这个形状去分配,显然会很浪费空间。那么很自然就会想到动态扩容。也就是给每个request一个很小的T_init,如果不够长再去分配。但这个代码写出来效率不高。

伪代码大概要做如下几步: if len == capacity new_k = allocate(capacityfactor) new_v = allocate(capacityfactor)

    copy old kv to new kv
    free old kv

    crud处理其他引用

连续的allocate在大容量的时候会很费时间。而且factor也不好选择。 pa的办法是不以一个KV为分配单位,而是以KV_block_size大小分配。如果一个请求的token长度是68,block_size的大小是32,则消耗3个block。这样最多浪费一个block的空间。如果只看某一层的 K cache 或 V cache,一个 block 的形状可以理解为 [block_size, HKV, DH],元素数是 block_size * HKV * DH。K 和 V 两份合起来就是 2 * block_size * HKV * DH。(所以为什么不叫blockattention呢)pa只有划分block还不够,pa还允许了这些block之间在物理上可以不连续,而只需要一个table保存他们之间的顺序即可。

核心思路说完了,但是还是有一些实现的细节,首先table在哪里构建?table放在哪个结构体?

未完待续。