本文是 LLM Theory 下 SMD 专题的第一篇,是关于 Spherical Motion Dynamics(SMD)的学习笔记,主要参考 Wan et al. 的 Spherical Motion Dynamics: Learning Dynamics of Neural Network with Normalization, Weight Decay, and SGD 以及 NeurIPS 2021 版本 Spherical Motion Dynamics: Learning Dynamics of Normalized Neural Network using SGD and Weight Decay

朴素的描述模型更新

对于本科一二年级时候的笔者而言,如果想要描述模型的更新量,那么只会考虑这个非常直接的东西:$\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$。 这确实是很直白的,相当于计算了模型权重的欧氏距离。是什么让我们必须放弃这种非常直观的欧氏距离呢?源于常用的归一化。 加上归一化之后,对于模型的输出,$y = \operatorname{BN}(x, k\boldsymbol{W}) = \operatorname{BN}(x, \boldsymbol{W})$,看不出模型尺度对 $y$ 的影响。但是却会从 $\lVert k\boldsymbol{W}_{t+1} - k\boldsymbol{W}_{t} \rVert$ 这一度量手段上产生 $k$ 倍的差距。 这就发现一个很显然的问题了:我们是希望通过观测类似于 $\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$ 这样的东西来控制模型的训练,而不是直接看最后的 $y$。因为只通过看 $y$ 的变化,并基于这种变化的规律,去指导如何进行训练前的设置,这件事并不容易。我们是希望找到一个更可以写出明确表达式,更容易观测,意义更明确的指标去观测,而这个指标恰好还要有一些规律和 $y$ 的规律“趋同”,这样我们就得到了一个 $y$ 的近似物,而且这个近似物比 $y$ 更好分析。所以我们会很直接的想到 $\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$。但是现在,当我们给模型加上BN(为什么加BN不赘述了)以后发现的问题是 $\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$ 会随着模型尺度产生成倍的差距,而 $y$ 却不会被尺度影响。这说明 $\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$ 不是一个好的指标,因为它错误地预测了 $y$ 随着尺度变化的行为规律。 当然,原则上,我们其实可以发现,上面的地方有一个逻辑漏洞:为什么权重变化一定能预测 $y$ 的变化,这是错误的呀,权重变了 $y$ 不一定变。所以 $\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$ 这个东西失效是情理之中的,不能先入为主的认为 $\lVert \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} \rVert$ 的规律能和 $y$ 趋同,那么出现上面的问题十分合理。我们一个理想中的目标是通过权重变化去预测 $y$ 的变化,这样我们就可以通过设计一套权重变化的方案去精确得到想要的 $y$。理想是很远大的,现实会比较难做,但是可以在最朴素的方案上向前不断推进预测的细化程度。SMD就是这样一种东西。

更可解释的模型更新–SMD方法论

笔者决定先把基于SMD分析模型更新量的方案给出,然后说明这种方案的合理性。 如果某一层权重是尺度无关的,比如BN层。记作 $\boldsymbol{W}$。先把 $\boldsymbol{W}$ 写成 $\boldsymbol{W}_t = r_t \boldsymbol{U}_t$,$\boldsymbol{U}_t$ 是单位的权重。

\[ r_t = \lVert \boldsymbol{W}_t \rVert \]

设优化器给出的更新为:

\[ \boldsymbol{W}_{t+1} = \boldsymbol{W}_t + \boldsymbol{\delta}_t \]

这个时候要对 $\boldsymbol{\delta}_t$ 做一次关于 $\boldsymbol{W}_t$ 方向的分解:

\[ \boldsymbol{\delta}_t = \boldsymbol{\delta}_r + \boldsymbol{\delta}_u \]

其中 $\boldsymbol{\delta}_r$ 是 $\boldsymbol{W}_t$ 的径向方向更新,$\boldsymbol{\delta}_u$ 是 $\boldsymbol{W}_t$ 的切向方向更新:

\[ \boldsymbol{\delta}_r = \left\langle \boldsymbol{\delta}_t, \boldsymbol{U}_t \right\rangle \boldsymbol{U}_t \]

那么我们主要分析的其实是两点,一点是 $\boldsymbol{\delta}_r$ 要稳定,另一点是 $\boldsymbol{\delta}_u$ 是否真的有改变。 因为 $L(k\boldsymbol{W}) = L(\boldsymbol{W})$(源于尺度无关的假设),而 $k\boldsymbol{W}$ 相对 $\boldsymbol{W}$ 也只是在 $\boldsymbol{W}$ 的切向做了变换,这说明在尺度无关模型模型改变量只是 $\boldsymbol{W}$ 径向的改变量。相较于之前的欧氏距离,我们实际上对改变量这个事情进行了细化,细化到了 $\boldsymbol{W}$ 的径向方向。

用SMD分析朴素的SGD

注意,所有分析都有一个前提,就是这个层加了BN或者其他能让尺度无关现象发生的操作。 对于朴素 SGD:

\[ \boldsymbol{W}_{t+1} = \boldsymbol{W}_{t} - \eta \boldsymbol{g}_{t} \]

由于 normalization 导致:

\[ \boldsymbol{g}_{t} \perp \boldsymbol{W}_{t} \]

所以:

\[ \boldsymbol{\delta}_t = -\eta \boldsymbol{g}_{t} \]

是纯切向更新。

因此:

\[ \boldsymbol{\delta}_{r} = \boldsymbol{0}, \qquad \boldsymbol{\delta}_{u} = -\eta \boldsymbol{g}_{t} \]

于是:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \left\lVert \boldsymbol{\delta}_{u} \right\rVert^2 \]

也就是:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{g}_{t} \right\rVert^2 \]

朴素 SGD 没有径向收缩项 $\boldsymbol{\delta}_{r}$,只有切向更新 $\boldsymbol{\delta}_{u}$。但切向更新虽然改变方向,却会二阶地增加半径。

角更新则是:

\[ \Delta_t \approx \frac{\left\lVert \boldsymbol{\delta}_{u} \right\rVert} {\left\lVert \boldsymbol{W}_{t} \right\rVert} \]

对于朴素 SGD:

\[ \Delta_t \approx \frac{\eta \left\lVert \boldsymbol{g}_{t} \right\rVert} {\left\lVert \boldsymbol{W}_{t} \right\rVert} \]

再用 scale-invariant gradient:

\[ \boldsymbol{g}_{t} = \frac{1}{\left\lVert \boldsymbol{W}_{t} \right\rVert} \tilde{\boldsymbol{g}}_{t} \]

得到:

\[ \Delta_t \approx \frac{\eta \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert} {\left\lVert \boldsymbol{W}_{t} \right\rVert^2} \]

用SMD分析SGD with Weight Decay

有了朴素 SGD 的铺垫之后,SGD + Weight Decay 就很好拆了。

先写更新式:

\[ \boldsymbol{W}_{t+1} = \boldsymbol{W}_{t} - \eta\left( \boldsymbol{g}_{t} + \lambda \boldsymbol{W}_{t} \right) \]

这里的 $\boldsymbol{g}_{t}$ 是任务 loss 对 $\boldsymbol{W}_t$ 的梯度,先不把 WD 算进 loss 里面。展开之后就是:

\[ \boldsymbol{W}_{t+1} = \boldsymbol{W}_{t} - \eta\boldsymbol{g}_{t} - \eta\lambda\boldsymbol{W}_{t} \]

所以总更新量是:

\[ \boldsymbol{\delta}_t = \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} = - \eta\boldsymbol{g}_{t} - \eta\lambda\boldsymbol{W}_{t} \]

因为尺度无关带来的正交性质:

\[ \left\langle \boldsymbol{W}_{t}, \boldsymbol{g}_{t} \right\rangle = 0 \]

所以 $\boldsymbol{g}_t$ 还是切向的。这个时候 WD 项就很有意思了,$-\eta\lambda\boldsymbol{W}_{t}$ 和 $\boldsymbol{W}_{t}$ 平行,因此它不是切向更新,而是一个非常干净的径向收缩。

所以按照前面的记号,可以直接拆成:

\[ \boldsymbol{\delta}_{r} = - \eta\lambda\boldsymbol{W}_{t} \]
\[ \boldsymbol{\delta}_{u} = - \eta\boldsymbol{g}_{t} \]

也就是:

\[ \boldsymbol{\delta}_{t} = \boldsymbol{\delta}_{r} + \boldsymbol{\delta}_{u} \]

简单描述就是,$\boldsymbol{\delta}_r$ 是 WD 给出的径向收缩,把 $\boldsymbol{W}$ 往原点拉;$\boldsymbol{\delta}_u$ 是任务梯度给出的切向更新,推动 $\boldsymbol{W}$ 在球面上转方向。

接下来看半径怎么变化。由于:

\[ \boldsymbol{W}_t + \boldsymbol{\delta}_r = \left(1-\eta\lambda\right)\boldsymbol{W}_t \]

并且 $\boldsymbol{\delta}_u \perp \boldsymbol{W}_t$,所以交叉项为 0:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left\lVert \left(1-\eta\lambda\right)\boldsymbol{W}_t + \boldsymbol{\delta}_u \right\rVert^2 \]
\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left(1-\eta\lambda\right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \left\lVert \boldsymbol{\delta}_{u} \right\rVert^2 \]

代入 $\boldsymbol{\delta}_{u} = -\eta\boldsymbol{g}_t$:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left(1-\eta\lambda\right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{g}_{t} \right\rVert^2 \]

这里和朴素 SGD 的区别就很清楚了。朴素 SGD 没有 $\boldsymbol{\delta}_r$,所以半径只会因为切向更新的二阶项变大;SGD + WD 多了一个 $-\eta\lambda\boldsymbol{W}_t$,它会把半径往回压。于是训练过程里同时存在两个方向相反的东西:WD 带来的径向收缩,和切向运动带来的离心增长。

再代入 scale-invariant gradient:

\[ \boldsymbol{g}_{t} = \frac{1}{\left\lVert \boldsymbol{W}_{t} \right\rVert} \tilde{\boldsymbol{g}}_{t} \]

可以得到:

\[ \left\lVert \boldsymbol{g}_{t} \right\rVert^2 = \frac{ \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert^2 }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \]

所以半径动力学可以写成:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left(1-\eta\lambda\right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \frac{ \eta^2 \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert^2 }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \]

这个式子其实已经把 SMD 的图像写出来了。当 $\lVert \boldsymbol{W}_t \rVert$ 很大时,WD 的收缩很明显,而切向梯度项因为分母很大反而不明显;当 $\lVert \boldsymbol{W}_t \rVert$ 很小时,切向梯度带来的离心增长会变强。所以 SGD + WD 不再像朴素 SGD 那样让半径一直长大,而是有机会稳定到某个平衡值。

如果用小步长近似去看半径本身的变化,可以写成:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert - \left\lVert \boldsymbol{W}_{t} \right\rVert \approx - \eta\lambda \left\lVert \boldsymbol{W}_{t} \right\rVert + \frac{ \eta^2 \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert^2 }{ 2 \left\lVert \boldsymbol{W}_{t} \right\rVert^3 } \]

前一项是 WD 的向心项,后一项是切向更新带来的离心项。平衡的时候可以认为:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert \approx \left\lVert \boldsymbol{W}_{t} \right\rVert \]

所以:

\[ - \eta\lambda \left\lVert \boldsymbol{W}_{t} \right\rVert + \frac{ \eta^2 \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert^2 }{ 2 \left\lVert \boldsymbol{W}_{t} \right\rVert^3 } \approx 0 \]

移项之后得到:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert^4 \approx \frac{ \eta \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert^2 }{ 2\lambda } \]

如果令:

\[ L = \mathbb{E} \left[ \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert^2 \right] \]

那么平衡半径就是:

\[ \left\lVert \boldsymbol{W} \right\rVert^* = \left( \frac{L\eta}{2\lambda} \right)^{1/4} \]

这个结论也解释了为什么在有 normalization 的网络里,WD 不能只被理解成“普通正则项”。它实际上在控制 scale-invariant weight 的半径,并且和切向梯度共同决定了半径的 equilibrium。

最后看角更新。角更新主要由切向更新决定:

\[ \Delta_t \approx \frac{ \left\lVert \boldsymbol{\delta}_{u} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

SGD + WD 里面 $\boldsymbol{\delta}_{u} = -\eta\boldsymbol{g}_t$,所以:

\[ \Delta_t \approx \frac{ \eta \left\lVert \boldsymbol{g}_{t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } = \frac{ \eta \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \]

把平衡半径代回去:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx \left\lVert \tilde{\boldsymbol{g}}_{t} \right\rVert \sqrt{ \frac{\eta}{2\lambda} } \]

因此平衡时:

\[ \Delta^* \approx \sqrt{2\eta\lambda} \]

这就很有意思了:进入 equilibrium 以后,角更新不再显式依赖 $\lVert \tilde{\boldsymbol{g}}_t \rVert$,而主要由学习率 $\eta$ 和 weight decay $\lambda$ 决定。也就是说,SGD + WD 的 SMD 意义不是“SGD 多加了一个惩罚项”,而是 WD 给 scale-invariant weight 加了一个径向控制器,让朴素 SGD 的半径增长变成稳定半径上的球面运动。

用SMD分析Adam

同样,这一层权重是 scale-invariant 的,也就是:

\[ L(\alpha\boldsymbol{W}) = L(\boldsymbol{W}) \]

因此任务梯度满足:

\[ \left\langle \boldsymbol{W}_{t}, \boldsymbol{g}_{t} \right\rangle = 0 \]

这里:

\[ \boldsymbol{g}_{t} = \nabla_{\boldsymbol{W}} L(\boldsymbol{W}_{t}) \]

也就是说,raw gradient $\boldsymbol{g}_t$ 是切向的。但是 更新权重的方向里面加了动量之类的东西,比较复杂。

先忽略 bias correction,Adam 的更新式是:

\[ \boldsymbol{m}_{t} = \beta_1 \boldsymbol{m}_{t-1} + \left(1-\beta_1\right)\boldsymbol{g}_{t} \]
\[ \boldsymbol{v}_{t} = \beta_2 \boldsymbol{v}_{t-1} + \left(1-\beta_2\right)\boldsymbol{g}_{t}^{2} \]
\[ \boldsymbol{q}_{t} = \frac{ \boldsymbol{m}_{t} }{ \sqrt{\boldsymbol{v}_{t}}+\epsilon } \]
\[ \boldsymbol{W}_{t+1} = \boldsymbol{W}_{t} - \eta\boldsymbol{q}_{t} \]

所以总更新量是:

\[ \boldsymbol{\delta}_{t} = \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} = - \eta\boldsymbol{q}_{t} \]

现在还是按前面的方式,把 $\boldsymbol{\delta}_t$ 拆成径向更新和切向更新:

\[ \boldsymbol{\delta}_{t} = \boldsymbol{\delta}_{r} + \boldsymbol{\delta}_{u} \]

径向分量就是 $\boldsymbol{\delta}_t$ 在 $\boldsymbol{W}_t$ 方向上的投影:

\[ \boldsymbol{\delta}_{r} = \frac{ \left\langle \boldsymbol{\delta}_{t}, \boldsymbol{W}_{t} \right\rangle }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \boldsymbol{W}_{t} \]

代入 $\boldsymbol{\delta}_t = -\eta\boldsymbol{q}_t$:

\[ \boldsymbol{\delta}_{r} = - \eta \frac{ \left\langle \boldsymbol{q}_{t}, \boldsymbol{W}_{t} \right\rangle }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \boldsymbol{W}_{t} \]

令:

\[ c_t = \frac{ \left\langle \boldsymbol{q}_{t}, \boldsymbol{W}_{t} \right\rangle }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \]

那么:

\[ \boldsymbol{\delta}_{r} = - \eta c_t \boldsymbol{W}_{t} \]

切向分量就是剩下的部分:

\[ \boldsymbol{\delta}_{u} = \boldsymbol{\delta}_{t} - \boldsymbol{\delta}_{r} = - \eta \left( \boldsymbol{q}_{t} - c_t\boldsymbol{W}_{t} \right) \]

令:

\[ \boldsymbol{q}_{u,t} = \boldsymbol{q}_{t} - c_t\boldsymbol{W}_{t} \]

那么有:

\[ \boldsymbol{q}_{u,t} \perp \boldsymbol{W}_{t} \]

因此:

\[ \boldsymbol{\delta}_{u} = - \eta\boldsymbol{q}_{u,t} \]

这里就是 Adam 和朴素 SGD 最关键的差别。对于朴素 SGD,$\boldsymbol{\delta}_t = -\eta\boldsymbol{g}_t$,而 $\boldsymbol{g}_t \perp \boldsymbol{W}_t$,所以它是纯切向更新,$\boldsymbol{\delta}_r = \boldsymbol{0}$。

但是 Adam 不一样。虽然 raw gradient $\boldsymbol{g}_t$ 是切向的,Adam 实际使用的却是:

\[ \boldsymbol{q}_{t} = \frac{ \boldsymbol{m}_{t} }{ \sqrt{\boldsymbol{v}_{t}}+\epsilon } \]

动量项和逐元素缩放一般不会保持“和 $\boldsymbol{W}_t$ 正交”这个性质。也就是说:

\[ \left\langle \boldsymbol{g}_{t}, \boldsymbol{W}_{t} \right\rangle = 0 \]

并不能推出:

\[ \left\langle \boldsymbol{q}_{t}, \boldsymbol{W}_{t} \right\rangle = 0 \]

所以 Adam 的更新通常会有:

\[ c_t \neq 0 \]

Adam 会把原本切向的 normalized gradient,经过动量和逐元素预条件之后,变成一个既有切向分量、也可能有径向分量的更新。(真坏啊)

接着看径向动力学。因为:

\[ \boldsymbol{W}_t + \boldsymbol{\delta}_{r} = \left(1-\eta c_t\right)\boldsymbol{W}_t \]

同时 $\boldsymbol{\delta}_u \perp \boldsymbol{W}_t$,所以:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left(1-\eta c_t\right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \left\lVert \boldsymbol{\delta}_{u} \right\rVert^2 \]

代入 $\boldsymbol{\delta}_u = -\eta\boldsymbol{q}_{u,t}$:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left(1-\eta c_t\right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \]

小步长近似下:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 - \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx - 2\eta c_t \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \]

这里依然可以看到两部分。第一项来自 $\boldsymbol{\delta}_r$,也就是 Adam 自己产生的径向分量;第二项来自 $\boldsymbol{\delta}_u$,也就是切向更新带来的离心增长。

这个时候 $c_t$ 的符号就很重要了。如果 $c_t > 0$,那么 $\boldsymbol{\delta}_r = -\eta c_t\boldsymbol{W}_t$ 是向内收缩,半径倾向于变小。如果 $c_t < 0$,那么 $\boldsymbol{\delta}_r$ 和 $\boldsymbol{W}_t$ 同向,是向外扩张,半径倾向于变大。如果 $c_t \approx 0$,Adam 近似没有径向力,半径主要由切向离心项撑大。

所以 Adam 的径向动力学被这个量控制:

\[ c_t = \frac{ \left\langle \boldsymbol{q}_{t}, \boldsymbol{W}_{t} \right\rangle }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \]

它被 moment、二阶矩估计和element-wise缩放共同作用。因此 Adam 的径向力更像是adaptive preconditioner 带来的。

再看角更新。角更新主要由切向分量决定:

\[ \Delta_t \approx \frac{ \left\lVert \boldsymbol{\delta}_{u} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

Adam 中:

\[ \boldsymbol{\delta}_{u} = - \eta\boldsymbol{q}_{u,t} \]

所以:

\[ \Delta_t \approx \frac{ \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

这也和 SGD 不一样。SGD 里面因为 scale-invariant gradient 会带来 $\lVert \boldsymbol{g}_t \rVert \propto 1/\lVert \boldsymbol{W}_t \rVert$,所以角更新大概按 $1/\lVert \boldsymbol{W}_t \rVert^2$ 衰减。

Adam 的 $\boldsymbol{q}_t$ 做了类似梯度尺度归一化的操作。如果 $\boldsymbol{g}_t$ 因为权重尺度变大而按 $1/\alpha$ 缩小,那么 $\boldsymbol{m}_t$ 大致也按 $1/\alpha$ 缩小,$\boldsymbol{v}_t$ 大致按 $1/\alpha^2$ 缩小,于是:

\[ \boldsymbol{q}_t = \frac{ \boldsymbol{m}_{t} }{ \sqrt{\boldsymbol{v}_{t}}+\epsilon } \]

在 $\epsilon$ 不主导的时候,$\boldsymbol{q}_t$ 对这种整体尺度变化并不敏感。因此 Adam 的角更新更像是:

\[ \Delta_t^{\mathrm{Adam}} \approx \frac{ \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

也就是更接近按 $1/\lVert \boldsymbol{W}_t \rVert$ 衰减,而不是 SGD 那种 $1/\lVert \boldsymbol{W}_t \rVert^2$ 衰减。当然这里有一个前提,就是 $\epsilon$ 没有主导分母;如果 $\epsilon$ 主导了,那这个尺度抵消就会变弱。

最后说 equilibrium。朴素 Adam 没有一个很干净的平衡半径公式,因为它的径向项不是固定超参数给出来的,而是:

\[ \boldsymbol{\delta}_{r} = - \eta c_t\boldsymbol{W}_{t} \]

其中 $c_t$ 可能为正,可能为负,也可能随着训练变化。因此 Adam 的平衡条件最多只能先写成形式上的:

\[ - 2\eta c_t \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \approx 0 \]

也就是:

\[ 2c_t \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \]

如果额外假设 $c_t > 0$,并且 $c_t$ 和 $\lVert \boldsymbol{q}_{u,t} \rVert^2$ 在局部都比较稳定,那么可以形式上写成:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx \frac{ \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 }{ 2c_t } \]

所以关键就是径向力c要和adam update的径向力差不多能抵消,这个没有太确切的解法,感觉应该需要调参。

用SMD分析AdamW

AdamW 可以看成是在 Adam 的自适应更新之外,又显式加了一个 decoupled weight decay。这里依然只分析带 normalization 的 scale-invariant weight:

\[ L(\alpha\boldsymbol{W}) = L(\boldsymbol{W}) \]

因此原始任务梯度还是满足:

\[ \left\langle \boldsymbol{W}_{t}, \boldsymbol{g}_{t} \right\rangle = 0 \]

AdamW 的 moment 部分和 Adam 一样:

\[ \boldsymbol{m}_{t} = \beta_1 \boldsymbol{m}_{t-1} + \left(1-\beta_1\right)\boldsymbol{g}_{t} \]
\[ \boldsymbol{v}_{t} = \beta_2 \boldsymbol{v}_{t-1} + \left(1-\beta_2\right)\boldsymbol{g}_{t}^{2} \]
\[ \boldsymbol{q}_{t} = \frac{ \boldsymbol{m}_{t} }{ \sqrt{\boldsymbol{v}_{t}}+\epsilon } \]

区别在更新式:

\[ \boldsymbol{W}_{t+1} = \boldsymbol{W}_{t} - \eta\boldsymbol{q}_{t} - \eta\lambda\boldsymbol{W}_{t} \]

其中 $-\eta\boldsymbol{q}_t$ 是自适应梯度更新,$-\eta\lambda\boldsymbol{W}_t$ 是 decoupled weight decay。总更新量为:

\[ \boldsymbol{\delta}_{t} = \boldsymbol{W}_{t+1} - \boldsymbol{W}_{t} = - \eta\boldsymbol{q}_{t} - \eta\lambda\boldsymbol{W}_{t} \]

和 Adam 一样,虽然 $\boldsymbol{g}_t \perp \boldsymbol{W}_t$,但是 $\boldsymbol{q}_t$ 不一定还和 $\boldsymbol{W}_t$ 垂直。所以先把 $\boldsymbol{q}_t$ 分解成径向和切向两部分。

定义:

\[ c_t = \frac{ \left\langle \boldsymbol{q}_{t}, \boldsymbol{W}_{t} \right\rangle }{ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 } \]

于是:

\[ \boldsymbol{q}_{t} = c_t\boldsymbol{W}_{t} + \boldsymbol{q}_{u,t} \]

其中:

\[ \boldsymbol{q}_{u,t} = \boldsymbol{q}_{t} - c_t\boldsymbol{W}_{t} \]

并且:

\[ \boldsymbol{q}_{u,t} \perp \boldsymbol{W}_{t} \]

这里 $c_t\boldsymbol{W}_t$ 是 AdamW update 里的径向改变,$\boldsymbol{q}_{u,t}$ 是切向改变。

代回 AdamW 的总更新:

\[ \boldsymbol{\delta}_{t} = - \eta \left( c_t\boldsymbol{W}_{t} + \boldsymbol{q}_{u,t} \right) - \eta\lambda\boldsymbol{W}_{t} \]

整理一下:

\[ \boldsymbol{\delta}_{t} = - \eta \left( \lambda+c_t \right) \boldsymbol{W}_{t} - \eta\boldsymbol{q}_{u,t} \]

所以 AdamW 的 SMD 分解就是:

\[ \boldsymbol{\delta}_{r} = - \eta \left( \lambda+c_t \right) \boldsymbol{W}_{t} \]
\[ \boldsymbol{\delta}_{u} = - \eta\boldsymbol{q}_{u,t} \]

这里很清楚:$-\eta\lambda\boldsymbol{W}_t$ 是显式 weight decay 带来的径向收缩,$-\eta c_t\boldsymbol{W}_t$ 是 AdamW 自适应更新自己产生的径向分量,$-\eta\boldsymbol{q}_{u,t}$ 是切向更新,主要改变权重方向。

接下来看半径动力学。由:

\[ \boldsymbol{W}_{t} + \boldsymbol{\delta}_{r} = \left( 1-\eta\left(\lambda+c_t\right) \right) \boldsymbol{W}_{t} \]

又因为 $\boldsymbol{\delta}_{u} \perp \boldsymbol{W}_{t}$,所以:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left( 1-\eta\left(\lambda+c_t\right) \right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \left\lVert \boldsymbol{\delta}_{u} \right\rVert^2 \]

代入 $\boldsymbol{\delta}_{u} = -\eta\boldsymbol{q}_{u,t}$:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 = \left( 1-\eta\left(\lambda+c_t\right) \right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \]

小步长近似下:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 - \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx - 2\eta \left( \lambda+c_t \right) \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \]

这个式子就是 AdamW 的半径动力学。前一项是径向项,后一项是切向更新带来的离心增长。

所以 AdamW 里面真正控制径向行为的是:

\[ \lambda+c_t \]

$\lambda$ 是显式 weight decay,通常提供向内收缩;$c_t$ 来自自适应更新 $\boldsymbol{q}_t$ 对 $\boldsymbol{W}_t$ 的径向投影。如果 $\lambda+c_t > 0$,那么 $\boldsymbol{\delta}_r$ 是向内的,半径倾向于减小。如果 $\lambda+c_t < 0$,那么 $\boldsymbol{\delta}_r$ 是向外的,半径倾向于增大。如果 $\lambda+c_t \approx 0$,那么显式 weight decay 和自适应更新的径向改变大致抵消,半径主要由切向离心项影响。

所以不难看出AdamW 相对于Adam的优势是Adam 只有隐式的 $c_t$,而 AdamW 至少多了一个显式可控的 $\lambda$。但是要注意,AdamW 不是只有 $\lambda$ 这一个径向项,因为 $\boldsymbol{q}_t$ 仍然可能通过 $c_t$ 引入额外径向改变。这里提一个思考,SMD分解的理论完全对吗?AdamW已经被广泛的用于LLM训练,如果存在问题,应该不会一直沿用。

角更新还是主要由切向分量决定:

\[ \Delta_t \approx \frac{ \left\lVert \boldsymbol{\delta}_{u} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

AdamW 中:

\[ \boldsymbol{\delta}_{u} = - \eta\boldsymbol{q}_{u,t} \]

所以:

\[ \Delta_t \approx \frac{ \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

如果 $\boldsymbol{q}_t$ 主要是切向的,也就是 $c_t \approx 0$,那么 $\boldsymbol{q}_{u,t} \approx \boldsymbol{q}_t$,此时:

\[ \Delta_t \approx \frac{ \eta \left\lVert \boldsymbol{q}_{t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

最后写一下形式上的平衡条件。半径平衡要求:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert \approx \left\lVert \boldsymbol{W}_{t} \right\rVert \]

也就是:

\[ - 2\eta \left( \lambda+c_t \right) \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \approx 0 \]

约掉一个 $\eta$:

\[ 2 \left( \lambda+c_t \right) \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 \]

因此形式上的平衡半径满足:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx \frac{ \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 }{ 2 \left( \lambda+c_t \right) } \]

也就是:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert \approx \sqrt{ \frac{ \eta \left\lVert \boldsymbol{q}_{u,t} \right\rVert^2 }{ 2 \left( \lambda+c_t \right) } } \]

这个式子成立的前提是 $\lambda+c_t > 0$,并且 $c_t$、$\lVert \boldsymbol{q}_{u,t} \rVert^2$ 在局部训练窗口内比较稳定。

如果考虑一个更理想的情况:AdamW 的自适应更新几乎没有径向改变,也就是 $c_t \approx 0$,那么:

\[ \boldsymbol{\delta}_{r} \approx - \eta\lambda\boldsymbol{W}_{t} \]
\[ \boldsymbol{\delta}_{u} \approx - \eta\boldsymbol{q}_{t} \]

半径动力学近似变成:

\[ \left\lVert \boldsymbol{W}_{t+1} \right\rVert^2 \approx \left( 1-\eta\lambda \right)^2 \left\lVert \boldsymbol{W}_{t} \right\rVert^2 + \eta^2 \left\lVert \boldsymbol{q}_{t} \right\rVert^2 \]

平衡时:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert^2 \approx \frac{ \eta \left\lVert \boldsymbol{q}_{t} \right\rVert^2 }{ 2\lambda } \]

所以:

\[ \left\lVert \boldsymbol{W}_{t} \right\rVert \approx \sqrt{ \frac{ \eta \left\lVert \boldsymbol{q}_{t} \right\rVert^2 }{ 2\lambda } } \]

此时角更新为:

\[ \Delta_t \approx \frac{ \eta \left\lVert \boldsymbol{q}_{t} \right\rVert }{ \left\lVert \boldsymbol{W}_{t} \right\rVert } \]

代入平衡半径,可以得到:

\[ \Delta^* \approx \sqrt{2\eta\lambda} \]

这个结果和前面的 SMD 直觉是相似的:如果 $c_t \approx 0$ 且 $\lVert \boldsymbol{q}_t \rVert$ 局部稳定,那么 AdamW 也会表现出稳定角更新。但是这个还是无法从理论上保证稳定。

用SMD分析Muon?

挖个坑以后再说。

参考文献

  1. Ruosi Wan, Zhanxing Zhu, Xiangyu Zhang, Jian Sun. Spherical Motion Dynamics: Learning Dynamics of Neural Network with Normalization, Weight Decay, and SGD. arXiv:2006.08419, 2020.
  2. Ruosi Wan, Zhanxing Zhu, Xiangyu Zhang, Jian Sun. Spherical Motion Dynamics: Learning Dynamics of Normalized Neural Network using SGD and Weight Decay. NeurIPS 2021.
  3. OpenReview: Spherical Motion Dynamics: Learning Dynamics of Normalized Neural Network using SGD and Weight Decay.