BitNet&BitNet b158の実装③
はじめに
BitNetおよびBitNet b158の実装を続けていこうと思います。
ボリュームが大きくなってきたため、記事を分けることとしました。前回までの内容は以下をご参照ください。
2日連続での投稿となるので前後関係をお気をつけください。
3. BitNet b158
これまでに作成したBitLinearを修正していく形でBitNet b158用のBitLinear b158を作成していきます。
BitLinearとBitLinear b158の変更点は以下の2つです。
重みwの量子化手法の変更({-1, 0, 1}の3値化)
非線形関数の前の入力の[0, Qb]スケーリングの削除
それ以外はBitLinearと同等となるため、BitLinearを継承する形でBitLinear b158を作成します。
class BitLinear158b(BitLinear):
3-1. {-1, 0, 1}の3値化
BitLinearにおけるquantize_weights関数を修正していきます。
元々、BitLinearでは重みを中心化してからcustom_signで{-1, 1}に振り分けていました。BitLinearではabsmean quantizationという方法を採用しています。
ここでわかりづらいことに、式(3)のγ(gamma)はBitNetの論文中ではβ(beta)と名付けられていました。なので私のコードではBitNetに引き続きこれをβとして取り扱います。
まずはこのβを計算します。
# 式(3): betaの計算
beta = self.weight.abs().mean().clamp(min=self.epsilon)
後にこの値でweightを割ることになるので、オーバーフロー防止のため極小値ε(epsilon)でクリッピングします。式(1)ではεを足していますが、公式FAQを正としてクリッピングを採用します。
次に、式(1), (2)の計算です。これも式の通り、①重みWをβで割る、②roundをとる、③クリッピングの実施のみです。
これによって、重みの各値は{-1, 0, +1}の中で最も近い整数に丸められます。
# 式(1),(2): 重みの量子化(-1, 0, 1)とクリップ
# 各値は{-1, 0, +1}の中で最も近い整数に丸められます。
weight_scaled = self.weight / beta
weight_trinarized = torch.round(weight_scaled)
weight_trinarized = torch.clamp(weight_trinarized, -1, 1)
最後に、STEの実装を行なってquantize_weightsの変更は完了です。STEの詳細については1-7.に記載しました。また、BitLinear b158において複数パターンを調べた記事もあるのでよければそちらもご確認ください。
# STE
weight_trinarized = (weight_trinarized - self.weight).detach() + self.weight
これによって全体像は以下の様になります。
class BitLinear158b(BitLinear):
# 1. quantize_weightsを{-1, 1}の2値化から{-1, 0, 1}の3値化に修正
def quantize_weights(self):
# 式(3): betaの計算
beta = self.weight.abs().mean().clamp(min=self.epsilon)
# 式(1),(2): 重みの量子化(-1, 0, 1)とクリップ
# 各値は{-1, 0, +1}の中で最も近い整数に丸められます。
weight_scaled = self.weight / beta
weight_trinarized = torch.round(weight_scaled)
weight_trinarized = torch.clamp(weight_trinarized, -1, 1)
# STE
weight_trinarized = (weight_trinarized - self.weight).detach() + self.weight_scaled
return weight_trinarized, beta
3-2. [0, Qb]スケーリングの削除
BitLinearで非線形関数の前では入力を[0, Qb]にスケーリングしていました。これを上記に記載がある通り、常に[−Qb, Qb]にスケーリングするようにabsmax_quantize関数を修正します。
元々、パターン①/②に分岐する書き方をしていました。
def absmax_quantize(self, x):
if self.flg_before_linear:
# パターン①: 通常は[-Qb, Qb]にスケール: 式(4), (5)を適用
gamma = torch.abs(x).max().clamp(min=self.epsilon)
x_scaled = x * self.Qb / gamma
x_q = torch.round(x_scaled).clamp(-self.Qb, self.Qb - 1)
else:
# パターン②: Reluなどの非線形関数前の場合は[0, Qb]にスケール: 式(6)を適用
# 論文中には記載はないですが、スケールが異なるためスケーリングの基準として使っているgammaもetaを反映した値にすべきだと考えます。
eta = x.min()
gamma = torch.abs(x - eta).max().clamp(min=self.epsilon)
x_scaled = (x - eta) * self.Qb / gamma
x_q = torch.round(x_scaled).clamp(0, self.Qb - 1)
# STE
x_q = (x_q - x_scaled).detach() + x_scaled
return x_q, gamma
これをパターン①だけとします。また、コードの書き方をquantize_weights関数に寄せました。
# 2. BitLinear b158では、[0, Qb]のスケーリングは行わないません。
def absmax_quantize(self, x):
# スケールgammaの計算(absmax quantization)
gamma = torch.abs(x).max().clamp(min=self.epsilon)
# 重みの量子化とクリップ
x_scaled = x * self.Qb / gamma
x_q = torch.round(x_scaled)
x_q = torch.clamp(x_q, -self.Qb, self.Qb - 1)
# STE
x_q = (x_q - x_scaled).detach() + x_scaled
return x_q, gamma
さらに、不要となったflg_before_linearをモジュールから削除します。
class BitLinear158b(BitLinear):
def __init__(self, in_features, out_features, rms_norm_eps=1e-6, bias=True, bits=8):
super().__init__(in_features, out_features, bias, rms_norm_eps, bits)
# 2. BitLinear b158では、[0, Qb]のスケーリングは行わないため、flg_before_linearは使用しません。
del self.flg_before_linear
これによって、最終的なコードは以下の様になりました。
class BitLinear158b(BitLinear):
def __init__(self, in_features, out_features, bias=True, bits=8):
super().__init__(in_features, out_features, bias, bits)
# 2. BitLinear b158では、[0, Qb]のスケーリングは行わないため、flg_before_linearは使用しません。
del self.flg_before_linear
# 1. quantize_weightsを{-1, 1}の2値化から{-1, 0, 1}の3値化に修正
def quantize_weights(self):
# 式(3): betaの計算
beta = self.weight.abs().mean().clamp(min=self.epsilon)
# 式(1),(2): 重みの量子化(-1, 0, 1)とクリップ
# 各値は{-1, 0, +1}の中で最も近い整数に丸められます。
weight_scaled = self.weight / beta
weight_trinarized = torch.round(weight_scaled)
weight_trinarized = torch.clamp(weight_trinarized, -1, 1)
# STE
weight_trinarized = (weight_trinarized - self.weight).detach() + self.weight_scaled
return weight_trinarized, beta
# 2. BitLinear b158では、[0, Qb]のスケーリングは行わないません。
def absmax_quantize(self, x):
# スケールgammaの計算(absmax quantization)
gamma = torch.abs(x).max().clamp(min=self.epsilon)
# 重みの量子化とクリップ
x_scaled = x * self.Qb / gamma
x_q = torch.round(x_scaled)
x_q = torch.clamp(x_q, -self.Qb, self.Qb - 1)
# STE
x_q = (x_q - x_scaled).detach() + x_scaled
return x_q, gamma
これまで作成したコードはGitHubに置いてありますので、よければご参照・ご活用ください。
4. BitNet b158の検証
今後は、BitNetの場合と同様にBitLlamaをBitLinear b158で作り、Lossがある程度下がるかの検証を行います。
読みやすさの関係で別ページを作成する予定です。書けたらこちらに続きのリンクを貼りたいと思います。また、Xにて更新の告知をいたします。
もうしばらくお待ちいただけると幸いです。