理解できていないけど、作ったからには説明してみる
というわけで、作ったカスタムノードを解説していきたいなぁと思ってたんですが……
解説できるほど理解できていない
作っておいてこういうのもアレなのですが、どうしてこんな挙動になるのかよくわかってません。
なので、なんとなくこんな感じかなぁという想像での備考になります。
パラメータとその影響
Depth decay
SlothfulAttention(以下SA)は Self-attentionの K, V を削減しているため、attentionの精度が落ちていると考えています。
この影響か、基本的にはどこかしら破綻しやすくなります。
特にt2iの場合は構図・ポーズにもがっつり影響してしまうのですが、depth_decayを上げる(深い層への適用強度を減らす)ことで、ある程度は軽減できるようです。
逆にi2iの場合は構図が決まっているので depth_decayを下げても破綻しづらいみたいです。
Time decay
time_decay=0(全step同じ強度)で適用するとわかりやすいんですが、SA適用すると細部が壊れた状態の絵が出力されやすいです
で、序盤のStepで描画に介入して後のStepで素のu-netに修正してもらうという目的で time_decayを導入してます。
出力結果がぼやけていたりディテールが破綻している場合は、time_decayを上げることで改善することがあります。
(SA未適用の絵に近づくということでもありますが)
i2iや 2段目以降のサンプラーなど TimeStepが途中から始まるようなケースでは、SAの効果が弱いため、strength を上げるか、time_decayを下げるなどの調整が必要な感じでした。
In/Outブロック
Inブロックについては、その領域(vaeで圧縮するので最低でも8x8の領域)が何であるのか判別する役割があるのか、SAの強度を上げると髪の毛とアクセサリが融合してしまうなどの影響が確認できました。
また、Inブロックの深い部分では構図/ポーズへの影響が強いようです。
Outブロックはどのように描くかを決めているようで、SAの影響が強めに出るようです。
このあたりは階層マージとかやっている人の方が詳しそうですね。
ModeとBlend
言葉で説明するのが辛かったので疑似コードで書きますが、K, V の削減は以下のような処理になります(Avgの場合)
stride = strength * time_ratio * depth_ratio
one = k.avg_pool(size=1, stride=stride)
avg = k.avg_pool(size=stride, stride=stride)
k = one * (1 - k_blend) + avg * k_blend
v = one * (1 - v_blend) + avg * v_blend
Q, K, V はデータとしてはテンソルですが、 このあとSelf-attentionの処理で類似度計算としてベクトルの内積取るので、ベクトルの配列と考えてみます。
(Vも後続のCross-attentionのQに影響するのでベクトルと考えてみます)
oneはsizeが1ですので、avg_pool でも max_pool でも stride ごとに1サンプル取るという操作になります。
これは元のベクトルの一部を抜き取る処理なので、(捨てられるベクトルはあるけど)ベクトル自体は変化させていません。
なので、blend=0で one だけを利用する場合は、出力画像への影響は控えめになるみたいです。
(まあ、Spatial-Reducion Attention が成立しているので、この操作なら影響は少ないんだと思います)
Avg
AvgPooling は 範囲内の値の平均を取っているわけですが、ベクトルとしては size本のベクトルの(線形)平均を取っているという事になるかと思います。
(ベクトルの正規化していないので、内積取る際に問題出るんじゃないかとも思いましたが、元々 to_q, to_k 後に正規化してなさそうなので、ここはアバウトでもなんとかなるんじゃないかと想像)
Vとして使うときは、平均値取っているだけにぼやっとした感じの出力になる感じです
Max
MaxPoolingについては、ベクトルの要素ごとに最大値を選択して新しいベクトルを作るという操作なので、正直うまくいかないだろうなぁと思って試したところ、思いのほか絵が壊れなかったので採用してみました。
Vとして使うと、コントラストや彩度が高くなりやすい感じですが、strength/blend を上げると絵が壊れます
k_blend
Kについては、低いBlend率では AvgでもMaxでも大きな差はなさそうです。 元々のベクトル(one)からどのくらい離れているか程度しか差はないのかもしれません。
Maxで高いk_blend値にすると、絵が壊れやすいです。 さすがに元ベクトルから遠すぎるとダメなのかもしれません
v_blend
Vについては、modeの違いが顕著に出るみたいです。
また、k_blendとv_blendの値が大きく違うと絵が破綻しやすいようですが、値によっては良い具合になったりするのでなかなか難しいところです。
その他の備考
K ≠ V で大丈夫なのか
HyperNetworkでは K, V を別のネットワーク通して変化させているみたいですので、K ≠ Vとなるのは問題なさそうです。
K は AvgPooling、 V は MaxPooling といった組み合わせで実験してみたのですが、かなり制御しづらかった(絵が安定しない)ので 異なるプーリング方法を混ぜるのは廃案にしました。
元モデルが K = V でトレーニングされているので、 K, V の差が大きくなると破綻しやすくなるんじゃないかと考えてます。
処理速度
SAを弱めにかけたときにむしろ処理速度が落ちてしまうことがあります。
TokenMergingも画像サイズが大きくないと高速化の恩恵が受けられないとか報告があるので、 思っている以上にSelf-attentionの処理時間は支配的ではないのかもしれません。
単に実装がイマイチなので速度出ていないだけかもしれませんが。