Stable Diffusionの学習コードを作る:5.ControlNet編
今回はControlNetの学習についてやっていきます。以下のような設定を増やすことで学習できるようにします。
controlnet:
train: true
resume: null # model file path
transformer_layers_per_block: false # default = false
global_average_pooling: false # default = false
DiffusionModel
create_controlnet
diffusersのControlNetModelを使います。最初から作る場合はfrom_unetで作れます。
真ん中らへんはTransformerを省略するためのものです。SDXLでモデルサイズを削減するためにhuggingfaceが提案したものです。空のTransformer省略ControlNetを作った後、ロード済みのControlNetのweightを適用します。
global_average_poolingはshuffle用ですが、学習時には使わないと思います。
def create_controlnet(self, config):
if config.resume is not None:
pre_controlnet = ControlNetModel.from_pretrained(config.resume)
else:
pre_controlnet = ControlNetModel.from_unet(self.unet)
if config.transformer_layers_per_block is not None:
down_block_types = tuple(["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in config.transformer_layers_per_block])
transformer_layers_per_block = tuple([int(x) for x in config.transformer_layers_per_block])
self.controlnet = ControlNetModel.from_config(
pre_controlnet.config,
down_block_types=down_block_types,
transformer_layers_per_block=transformer_layers_per_block,
)
self.controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
del pre_controlnet
else:
self.controlnet = pre_controlnet
self.controlnet.config.global_pool_conditions = config.global_average_pooling
forward
UNetの推論と統合します。diffusersの機能をそのまま使うだけなので、別に難しいことはないですね。新たにcontrolnet_hintという入力が増えています。これは値が[0, 1]でサイズが画像と同じ[b, 3, h, w]のてんさーです。
def forward(self, latents, timesteps, encoder_hidden_states, pooled_output, size_condition=None, controlnet_hint=None):
if self.sdxl:
if size_condition is None:
h, w = latents.shape[2] * 8, latents.shape[3] * 8
size_condition = torch.tensor([h, w, 0, 0, h, w]) # original_h/w. crop_top/left, target_h/w
size_condition = size_condition.repeat(latents.shape[0], 1).to(latents)
added_cond_kwargs = {"text_embeds": pooled_output, "time_ids": size_condition}
else:
added_cond_kwargs = None
if self.controlnet is not None:
assert controlnet_hint is not None, "controlnet_hint is required when controlnet is enabled"
down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_hint,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
else:
down_block_additional_residuals = None
mid_block_additional_residual = None
model_output = self.unet(
latents,
timesteps,
encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
).sample
データセット
ControlNetのヒント画像を取り込むメソッドを増やします。上記にある通り[0, 1]のてんさーであり、ToTensor()をするだけです。引数も増やしてますが省略。
def get_control(self, samples, dir="control"):
images = []
transform = transforms.ToTensor()
for sample in samples:
image = Image.open(os.path.join(self.path, dir, sample + f".png")).convert("RGB")
images.append(transform(image))
images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
return images_tensor
canny edgeの場合、計算が軽いのでわざわざ前処理済み画像を保存せずとも、学習中に計算して取得することもできます。そういった場合はカスタムデータセットを作りましょう。式の中身は全然わかっていません。
from modules.dataset import BaseDataset
import cv2
import os
import torch
import numpy as np
from torchvision import transforms
class CannyDataset(BaseDataset):
def get_control(self, samples, dir="control"):
images = []
transform = transforms.ToTensor()
for sample in samples:
# ref https://qiita.com/kotai2003/items/662c33c15915f2a8517e
image = cv2.imread(os.path.join(self.path, dir, sample + f".png"))
med_val = np.median(image)
sigma = 0.33 # 0.33
min_val = int(max(0, (1.0 - sigma) * med_val))
max_val = int(max(255, (1.0 + sigma) * med_val))
image = cv2.Canny(image, threshold1 = min_val, threshold2 = max_val)
image = image[:, :, None] # add channel
image = np.concatenate([image]*3, axis=2) # grayscale to rgb
images.append(transform(image))
images_tensor = torch.stack(images).to(memory_format=torch.contiguous_format).float()
return images_tensor
とれーなー
controlnetの準備はこんな感じ
def prepare_controlnet(self, config):
if config is None:
self.controlnet = None
self.controlnet_train = False
logger.info("コントロールネットはないみたい。")
return
self.diffusion.create_controlnet(config)
self.controlnet_train = config.train
self.diffusion.controlnet.to(self.device, self.train_dtype if self.controlnet_train else self.weight_dtype)
self.diffusion.controlnet.train(self.controlnet_train)
self.diffusion.controlnet.requires_grad_(self.controlnet_train)
logger.info("コントロールネットを作ったよ!")
他にgradient checkpointingをcontrolnetにも適用できるようにするとかそういうどうでもいい変更があります。
損失の計算ではヒント画像の読み込みを追加します。またdiffusionでcontrolnet_hintも追加で入力します。
if "controlnet_hint" in batch:
controlnet_hint = batch["controlnet_hint"].to(self.device)
else:
controlnet_hint = None
サンプル生成
sampleにもcontrolnet_hint関連を追加します。ファイルパスを直接指定することもできるようにします。損失とどうようdiffusionへの引数にも追加します。
if controlnet_hint is not None:
if isinstance(controlnet_hint, str):
controlnet_hint = Image.open(controlnet_hint).convert("RGB")
controlnet_hint = transforms.ToTensor()(controlnet_hint).unsqueeze(0)
controlnet_hint = controlnet_hint.to(self.device)
if guidance_scale != 1.0:
controlnet_hint = torch.cat([controlnet_hint] *2)
学習設定でヒント画像のファイルパスを指定すれば任意の画像でテストできるようになります。
validation_args:
prompt: "1girl, solo, sitting, blonde hair, red eyes , sailor collar, blue skirt, black thighhighs, room"
negative_prompt: "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
width: 832
height: 1216
controlnet_hint: "data/pose.png"
モデルファイルについて
diffusers形式になっています。sgm形式にするためには変換コードがありますがSD1/SD2用になります。SDXLはComfyUIだと中身のsafetensorsをそのまま使えるので変換コードはありません(SD1/SD2でも使えるのかな?)。
Control-LoRA
SDXLのControlNetはそのままだと2.5GBとかいうくそでかファイルになってしまいますが、学習差分を特異値分解によってLoRAに変換することができます。ControlNet固有のモジュール(input hintやzero conv)はLoRAにせず、またbiasやLayerNormなどの重みがベクトルのものもLoRAにはしません(できません)。
変換コード用意しておきました(どこか別のところにもありそうだけど。。。)
https://github.com/laksjdjf/sd-trainer/blob/main/tools/create_control_lora.py