
RWKVの学習の安定性について
はじめに
RWKVに関して二つのブログを書いてきました(その1,その2)が,追加で学習(勾配)の安定性についても日本語で書いておこうと思います.というのも,RWKV-4の学習では以下の画像のように,LLMでありがちなLoss Spikeが確認されなかったという実験結果があります.

これがなぜそうなのか,という根拠を若干不十分な説明でありますが,Appendix Hに書いてあったので,文脈を補いながら解説します.
前準備:勾配と学習の安定性
RWKVにおいて,TransformerのAttentionに相当する機構であるwkvは以下の数式でした.
$$
wkv_{t} = \frac{\sum\limits_{i=1}^{t-1} e^{-(t-1-i)w+k_{i}}v_{i} + e^{u+k_{t}}v_{t}}{\sum\limits_{i=1}^{t-1} e^{-(t-1-i)w+k_{i}} + e^{u+k_{t}}}
(1)
$$
ここで,この式を簡略化するために,valueは$${v_t = W_v x_t}$$,指数関数部分は$${K_t^e = e^{W_k x_t + w_{T,t}}}$$と置き換えます.また,トークンは最終トークン$${T}$$を見ることとします.すると,wkvは次のようになります.Eは平均,Sは重み付きの総和を表す関数です.wkvは$${v_t}$$の平均とも見れ,その分母と分子を,重み$${K_t^e}$$の総和と,重みつき$${v_t}$$の総和とも見ることができます.
$$
wkv_T =\frac{\sum_{t=1}^{T} e^{W_k x_t + w_{T,t}} v_t}{\sum_{t=1}^{T} e^{W_k x_t + w_{T,t}}}= \frac{\sum_{t=1}^{T} K_t^e v_t}{\sum_{t=1}^{T} K_t^e} = E(v_t) = \frac{S(v_t)}{S(1)} (2)
$$
wkv層を出ると,以降の層$${f(wkv_t)}$$に入り,それが正解ラベル$${y_t}$$との誤差が評価されます.今回,最終トークン$${T}$$の損失を見るとすると,以下の式が得られます.
$$
L_T=l(f(wkv_T),y_T) (3)
$$
機械学習における学習は,勾配の逆向きの更新により行われます(勾配降下法).従って,RWKVにおいて損失を減らす向きに,ある隠れ層$${a}$$のパラメータ$${(W_a)_{i,j}}$$を更新をすると,以下のようになります.
$$
{ (W_a)_{i,j}} ← { (W_a)_{i,j}} - γ\frac{\partial L_T}{\partial (W_a)_{i,j}} (4)
$$
発散しない勾配
式(4)を元に,$${W_v}$$の更新がどのような上限が加わるかを見ます.
wkv層に対して合成関数の偏微分の連鎖律を考えると,以下のように損失は損失→wkv層→valueと,経由できます.
$$
\frac{\partial L_T}{\partial (W_v)_{i,j}} = \frac{\partial L_T}{\partial (wk v_T)_i} \cdot \frac{\partial (wk v_T)_i}{\partial (W_v)_{i,j}} (5)
$$
連鎖律はRWKVのアーキテクチャ図を見ると理解しやすいと思います.

勾配の発散は正負関係無しに発生するため,式(5)のwkv層から伝わる勾配に絶対値をとります.
$$
\frac{\partial L_T}{\partial (W_v)_{i,j}} = \frac{\partial L_T}{\partial (wk v_T)_i} \cdot \frac{\partial (wk v_T)_i}{\partial (W_v)_{i,j}}→\frac{\partial L_T}{\partial (wk v_T)_i} \cdot \left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right| (6)
$$
この絶対値をさらに見ていくと,式(2)を用いて
$$
\left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|=\left|\frac{\partial \left( \frac{\sum_{t=1}^{T} K_t^e v_t}{\sum_{t=1}^{T} K_t^e} \right)_i}{\partial (W_v)_{i,j}}\right| (7)
$$
式(7)の$${wkv_T}$$に式(2)を代入し,最初に書いたように$${v_t=W_vx_t}$$であることを踏まえると
$$
\left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|=\left| \frac{\partial E_i[(v_t)_i]}{\partial (W_v)_{i,j}} \right| = \left| \frac{\partial E_i[( W_v x_t)_i]}{\partial (W_v)_{i,j}} \right| = \left| \frac{ E_i[(\partial W_v x_t)_i]}{\partial (W_v)_{i,j}} \right|=\left| E[(x_t)_j]\right| (8)
$$
式(8)を見たらわかる通り,入力$${x_t}$$の平均の絶対値となっています.平均の絶対値は,その要素の最大値の絶対値を超えない ことから,以下の不等式が成り立ちます.
$$
\left|\frac{\partial (wkv_T)_i}{\partial (W_v)_{i,j}}\right|=\left| E[(x_t)_j]\right| = \left| \frac{x_{1j} + \ldots + x_{tj} + \ldots + x_{Tj}}{T} \right| \leq \max_t \left| (x_t)_j \right| (9)
$$
つまり,wkv層における$${W_v}$$を更新する勾配は,入力列の最大要素を超えない,という上限がつくわけです.
加えて,式(9)の上限に$${T}$$が出てこないため,系列の長さによって勾配の上限は制限されない,系列長により勾配が消失がしずらいとわかります.
消失しない勾配
式(4)から,同様に$${W_k}$$の更新にどのような下限が加わるかを見ます.
$$
\frac{\partial L_T}{\partial (W_k)_{i,j}} = \frac{\partial L_T}{\partial (wk v_T)_i} \cdot \frac{\partial (wk v_T)_i}{\partial (W_k)_{i,j}} (10)
$$
式(10)のwkvに式(2)を代入すると
$$
\frac{\partial (wk v_T)_i}{\partial (W_k)_{i,j}}= \frac{\partial S_i[(v_t)_i]}{\partial (W_k){i,j}} \cdot \frac{S_i(1)}{\partial (W_k){i,j}}=\frac{\partial}{\partial (W_k){i,j}} \left( \frac{\sum_{t=1}^T (K^t)_i(v_t)_i}{S_i(1)} \right) (11)
$$
商の偏微分公式$${\frac{∂}{∂x} \left( \frac{f(x)}{g(x)} \right) = \frac{f'(x) \cdot g(x) - f(x) \cdot g'(x)}{g(x)^2}}$$から
$$
= \frac{\left( \frac{\partial}{\partial (W_k)_{i,j}} \sum_{t=1}^T (K^t)_i(v_t)_i \right) \cdot S_i(1) - \left( \sum_{t=1}^T (K^t)_i(v_t)_i \right) \cdot \frac{\partial}{\partial (W_k)_{i,j}} S_i(1)}{S_i(1)^2} (12)
$$
部分ごとの偏微分を計算すると,最初に簡素化した式を参照して,
$${\frac{\partial (v_t)_i}{\partial (W_v)_{i,j}} = (x_t)_j}$$
$${\frac{\partial ({K}_t^e)_i}{\partial ({W}k){i,j}} = (x_t)_j ({K}_t^e)_i}$$.
これらを式(12)に戻すと
$${= \frac{(\sum_{t=1}^T (x_t)_j (K^t)_i(v_t)_i) \cdot (\sum_{t=1}^T (K^t)_i(v_t)_i) - (\sum_{t=1}^T (x_t)_j (K^t)_i)}{S_i(1)^2} \\= \frac{S_i[(x_t)_j(v_t)_i]}{S_i(1)} - \frac{S_i[(x_t)_j]S_i[(v_t)_i]}{S_i(1)^2} \\= E_i[(x_t)_j(v_t)_i] - E_i[(x_t)_j]E_i[(v_t)_i] }$$
これは共分散の式であるため,
$${= \text{cov}_i((x_t)_j, (v_t)_i) (13)}$$
(2)式よりwkv層の計算において,$${x_t}$$と$${v_t}$$の共分散は0にはならないことから,
$$
\frac{\partial (wk v_T)_i}{\partial (W_k)_{i,j}} = \text{cov}_i((x_t)_j, (v_t)_i) ≠0 (14)
$$
よって,wkv層における$${W_k}$$を更新する勾配は0にならない,消失しないとなります.
まとめ
以上,RWKVを構成するwkvのパラメータは勾配と発散,両者の対策がなされていることが数式からわかります.故に最初に示した画像のように,loss spikeが発生しなかったことが説明できる,とされています.
誤り等あれば,コメント or @gojiteji まで連絡お願いします🙇.
参考文献
Bo Peng, et al. 2023. RWKV: Reinventing RNNs for the transformer era.
https://arxiv.org/abs/2305.13048The RWKV Language Model (and my LM tricks) - BlinkDL/RWKV-LM GitHub
https://github.com/BlinkDL/RWKV-LM
関連ブログ
RWKVのRNN mode(推論時の計算方法)→GPT mode(学習時の計算方法)の計算:漸化式から一般項の導出
https://blog.gojiteji.com/2023/10/19/rwkv-gptmode-and-rnnmode/Transformer→Attention Free Transformer→RWKVへ式変形の気持ち
https://blog.gojiteji.com/2023/11/14/rwkv-and-aft/