StableDiffusionのモデルサイズ削減を試してみる(2)
この記事の続きです。
先は長そうですので、とりあえず現時点での進捗です。
実装について
いくつか実装を変更・拡張しました
ResNetも対応してみる
に書いた通り、ResNetのサイズ削減も実装してみました。
Attentionブロックの IN/OUT LinearにLoRA適用
FeedForwardでの誤差を吸収してくれることを期待して、Attentionブロックの出口にあるLinearをLoRAで学習させるようにしました。
また、ResNetの誤差についても Attentionブロックの入り口にあるLinearをLoRA学習させるようにしました。
これらのLinearに対して作用する拡張機能は少ない(LoRAでは学習対象ですが)ため、多分学習させても影響は少ないだろうと判断しました。
timestepによって学習の寄与率を調整する
SDXL Turbo に関する論文 Adversarial Diffusion Distillation で言及採用されている手法で、初期(ノイズ成分が多い)ステップの寄与率を下げます。
「exponential weighting」との名称でしたので、exp(-timestep/1000) をウェイトとして採用しました。
(ウェイトの合計が1となるようにスケーリングしてロスに乗算)
UNet全体で学習させず、入れ替えたブロックを含む区間で学習
今回は1280チャンネルとなっているブロックを対象としていますので、
教師モデルでの推論時に IN7の入力とOUT5の出力をキャッシュしておき、
生徒モデルの学習時は IN7からOUT5までだけ通します。
これにより、学習時に浅い層のブロックを通す必要がないため、学習を高速化できました。
※IN1, IN2, OUT9, OUT10, OUT11はサンプル数が多い状態なので、ここでの Self-Attention処理は結構重いということですね。
※SD1.5のUNet構造については https://note.com/gcem156/n/nf2672cd16a9d や https://hoshikat.hatenablog.com/entry/2023/03/31/022605 がわかりやすいです
圧縮後の次元とロスを調査してみる
前回、1/4以下と大幅に削減してみたわけですが、削減後の次元によってどの程度ロスが生じるかを確認してみます。
以下、ブロックごとに200ステップ学習させたときのロスです。
・dim640 : 1/2 の次元に圧縮
・dim320: 1/4 の次元に圧縮
・dim160: 1/8 の次元に圧縮
FeedForward
圧縮後の次元数での差は少ないようです。
MID以外のブロックは ロスが1未満に落ち着く感じです。
MIDブロックだけは次元にかかわらずロスが暴れており、次元圧縮は厳しそうです。
ResNet
ResNetについては、 1/8 まで削減するとロスが大きい感じですが、1/4と1/2ではそこまで差はないように見えます。
また、ブロックによっては 200Step時点でもロスが減少傾向にあるため、さらに学習させてもよさそうです。
MID, OUT0, OUT1, OUT3, OUT4 などロスが 1以上と比較的高いブロックもあるため、MIDブロック付近のResNetは次元圧縮が厳しそうです。
対象ブロックを絞る
前述の調査で、ロスが高い状態で収束しそうなブロック(MID層全て、OUT0 ResNet, OUT1 ResNet, OUT3 ResNet, OUT4 ResNet)を次元圧縮の対象から外すようにしてみます。
また、IN11, IN12のResNet についてはロスはそこまで悪くないですが、
MID層に近く誤差が大きな影響を与えそうだったため、これも対象から外してみます。
また、次元(チャンネル数)は 1/4 程度までは削減できそうな感じ
(というか、1/2にしても精度は上がらなそう)でしたので、1/4を目安に削減してみます
なお、これらを適用すると20%程度のモデルサイズ削減……
SD1.5については実用性は薄そうですね。