見出し画像

BitNetLLMの罠(学習に失敗した話)

こないだ試した1bitllmが割と上手く行ってるようなのと、あまり日本語が下手なのでとりあえずファインチューニングでもするかと思ってやってみたらハマって数日無駄にしたという話。

BitNetは、よく知られているように推論と学習で動きを変えないといけない。

ところが1bitllmの実装では、そこいらへんが僕が前にやったBitLinearの実験で使ったコードとは微妙に違ったのでメモがてらご報告。

結論から言うと、一度でもoptimizer.step()すると勾配が爆発して死ぬ

それを確かめるために、便利なTRLを捨てて生実装を書いた。


from torch.utils.data import DataLoader,Dataset

import copy
class CompletionDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data=copy.deepcopy(self.data[idx])
        label=torch.LongTensor(data["input_ids"][1:]).to(device)
        data["input_ids"]=torch.LongTensor(data["input_ids"][:-1]).to(device)
        data["attention_mask"]=torch.LongTensor(data["attention_mask"][:-1]).to(device)
        return data,label

dataset = CompletionDataset(tokenized_dataset)   

import numpy as np

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99, eps=1e-08)

# grad_clipした方が爆発しないというClaude-3の助言を信じて入れてみたがダメ
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

train_dataloader = DataLoader(dataset, batch_size=16)
criterion = torch.nn.CrossEntropyLoss()

#model, optimizer, train_dataloader  = accelerator.prepare(model, optimizer, train_dataloader)

num_epochs=4
print(model)

for epoch in range(num_epochs):
    model.train()
    for batch,labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(**batch,labels=labels)
        print(outputs.logits[:10])
        loss = outputs.loss
        print(loss)
        if torch.isnan(loss):
            exit()
        loss.backward()
        optimizer.step()

笑えたのは、最初の一回は学習できる。
optimizer.step()をコメントアウトするとlossは少しずつだが減っていく。

しかし一度でもoptimizer.step()をするとembedingの段階で「-inf(-♾️)」と「inf(♾️)」が出まくって、lossがNaN(Not a Number)になってしまう。

BItnetCalualLMの本体はmodeling_bitnet.pyに入っていて、ここには論文と全く同じBItnetRMSNormが書かれていた。

class BitnetRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        BitnetRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

これは僕が自分でMNISTをBitNetに学習させたのに使ったのと同じコード。
違うのは、BitnetRMSNormの適用のさせ方で、1bitllmではこんな感じで適用されている。

ALL_LAYERNORM_LAYERS.append(BitnetRMSNorm)

そのせいで、BItLinearの実装はこうなっている

class BitLinear(nn.Linear):

    def __init__(self,
            *kargs,
            weight_bits=1,
            input_bits=8,
            **kwargs
        ):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        #RMSNorm is placed outside BitLinear
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        if not self.bias is None:
            out += self.bias.view(1, -1).expand_as(out)

        return out

ちなみに僕が実験して上手く学習できた方のBItLinearはこう

class BitLinear(nn.Linear):
   def __init__(self,in_features,out_features,bias=False,flg_before_linear=True,input_bits=8,weight_bits=1,):
       super(BitLinear, self).__init__(in_features, out_features, bias)
       self.layernorm = nn.LayerNorm(in_features)
       self.RMSNorm = BitRMSNorm(in_features) #自分のRMSNormを持つ
       self.bits = input_bits

   def forward(self,x):
       w=self.weight
       x_norm = self.RMSNorm(x) #自分でRMSNormを適用する
       x_quant = x_norm + (activation_quant(x_norm)-x_norm).detach()
       w_quant = w+(weight_quant(w)-w).detach()
       y = F.linear(x_quant,w_quant)
       return y

だから多分モデルのパラメータ構造が根本的に違うはず。
forwardの中でやってる計算は、基本的にRMSNormに関係ないところはほぼ一緒に見える。

この二つが微妙に似て非なるものなので、元の1bitllmの学習にどのくらいの時間がかかったのかわからない状態では怖くて手を出せない。

ALL_LAYERNORM_LAYERS.appendという書き方はLlamaの元のコードがそうなっているのでそうなんだろう。

もしかして勾配爆発を恐れて学習率を下げすぎたか?
とここまで書いて思った。
でも学習率が高ければ勾配爆発しないというものでもない気がするが
どうなんだろう。

とりあえず副産物的にAccelerateの使い方をまとめたものをFree-AIのブログに上げておいた。

こっちのブログとあっちのブログの使い分けは、あっちは継之助でないとできないこと、かつ、ある程度ちゃんと成功した話やチュートリアルに近い話を乗っけることを目的としている。というのも、継之助のレンタル事業はFreeAIの仕事の一つだからね。継之助を借りてみたい人にとって役立つ話を載せるのは理にかなっている。

こっちのブログは基本的に愚痴、妄言、戯言、失敗談、しょうもない実験を載せる。もっと柔らかくてどうでもいい話がメインとなる。

継之助についてだが、いきなりA100 80GBが11枚も手元に来てしまったのでとりあえず動かさないと、という気持ちで発注したシャーシだったが、「まあメモリは256GBもあれば十分だろう」と軽自動車感覚で安易に発注したのが失敗だった。

よく考えると、継之助は80GBのVRAMを積んだGPUを8つ持ってるわけだからVRAMだけで80x8=640GB。それに対してメインメモリが256GBでは少なすぎるのである。

最初は実用上の問題がなかったので無視していたのだが、最近出てきたLISAなどの「メインメモリに盛大にオフロードするタイプのファインチューニング手法」が出てきたことで、「256GBじゃ全然足りないじゃん」ということがわかった。

そこでとりま少なくとも1.5TBくらいのメインメモリを買おうということになったのだが、これがなかなか難しい。

何が難しいかというと、継之助のシャーシは、RDIMMスロットが32スロットあって、今は16GBx16が刺さっていて256GBということになっている。
仮に16GBをもう16個買ったとしても、合計512GBにしかならない。

ということはまだ納品されて半年も経ってないのに16GBx16は全部外すことになる。

で、64GBx32(=2TB)にするか、96GBx16=(1.5TB)から始めるか、それとも128GBx16(=2TB)にするか悩むわけだ。

しかし、メモリ価格は容量と必ずしも比例しない。
例えば24GBのメモリは3万円だが、32GBのメモリは26000円だったりと逆転現象があったり、48GBは5万円で64GBは54千円とあまり変わらないことになる。

いろいろ悩んだ挙句、96GBで16枚、とりあえず買うことにした。これなら万が一足りなくてもまだスロットに余裕があるので買い増して2TBにできるし、最終的には3TBまで拡張できる。

何より、96GBは9万円くらいなのに対し、128GBは46万円もするというコストバランスの悪さだ。今のところはここがボリュームゾーンということで発注することにした。

そもそも70Bモデルは、32ビット(4バイト)で約280GBなので、256GBしかメモリがなかったら乗るわけがないのだ。オフロードができない。

ちなみに実際にどのくらいのメモリを使うのか、スワップを1.5TBにして調べたら、スワップが1TBくらいの時にちゃんと乗った(あくまでもVirtual Memoryに乗っただけだが)ので、1.5TBあればとりあえずは70BのLISAは動くだろうと思う。

最近H100のカードが市場に出回り始めてきているが、H200のスペックやいろんなメーカーが対抗策を出してきているところを見ると少し様子見した方がいいかもしれない。一枚も持ってないなら買った方がいいかもしれないけど、割高(大体550万円くらい。定価の二倍)だし、カードだけあってもシャーシがないとどうにもならず、シャーシそのものだって200万円以上する(1.5TBのメモリ単体でも200万円する)。ということは最低1000万は用意しないと入り口にすら立てないということになる。

今思い返すと、去年の9月のタイミングでA100買ったのは正解だったなあ。
まさかその日の夕方に海老根さんがもう10個買うとは思わなかったけど。

やっぱり研究には最低限必要なハードウェアというのがあって、西田先生もCRTが日本で三番目くらいに広島大学に来た時に大学に戻ったからCGの世界的権威になったんだしなあ。