Latentの量子化
りょーしか?っていうのを理解したかったのと、学習時に保存するlatentの容量を減らせたらいーなーて思ったのでやってみます。
latent
VAEで画像をエンコードして、ちっちゃくしたやつです。今回はSD3.5のVAEでやってみよーと思います。特にSD3はチャンネルが16もあるのでサイズがおっきいです。
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-large", subfolder="vae")
vae.to("cuda", dtype=torch.float32)
@torch.no_grad()
def decode_latents(latents):
images = []
for i in range(latents.shape[0]):
image = vae.decode(latents[i].unsqueeze(0)).sample
images.append(image)
images = torch.cat(images, dim=0)
images = (images / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
@torch.no_grad()
def encode_latents(images):
to_tensor_norm = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
images = torch.stack([to_tensor_norm(image) for image in images]).to("cuda", dtype=torch.float32)
latents = vae.encode(images).latent_dist.mode()
return latents
ランダム性を避けるためエンコード時はmode()にしておきます。
整数量子化
latentの値を整数で表すことを考えます。たとえばnビットに量子化したい場合、値が$${[-2^{n-1}, 2^{n-1}-1]}$$の範囲になるように変換・切り取りをして、整数に丸めることで量子化できます。変換の方法によっていろんな量子化がありますが、今回はだいたいなんらかのスケールやシフトを求めて、$${x * \mathrm{scale} + \mathrm{shift}}$$することを考えます。
実験はSD3.5のVAEをFP32で計算します。北斎ちゃんの絵からてきとーに4枚選んでエンコード、量子化、逆量子化して量子化前との平均二乗誤差を計算しました。
静的量子化
性的量子化とかえっちだなあ。
スケールやシフトを適当な定数を決めて行います。今回はなんとなくスケールを$${(2^{n-1}-1)/2}$$、シフトを0にしてみます。最大値が1.9いくつとかだったので、これならほとんど切り取られずにすみます。
# Noneにも対応するcast
def cast(x, dtype):
if isinstance(x, torch.Tensor):
return x.to(dtype)
return x
# intNbitの最小値
def nbit_min(bit):
return -(2**(bit-1))
# intNbitの最大値
def nbit_max(bit):
return 2**(bit-1)-1
# intNbitの範囲
def nbit_range(bit):
return (nbit_min(bit), nbit_max(bit))
def static_quantize(x, bit, scale, shift=0):
x_int = (x * scale + shift).round().to(dtype=torch.int64)
x_int = x_int.clamp(*nbit_range(bit))
return x_int
def static_dequantize(x_int, scale, shift=0):
x = (x_int.to(dtype=torch.float32) - shift) / scale
return x
def static_quantize_dequantize(x, bit, scale, shift=0):
x_int = static_quantize(x, bit, scale, shift)
x_dequant = static_dequantize(x_int, scale, shift)
bpw = bit # bit per weight
log_mse = ((x - x_dequant)**2).mean().log10().item()
return x_dequant, bpw , log_mse
bit = 4
latents_dequant, bpw, log_mse = static_quantize_dequantize(latents, bit, nbit_max(bit) / 2)
print(f"log_mse:{log_mse:.6f}, bpw:{bpw:.6f}")
8bitは結構きれいそうですが、4bitだとかなりざらざらーですね。
動的量子化
絶対値の最大値で正規化することを考えます。$${x_q=\mathrm{int}(x\times\frac{2^{n-1}-1}{\mathrm{max}(|x|)})}$$とすることで、すべての値が範囲に入り、切り取られることがなくなることで、精度上昇が期待できます。しかしもとに戻すためには最大値も保存しておく必要があります。
先ほどの方法と比較してみます。
横軸は値1つあたりのbit数で、縦軸は平均二乗誤差のlog10です。左下であればあるほどよいグラフになります。torchのfp8とも比較してみました。低bitではスケール値を定数にしたほうがよくなりましたが、定数によってもかわるのでよくわかりません。どちらにせよ整数量子化はfloatの低精度化よりも精度がよさそうです。
というのが一般的な方法なんですが、この方法だと$${[-(2^{n-1}-1), 2^{n-1}-1]}$$の範囲になってしまい、負の数が1個無駄になってしまいます。8bitくらいならいいですが、4bitだと問題がありそうです。そこで$${2^{n-1}-0.5}$$になるようスケーリングして、0.5を引くことで、全部使い切るようにします。たとえば4bitだったらまず$${[-15.5, 15.5]}$$の範囲にして、0.5を引いて$${[-16,15]}$$に変えます。
0.5シフトするかどうかで比較してみました。bit数が少ないと黄のシフトしたほうが精度が高くなりますが、13itとかになると逆転が起きたりしますが、実用的には8bit以下にしたいのでシフトしたほうがいいような気がします。
ブロック分け
最大値でスケーリングすると、一部の外れ値に弱くなります。極端な話1つの値だけが5000兆!!!だと、それで割ってしまったらそれ以外のすべての値が0になってしまいます。そこでいくつかのブロックに分けてそれぞれで最大値をとって量子化する方法が考えられます。ブロックのサイズを小さくすると精度は高くなりますが、保存すべき最大値の数も増えていきます。たとえばブロックサイズ64で、値をint8、最大値を32bitで保存すると、8+32 / 64 = 8.5bitの量子化になります。4bitとか8bitの量子化が、16bitの半分や4分の1より大きいサイズになるのはそういう理由です。
ブロックサイズやbitを変えながら比較してみます。
紫や黄といった今までの方法より、ブロック分けしたほうが左下に分布しており効率がいいことがわかります。int15でいっきにこわれちゃうのはなんでだろう?
量子化定数の型
上の表では最大値をfloat16にしましたが、何bitでもっておけばいいか実験しました。torchのfloat16, bfloat16, float32で、ブロックサイズはそれぞれ32, 32, 64とします。こうすると全体のデータ量は同じ大きさで比較できます。
bfloat16はだめそうですね。float16は10bitくらいまで優位ですが、12bitくらいからfloat32が逆転してます。bitが増えるほど最大値側の精度低下がより支配的になる感じですかね。これもやっぱり12bitとかは実用的にはいらないのでfloat16でよさそうですね。
最小値でシフト
今までは絶対値の最大値でスケーリングしていましたが、これだと平均が0から遠い値に対してはうまくいきません。たとえば値が全部正の数だったら、負の整数が使われなくなってしまいます。そこで最小値でシフトすることで、端から端までめいっぱい使えます。新たに最小値も保存する必要がでてきますが、その代わりに精度があがりそうです。
ブロックサイズ16, 32, 64で比較してみます。
最小値が必要な分、bpwが上がりますが、ブロックサイズを2倍にすれば同じ大きさで比較できます。10bitまではシフトありの方が優れていますね。また12bitで逆転現象が起きてます・・・。
今までの手法の実装
def dynamic_quantize(x, bit, do_shift=False, block_size=None, ctype=torch.float32, point_five_shift=False):
bsz = x.shape[0]
block_size = block_size or x.numel() // bsz
x = x.view(bsz, -1, block_size)
if do_shift:
x_max = x.max(dim=2, keepdim=True).values
x_min = x.min(dim=2, keepdim=True).values
c_scale = (2**bit - 1) / (x_max - x_min)
c_shift = - x_min * c_scale + nbit_min(bit)
else:
x_max = x.abs().max(dim=2, keepdim=True).values
c_shift = -0.5 if point_five_shift else 0
c_scale = (nbit_max(bit) -c_shift) / x_max
x_int = static_quantize(x, bit, c_scale, c_shift)
return x_int, cast(c_scale, ctype), cast(c_shift, ctype)
def dynamic_quantize_dequantize(x, bit, do_shift=False, block_size=None, ctype=torch.float32, point_five_shift=False):
shape = x.shape
block_size = block_size or x.numel() // shape[0]
x_int, c_scale, c_shift = dynamic_quantize(x, bit, do_shift, block_size, ctype, point_five_shift=point_five_shift)
x_dequant = static_dequantize(x_int, c_scale, c_shift).view(shape)
x_dequant = x_dequant.to(x.dtype)
ctype = c_scale.dtype
c_bit = torch.tensor([], dtype=ctype).element_size() * 8
c_bpw = c_bit / block_size * 2 if do_shift else c_bit / block_size
bpw = bit + c_bpw
log_mse = ((x - x_dequant)**2).mean().log10().item()
return x_dequant, bpw, log_mse
二重量子化
量子化定数を量子化することで、データサイズを削減します。さっきまでfloat16でしたが、これをint8(シフトありブロックサイズ45)にするとどうなるか実験してみました。
矢印で二重量子化の効果を表していますが、ビミョーな感じですね。QLoRAの論文では最大値をfloat32としていたので、8bitにすると削減効果が大きいですが、float16だと削減効果が限定的な感じがします。
def double_quantize_dequantize(x, bit, do_shift=False, block_size=None, cbit=None, cblock_size=None, ctype=torch.float32, point_five_shift=False):
shape = x.shape
x_int, c_scale, c_shift = dynamic_quantize(x, bit, do_shift, block_size, torch.float32, point_five_shift)
c_scale_dequant, c_scale_bpw, _ = dynamic_quantize_dequantize(c_scale, cbit, do_shift, cblock_size, ctype, point_five_shift)
if do_shift:
c_shift_dequant, c_shift_bpw, _ = dynamic_quantize_dequantize(c_shift, cbit, do_shift, cblock_size, ctype, point_five_shift)
else:
c_shift_dequant = c_shift
c_shift_bpw = 0
x_dequant = static_dequantize(x_int, c_scale_dequant, c_shift_dequant).view(shape)
bpw = bit + (c_scale_bpw + c_shift_bpw) / block_size
log_mse = ((x - x_dequant)**2).mean().log10().item()
return x_dequant, bpw, log_mse
Normal Float 4
QLoRAで提唱されたらしいやつで、いままでのようにアフィン変換でなく、正規分布のように0付近を重視した量子化になります。ニューラルネットワークのパラメータは正規分布でサンプリングされたものっぽくなるらしいのでこうするといいとか。
というわけでやってみましたが、今回はintに比べて精度向上できませんでした。latentには向いていないのか、それとも実装が違うのかな。
NF4_VALUES = [
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
0.44070982933044434,0.5626170039176941, 0.7229568362236023, 1.0,
]
def nf4_quantize_dequantize(x, block_size=None, ctype=torch.float32):
block_size = block_size or x.numel()
q = torch.tensor(NF4_VALUES).to(x)
x = x.reshape(-1, block_size)
max_x = x.abs().max(dim=1, keepdim=True).values.to(dtype=ctype)
x_nf4 = torch.argmin((q.view(-1,1,1) - (x / max_x).unsqueeze(0)).abs(), dim=0)
x_dequant = q[x_nf4] * max_x
x_dequant = x_dequant.to(x.dtype)
bpw = 4 + torch.tensor([], dtype=ctype).element_size() * 8 / block_size
log_mse = ((x - x_dequant)**2).mean().log10().item()
return x_dequant, bpw, log_mse
PSNR
いままではlatent上でmseを計算していましたが、元画像とデコードした画像のPSNRを計算してみました。シフトありで、二重量子化あるなしの比較をPSNRでしてみます。
二重量子化によってPSRNをあまり落とさずbpwを減らせているような感じがしますね。
def calc_psnr(images_1, images_2):
mse = ((np.array(images_1) - np.array(images_2)) ** 2).mean()
return 10 * np.log10(255 ** 2 / mse)
保存方法
これまでnbitとかやってきましたが、そもそもtorchとかnumpyはint8やらint16やら一部のbitにしか対応していません。中途半端なbit数のときは、ビットパッキングとやらをする必要があります。たとえば4bitなら二つの変数を一つにまとめて8bitにできます。
latents_int, c_scale, c_shift = dynamic_quantize(latents, 4, do_shift=True, block_size=32, point_five_shift=True, ctype=torch.float16)
latents_uint = latents_int + 8
latents_pack = latents_uint[:, :, :16] + latents_uint[:, :, 16:] * 16
latents_pack = latents_pack.to(torch.uint8)
latents_unpack = torch.cat([latents_pack % 16, latents_pack // 16], dim=2).to(torch.long) - 8
latents_dequant = static_dequantize(latents_unpack, c_scale, c_shift).view(latents.shape)
latents_dequant = latents_dequant.to(latents.dtype)
print("mse:", ((latents - latents_dequant)**2).mean().log10().item())
images_dequant = decode_latents(latents_dequant)
psnr = calc_psnr(images, images_dequant)
print("psnr:", psnr)
print("psnr:", psnr)
>mse: -3.4560818672180176
>psnr: 31.281642032888684
4bit+量子化定数が32bit/32block_sizeなので実質5bitですが、目を凝らしてもあんまり違いが判らないものがでてきました。16bitと比べても3分の1くらいにサイズを落とせそうですね。二重量子化でさらに削減できそうですが、めんどくさくなりそう。