見出し画像

DRLX による Stable Diffusion の RLHF を試す

「DRLX」による「Stable Diffusion」の「RLHF」を試したので、まとめました。

【注意】 Google Colab Pro/Pro+ の A100 で動作確認しています。


1. DRLX

DRLX」は、「Cyper」が開発したStable Diffusion用のRLHFライブラリです。

「Carper」では、LLMをRLHFで学習するために「TRLX」を開発しました。そして今度は、diffusersモデルをRLHFで学習するために「DRLX」を開発しました。初期リリースでは、「DDPOトレーナー」とさまざまな「報酬モデル」が提供されています。

現在の制限事項は、次のとおりです。

・「StableDiffusion 1.4」のみでテストされているが、他のパイプラインも使用できる予定。
・サポートされているアルゴリズムは「DDPO」のみ。
・マルチGPUおよびマルチノード用のAccelerateも使用できるが未テスト。

2. 報酬モデル

報酬モデル」は、画像生成モデルの強化学習で報酬信号を生成するために使用されます。 通常、画像を取得し、何らかの報酬を返します。報酬を生成する際にプロンプトを使用するものもありますが、必須ではありません。

2-1. Toy Rewards

テスト用の報酬モデルです。

・AverageBlueReward : 画像の「青さ」に報酬を与える。
JPEGCompressability : 画像の JPEG 圧縮の可能性に報酬を与える。

2-2. Aesthetics

美的スコアが高い画像に報酬を与えるモデルです。

・Aesthetics : 美的スコアが高い画像に報酬を与える。
 CLIP と MLP を使用 (デフォルトではデバイスに配置されない)

2-3. Pickscore (WIP)

「PickAPic」の「PickScore」モデルを使用した報酬モデルです。

PickScoreModel : PickAPic の PickScore モデルで報酬を与える。

3. 学習の実行

Colabでの学習手順は、次のとおりです。

(1) パッケージのインストール。

# パッケージのインストール
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install git+https://github.com/CarperAI/DRLX.git
!pip install datasets

(2) 「ddpo_sd.yml」のダウンロード。
「DRLX」の設定ファイルのサンプルです。

# ddpo_sd.ymlのダウンロード
!wget https://raw.githubusercontent.com/CarperAI/DRLX/main/configs/ddpo_sd.yml

(3) 「ddpo_sd.yml」の編集。
学習時間がかかるため、total_samplesを100分の1 (5.0e+2) に変更します。

train:
  total_samples: 5.0e+4 # 50k

train:
  total_samples: 5.0e+2

(4) 学習の実行。
A100で1時間半かかりました (Accelerateでもっと高速に学習する方法ありそう)。
最後にモデルを保存してます。

from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# 学習の実行
pipe = PickAPicPrompts()
config = DRLXConfig.load_yaml("ddpo_sd.yml")
trainer = DDPOTrainer(config)
trainer.train(pipe, Aesthetics())
trainer.save_pretrained("./drlx_model")

「DRLX」で保存されたモデルは、元のパイプラインと互換性があり、他のオブジェクトと同様にロードできます。

・drlx_model

・feature_extractor
・safety_checker
・scheduler
・text_encoder
・tokenizer
・unet
・vae
・module_index.json

4. 推論の実行

Colabでの推論の実行手順は、次のとおりです。

(1) 推論の実行。

from diffusers import StableDiffusionPipeline

# 画像生成の実行
pipe = StableDiffusionPipeline.from_pretrained("./drlx_model")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")

美的スコアは効果がわかりにくいので、まずはAverageBlueRewardでテストすればよかった気がする。



いいなと思ったら応援しよう!