見出し画像

DreamBooth Stable Diffusion を試す

愛犬の合成画像を生成できる画像生成AI「DreamBooth」の「Stable Diffusion」版を作ってる人がいたので、愛猫の合成画像の生成に挑戦してみました。

・DreamBooth Stable Diffusion

GPUのメモリは32GB以上必要です。

1. DreamBooth

「DreamBooth」は、数枚の被写体画像 (例 : 特定の犬) と対応するクラス名 (例 : 犬) を与えてファインチューニングすることで、Text-to-Imageモデルに新たな被写体を学習させる手法です。愛犬の合成画像を生成できる画像生成AIとして話題になりました。

オリジナルの「DreamBooth」は「Imagen」をベースにしていますが、この実装は「Textual Inversion」をベースにしています。

2. DreamBooth Stable Diffusion


「DreamBooth Stable Diffusion」その名の通り、「DreamBooth」の「Stable Diffusion」の実装です。オリジナルの「DreamBooth」のText-to-Imageモデルは「Imagen」をベースにしていますが、この実装は「Textual Inversion」をベースにしています。

3. 入力画像の準備

今回はうちの猫を学習させるため、写真を7枚ほど用意しました。

4. 正則化画像の準備

「DreamBooth」は、正則化のための一連の画像を必要とするため、クラスとなる猫の画像を200枚ほど用意しました。

5. DreamBooth Stable Diffusion のインストール

「Stable Diffusion」のインストール方法と同じです。

(1) PyTorchのインストール。

(2) パッケージのインストール。

# パッケージのインストール
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
pip install pytorch_lightning tensorboard==2.8 omegaconf einops taming-transformers==0.0.1 clip transformers kornia test-tube
pip install diffusers invisible-watermark

(3) Dreambooth Stable Diffusion のインストール。

git clone https://github.com/XavierXiao/Dreambooth-Stable-Diffusion.git
cd Dreambooth-Stable-Diffusion
pip install -e .
mkdir inputs reg_inputs outputs checkpoint

(4) inputsに入力画像、reg_inputsに正則化画像、checkpointにStable Diffusionのチェックポイント(sd-v1-4-full-ema.ckpt)を配置。
Stable Diffusionのチェックポイントは、以下のサイトで入手できます。

(5) 「Dreambooth-Stable-Diffusion/configs/stable-diffusion/v1-finetune_unfrozen.yaml」の編集。
ファインチューニングの各種パラメータを調整します。GPUのメモリが足りないのでその対策になります。

・num_workers:2 → 1
・batch_frequency
: 500 → 99999
・max_images : 8 → 1 

学習するステップ数を増やしたい場合は、max_stepsを増やします。

・max_steps : 800 → 1000

学習するほど特徴を捉えますが、プロンプト編集が効きにくくなります。

6. 学習の実行

以下のコマンドで学習を実行します。

python main.py \
    --no-test \
    --base configs/stable-diffusion/v1-finetune_unfrozen.yaml \
    -t \
    --actual_resume ./checkpoint/sd-v1-4-full-ema.ckpt \
    -n cat \
    --gpus 0, \
    --data_root ./inputs \
    --reg_data_root ./reg_inputs \
    --class_word cat

学習には15分ほどかかります。「./logs/<job_name>/checkpoints」に2つのチェックポイントが保存されます。1つは500ステップで、もう1つは最終ステップです。通常、500ステップで十分に機能します。

学習率」のデフォルトは「1.0e-6」です。これは、Dreamboothの論文で「1.0e-5」を使用すると編集性が低下することがわかったためです。

「Dreambooth」には、「プレースホルダーワード[V]」と呼ばれる識別子が必要になります。Dreamboothの論文では、T5-XXLトークナイザーで珍しい単語を使用しています。この実装では、ランダムな単語「sks」をハードコードしています。「dog」で一般的な犬、「sks dog」で特定の犬を、プロンプトで指示できます。

7. 推論の実行

以下のコマンドで推論を実行します。

python scripts/stable_txt2img.py \
    --n_samples 8 \
    --n_iter 1 \
    --scale 7.0 \
    --ddim_eta 0.0 \
    --ddim_steps 50 \
    --ckpt ./logs/inputs2022-09-11T21-20-34_cat/checkpoints/last.ckpt \
    --prompt "photo of sks cat on the moon " \
    --seed 13167

「outputs」に月面に行ったうちの猫の画像が生成されます。

・photo of sks cat on the moon (左は本物)

実際には、猫のブチの模様が安定しないので、まだまだ学習データやパラメータに調整が必要そうです。

・photo of sks cat

8. 参考

・DreamBoothとTextual Inversionの比較

・22/09/18 Stable Diffusionの追加トレーニングについてのサーベイ

・22/09/26 Stable Diffusion追加学習の記録

9. 関連



いいなと思ったら応援しよう!