WSL2でNitroFusionを試してみる
「高品質な単一ステップの画像生成を実現するための動的敵対的トレーニング手法」らしいNitorFusionを試してみます。
使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。
1. 準備
環境構築
venv環境構築。
python3 -m venv nitrofusion
cd $_
source bin/activate
リポジトリをクローンしま…せん。Hugginf Face Spacesにある requirements.txt を参考に必要なパッケージのインストールです。
pip install diffusers torch transformers accelerate
2. 流し込むコード
GitHubのREADMEにあるコードをベースにして、以下のようなコードにしました。変更点は、
プロンプトを指定できるように関数を定義
RealismモデルとVibrantモデルのそれぞれを、関数qr、関数qvで呼び出せるように
カレントディレクトリにファイル保存
これをファイル名 infer.py として保存します。
from diffusers import LCMScheduler
class TimestepShiftLCMScheduler(LCMScheduler):
def __init__(self, *args, shifted_timestep=250, **kwargs):
super().__init__(*args, **kwargs)
self.register_to_config(shifted_timestep=shifted_timestep)
def set_timesteps(self, *args, **kwargs):
super().set_timesteps(*args, **kwargs)
self.origin_timesteps = self.timesteps.clone()
self.shifted_timesteps = (self.timesteps * self.config.shifted_timestep / self.config.num_train_timesteps).long()
self.timesteps = self.shifted_timesteps
def step(self, model_output, timestep, sample, generator=None, return_dict=True):
if self.step_index is None:
self._init_step_index(timestep)
self.timesteps = self.origin_timesteps
output = super().step(model_output, timestep, sample, generator, return_dict)
self.timesteps = self.shifted_timesteps
return output
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import re
# Load base model ID and repository
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ChenDY/NitroFusion"
# Pre-load models and schedulers for Realism and Vibrant
models = {}
for mode, ckpt, shifted_timestep in [
("Realism", "nitrosd-realism_unet.safetensors", 250),
("Vibrant", "nitrosd-vibrant_unet.safetensors", 500),
]:
# Load UNet model
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
# Configure scheduler
scheduler = TimestepShiftLCMScheduler.from_pretrained(
base_model_id,
subfolder="scheduler",
shifted_timestep=shifted_timestep
)
scheduler.config.original_inference_steps = 4
# Initialize pipeline
pipe = DiffusionPipeline.from_pretrained(
base_model_id,
unet=unet,
scheduler=scheduler,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
# Store in dictionary
models[mode] = pipe
# Define functions
def qr(prompt):
"""
Generate an image using the Realism model.
Args:
prompt (str): The text prompt for the image generation.
"""
pipe_realism = models["Realism"]
image = pipe_realism(
prompt=prompt,
num_inference_steps=1, # Supports 1 - 4 inference steps.
guidance_scale=0,
).images[0]
# Display and save the image
image.show()
safe_filename = re.sub(r'[^\w\s-]', '', prompt).strip().replace(' ', '_')
image.save(f"./{safe_filename}_Realism.png")
def qv(prompt):
"""
Generate an image using the Vibrant model.
Args:
prompt (str): The text prompt for the image generation.
"""
pipe_vibrant = models["Vibrant"]
image = pipe_vibrant(
prompt=prompt,
num_inference_steps=1, # Supports 1 - 4 inference steps.
guidance_scale=0,
).images[0]
# Display and save the image
image.show()
safe_filename = re.sub(r'[^\w\s-]', '', prompt).strip().replace(' ', '_')
image.save(f"./{safe_filename}_Vibrant.png")
# Example usage
prompt = "a photo of a cat"
qr(prompt) # Generates an image using the Realism model
qv(prompt) # Generates an image using the Vibrant model
カスタムクラス TimestepShiftLCMScheduler がポイント。タイムステップを「シフト」させることで、タイムステップの進行を再スケーリングさせています。Realismは 250、Vibrantは 500と、モデル毎にシフトするタイムステップを変更できるよう設計されてます。
3. 試します
では試しましょう。
CUDA_VISIBLE_DEVICES=0 python -i ./infer.py
RealismモデルとVibrantモデルの両方を読み込んでもぎりぎり 23.5GB と溢れず。VRAMが溢れてしまっている方は、いずれかのモデルだけをロードするようにしてください。
実際の生成は、
100%|██████ (略) █████| 1/1 [00:00<00:00, 12.07it/s]
00:00 ・・・ ゼロ秒!?
生成された画像がこちら。
1秒未満だと 0秒となってしまうので、きちんと計測しましょう。
image.show()とimage.save()をコメントアウトして、適当なプロンプトを与えるプログラムを作成。
import random
import time
# サンプル単語リスト(プロンプト生成用)
words = ["cat", "dog", "mountain", "river", "car", "sky", "flower", "bird", "sunset", "beach",
"forest", "city", "person", "portrait", "animal", "abstract", "fantasy", "robot", "spaceship", "alien"]
# プロンプトをランダムに生成
def generate_random_prompt():
return "a " + " ".join(random.choices(words, k=3)) # 単語を3つ組み合わせてプロンプトを生成
# qr関数を100回呼び出し、時間を測定
total_start_time = time.time() # 全体の開始時刻
for i in range(100):
prompt = generate_random_prompt()
print(f"Generating image {i+1}: {prompt}") # 現在のプロンプトを表示
start_time = time.time() # 1回の開始時刻
qr(prompt) # qr関数を呼び出し
end_time = time.time() # 1回の終了時刻
elapsed_time = end_time - start_time # 1回の処理時間
print(f"Image {i+1} generated in {elapsed_time:.2f} seconds.\n")
total_end_time = time.time() # 全体の終了時刻
total_elapsed_time = total_end_time - total_start_time # 全体の処理時間
print(f"All 100 images generated in {total_elapsed_time:.2f} seconds.")
で、計測したところ、
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5.94it/s]
Image 1 generated in 1.02 seconds.
Generating image 2: a mountain portrait flower
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.91it/s]
Image 2 generated in 0.88 seconds.
Generating image 3: a abstract forest abstract
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.10it/s]
Image 3 generated in 0.88 seconds.
Generating image 4: a river cat river
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.04it/s]
Image 4 generated in 0.88 seconds.
Generating image 5: a person city portrait
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.84it/s]
Image 5 generated in 0.88 seconds.
(snip)
Generating image 99: a cat beach sky
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.83it/s]
Image 99 generated in 0.90 seconds.
Generating image 100: a car spaceship mountain
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.60it/s]
Image 100 generated in 0.90 seconds.
All 100 images generated in 99.93 seconds.
100枚あたり99.93秒。生成にかかる時間は1枚あたり 0.9秒 ぐらいですね。
4. まとめ
1秒かからない、というのが正しい表現ですね。