見出し画像

BitNet&BitNet b158の実装③

はじめに

BitNetおよびBitNet b158の実装を続けていこうと思います。
ボリュームが大きくなってきたため、記事を分けることとしました。前回までの内容は以下をご参照ください。
2日連続での投稿となるので前後関係をお気をつけください。


3. BitNet b158

これまでに作成したBitLinearを修正していく形でBitNet b158用のBitLinear b158を作成していきます。
BitLinearとBitLinear b158の変更点は以下の2つです。

  1. 重みwの量子化手法の変更({-1, 0, 1}の3値化)

  2. 非線形関数の前の入力の[0, Qb]スケーリングの削除

それ以外はBitLinearと同等となるため、BitLinearを継承する形でBitLinear b158を作成します。

class BitLinear158b(BitLinear):

3-1. {-1, 0, 1}の3値化

To constrain the weights to -1, 0, or +1, we adopt an absmean quantization function. It first scales the weight matrix by its average absolute value, and then round each value to the nearest integer among {-1, 0, +1}:

論文参照『The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits』
量子化関数

BitLinearにおけるquantize_weights関数を修正していきます。
元々、BitLinearでは重みを中心化してからcustom_signで{-1, 1}に振り分けていました。BitLinearではabsmean quantizationという方法を採用しています。
ここでわかりづらいことに、式(3)のγ(gamma)はBitNetの論文中ではβ(beta)と名付けられていました。なので私のコードではBitNetに引き続きこれをβとして取り扱います。
まずはこのβを計算します。

# 式(3): betaの計算
beta = self.weight.abs().mean().clamp(min=self.epsilon)
スケールの計算

後にこの値でweightを割ることになるので、オーバーフロー防止のため極小値ε(epsilon)でクリッピングします。式(1)ではεを足していますが、公式FAQを正としてクリッピングを採用します。
次に、式(1), (2)の計算です。これも式の通り、①重みWをβで割る、②roundをとる、③クリッピングの実施のみです。
これによって、重みの各値は{-1, 0, +1}の中で最も近い整数に丸められます。

# 式(1),(2): 重みの量子化(-1, 0, 1)とクリップ
# 各値は{-1, 0, +1}の中で最も近い整数に丸められます。
weight_scaled = self.weight / beta
weight_trinarized = torch.round(weight_scaled)
weight_trinarized = torch.clamp(weight_trinarized, -1, 1)

最後に、STEの実装を行なってquantize_weightsの変更は完了です。STEの詳細については1-7.に記載しました。また、BitLinear b158において複数パターンを調べた記事もあるのでよければそちらもご確認ください。

# STE
weight_trinarized = (weight_trinarized - self.weight).detach() + self.weight

これによって全体像は以下の様になります。

class BitLinear158b(BitLinear):
 
    # 1. quantize_weightsを{-1, 1}の2値化から{-1, 0, 1}の3値化に修正
    def quantize_weights(self):
        # 式(3): betaの計算
        beta = self.weight.abs().mean().clamp(min=self.epsilon)

        # 式(1),(2): 重みの量子化(-1, 0, 1)とクリップ
        # 各値は{-1, 0, +1}の中で最も近い整数に丸められます。
        weight_scaled = self.weight / beta
        weight_trinarized = torch.round(weight_scaled)
        weight_trinarized = torch.clamp(weight_trinarized, -1, 1)

        # STE
        weight_trinarized = (weight_trinarized - self.weight).detach() + self.weight_scaled

        return weight_trinarized, beta

3-2. [0, Qb]スケーリングの削除

The quantization function for activations follows the same implementation in BitNet, except that we do not scale the activations before the non-linear functions to the range [0,Qb]. Instead, the activations are all scaled to [−Qb, Qb] per token to get rid of the zero-point quantization.

論文参照『The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits』

BitLinearで非線形関数の前では入力を[0, Qb]にスケーリングしていました。これを上記に記載がある通り、常に[−Qb, Qb]にスケーリングするようにabsmax_quantize関数を修正します。
元々、パターン①/②に分岐する書き方をしていました。

def absmax_quantize(self, x):
    if self.flg_before_linear:
        # パターン①: 通常は[-Qb, Qb]にスケール: 式(4), (5)を適用
        gamma = torch.abs(x).max().clamp(min=self.epsilon)
        x_scaled = x * self.Qb / gamma
        x_q = torch.round(x_scaled).clamp(-self.Qb, self.Qb - 1)
    else:
        # パターン②: Reluなどの非線形関数前の場合は[0, Qb]にスケール: 式(6)を適用
        # 論文中には記載はないですが、スケールが異なるためスケーリングの基準として使っているgammaもetaを反映した値にすべきだと考えます。
        eta = x.min()
        gamma = torch.abs(x - eta).max().clamp(min=self.epsilon)
        x_scaled = (x - eta) * self.Qb / gamma
        x_q = torch.round(x_scaled).clamp(0, self.Qb - 1)
    # STE
    x_q = (x_q - x_scaled).detach() + x_scaled
    return x_q, gamma

これをパターン①だけとします。また、コードの書き方をquantize_weights関数に寄せました。

# 2. BitLinear b158では、[0, Qb]のスケーリングは行わないません。
def absmax_quantize(self, x):
    # スケールgammaの計算(absmax quantization)
    gamma = torch.abs(x).max().clamp(min=self.epsilon)

    # 重みの量子化とクリップ
    x_scaled = x * self.Qb / gamma
    x_q = torch.round(x_scaled)
    x_q = torch.clamp(x_q, -self.Qb, self.Qb - 1)
    
    # STE
    x_q = (x_q - x_scaled).detach() + x_scaled
    
    return x_q, gamma

さらに、不要となったflg_before_linearをモジュールから削除します。

class BitLinear158b(BitLinear):
    def __init__(self, in_features, out_features, rms_norm_eps=1e-6, bias=True, bits=8):
        super().__init__(in_features, out_features, bias, rms_norm_eps, bits)
        # 2. BitLinear b158では、[0, Qb]のスケーリングは行わないため、flg_before_linearは使用しません。
        del self.flg_before_linear

これによって、最終的なコードは以下の様になりました。

class BitLinear158b(BitLinear):
    def __init__(self, in_features, out_features, bias=True, bits=8):
        super().__init__(in_features, out_features, bias, bits)
        # 2. BitLinear b158では、[0, Qb]のスケーリングは行わないため、flg_before_linearは使用しません。
        del self.flg_before_linear
        
    # 1. quantize_weightsを{-1, 1}の2値化から{-1, 0, 1}の3値化に修正
    def quantize_weights(self):
        # 式(3): betaの計算
        beta = self.weight.abs().mean().clamp(min=self.epsilon)

        # 式(1),(2): 重みの量子化(-1, 0, 1)とクリップ
        # 各値は{-1, 0, +1}の中で最も近い整数に丸められます。
        weight_scaled = self.weight / beta
        weight_trinarized = torch.round(weight_scaled)
        weight_trinarized = torch.clamp(weight_trinarized, -1, 1)

        # STE
        weight_trinarized = (weight_trinarized - self.weight).detach() + self.weight_scaled

        return weight_trinarized, beta
    
    # 2. BitLinear b158では、[0, Qb]のスケーリングは行わないません。
    def absmax_quantize(self, x):
        # スケールgammaの計算(absmax quantization)
        gamma = torch.abs(x).max().clamp(min=self.epsilon)

        # 重みの量子化とクリップ
        x_scaled = x * self.Qb / gamma
        x_q = torch.round(x_scaled)
        x_q = torch.clamp(x_q, -self.Qb, self.Qb - 1)
        
        # STE
        x_q = (x_q - x_scaled).detach() + x_scaled
        
        return x_q, gamma

これまで作成したコードはGitHubに置いてありますので、よければご参照・ご活用ください。

4. BitNet b158の検証

今後は、BitNetの場合と同様にBitLlamaをBitLinear b158で作り、Lossがある程度下がるかの検証を行います。
読みやすさの関係で別ページを作成する予定です。書けたらこちらに続きのリンクを貼りたいと思います。また、Xにて更新の告知をいたします。
もうしばらくお待ちいただけると幸いです。


参照


いいなと思ったら応援しよう!