Würstchenアーキテクチャを利用したStable Cascadeの仕組みを理解する
はじめに
久しぶりの知見を深めようシリーズです。(?)
2024/2/13に、Stability AI社がStable Cascadeを発表しました。
その新モデルに採用されているWürstchenアーキテクチャの論文を読んでみて、理解を深めようというのが今回の記事です。
Würstchenアーキテクチャの概要とStable Cascadeについて
Würstchenとは、ドイツ語でソーセージを意味します。
キンキンに冷えたビールに合いそうな名前で、まさにビールとソーセージが有名なドイツに相応しい名前です。
このアーキテクチャは比較的新しいモデルで、Stage A~Cの3段階から構成されています。
上図の通り、このモデルも拡散モデルを採用しておりますが、Stable Diffusionと大きく異なる点はこの3段階構成に加え、Stable Diffusion v1.xでは64x64であった潜在空間モデルに対して、1/2以下である24x24という極めて小さなサイズに変換されることが特徴です。
この3段階構造では、以下の処理を行います。
Stage A(Latent Decoder)
デコーダ部分であり、VAEを利用して復号処理を行います。Stage B(Latent Decoder)
拡散モデルを使用して、小さい潜在空間からアップスケールを行います。Stage C(Latent Generator)
従来でいうところのText Encoder部分になります。
主にユーザが取り扱うのはこの部分です。(後述)
学習過程はStage Aから逆復号で学習されます。
逆復号で学習するのは、Stable Diffusionモデルでも同様で、他の拡散モデルであれば一様に同じです。
中でも特筆すべき点はStage Cです。
各工程が切り離されたことで、ユーザはモデルへの追加学習やLoRA学習を行う際、Stage Cのみを扱えば良くなります。
その為、従来よりも学習コストが低減し、かつ簡略的に扱えるようになることで、Stability AI社では「一般消費者向けハードウェアでのトレーニングと微調整が簡単にできる」と発表しています。
これはあくまでもファインチューニング等の追加学習に関する点であり、事前学習などの大規模学習に関しては、一般家庭で行うには依然ハードルが高いことに変わりはありません。
また、Stage B・Cには2種類の異なるパラメータが用意されております。
Stage Bでは700M(7億)と1.5B(15億)、Stage Cでは1B(10億)と3.6B(36億)となっており、感覚的にも分かると思いますが、パラメータが多いほうが高品質になる反面、それ相応のマシンスペックが必要になってくると思われます。
Stable Cascadeをリリースにするにあたって、Stability AI社は付加機能であるControlNetに加え、追加学習やLoRAに関するトレーニング用のコードを公開予定とのことです。
Würstchenアーキテクチャの各Stageの仕組みについて
前述した通り、Würstchenアーキテクチャには3つの独立したStageがあり、それらをパイプラインで結合することで実装しているものになります。
Stage A・BはLatent Decoder部分であると前述しました。
ここでは、主にデコード処理を扱いますが、このアプローチは以下を参考にしているようです。
ただ、高い係数を利用するとデコード品質が劣化してしまってあまり良くないので、ここではf4を係数としてVQGAN(ベクトル量子化敵対的生成ネットワーク)を使用しているようです。
以下で量子化した潜在値を求めます。
$$
f_θ ^−1 (f_θ (X)) = f_θ ^−1
(Xq ) ≈ X
$$
上記までが、かなりざっくりとした説明ですが、Stage A部分でStage Bへ引き渡す際は、この量子化部分は取り除かれ、条件付き潜在拡散モデル(Conditioned LDM)として学習されるようです。
めっちゃ端的に説明すると、ベースはVQGANを利用したアーキテクチャのようです。
多分ここの部分(Stage A・B)は、普段利用する際に弄ることはあまりないと思います。
最後のStage Cでは、最終的に1/42に圧縮され、テキスト条件付きの潜在拡散モデルとして学習するようです。
Stage Cでは、16個のDown SamplingなしのConvNeXt Blockで構成され、各ブロックの後にテキストなどの条件付けがCross Attentionで処理されます。
U-Netモデルとはかなり構成が違いますが、同じ拡散モデルということもあり、Stable Diffusionモデルと似ている部分があるので、チューニングする際でもあまり迷いはしなさそうですね。
Paperでは、最終的な圧縮率が42倍を超えると品質に悪影響を及ぼすと書かれているので、これ以上の潜在空間サイズは厳しいのでしょうか?
かなりざっくりと足早に進めてきましたが、こう見るとStable Diffusionとは大きく異なるアーキテクチャモデルであるのが分かります。
Stable Diffusionモデルとの比較
Stability AI社では、プロンプトアライメントと美的クオリティに対して、従来のStable Diffusionモデルよりも上回ると発表しています。
AIモデルにおけるアライメント(AIアライメント)とは、AIに人間の価値観や倫理観、目標を埋め込み、可能な限り有用で安全かつ信頼できるものにするプロセスのことを指しますが、ここでいうプロンプトアライメントとは、推論結果と指定したプロンプトの合致性を指しています。
ほぼ全ての画像生成モデルより、合致性も品質も上回っていますね。
また、Stable Cascadeは推論速度も大きく改善しています。
ただこの結果については、詳しいことが書かれていないので一概に言えません。
一定の推論結果に達するまでのStep数と読み取ればいいのか、であれば推論に使用したSamplerは何なのかが分からないので、どう読み取ればいいのか分かりませんが、「同品質の推論結果に達するStep数での比較」という意味であれば、推論速度が非常に向上しています。
では、実際に推論してみましょう。
Stable Cascadeで推論してみる
もう簡単に導入できるように、有志の方がスクリプト組んでくれているので、そちらを利用してみます。
git cloneで任意の場所にクローンして、install.batを実行するだけで、実行環境を構築してくれます。
バッチ処理完了後は、generate_images.batを実行することで推論が可能になります。
以下条件で推論を実行してみました。
Size : 1024 x 1024
CFG Scale : 11
Prompt
4k wallpaper,highres,upper body,kawaii,cute,1 girl, solo,silver hair, long hair, bangs, flower_hair_ornament,orange eyes,bokeh
Negative Prompt
Size : 1024 x 1024
CFG Scale : 11
Prompt
4k wallpaper,highres,upper body,kawaii,cute,1 girl, solo,silver hair, long hair, bangs, flower_hair_ornament,orange eyes,bokeh
Negative Prompt
worst quality,low quality,normal quality,monochrome,grayscale,watermark,white letters,signature,username,text,error
Negative Promptを入れると、かなり推論が遅くなるようですが、実行環境が起因なのかは不明です。
おわりに
今回は、新しく発表されたStable Cascadeについてご紹介しました。
Stability AI社の方で、今後様々な学習コードなどを発表するそうなので、今後に期待ですね!
仕組みも学習コストがかなり低減されているとのことなので、チューニングやモデル製作をする方は是非動向をチェックしたい内容でした。
是非、皆さんも触れてみてください!