DreamBooth Stable Diffusion を試す
愛犬の合成画像を生成できる画像生成AI「DreamBooth」の「Stable Diffusion」版を作ってる人がいたので、愛猫の合成画像の生成に挑戦してみました。
GPUのメモリは32GB以上必要です。
1. DreamBooth
「DreamBooth」は、数枚の被写体画像 (例 : 特定の犬) と対応するクラス名 (例 : 犬) を与えてファインチューニングすることで、Text-to-Imageモデルに新たな被写体を学習させる手法です。愛犬の合成画像を生成できる画像生成AIとして話題になりました。
オリジナルの「DreamBooth」は「Imagen」をベースにしていますが、この実装は「Textual Inversion」をベースにしています。
2. DreamBooth Stable Diffusion
「DreamBooth Stable Diffusion」その名の通り、「DreamBooth」の「Stable Diffusion」の実装です。オリジナルの「DreamBooth」のText-to-Imageモデルは「Imagen」をベースにしていますが、この実装は「Textual Inversion」をベースにしています。
3. 入力画像の準備
今回はうちの猫を学習させるため、写真を7枚ほど用意しました。
4. 正則化画像の準備
「DreamBooth」は、正則化のための一連の画像を必要とするため、クラスとなる猫の画像を200枚ほど用意しました。
5. DreamBooth Stable Diffusion のインストール
「Stable Diffusion」のインストール方法と同じです。
(1) PyTorchのインストール。
(2) パッケージのインストール。
# パッケージのインストール
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
pip install pytorch_lightning tensorboard==2.8 omegaconf einops taming-transformers==0.0.1 clip transformers kornia test-tube
pip install diffusers invisible-watermark
(3) Dreambooth Stable Diffusion のインストール。
git clone https://github.com/XavierXiao/Dreambooth-Stable-Diffusion.git
cd Dreambooth-Stable-Diffusion
pip install -e .
mkdir inputs reg_inputs outputs checkpoint
(4) inputsに入力画像、reg_inputsに正則化画像、checkpointにStable Diffusionのチェックポイント(sd-v1-4-full-ema.ckpt)を配置。
Stable Diffusionのチェックポイントは、以下のサイトで入手できます。
(5) 「Dreambooth-Stable-Diffusion/configs/stable-diffusion/v1-finetune_unfrozen.yaml」の編集。
ファインチューニングの各種パラメータを調整します。GPUのメモリが足りないのでその対策になります。
学習するステップ数を増やしたい場合は、max_stepsを増やします。
学習するほど特徴を捉えますが、プロンプト編集が効きにくくなります。
6. 学習の実行
以下のコマンドで学習を実行します。
python main.py \
--no-test \
--base configs/stable-diffusion/v1-finetune_unfrozen.yaml \
-t \
--actual_resume ./checkpoint/sd-v1-4-full-ema.ckpt \
-n cat \
--gpus 0, \
--data_root ./inputs \
--reg_data_root ./reg_inputs \
--class_word cat
学習には15分ほどかかります。「./logs/<job_name>/checkpoints」に2つのチェックポイントが保存されます。1つは500ステップで、もう1つは最終ステップです。通常、500ステップで十分に機能します。
「学習率」のデフォルトは「1.0e-6」です。これは、Dreamboothの論文で「1.0e-5」を使用すると編集性が低下することがわかったためです。
「Dreambooth」には、「プレースホルダーワード[V]」と呼ばれる識別子が必要になります。Dreamboothの論文では、T5-XXLトークナイザーで珍しい単語を使用しています。この実装では、ランダムな単語「sks」をハードコードしています。「dog」で一般的な犬、「sks dog」で特定の犬を、プロンプトで指示できます。
7. 推論の実行
以下のコマンドで推論を実行します。
python scripts/stable_txt2img.py \
--n_samples 8 \
--n_iter 1 \
--scale 7.0 \
--ddim_eta 0.0 \
--ddim_steps 50 \
--ckpt ./logs/inputs2022-09-11T21-20-34_cat/checkpoints/last.ckpt \
--prompt "photo of sks cat on the moon " \
--seed 13167
「outputs」に月面に行ったうちの猫の画像が生成されます。
・photo of sks cat on the moon (左は本物)
実際には、猫のブチの模様が安定しないので、まだまだ学習データやパラメータに調整が必要そうです。
・photo of sks cat
8. 参考
・DreamBoothとTextual Inversionの比較
・22/09/18 Stable Diffusionの追加トレーニングについてのサーベイ
・22/09/26 Stable Diffusion追加学習の記録