StableDiffusionのモデルサイズ削減を試してみる(3)
この記事の続きです。
色々と試してみたのですが、実用的な品質には程遠いかなぁというのが現状です。
暗黙の前提が
学習速度向上のためにUNetの途中でロス計算しているのですが、損失関数はMSEをそのまま使っています。
MSEのMは Mean。平均ですね……
各チャンネルについて単に平均を取っているということで、各チャンネルが等しく重要であるという前提での計算になってますね……
恐らく中間層では、チャンネルごとの重要度は違いますね。
チャンネルごとの重要度が異なるならば、ブロック単位での学習時は
・ロス計算箇所のチャンネルごとの重要度を調べる
・ロス計算時に重要度を重みとして乗算する
という手法が良さそうな感じです。
自動微分って便利ですよね
チャンネルごとの重要度については、PyTorchのAutoGradを使えば計算できそうです。
・勾配を計算用に全要素 1 のベクトルを用意
・計測箇所で、上記ベクトルと乗算
・UNet出力とゼロをMSEでロス計算してBackward
・勾配計算用ベクトルの勾配が更新される
・重みの更新は行わない
UNet出力をゼロに近づける方向で勾配が計算されるはずですので、出力への寄与度が高いチャンネルは勾配が大きくなる……はず。
まあ、未実装なんですが
とりあえず、この手法でチャンネルごとの重みを計測できるような実装進めてみようかと考えてます。