Stable Diffusion XLで生成画像の破綻が起きる問題の解析と対策
Animagine XL V3が出始めた辺りから、特定のプロンプトを入力すするとアーティファクトが出てくると言う話がポツポツと聞こえてきました。アーティファクトって言われると古代兵器を思い浮かべてしまうわけですが、ここでいうアーティファクトというのはノイズではないけれどなんか変な画像という意味ですね。
細けぇ事はいいんだよ解決方法だけ教えろという方は最後の方にある説明を読んで下さい。
入力するプロンプトの順番や強度を入れ替えると出てきたりするが、発生条件が解らないという話で、どうにも雲をつかむような話でしたが、原因を突き止めたという話が回ってきました。
Web-UI系列では発生し、Comfyなどでは発生しない。強度を強くしたときに似たようなアーティファクトが現れると言うことから、Web-UIの強度計算にバグがあるという推測です。
なるほど、と思う反面、妙だなとも思いました。というのも、Web-UIの強度計算にバグがあるのならば、これまでにも指摘されていたはずですが、なぜだかAnimagineが流行しだしたあたりから報告されるようになったからです。強度計算のバグというと、例えば(mikan:1.3)としたときにはmikanのテンソルが1.3倍に強められるわけですが、この計算式を積み上げたところで強度が10みたいになるのはどうにもおかしな話です。(kotatu,(mikan:1.3), cat:1.5)のような複雑なプロンプトであるとしても、掛け合わせて強度が2以上になるようにならなければ発生しないはずです。
そもそも強度に関係がなく、プロンプトの順番を変えただけでも出るなどの報告もありました。ということで、様々な報告を加味すると強度だけでは説明できない事象が多いように思えます。
ではなにが原因だろうかとコードを巡る冒険に出かけることにします。プロンプトの処理をするWeb-UIのコードはそこかしこに行ったり来たりして結構難解なんですよね。
まず入力されたプロンプトは強度ごとに分割されて処理されます。
a girl in (kotatu:1.3) eating mikan
のようなプロンプトだと
a girl in : 1.0
kotatu : 1.3
eating mikan : 1.0
のようになるわけです。強度計算はここで行われるわけですが、バグが入り込む余地はないように思えました。
その後、プロンプトはtokenの形に変換され、transformersに渡されテンソルになります。次にテンソルに対してそれぞれの強度を掛け合わせることで生成に使われるテンソルとなるわけです。怪しいのは強度を掛け合わせる周辺の処理だと感じたのでそこを見てみましょう。
以下はそのコードであって、GPTちゃんのわかりやすい解説が付いています。modulesフォルダのsd_hijack_clip.pyの中です。
def process_tokens(self, remade_batch_tokens, batch_multipliers):
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
# remade_batch_tokensをPyTorchの配列に変換し、指定されたデバイスに移動する
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
# SD1とSD2の扱いを区別するための条件分岐。SD1ではパディングとテキスト終了のためのトークンが同じだが、SD2では異なる。
# このブロックでは、各バッチ内でテキスト終了トークンの後の位置をパディングトークンで埋める。
z = self.encode_with_transformers(tokens)
# トークンをTransformerモデルによってエンコードする
pooled = getattr(z, 'pooled', None)
# エンコードされた出力からプールされた表現を取得する。プールされた表現が存在しない場合はNoneを返す。
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
# エンコードされた出力の平均値を計算し、オリジナルの平均値として保持する
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
# バッチの各要素に乗算するための倍数を適用する。これにより、エンコードされた表現が変更される。
new_mean = z.mean()
# 変更後のエンコードされた出力の平均値を計算する
z = z * (original_mean / new_mean)
# オリジナルの平均値を復元するために、新しい平均値で調整する。これにより、アーティファクトを防ぐことが目的。
if pooled is not None:
z.pooled = pooled
# プールされた表現が存在する場合、変更後のエンコードされた出力にそれを再割り当てする。
return z
# 最終的に調整されたエンコードされた出力を返す
怪しい処理がありますね。
z = z * (original_mean / new_mean)
プログラミングにおいて割り算は常に想定外の出力を生み出すわけです。この式を見ると、original_meanに比べてnew_meanの値が小さければ、テンソルの強度が想定外に大きくなってしまう可能性があります。元のコードによると、むしろアーティファクトを防ぐためだと書かれていますが、これはXLのモデルが出る前のコードなので、XLと相性が良くない可能性もありますね。
ここでoriginal_meanとnew_meanは強度計算が行われる前のテンソルの平均値と、強度計算が行われた後のテンソルの平均値です。通常、強度は1以上の値が使われるので、(original_mean / new_mean)の値は1から大きく外れないはずです。おそらくはanimagineにおいて、テンソルがマイナスに大きく振れてしまうようなtokenがあり、そのtokenから計算されたテンソルが混ざっていると、強度を強くしてしまったときに(original_mean / new_mean)の計算結果が大きくなってしまい、結果として破綻が起きてしまうのではないでしょうか。あるいは、強度計算によってoriginal_mean と new_meanの符号が反転してしまう可能性も考えられます。
と言うわけで、この仮説が正しいかどうか問題の再現に挑戦してみます。闇雲にあやしいtokenを探すのもあれなので、すべてのtokenから生成されるテンソルのmeanを計算します。すると、プラスになるテンソルとマイナスになるテンソルがあることが解りました(これはSD1.Xでも同様のようです)。
「eni」という言葉がマイナスのテンソルの中で一番強く、1がプラスのテンソルの中で一番強いと言うことが解りました。eniってなに?
そこでここらを組み合わせて生成すると、問題を再現することができました。1とeniだけでは発生しませんでしたが、(1 2 am:1.3) eniというプロンプトを入力したところ破綻が発生しました。
ここで、計算されるoriginal_mean と new_meanの値を見てみると、original_mean -0.0003/、 new_meanが-7.8293e-05となりました。そうすると、倍率は3.83になって、破綻に繋がるわけですね。さらに強度を1.3から上げていくと、正負の反転も起こるようでした。このプロンプトの組み合わせがちょうど正負の境目になるようです。ここから少しでもプロンプトを変えると破綻は起きません。これはこれまでの報告とも一致しますね。よって、以下のコードを無効化すればいいわけですね。
z = z * (original_mean / new_mean)
と言うわけで無効化してみます。
直りましたね。予想通りです。無効化した事による問題もなさそうです。さて、A1111とforgeその他もろもろにissueを出してって言うのを待つのもあれなので、これを解決するscriptを作りました。いつも通りインストールして、settingsから設定すれば無効化できるようにしておきました。
https://github.com/hako-mikan/sd-webui-prevent-artifact