見出し画像

StableDiffusionのメモリ消費量を減らしたい

僕が普段利用しているマシンは RTX3070 mobile搭載のノートPCで、VRAMが 8GBと機械学習には厳しいので、メモリ的な効率も無視できない要素です。

ここ最近になってようやくソースコード読むようになって、モデル構造もある程度理解できてきましたので、メモリ消費量削減を目指してみようかと考えています。


そう簡単には削減できなそう

機械学習では、学習後にモデルサイズを削減するのはわりと一般的で、手法についても色々と提案されているみたいです。
ただ、StableDiffusionの場合は一般的な手法がちょっと使いづらいんですよね……

AI民主化とかそんな標語でしたっけ?

まあ、標語は何でも良いんです。
ソースコードにアクセスもできるし、HuggingfaceのDiffusersライブラリで簡単にプログラム組むこともできる。
WebUIとかのツールならソースコード書かなくても使えるし、追加学習もスクリプト類が充実してる。

便利なツールを作る人、生成ロジックに手を加える人、独自データセットでモデルをゼロから学習する人、追加学習する人、複数モデルをマージする人、それらを使って画像生成する人。

多くの人が関わっていながら、大きな分断なしでコミュニティが成立しているのは、ちょっと奇跡的かなぁとか思ってます。
(まあ、SD1.5とSDXLで分断気味かもしれませんが)

互換性を無視できない

で、StableDiffusionに関しても 派生モデルやLoRAが活発に開発されてますので、このあたりが使えなくなるような改造は、ちょっと使いづらいです。
(研究、というか実験としては意義があるんですが)

そんなわけで

  • 既存のLoRAが適用できる

    • LoRA適用時点では、モデル構造・サイズはそのまま

    • LoRA適用後にサイズ変更は多分大丈夫

  • ControlNet等のメジャーな拡張が適用できる

    • ブロックの入出力チャンネル数は変更しない

    • CrossAttentionには触らない
      (CrossAttentionに作用する拡張が多そうなので)

を目標としつつ、モデルサイズ削減できるか模索していきます。
あと、できたら重たい学習とかは避けたい

モデルサイズの内訳

SD1.5とSDXLにおける UNet パラメータ数をざっくり計算してみました。
(GroupNorm, Biasなどは無視した近似ですが)

SD1.5			params	
Resnet	 		537,670,400 	62.7%
Transformer		267,239,360 	31.2%
Downsample		 19,355,840 	 2.3%
Upsample		 33,180,800 	 3.9%
		
Total	 		857,446,400 	
(FeedForward)	148,684,800 	17.3%
SDXL			params	
Resnet			   327,782,400 	12.8%
Transformer		 2,197,913,600 	85.7%
Downsample		    19,353,600 	 0.8%
Upsample		    18,432,000 	 0.7%
		
Total			 2,563,481,600 	
(FeedForward)	 1,228,800,000 	47.9%

SDXLは Transformer関連が 85% 超えています。
またTransformerブロックはFeedForward部分が大きい(半分くらい)ので、FeedForward部のサイズ削減ができれば有用かもしれません。

SD1.5の方では、ResNetの比率が大きいです。
これは Midブロック周り(IN11からOUT2)に ResNetが7ブロックもあるのでまあ仕方が無いですね。
ここもサイズ削減できたら良いなぁという感じです。

FeedForwardのサイズ削減計画

FeedForwardは、CrossAttentionの後などに良い具合にチャンネルを混ぜ合わせるような働きをしている?ようなのですが、隠れ層のチャンネル数が4倍とかなり大きいため、Linearのサイズがかなり大きくなっています。
(Self-attentionの3倍くらい)

幸いなことに FeedForwardに介入するタイプの拡張は少ないので、LoRAだけ考えれば多分問題無いです。

問題は途中で活性化関数(gelu)が入るので、単純にPCAなどで次元圧縮すると誤差が大きくて駄目そうかなぁという感じです。
PCAで次元圧縮するような行列を初期値として学習させるのが良いかもしれません。

こんな感じにチャンネルを減らす/戻す Linearを挿入すれば、内部のチャンネル数を減らせそうです。
下側ルートにgeluがなければ、右側の Linese(dim→4c) の(疑似)逆行列を
左側2つのLinear(4c→dim) に設定すれば行けそうなんですが、geluが入っているため逆行列では駄目そうです。
元モデルを教師として学習させてなんとかします。

学習後は画像生成時(LoRA適用後)に、元モデルのLinearと結合させれば、
c→4c、2c→c のLinearになるので、VRAMに配置する重みは半分程度になる見込みです。

……実験してみないと何とも言えないですね。

ResNetのサイズ削減計画

ResNetもちょいちょい活性化関数(silu)が入るので、PCAなどで次元圧縮するだけでは駄目そうですね。

こんな感じにチャンネルを減らす/戻す Linearを挿入すれば、内部のチャンネル数を減らせそうです。まあ、FeedForwardと時とほぼ同じ発想ですね。
画像生成時には Linearとconvを結合したいので、右側のLinearは siluとconvの間に配置します。

これについてもsiluが挟まってるので、疑似逆行列では駄目かもしれません。
左側2つのLinear(c→dim)はその後加算するだけなので、パラメータは共有で良いかもしれません。
まあ、共有でも独立でもストレージサイズが変化するだけでVRAM消費量は変わらないので、ここはどうでもよいかもしれません。

結局学習は避けられなそうです

PCAだとかで次元圧縮すれば楽できるんじゃないかとか考えてたんですが、
モデル構造見直したらやっぱり学習させないと駄目そうです。

手法としては BERT-of-Theseus に近い感じかもしれません。
教師モデルの1ブロックを学習先モデルとしてコピー(とパラメータ削減)
元モデルのブロックの入出力で学習先ブロックを学習。
程々に進んだら、学習先モデルで元モデルを置き換えるという感じでしょうか。

いきなりSDXLで適用するのはサイズ的に厳しそうなので、まずはSD1.5で実験してみてから。という感じになりそうです。

いいなと思ったら応援しよう!