学習不要のUNet改変法まとめ
最近になって、UNetの構造をいじって生成速度をあげたり、品質を高める方法が色々提案されたのでまとめてみます。
UNetについて
過去に記事をあげてたと思います。UNetはIN層・MID層・OUT層に分かれます。downとかupとか呼ばれることもありますが、今回はIN/MID/OUTで統一します。IN層では入力がどんどん縮小されていって、OUT層では元のサイズに戻っていきます。UNetは一直線ではなくIN層からOUT層へのSkip connectionがあることが特徴です。ディープラーニングでは、一般的に入力に近い層を浅い層、出力に近い層を深い層と言いますが、UNetはその構造上入出力に近い層を浅い層、真ん中を深い層と表現します。
生成速度の向上法
Token Merging
これは前に記事を作った気がします。元々は画像のクラス分類に利用するViTモデルに対する手法ですが、それをStable Diffusionに応用したものです。簡単に言うとピクセル間の似た要素をマージすることで計算量を減らします。UNetのSelf Attentionは画像サイズの2乗の計算量になるので、ピクセル数を減らすと大きく計算量を削減できます。一番計算量の大きいUNetの浅い層(画像が縮小されていない層)に適用するのが最も効率がいいっぽいです。そのため深い層ほどAttention層が大きくなっていくSDXLでは、あまり恩恵を受けられないです。また効果が画像サイズに依存するため、高解像度の生成をするときに効果が大きいです。
HyperTile
https://github.com/tfernd/HyperTile
HyperTileはSelf Attentionをタイル分けして複数回に分けて計算します。Self Attentionの計算量オーダーは画像サイズの2乗になるので、たとえば4分割すれば、計算量は4分の1になります($${(4n)^2\to 4n^2}$$)。Token Mergingはトークンの類似度を計算する必要がありますが、こちらはそういうのは必要ないです。ただしこのままだと各タイルを個別に生成するような形になってしまうため、分割サイズをレイヤーごとにランダムにしてるようですね。Token Mergingと同じくUNetの浅い層に適用するのが基本的な使い方のようです。
DeepCache
Deep Cacheは、近い時刻で中間出力を使いまわすことで、計算時間を削減します。入力側から直接出力側につながるskip connectionがあるUNetならではの手法ですね。生成時、近いステップ同士では中間の出力は結構類似しているようなので、こんなことができるそうです。上の二つの手法はUNetの浅い層の計算時間を削減する方法でしたが、こちらは深い層を削減するのでSDXLとの相性もいいかもしれません。
ステップ数が大きくないと使えない方法なので、LCM-LoRA等の低ステップ生成方法と一緒に使えないのが欠点ですかね。
高解像度生成法
学習データに比べて著しく大きい画像を生成しようとすると、ぐちゃぐちゃな画像が生成されることは有名ですが、それを何とかしようとする方法です。Hires fixじゃだめなのとかいってはいけない。
ScaleCrafter
この手法は、高解像度生成がうまくいかない原因は畳み込み層の受容野の違いが原因と考えて改良を加えています。畳み込み層のカーネルが3×3のままなのに、入力が大きくなるとおかしなことになるというわけですね。
一つ目の方法として、膨張畳み込みを使うことがあげられています。膨張畳み込みは、まあ下みたいな感じです。
カーネルの各マスを1マスとか2マスずつとか開けることで、受容野を2倍か3倍にできます。見ての通りマスの数は変わらないので、重みは変更無しです。しかしそのままでは整数倍にしか対応できません。有理数倍に対応するためには、入力を一度拡大して膨張畳み込みを適用した後、元のサイズに縮小します。
たとえば1.5倍にしたい場合、2倍の膨張畳み込みを2/1.5倍に拡大した画像に対して適用します。
もう一つの方法として、カーネルをいい感じの線形変換で大きくする方法が紹介されています。たとえば3×3カーネルを5×5カーネルに拡張するためには、(25,9)行列が必要です。どんな線形変換がいいかというと、
元の畳み込みを適用して、出力を拡大したもの
入力を拡大したあと、変換後の畳み込みを適用した出力
この二つが同じような結果になっていればよさそうなので、そんな線形変換を求めます。この辺の話よく分かってないので、よく分からないですがこのままだと解は不定になるので、以下の最小二乗問題を解くそうです。
k'=Rkとします。解は入力や元のカーネルの具体的な値によらない(サイズのみに依存する)らしいので、一度Rを計算すればあとは全層のカーネルにRを適用して、パディングとかを適切に設定するだけみたいです。コードがMATLABなんでお手上げです。
Deep Shrink Hires fix
https://gist.github.com/kohya-ss/3f774da220df102548093a7abc8538ed
不世出の天才、人民の太陽、常勝無敗の大元帥Kohyaさんが考えた方法です。高解像度生成による構図の崩壊を防ぐためには、構図に影響する部分だけ低解像度にすればいいじゃないかというアイデアです。構図に影響する場所は、経験的にノイズの大きい時刻(生成の初期段階)かつUNetの深い層(真ん中の層)ということが分かっています。そのためノイズの大きい時刻で、Down側のある層で縮小、Up側の対応する層で元のサイズに戻します。
ここからは特に根拠のない妄想ですが、この手法を使うと吹雪が降ったり火花が舞ったりします。この原因として、UNetのある区間を縮小することで、ノイズ予測がぼやけてノイズを除去しきれないからなのかなと思いました。ancestralやSDEが付くサンプラーは初期ノイズのみに依存しないため、初期ノイズの特定パターンが生成画像に現れにくくなり改善される気がします。
生成速度向上+高解像度生成
HiDiffusion
これはほとんどDeep shrink+HyperTileです。
Deep shrinkと異なる点は拡大や縮小をUNetのUpsample層やDownsample層の倍率を変えることで実装することです。こっちの方が正確になりそうな気もしますが、適用位置が特定の深さに限定されたり、倍率もDownsample層のdilationやstrideの調整によって行うため任意の倍率で使えなかったりと設定の自由度は低そうです。
HyperTileとの違いはSwin Transformerから着想を得て時刻ごとにタイルをシフトする戦略っぽいです。詳しく書いてなかったんであんまり分かりませんでした。
品質向上法
CD-Tuner
https://github.com/hako-mikan/sd-webui-cd-tuner
ColorやDetailを変更する方法です。全部は理解してないんですが、Detailに関してはUNetの重みを調整する方法なので紹介してみます。簡単に言うとDetailの1はUNetの最初の畳み込み層のweightを小さくし、biasを大きくします。Detail2は最後のGroupNorm層のweightを小さくし、biasを大きくします。これでディテールが上がるらしい。なんでかよく分かりませんが、入力や出力の分散を小さくすることで、予測ノイズの分散が小さくなり、生成画像がノイジーになる(=書き込みが増える)ということなのかな・・・?
FreeU
backbone(前層からの入力)とskip connectionを人為的にいじくることで生成画像の品質を高める方法です。backboneは低周波成分(構図など画像の大まかな要素)を、skip connectionは高周波成分(細部など)を担当するらしいです。backboneは画像が縮小されている深い層からの入力で、skip connectionは画像が縮小されていない浅い層からの入力なので、そうなるだろうという感じはしますね。backboneを強めて、skip connectionの低周波成分を弱めることでいい感じになるらしい。といってもbackbone側の強めるチャンネルを限定するなど根拠の分からない操作も多いです。品質が向上するかどうかはともかく、生成画像は確かに変化するので試してみると面白いです。