見出し画像

StableDiffusionのスケジューラを切り替える

前回はStableDiffusionの学習モデルを切り替える方法を紹介しました。

しばらく使っていたところ、画像のノイズを除去するスケジューラも切り替えたくなったこと、画像生成に使ったモデルやスケジューラの情報もデータベースに残したいということでその対応を追加しました。


1. スケジューラの切り替え

細かい部分は把握していませんが、StableDiffusionを扱うときにPythonでインポートするdiffuserというライブラリには、画像を生成するときにノイズ除去するやり方をいくつも実装しています。diffusers v0.19では以下の14種類のスケジューラを使うことができます。

DPMSolverSinglestepScheduler
DDIMScheduler
KDPM2DiscreteScheduler
DEISMultistepScheduler
EulerAncestralDiscreteScheduler
PNDMScheduler
UniPCMultistepScheduler
DDPMScheduler
EulerDiscreteScheduler
LMSDiscreteScheduler
DPMSolverMultistepScheduler
HeunDiscreteScheduler
DPMSolverSDEScheduler
KDPM2AncestralDiscreteScheduler

pipeline.scheduler.compatiblesで出力

スケジューラには処理自体の早いもの遅いもの、ステップ数を上げないとノイズが取り切れないもの、結果が似ているものがあるため、このなかから以下の4つを選べるようにしました。

  • DDIM

  • DDPM

  • EulerAncestralDiscrete

  • UniPCMultistep

2. データベースの拡張

従来のテーブルに使用したモデルを記録するmodel、使用したスケジューラを記録するschedulerを追加しました。空のデータベースを下記からダウンロードできるようにしました。

3. Pythonコードの修正

変更点は以下の部分です。

3-1. diffusersのスケジューラをインポートに追加

DDIM、DDPM、EulerAncestralDiscrete、UniPCMultistepを切り替えできるようにするため、この4つをインポートします。もともとStableDiffusionPipelineとEulerAncestralDiscreteはインポートしていましたので追加はDDIM、DDPM、UniPCMultistepです。

from diffusers import (
    StableDiffusionPipeline,
    DDIMScheduler,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    UniPCMultistepScheduler,
)

3-2. スケジューラー用のコンボボックスを作成

やり方は学習モデルの切り替えの時と同じです。コンボボックスに表示するテキストと選択前に表示する初期設定を行います。

# スケジューラ選択用のコンボボックス
combobox_scheduler = customtkinter.CTkComboBox(root,values=["DDIM","DDPMS","EulerAncestralDiscrete","UniPCMultistep"])
combobox_scheduler.set("EulerAncestralDiscrete")  # set initial value
combobox_scheduler.place(relx=0.5, rely=0.03, relheight=0.05, relwidth=0.4)

データベースから画像を読み込んだ際に使用モデルやスケジューラをコンボボックスに表示します。

 combobox_model.set(pictinfo.model[0])
 combobox_scheduler.set(pictinfo.scheduler[0])

3-3. スケジューラー用のリストを作成

コンボボックスで選択したスケジューラをStableDiffusionのPipelineに与えるためリストを作成しました。

    scheduler = combobox_scheduler.get()
    scheduler_list = {
        'DDIM' : DDIMScheduler,
        'DDPMS' : DDPMScheduler,
        'EulerAncestralDiscrete' : EulerAncestralDiscreteScheduler,
        'UniPCMultistep' : UniPCMultistepScheduler,
    }

#-------------------------------

    model_id = combobox_model.get()
    pipe = StableDiffusionPipeline.from_ckpt(
        model_id,
        load_safety_checker = False,
        extract_ema = True,
        torch_dtype = torch.float16)
    pipe.scheduler = scheduler_list[scheduler].from_config(pipe.scheduler.config)

3-4.データベースへの拡張

今回は新たに使用した学習モデルとスケジューラをデータベースに書き込むため書き込み項目を追加しました。

sql = 'insert into history (filename, model, scheduler, prompt, negative_prompt, guidance_scale, Inference_steps, seeds ) values (?,?,?,?,?,?,?,?)'
        data = (pict_name, model_id, scheduler, prompt, negative_prompt, scale, steps, seed)

4.完成したPythonコード

以上の修正を行ったコードは以下のとおりです。こちらのコードと同じディレクトリに学習モデル(BracingEvoMix_v1.safetensors)やデータベースファイル(generate_history.db)を配置して実行してください。

import os
import customtkinter
from tkinter import filedialog
from tkinter import messagebox
import random
import math
import datetime
import sqlite3
import pandas as pd
import torch
from torch import autocast
from diffusers import (
    StableDiffusionPipeline,
    DDIMScheduler,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    UniPCMultistepScheduler,
)

def btn_click_generate():
    scheduler = combobox_scheduler.get()
    scheduler_list = {
        'DDIM' : DDIMScheduler,
        'DDPMS' : DDPMScheduler,
        'EulerAncestralDiscrete' : EulerAncestralDiscreteScheduler,
        'UniPCMultistep' : UniPCMultistepScheduler,
    }

    model_id = combobox_model.get()
    try:
        pipe = StableDiffusionPipeline.from_ckpt(
            model_id,
            load_safety_checker = False,
            extract_ema = True,
            torch_dtype = torch.float16)
        pipe.scheduler = scheduler_list[scheduler].from_config(pipe.scheduler.config)
        pipe.load_textual_inversion("sayakpaul/EasyNegative-test",weight_name="EasyNegative.safetensors", token="EasyNegative")
        pipe.to("cuda")

        prompt = prompt_data.get()
        negative_prompt = negative_prompt_data.get()
        if prompt == "":
            prompt = "((masterpiece:1.4, best quality)), ((masterpiece, best quality)),  (photo realistic:1.4), woman, female, Beautiful face, bright eyes,"
        if negative_prompt == "":
            negative_prompt = "Easy Negative (worst quality:2) (low quality:2) (normal quality:2) lowers normal quality ((monochrome)) ((grayscale)),skin spots,acnes,skin blemishes,age spot,ugly face,fat,missing fingers, extra fingers, extra arms,open chest,thick eyebrows, huge breasts, open chest"

        generate = generate_data.get()
        if generate == "":
            generate = 1
        else:
            generate = int(generate)

        conn = sqlite3.connect('generate_history.db')

        for i in range(generate):
            if inference_data.get() == "":
                steps = random.randrange(20, 40, 1)
            else:
                steps = int(inference_data.get())
            if guidance_data.get() == "":
                scale = math.floor((random.uniform(7, 10))*100)/100
                #scale = math.floor((random.uniform(steps/3, steps/4))*100)/100
            else:
                scale = float(guidance_data.get())

            if seeds_data.get() == "":
                seed = random.randrange(0, 4294967295, 1)
            else:
                seed = int(seeds_data.get())

            with autocast("cuda"):
                generator = torch.Generator("cuda").manual_seed(seed)
                image = pipe(
                    prompt = prompt,
                    negative_prompt = negative_prompt,
                    generator = generator,
                    num_inference_steps = steps,
                    guidance_scale = scale,
                    width = 768,height = 768,
                    ).images[0]

            # ファイル作成時間を求める
            t_delta = datetime.timedelta(hours=9)
            JST = datetime.timezone(t_delta, 'JST')
            now = datetime.datetime.now(JST)

            pict_name = f"{now.strftime('%Y%m%d%H%M%S')}_{scale}_{steps}_{seed}.png"
            image.save(pict_name)

            # 生成された画像情報をデータベースに書き込む
            cur = conn.cursor()
            sql = 'insert into history (filename, model, scheduler, prompt, negative_prompt, guidance_scale, Inference_steps, seeds ) values (?,?,?,?,?,?,?,?)'
            data = (pict_name, model_id, scheduler, prompt, negative_prompt, scale, steps, seed)
            cur.execute(sql,data)
            conn.commit()
        conn.close()
    except:
        messagebox.showerror('ファイルがありません', '選択した学習モデルがパスにありません')

def btn_click_open():
    selfile = []

    # エクスプローラを開いて画像ファイルを指定する
    filename = filedialog.askopenfilename(title = "画像ファイルを開く",
    filetypes = [("Image file", ".png"), ("PNG", ".png")],
    initialdir = "./")
    selfile.append(os.path.basename(filename))

    # データベースから選択した画像情報を検索し、その情報をGUIに書き込む
    conn = sqlite3.connect('generate_history.db')
    pictinfo = pd.read_sql_query('SELECT * FROM history WHERE filename == ?', conn, params=selfile)
    if len(pictinfo) == 0:
        messagebox.showerror('DBエラー', '画像を選択していないか、選択した画像情報がデータベースに存在しません')
    else:
        prompt_data.delete(0, customtkinter.END)
        negative_prompt_data.delete(0, customtkinter.END)
        guidance_data.delete(0, customtkinter.END)
        inference_data.delete(0, customtkinter.END)
        seeds_data.delete(0, customtkinter.END)

        combobox_model.set(pictinfo.model[0])
        combobox_scheduler.set(pictinfo.scheduler[0])
        prompt_data.insert(customtkinter.END,pictinfo.prompt[0])
        negative_prompt_data.insert(customtkinter.END,pictinfo.negative_prompt[0])
        guidance_data.insert(customtkinter.END,pictinfo.guidance_scale[0])
        inference_data.insert(customtkinter.END,pictinfo.Inference_steps[0])
        seeds_data.insert(customtkinter.END,pictinfo.seeds[0])
    conn.close()

customtkinter.set_appearance_mode("System")  # Modes: system (default), light, dark
customtkinter.set_default_color_theme("blue")  # Themes: blue (default), dark-blue, green

root = customtkinter.CTk()
root.geometry('700x500')
root.title('Stable Diffusion GUI')

# 学習モデル選択用のコンボボックス
combobox_model = customtkinter.CTkComboBox(root,values=["BracingEvoMix_v1.safetensors", "BracingEvoMix_Another_v1.safetensors", "Brav6.safetensors", "chilloutmix_NiPrunedFp32Fix.safetensors"])
combobox_model.set("BracingEvoMix_v1.safetensors")  # set initial value
combobox_model.place(relx=0.05, rely=0.03, relheight=0.05, relwidth=0.4)


# スケジューラ選択用のコンボボックス
combobox_scheduler = customtkinter.CTkComboBox(root,values=["DDIM","DDPMS","EulerAncestralDiscrete","UniPCMultistep"])
combobox_scheduler.set("EulerAncestralDiscrete")  # set initial value
combobox_scheduler.place(relx=0.5, rely=0.03, relheight=0.05, relwidth=0.4)

# プロンプトラベル
prompt_lbl = customtkinter.CTkLabel(root, text="プロンプト(未入力時はデフォルト値)", width=60, justify="left", anchor="w")
prompt_lbl.place(relx=0.05, rely=0.1, relheight=0.05, relwidth=0.5)

# プロンプト入力テキスト
prompt_data = customtkinter.CTkEntry(root,placeholder_text="", width=20, height=25, border_width=2, corner_radius=6)
prompt_data.place(relx=0.05, rely=0.15, relheight=0.07, relwidth=0.9)

# ネガティブプロンプト・ラベル
negative_prompt_lbl = customtkinter.CTkLabel(root, text='ネガティブプロンプト(未入力時はデフォルト値)', width=60, justify="left", anchor="w")
negative_prompt_lbl.place(relx=0.05, rely=0.23, relheight=0.05, relwidth=0.5)

# ネガティブ・プロンプト入力テキスト
negative_prompt_data = customtkinter.CTkEntry(root,placeholder_text="", width=25, height=25, border_width=2, corner_radius=6)
negative_prompt_data.place(relx=0.05, rely=0.28, relheight=0.07, relwidth=0.9)

# 生成枚数ラベル
generate_lbl = customtkinter.CTkLabel(root, text='生成枚数(未入力時は1枚作成)', width=60, justify="left", anchor="w")
generate_lbl.place(relx=0.05, rely=0.38, relheight=0.05, relwidth=0.5)

# 生成枚数入力テキスト
generate_data = customtkinter.CTkEntry(root,placeholder_text="", width=25, height=25, border_width=2, corner_radius=6)
generate_data.place(relx=0.05, rely=0.43, relheight=0.05, relwidth=0.2)

# Guidance scaleラベル
guidance_lbl = customtkinter.CTkLabel(root, text='プロンプトと出力画像の類似度(未入力時はランダム)', width=60, justify="left", anchor="w")
guidance_lbl.place(relx=0.05, rely=0.48, relheight=0.05, relwidth=0.5)

# Guidance scale(プロンプトと出力画像の類似度)入力テキスト
guidance_data = customtkinter.CTkEntry(root,placeholder_text="", width=25, height=25, border_width=2, corner_radius=6)
guidance_data.place(relx=0.05, rely=0.53, relheight=0.05, relwidth=0.2)

# Inference stepsラベル
inference_lbl = customtkinter.CTkLabel(root, text='画像生成に費やすステップ数(未入力時はランダム)', width=60, justify="left", anchor="w")
inference_lbl.place(relx=0.05, rely=0.58, relheight=0.05, relwidth=0.5)

# Inference steps(画像生成に費やすステップ数)入力テキスト
inference_data = customtkinter.CTkEntry(root,placeholder_text="", width=25, height=25, border_width=2, corner_radius=6)
inference_data.place(relx=0.05, rely=0.63, relheight=0.05, relwidth=0.2)

# Seeds値ラベル
seeds_lbl = customtkinter.CTkLabel(root, text='シード値(未入力時はランダム)', width=60, justify="left", anchor="w")
seeds_lbl.place(relx=0.05, rely=0.68, relheight=0.05, relwidth=0.5)

# Seeds値 入力テキスト
seeds_data = customtkinter.CTkEntry(root,placeholder_text="", width=25, height=25, border_width=2, corner_radius=6)
seeds_data.place(relx=0.05, rely=0.73, relheight=0.05, relwidth=0.2)

# 画像生成ボタン
btn_gen = customtkinter.CTkButton(root, text='画像生成', command=btn_click_generate)
btn_gen.place(relx=0.05, rely=0.85, relheight=0.07, relwidth=0.15)

# 画像選択ボタン
btn_open = customtkinter.CTkButton(root, text='画像を開いて情報を取得する', command=btn_click_open)
btn_open.place(relx=0.35, rely=0.85, relheight=0.07, relwidth=0.3)

root.mainloop()


いいなと思ったら応援しよう!