Properties of BitNet revealed by training it on MNIST
I have been wrestling with BitNet for close to a month now. BitNet is the 1-bit (also described as 1.58-bit) quantized neural network that Microsoft claims to have invented.
Using code I found lying around, my very first run got the loss down to about 2, but an LLM has no practical use unless its loss drops below 1.
Every run after that stalled around 6, or 5 at best, so the first attempt apparently just happened to go well.
Since things never improved no matter how long I waited, I started to suspect I needed to rethink BitNet's fundamental properties, so I went back to basics and tried to make it learn logic gates.
For the BitNet code base I used a combination of Hachi-san's code and an implementation based on Microsoft's official paper.
The first thing I tried was code like this:
from bitnet import *
import torch
from torch import optim
import torch.nn as nn

input = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
output = torch.tensor([[0.], [0.], [0.], [1.]])

#layer = BitLinear(2,1)
model = nn.Sequential(
    #nn.Linear(2,1),
    BitLinear(2, 2),
    nn.ReLU(),
    BitLinear(2, 1),
    nn.Sigmoid()
)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
loss_fn = nn.BCELoss()

for i in range(90000):
    #for x,t in zip(input,output):
    optimizer.zero_grad()
    y = model.forward(input)
    loss = loss_fn(y, output)
    print(loss)
    loss.backward()
    optimizer.step()

y = model(input)
print(y)
It is a very simple three-layer perceptron with two inputs and one output.
I tried over and over to train it, but it could not learn at all.
Early AI researchers gave up quickly, reasoning that if a network cannot learn logic gates it could never learn anything more complex. This kind of property is probably part of why BitNet's effectiveness went unrecognized until now.
Still, my own question is: isn't reproducing a logic gate with only three weight states too much of a game of luck to begin with? The initial weights are random, so finding a working configuration owes more to luck (the initial state) than to learning.
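To get a feel for how much luck is involved, here is a rough sketch I put together (my own, not part of the experiment above). It ignores BitLinear's RMSNorm and scaling entirely and simply brute-forces every {-1, 0, +1} weight assignment of a bias-free 2-2-1 ReLU network, counting how many of the 3^6 = 729 assignments give a positive pre-sigmoid output only for the input (1,1), i.e. actually separate AND.

import itertools
import numpy as np

inputs = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])

hits = 0
for ws in itertools.product([-1, 0, 1], repeat=6):
    W1 = np.array(ws[:4], dtype=float).reshape(2, 2)  # hidden-layer weights, ternary
    W2 = np.array(ws[4:], dtype=float).reshape(1, 2)  # output-layer weights, ternary
    h = np.maximum(inputs @ W1.T, 0.0)                # ReLU
    y = (h @ W2.T).ravel()                            # pre-sigmoid outputs for the 4 inputs
    # count it as "AND" if the output is positive only for (1,1);
    # without a bias, (0,0) always maps to exactly 0, so non-positive is the best we can ask for there
    if y[3] > 0 and np.all(y[:3] <= 0):
        hits += 1

print(f"{hits} of {3**6} ternary weight assignments separate AND")

A real BitLinear layer rescales the quantized weights and activations, so this only illustrates how sparse the workable configurations are in such a small discrete space, not what the trained layer actually computes.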
So I experimented with what happens when BitNet is applied to a somewhat more complex problem.
MNIST.
Writing the MNIST code by hand was a chore, so I had Claude-3 write it and only swapped the layers out for BitLinear. The comments were also written by Claude-3, by the way. Almost too convenient.
from bitnet import *
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Download and preprocess the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = BitLinearOriginal(784, 128)
        self.fc2 = BitLinearOriginal(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Net()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Initialize a list for recording the parameters
param_history = []

def calculate_accuracy(output, target):
    _, predicted = torch.max(output, 1)
    correct = (predicted == target).sum().item()
    accuracy = correct / target.size(0)
    return accuracy

# Training loop
for epoch in range(50):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # Record the model parameters at each step
        param_history.append([p.clone().detach().numpy() for p in model.parameters()])
        optimizer.step()
        acc = calculate_accuracy(output, target)
        print(f"loss:{loss.item():.5f} acc:{acc:.5f}")
        if acc > 0.99:
            break
Training with this, it really does learn.
But it learns so instantly that something about it still felt fishy.
So I printed the early stage of training in detail and confirmed that it genuinely works its way from a low-accuracy state to a high-accuracy state.
$ python mnist.py
loss:2.41108 acc:0.04688
loss:2.18548 acc:0.23438
loss:2.00562 acc:0.42188
loss:1.94573 acc:0.45312
loss:1.82240 acc:0.48438
loss:1.76757 acc:0.56250
loss:1.55644 acc:0.59375
loss:1.39945 acc:0.81250
loss:1.53707 acc:0.64062
loss:1.37297 acc:0.70312
loss:1.30705 acc:0.76562
loss:1.29431 acc:0.78125
loss:1.17070 acc:0.81250
loss:1.20207 acc:0.76562
loss:1.22176 acc:0.70312
loss:1.17043 acc:0.76562
loss:1.29147 acc:0.60938
loss:1.13530 acc:0.79688
loss:1.13138 acc:0.70312
loss:1.10515 acc:0.70312
loss:1.12286 acc:0.70312
loss:1.03146 acc:0.81250
loss:0.90347 acc:0.84375
loss:1.00460 acc:0.81250
loss:0.99084 acc:0.78125
loss:0.82783 acc:0.89062
loss:1.06856 acc:0.75000
loss:0.89284 acc:0.81250
loss:1.14347 acc:0.70312
loss:0.86533 acc:0.82812
loss:0.76226 acc:0.89062
loss:0.98805 acc:0.76562
loss:0.77720 acc:0.87500
loss:0.91726 acc:0.76562
loss:0.84186 acc:0.79688
loss:0.89622 acc:0.79688
loss:0.77819 acc:0.79688
loss:0.64624 acc:0.96875
loss:0.85572 acc:0.81250
loss:0.81650 acc:0.82812
loss:0.70959 acc:0.92188
loss:0.79852 acc:0.79688
loss:0.58357 acc:0.92188
loss:0.69148 acc:0.87500
loss:0.68034 acc:0.85938
loss:0.71641 acc:0.90625
loss:0.52498 acc:0.90625
loss:0.62418 acc:0.90625
loss:0.57807 acc:0.84375
loss:0.60706 acc:0.87500
loss:0.49764 acc:0.92188
loss:0.55832 acc:0.92188
loss:0.51869 acc:0.89062
loss:0.81124 acc:0.75000
loss:0.57983 acc:0.87500
loss:0.65688 acc:0.87500
loss:0.65231 acc:0.82812
loss:0.58111 acc:0.90625
loss:0.53314 acc:0.89062
loss:0.59246 acc:0.87500
loss:0.56978 acc:0.87500
loss:0.54524 acc:0.89062
loss:0.43568 acc:0.90625
loss:0.67516 acc:0.79688
loss:0.50416 acc:0.90625
loss:0.47741 acc:0.89062
loss:0.53961 acc:0.87500
loss:0.54430 acc:0.90625
loss:0.59080 acc:0.79688
loss:0.39868 acc:0.93750
loss:0.45935 acc:0.92188
loss:0.46753 acc:0.90625
loss:0.55835 acc:0.89062
loss:0.45517 acc:0.90625
loss:0.51455 acc:0.87500
loss:0.45387 acc:0.90625
loss:0.55507 acc:0.82812
loss:0.36060 acc:0.95312
loss:0.34529 acc:0.93750
loss:0.44868 acc:0.87500
loss:0.51357 acc:0.85938
loss:0.51937 acc:0.87500
loss:0.43396 acc:0.89062
loss:0.40336 acc:0.90625
loss:0.39574 acc:0.92188
loss:0.31379 acc:0.95312
loss:0.29613 acc:0.96875
loss:0.44163 acc:0.89062
loss:0.54892 acc:0.85938
loss:0.34996 acc:0.90625
loss:0.43126 acc:0.90625
loss:0.47235 acc:0.85938
loss:0.38899 acc:0.93750
loss:0.38725 acc:0.90625
loss:0.51207 acc:0.82812
loss:0.58700 acc:0.78125
loss:0.40778 acc:0.92188
loss:0.51206 acc:0.87500
loss:0.46178 acc:0.89062
loss:0.31953 acc:0.96875
loss:0.38844 acc:0.92188
loss:0.36591 acc:0.90625
loss:0.33379 acc:0.90625
loss:0.30978 acc:0.93750
loss:0.47275 acc:0.85938
loss:0.44161 acc:0.89062
loss:0.40364 acc:0.90625
loss:0.24938 acc:0.96875
loss:0.24971 acc:0.98438
loss:0.41120 acc:0.90625
loss:0.51807 acc:0.84375
loss:0.49253 acc:0.89062
loss:0.46641 acc:0.92188
loss:0.44445 acc:0.87500
loss:0.40190 acc:0.90625
loss:0.42140 acc:0.90625
loss:0.28105 acc:0.93750
loss:0.42497 acc:0.87500
loss:0.58897 acc:0.81250
loss:0.37514 acc:0.89062
loss:0.35815 acc:0.92188
loss:0.39548 acc:0.89062
loss:0.35628 acc:0.96875
loss:0.38870 acc:0.90625
loss:0.35709 acc:0.93750
loss:0.40669 acc:0.87500
loss:0.20514 acc:1.00000
loss:0.30028 acc:0.92188
loss:0.26510 acc:0.96875
loss:0.53790 acc:0.76562
loss:0.33812 acc:0.92188
loss:0.22588 acc:0.95312
loss:0.40264 acc:0.87500
loss:0.21730 acc:0.98438
loss:0.44845 acc:0.89062
loss:0.31983 acc:0.92188
loss:0.45905 acc:0.87500
loss:0.32539 acc:0.90625
loss:0.30357 acc:0.89062
loss:0.36034 acc:0.90625
loss:0.31498 acc:0.95312
loss:0.23514 acc:0.93750
loss:0.42586 acc:0.90625
loss:0.31543 acc:0.93750
loss:0.19443 acc:0.98438
loss:0.37593 acc:0.90625
loss:0.26972 acc:0.96875
loss:0.23483 acc:0.96875
loss:0.25687 acc:0.92188
loss:0.29415 acc:0.90625
loss:0.34700 acc:0.85938
loss:0.40503 acc:0.89062
loss:0.46366 acc:0.85938
loss:0.32815 acc:0.92188
loss:0.34704 acc:0.90625
loss:0.18917 acc:0.96875
loss:0.37811 acc:0.87500
loss:0.34091 acc:0.87500
loss:0.31946 acc:0.89062
loss:0.23625 acc:0.95312
loss:0.19402 acc:0.96875
loss:0.35543 acc:0.89062
loss:0.18609 acc:0.98438
loss:0.34984 acc:0.90625
loss:0.29880 acc:0.95312
loss:0.23121 acc:0.93750
loss:0.27589 acc:0.93750
loss:0.16535 acc:0.98438
loss:0.27656 acc:0.95312
loss:0.27613 acc:0.90625
loss:0.17787 acc:1.00000
loss:0.22559 acc:0.96875
loss:0.18576 acc:0.95312
loss:0.20565 acc:0.96875
loss:0.33294 acc:0.92188
loss:0.37279 acc:0.89062
loss:0.23660 acc:0.92188
loss:0.19797 acc:0.96875
loss:0.51400 acc:0.87500
loss:0.23547 acc:0.92188
loss:0.32005 acc:0.89062
loss:0.38082 acc:0.87500
loss:0.37463 acc:0.85938
loss:0.32102 acc:0.90625
loss:0.30536 acc:0.89062
loss:0.19124 acc:0.96875
loss:0.38712 acc:0.90625
loss:0.24493 acc:0.92188
loss:0.25722 acc:0.93750
loss:0.37526 acc:0.90625
loss:0.24581 acc:0.93750
loss:0.18983 acc:0.95312
loss:0.19335 acc:0.93750
loss:0.26011 acc:0.92188
loss:0.16785 acc:0.96875
loss:0.31089 acc:0.92188
loss:0.41377 acc:0.87500
loss:0.31412 acc:0.90625
loss:0.18055 acc:0.95312
loss:0.21896 acc:0.95312
loss:0.22426 acc:0.93750
loss:0.17735 acc:0.96875
loss:0.12969 acc:0.98438
loss:0.25509 acc:0.95312
loss:0.36201 acc:0.85938
loss:0.30848 acc:0.90625
loss:0.37621 acc:0.89062
loss:0.45327 acc:0.84375
loss:0.29644 acc:0.90625
loss:0.26313 acc:0.92188
loss:0.44245 acc:0.89062
loss:0.27033 acc:0.90625
loss:0.29708 acc:0.89062
loss:0.30118 acc:0.87500
loss:0.27827 acc:0.93750
loss:0.21405 acc:0.95312
loss:0.15615 acc:0.96875
loss:0.20493 acc:0.95312
loss:0.39926 acc:0.85938
loss:0.33024 acc:0.90625
loss:0.33339 acc:0.90625
loss:0.35240 acc:0.92188
loss:0.28695 acc:0.93750
loss:0.24316 acc:0.93750
loss:0.22253 acc:0.93750
loss:0.21479 acc:0.95312
loss:0.29799 acc:0.95312
loss:0.33985 acc:0.87500
loss:0.27656 acc:0.95312
loss:0.34378 acc:0.90625
loss:0.20759 acc:0.93750
loss:0.32314 acc:0.93750
loss:0.30703 acc:0.89062
loss:0.23211 acc:0.95312
loss:0.18815 acc:0.95312
loss:0.36385 acc:0.90625
loss:0.25207 acc:0.90625
loss:0.26671 acc:0.92188
loss:0.36773 acc:0.87500
loss:0.28102 acc:0.90625
loss:0.30742 acc:0.92188
loss:0.29756 acc:0.90625
loss:0.39121 acc:0.85938
loss:0.30177 acc:0.92188
loss:0.32003 acc:0.89062
loss:0.33404 acc:0.89062
loss:0.14668 acc:0.98438
loss:0.22110 acc:0.96875
loss:0.25015 acc:0.93750
loss:0.29098 acc:0.93750
loss:0.19943 acc:0.92188
loss:0.23941 acc:0.92188
loss:0.15813 acc:0.96875
loss:0.32510 acc:0.90625
loss:0.24515 acc:0.92188
loss:0.13650 acc:0.96875
loss:0.34347 acc:0.90625
loss:0.37198 acc:0.89062
loss:0.19716 acc:0.93750
loss:0.20936 acc:0.95312
loss:0.22594 acc:0.92188
loss:0.31214 acc:0.84375
loss:0.22110 acc:0.95312
loss:0.16838 acc:0.95312
loss:0.17118 acc:0.98438
loss:0.27213 acc:0.93750
loss:0.36451 acc:0.87500
loss:0.13542 acc:0.96875
loss:0.11854 acc:0.98438
loss:0.37582 acc:0.87500
loss:0.20354 acc:0.95312
loss:0.18940 acc:0.93750
loss:0.25606 acc:0.90625
loss:0.24460 acc:0.90625
loss:0.20000 acc:0.95312
loss:0.17487 acc:0.93750
loss:0.18902 acc:0.93750
loss:0.18991 acc:0.96875
loss:0.20206 acc:0.93750
loss:0.22814 acc:0.90625
loss:0.26283 acc:0.92188
loss:0.09305 acc:0.98438
loss:0.21910 acc:0.93750
loss:0.22599 acc:0.95312
loss:0.31945 acc:0.90625
loss:0.32415 acc:0.93750
loss:0.46490 acc:0.87500
loss:0.32415 acc:0.89062
loss:0.10184 acc:1.00000
loss:0.21687 acc:0.95312
loss:0.19604 acc:0.96875
loss:0.17721 acc:0.96875
loss:0.32351 acc:0.89062
loss:0.12123 acc:0.98438
loss:0.23676 acc:0.96875
loss:0.25150 acc:0.93750
loss:0.43190 acc:0.92188
loss:0.20894 acc:0.95312
loss:0.09201 acc:0.98438
loss:0.17339 acc:0.96875
loss:0.32940 acc:0.87500
loss:0.34605 acc:0.89062
loss:0.27641 acc:0.92188
loss:0.27213 acc:0.93750
loss:0.13855 acc:0.96875
loss:0.29624 acc:0.87500
loss:0.20382 acc:0.89062
loss:0.18497 acc:0.96875
loss:0.18618 acc:0.95312
loss:0.16166 acc:0.96875
loss:0.20708 acc:0.93750
loss:0.17241 acc:0.96875
loss:0.20468 acc:0.93750
loss:0.22756 acc:0.93750
loss:0.16883 acc:0.96875
loss:0.22368 acc:0.93750
loss:0.20490 acc:0.95312
loss:0.25694 acc:0.89062
loss:0.15531 acc:0.96875
loss:0.30559 acc:0.89062
loss:0.22385 acc:0.95312
loss:0.18270 acc:0.95312
loss:0.24967 acc:0.92188
loss:0.10997 acc:0.98438
loss:0.25074 acc:0.92188
loss:0.21005 acc:0.93750
loss:0.13791 acc:0.96875
loss:0.35392 acc:0.92188
loss:0.24487 acc:0.93750
loss:0.23756 acc:0.95312
loss:0.26027 acc:0.89062
loss:0.28201 acc:0.92188
loss:0.26195 acc:0.89062
loss:0.26110 acc:0.90625
loss:0.18067 acc:0.92188
loss:0.23617 acc:0.92188
loss:0.25224 acc:0.93750
loss:0.22404 acc:0.93750
loss:0.26276 acc:0.90625
loss:0.16056 acc:0.93750
loss:0.30658 acc:0.92188
loss:0.21245 acc:0.92188
loss:0.13786 acc:0.96875
loss:0.28125 acc:0.87500
loss:0.18568 acc:0.95312
loss:0.15106 acc:0.95312
loss:0.19052 acc:0.96875
loss:0.31077 acc:0.89062
loss:0.15633 acc:0.98438
loss:0.30772 acc:0.93750
loss:0.25433 acc:0.90625
loss:0.23351 acc:0.89062
loss:0.21224 acc:0.96875
loss:0.14788 acc:0.96875
loss:0.20176 acc:0.95312
loss:0.16864 acc:0.95312
loss:0.14577 acc:0.95312
loss:0.25367 acc:0.92188
loss:0.15585 acc:0.96875
loss:0.10192 acc:0.98438
loss:0.29795 acc:0.93750
loss:0.15300 acc:0.96875
loss:0.20459 acc:0.92188
loss:0.40997 acc:0.89062
loss:0.27964 acc:0.92188
loss:0.20234 acc:0.96875
loss:0.33700 acc:0.90625
loss:0.31243 acc:0.89062
loss:0.19571 acc:0.93750
loss:0.32722 acc:0.92188
loss:0.20196 acc:0.95312
loss:0.15080 acc:0.95312
loss:0.10359 acc:0.96875
loss:0.31100 acc:0.92188
loss:0.23306 acc:0.95312
loss:0.22770 acc:0.96875
loss:0.30542 acc:0.93750
loss:0.15798 acc:0.93750
loss:0.10705 acc:0.96875
loss:0.23852 acc:0.90625
loss:0.20319 acc:0.95312
loss:0.21662 acc:0.93750
loss:0.22358 acc:0.93750
loss:0.11667 acc:0.98438
loss:0.19136 acc:0.95312
loss:0.13301 acc:0.96875
loss:0.34246 acc:0.90625
loss:0.19936 acc:0.95312
loss:0.27872 acc:0.90625
loss:0.28035 acc:0.87500
loss:0.13166 acc:0.96875
loss:0.15417 acc:0.96875
loss:0.20961 acc:0.93750
loss:0.27052 acc:0.90625
loss:0.16499 acc:0.95312
loss:0.12835 acc:0.95312
loss:0.21926 acc:0.92188
loss:0.26808 acc:0.93750
loss:0.21866 acc:0.92188
loss:0.25416 acc:0.93750
loss:0.20989 acc:0.95312
loss:0.12020 acc:0.98438
loss:0.13833 acc:0.95312
loss:0.16170 acc:0.95312
loss:0.36952 acc:0.89062
loss:0.13549 acc:0.95312
loss:0.14136 acc:0.95312
loss:0.28028 acc:0.92188
loss:0.13040 acc:0.98438
loss:0.16315 acc:0.95312
loss:0.16745 acc:0.98438
loss:0.33258 acc:0.93750
loss:0.17818 acc:0.95312
loss:0.15952 acc:0.98438
loss:0.25364 acc:0.92188
loss:0.30436 acc:0.92188
loss:0.18246 acc:0.95312
loss:0.24509 acc:0.93750
loss:0.17627 acc:0.95312
loss:0.21730 acc:0.93750
loss:0.25020 acc:0.93750
loss:0.31416 acc:0.89062
loss:0.11529 acc:0.98438
loss:0.17905 acc:0.92188
loss:0.19555 acc:0.95312
loss:0.16704 acc:0.95312
loss:0.19937 acc:0.93750
loss:0.19374 acc:0.92188
loss:0.13501 acc:0.96875
loss:0.15464 acc:0.95312
loss:0.14569 acc:0.96875
loss:0.24573 acc:0.87500
loss:0.32774 acc:0.87500
loss:0.12643 acc:0.96875
loss:0.34496 acc:0.92188
loss:0.20144 acc:0.95312
loss:0.17076 acc:0.95312
loss:0.24230 acc:0.95312
loss:0.23473 acc:0.95312
loss:0.30967 acc:0.92188
loss:0.15197 acc:0.98438
loss:0.12085 acc:0.96875
loss:0.10847 acc:0.96875
loss:0.30175 acc:0.93750
loss:0.18536 acc:0.96875
loss:0.12876 acc:0.96875
loss:0.21572 acc:0.93750
loss:0.18642 acc:0.95312
loss:0.15870 acc:0.95312
loss:0.18767 acc:0.95312
loss:0.14790 acc:0.96875
loss:0.22414 acc:0.92188
loss:0.16376 acc:0.95312
loss:0.28045 acc:0.92188
loss:0.21378 acc:0.95312
loss:0.15277 acc:0.96875
loss:0.15975 acc:0.96875
loss:0.10392 acc:0.98438
loss:0.20615 acc:0.93750
loss:0.12076 acc:0.96875
loss:0.19137 acc:0.93750
loss:0.13918 acc:0.98438
loss:0.22372 acc:0.95312
loss:0.12778 acc:1.00000
loss:0.10982 acc:0.98438
loss:0.15405 acc:0.92188
loss:0.26379 acc:0.92188
loss:0.14794 acc:0.93750
loss:0.18582 acc:0.95312
loss:0.22060 acc:0.96875
loss:0.19452 acc:0.93750
loss:0.24615 acc:0.90625
loss:0.28249 acc:0.93750
loss:0.09747 acc:0.98438
loss:0.13903 acc:0.96875
loss:0.19606 acc:0.92188
loss:0.06221 acc:0.98438
loss:0.24209 acc:0.92188
loss:0.14922 acc:0.95312
loss:0.18266 acc:0.95312
loss:0.18010 acc:0.93750
loss:0.18483 acc:0.96875
loss:0.11312 acc:0.98438
loss:0.13887 acc:0.95312
loss:0.31176 acc:0.89062
loss:0.18163 acc:0.93750
loss:0.08486 acc:0.98438
loss:0.16347 acc:0.93750
loss:0.05822 acc:1.00000
loss:0.29235 acc:0.92188
loss:0.18129 acc:0.96875
loss:0.45740 acc:0.92188
loss:0.18626 acc:0.96875
loss:0.23226 acc:0.90625
loss:0.26856 acc:0.93750
loss:0.27806 acc:0.90625
loss:0.38523 acc:0.89062
loss:0.32276 acc:0.87500
loss:0.20932 acc:0.92188
loss:0.28150 acc:0.85938
loss:0.33946 acc:0.92188
loss:0.18500 acc:0.93750
loss:0.19951 acc:0.90625
loss:0.08960 acc:0.98438
loss:0.21694 acc:0.93750
loss:0.26588 acc:0.90625
loss:0.09444 acc:0.96875
loss:0.19150 acc:0.93750
loss:0.16699 acc:0.98438
loss:0.22975 acc:0.95312
loss:0.14448 acc:0.96875
loss:0.34983 acc:0.90625
loss:0.19037 acc:0.95312
loss:0.14290 acc:0.95312
loss:0.17222 acc:0.95312
loss:0.19792 acc:0.95312
loss:0.15288 acc:0.95312
loss:0.11856 acc:0.98438
loss:0.17556 acc:0.95312
loss:0.17508 acc:0.93750
loss:0.26778 acc:0.89062
loss:0.24382 acc:0.92188
loss:0.16512 acc:0.96875
loss:0.12264 acc:0.96875
loss:0.23912 acc:0.95312
loss:0.15832 acc:0.96875
loss:0.30289 acc:0.92188
loss:0.36022 acc:0.89062
loss:0.06129 acc:1.00000
loss:0.08585 acc:1.00000
loss:0.12552 acc:0.98438
loss:0.11478 acc:0.95312
loss:0.16231 acc:0.96875
loss:0.28971 acc:0.92188
loss:0.17992 acc:0.96875
loss:0.15360 acc:0.96875
loss:0.14996 acc:0.96875
loss:0.12951 acc:0.96875
loss:0.14277 acc:0.95312
loss:0.17858 acc:0.95312
loss:0.18334 acc:0.95312
loss:0.12234 acc:0.96875
loss:0.11103 acc:0.96875
loss:0.17434 acc:0.95312
loss:0.15722 acc:0.93750
loss:0.10310 acc:0.96875
loss:0.08445 acc:0.98438
loss:0.14952 acc:0.93750
loss:0.26554 acc:0.92188
loss:0.18645 acc:0.96875
loss:0.21273 acc:0.92188
loss:0.07908 acc:0.96875
loss:0.14561 acc:0.96875
loss:0.09148 acc:1.00000
loss:0.12898 acc:0.98438
loss:0.20631 acc:0.96875
loss:0.19616 acc:0.93750
loss:0.11032 acc:0.96875
loss:0.12006 acc:0.98438
loss:0.26047 acc:0.90625
loss:0.14672 acc:0.95312
loss:0.20060 acc:0.95312
loss:0.13894 acc:0.96875
loss:0.23488 acc:0.92188
loss:0.11752 acc:0.95312
loss:0.10076 acc:0.98438
loss:0.16476 acc:0.95312
loss:0.11161 acc:0.95312
loss:0.24382 acc:0.95312
loss:0.10563 acc:0.96875
loss:0.24099 acc:0.89062
loss:0.25432 acc:0.92188
loss:0.26607 acc:0.89062
loss:0.18009 acc:0.93750
loss:0.09952 acc:0.98438
loss:0.28454 acc:0.89062
loss:0.08948 acc:0.95312
loss:0.21783 acc:0.92188
loss:0.15621 acc:0.95312
loss:0.25727 acc:0.92188
loss:0.24574 acc:0.90625
loss:0.09447 acc:0.96875
loss:0.09729 acc:0.96875
loss:0.08467 acc:0.98438
loss:0.13028 acc:0.95312
loss:0.21107 acc:0.92188
loss:0.07938 acc:0.98438
loss:0.08832 acc:0.98438
loss:0.13598 acc:0.93750
loss:0.10205 acc:0.95312
loss:0.16719 acc:0.95312
loss:0.25444 acc:0.93750
loss:0.07656 acc:1.00000
loss:0.11371 acc:0.95312
loss:0.18071 acc:0.96875
loss:0.18862 acc:0.92188
loss:0.10793 acc:0.98438
loss:0.27224 acc:0.93750
loss:0.12655 acc:0.96875
loss:0.22090 acc:0.92188
loss:0.13526 acc:0.95312
loss:0.14960 acc:0.96875
loss:0.07072 acc:0.98438
loss:0.24943 acc:0.90625
loss:0.10602 acc:0.96875
loss:0.13111 acc:0.96875
loss:0.24879 acc:0.92188
loss:0.30222 acc:0.89062
loss:0.16071 acc:0.95312
loss:0.19176 acc:0.93750
loss:0.04011 acc:1.00000
loss:0.10362 acc:0.96875
loss:0.15975 acc:0.93750
loss:0.14810 acc:0.96875
loss:0.11138 acc:0.96875
loss:0.16518 acc:0.93750
loss:0.09359 acc:0.98438
loss:0.14622 acc:0.96875
loss:0.12750 acc:0.96875
loss:0.11546 acc:0.96875
loss:0.20611 acc:0.95312
loss:0.34404 acc:0.92188
loss:0.10114 acc:0.98438
loss:0.16506 acc:0.95312
loss:0.17545 acc:0.95312
loss:0.26379 acc:0.92188
loss:0.21496 acc:0.90625
loss:0.26072 acc:0.93750
loss:0.07525 acc:0.98438
loss:0.20520 acc:0.93750
loss:0.07231 acc:0.98438
loss:0.08859 acc:0.96875
loss:0.16149 acc:0.95312
loss:0.15149 acc:0.95312
loss:0.18882 acc:0.93750
loss:0.10414 acc:0.96875
loss:0.19076 acc:0.98438
loss:0.14612 acc:0.95312
loss:0.11762 acc:0.96875
loss:0.11453 acc:0.96875
loss:0.15372 acc:0.93750
loss:0.16458 acc:0.95312
loss:0.05676 acc:0.98438
loss:0.20118 acc:0.92188
loss:0.11675 acc:0.96875
loss:0.08715 acc:0.98438
loss:0.08951 acc:0.98438
loss:0.17705 acc:0.93750
loss:0.22871 acc:0.95312
loss:0.14870 acc:0.96875
loss:0.16718 acc:0.92188
loss:0.15694 acc:0.95312
loss:0.12892 acc:0.98438
loss:0.16018 acc:0.96875
loss:0.10527 acc:0.96875
loss:0.06917 acc:1.00000
loss:0.24590 acc:0.92188
loss:0.07912 acc:1.00000
loss:0.12666 acc:0.96875
loss:0.17376 acc:0.95312
loss:0.14857 acc:0.92188
loss:0.33249 acc:0.90625
loss:0.11874 acc:0.98438
loss:0.04404 acc:1.00000
loss:0.10323 acc:0.98438
loss:0.26786 acc:0.90625
loss:0.25149 acc:0.92188
loss:0.16042 acc:0.95312
loss:0.27602 acc:0.92188
loss:0.16729 acc:0.98438
loss:0.13113 acc:0.98438
loss:0.25255 acc:0.95312
loss:0.17137 acc:0.93750
loss:0.18273 acc:0.96875
loss:0.06922 acc:1.00000
loss:0.24624 acc:0.90625
loss:0.18232 acc:0.90625
loss:0.07420 acc:0.98438
loss:0.15953 acc:0.98438
loss:0.07206 acc:0.96875
loss:0.18828 acc:0.93750
loss:0.19701 acc:0.96875
loss:0.15299 acc:0.93750
loss:0.07865 acc:0.98438
loss:0.20735 acc:0.93750
loss:0.17156 acc:0.93750
loss:0.07509 acc:1.00000
loss:0.14986 acc:0.95312
loss:0.13248 acc:0.96875
loss:0.14388 acc:0.95312
loss:0.05571 acc:1.00000
loss:0.09408 acc:0.96875
loss:0.24787 acc:0.93750
loss:0.19877 acc:0.92188
loss:0.08191 acc:0.98438
loss:0.31069 acc:0.92188
loss:0.12103 acc:0.96875
loss:0.24595 acc:0.90625
loss:0.14003 acc:0.98438
loss:0.08255 acc:0.98438
loss:0.17961 acc:0.93750
loss:0.13665 acc:0.93750
loss:0.23543 acc:0.92188
loss:0.25226 acc:0.90625
loss:0.12286 acc:0.95312
loss:0.17261 acc:0.95312
loss:0.13196 acc:0.95312
loss:0.22068 acc:0.90625
loss:0.38338 acc:0.90625
loss:0.09904 acc:0.96875
loss:0.06626 acc:0.98438
loss:0.15176 acc:0.96875
loss:0.14975 acc:0.95312
loss:0.29266 acc:0.93750
loss:0.08018 acc:0.98438
loss:0.22623 acc:0.93750
loss:0.09631 acc:0.96875
loss:0.13066 acc:0.96875
loss:0.13228 acc:0.95312
loss:0.14615 acc:0.92188
loss:0.14123 acc:0.96875
loss:0.26964 acc:0.89062
loss:0.08345 acc:0.96875
loss:0.18795 acc:0.92188
loss:0.14019 acc:0.96875
loss:0.18857 acc:0.95312
loss:0.20577 acc:0.92188
loss:0.11165 acc:0.98438
loss:0.11152 acc:0.95312
loss:0.10814 acc:0.96875
loss:0.10788 acc:0.98438
loss:0.10583 acc:0.95312
loss:0.08795 acc:0.98438
loss:0.16285 acc:0.95312
loss:0.13543 acc:0.96875
loss:0.14316 acc:0.95312
loss:0.20426 acc:0.96875
loss:0.06097 acc:1.00000
loss:0.08930 acc:0.98438
loss:0.14915 acc:0.93750
loss:0.05893 acc:1.00000
loss:0.03859 acc:1.00000
loss:0.10800 acc:0.96875
loss:0.26009 acc:0.90625
loss:0.07896 acc:1.00000
loss:0.06844 acc:0.98438
loss:0.10040 acc:0.98438
loss:0.08292 acc:0.98438
loss:0.37432 acc:0.87500
loss:0.16169 acc:0.95312
loss:0.22969 acc:0.90625
loss:0.15961 acc:0.93750
loss:0.13284 acc:0.98438
loss:0.09883 acc:0.98438
loss:0.14311 acc:0.96875
loss:0.18110 acc:0.93750
loss:0.20608 acc:0.93750
loss:0.08434 acc:0.98438
loss:0.13159 acc:0.95312
loss:0.10497 acc:0.95312
loss:0.13571 acc:0.96875
loss:0.15379 acc:0.96875
loss:0.12703 acc:0.96875
loss:0.22704 acc:0.90625
loss:0.11205 acc:0.95312
loss:0.13118 acc:0.95312
loss:0.11921 acc:0.96875
loss:0.13128 acc:0.96875
loss:0.16961 acc:0.98438
loss:0.26039 acc:0.92188
loss:0.10821 acc:0.96875
loss:0.15362 acc:0.92188
loss:0.20221 acc:0.93750
loss:0.25353 acc:0.93750
loss:0.17650 acc:0.92188
loss:0.17608 acc:0.92188
loss:0.10742 acc:0.96875
loss:0.14668 acc:0.95312
loss:0.11982 acc:0.95312
loss:0.14160 acc:0.96875
loss:0.14920 acc:0.96875
loss:0.07769 acc:0.98438
loss:0.07435 acc:0.98438
loss:0.12596 acc:0.96875
loss:0.06209 acc:0.98438
loss:0.23671 acc:0.93750
loss:0.32712 acc:0.90625
loss:0.16227 acc:0.93750
loss:0.07742 acc:0.98438
loss:0.11433 acc:0.96875
loss:0.14617 acc:0.95312
loss:0.10022 acc:0.98438
loss:0.11036 acc:0.96875
loss:0.22448 acc:0.95312
loss:0.17712 acc:0.92188
loss:0.14171 acc:0.95312
loss:0.14837 acc:0.95312
loss:0.05253 acc:1.00000
loss:0.05911 acc:1.00000
loss:0.16225 acc:0.95312
loss:0.22476 acc:0.92188
loss:0.18999 acc:0.93750
loss:0.17341 acc:0.90625
loss:0.09191 acc:0.98438
loss:0.18701 acc:0.95312
loss:0.11330 acc:0.96875
loss:0.17069 acc:0.96875
loss:0.07430 acc:0.98438
loss:0.14603 acc:0.96875
loss:0.04871 acc:1.00000
loss:0.28533 acc:0.92188
loss:0.18542 acc:0.93750
loss:0.23417 acc:0.90625
loss:0.14049 acc:0.93750
loss:0.25800 acc:0.90625
loss:0.11138 acc:0.98438
loss:0.07213 acc:0.98438
loss:0.07436 acc:0.98438
loss:0.14025 acc:0.95312
loss:0.07013 acc:0.98438
loss:0.17246 acc:0.95312
loss:0.18267 acc:0.95312
loss:0.22772 acc:0.93750
loss:0.06264 acc:0.98438
loss:0.02746 acc:1.00000
loss:0.24368 acc:0.89062
loss:0.05390 acc:1.00000
loss:0.19132 acc:0.95312
loss:0.19019 acc:0.96875
loss:0.13184 acc:0.96875
loss:0.07464 acc:0.98438
loss:0.09115 acc:0.98438
loss:0.09529 acc:0.96875
loss:0.08990 acc:0.96875
loss:0.07885 acc:0.96875
loss:0.09846 acc:0.95312
loss:0.14663 acc:0.95312
loss:0.22904 acc:0.93750
loss:0.11402 acc:0.96875
loss:0.19972 acc:0.93750
loss:0.20436 acc:0.96875
loss:0.32500 acc:0.87500
loss:0.11743 acc:0.95312
loss:0.17370 acc:0.92188
loss:0.08371 acc:0.98438
loss:0.08599 acc:0.96875
loss:0.12539 acc:0.95312
loss:0.12278 acc:0.98438
loss:0.10975 acc:0.96875
loss:0.10831 acc:0.98438
loss:0.06411 acc:0.96875
loss:0.23990 acc:0.93750
loss:0.08271 acc:0.98438
loss:0.10399 acc:0.96875
loss:0.06397 acc:1.00000
loss:0.08254 acc:0.96875
loss:0.09192 acc:0.98438
loss:0.17281 acc:0.95312
loss:0.06051 acc:1.00000
loss:0.14834 acc:0.98438
loss:0.09404 acc:0.96875
loss:0.06473 acc:0.98438
loss:0.15557 acc:0.98438
loss:0.14979 acc:0.96875
loss:0.19210 acc:0.95312
loss:0.13649 acc:0.93750
loss:0.13135 acc:0.93750
loss:0.21046 acc:0.90625
loss:0.10882 acc:0.95312
loss:0.17906 acc:0.93750
loss:0.19661 acc:0.93750
loss:0.10849 acc:0.95312
loss:0.07449 acc:1.00000
loss:0.24108 acc:0.90625
loss:0.06498 acc:0.98438
loss:0.17111 acc:0.95312
loss:0.19536 acc:0.93750
loss:0.08626 acc:0.95312
loss:0.12976 acc:0.98438
loss:0.04587 acc:1.00000
loss:0.31982 acc:0.90625
loss:0.09512 acc:0.98438
loss:0.23773 acc:0.90625
loss:0.20744 acc:0.95312
loss:0.23242 acc:0.95312
loss:0.15539 acc:0.96875
loss:0.14239 acc:0.93750
loss:0.08568 acc:0.96875
loss:0.14298 acc:0.95312
loss:0.11364 acc:0.95312
loss:0.15121 acc:0.98438
loss:0.10850 acc:0.93750
loss:0.13976 acc:0.95312
loss:0.19362 acc:0.93750
loss:0.08078 acc:0.98438
loss:0.05183 acc:0.98438
loss:0.15727 acc:0.96875
loss:0.07899 acc:0.96875
loss:0.05916 acc:0.98438
loss:0.09187 acc:0.96875
loss:0.07978 acc:0.98438
loss:0.07401 acc:0.98438
loss:0.20396 acc:0.95312
loss:0.04685 acc:1.00000
loss:0.17921 acc:0.96875
loss:0.05774 acc:1.00000
loss:0.09466 acc:0.96875
loss:0.21201 acc:0.95312
loss:0.12821 acc:0.95312
loss:0.08764 acc:0.98438
loss:0.11809 acc:0.96875
loss:0.16199 acc:0.96875
loss:0.11255 acc:0.98438
loss:0.14778 acc:0.95312
loss:0.08199 acc:0.96875
loss:0.06982 acc:0.96875
loss:0.08677 acc:0.96875
loss:0.14956 acc:0.95312
loss:0.15698 acc:0.93750
loss:0.12428 acc:0.96875
loss:0.19111 acc:0.93750
loss:0.04491 acc:1.00000
loss:0.19884 acc:0.90625
loss:0.24804 acc:0.93750
loss:0.13953 acc:0.96875
loss:0.24288 acc:0.90625
loss:0.15888 acc:0.95312
loss:0.17232 acc:0.95312
loss:0.09322 acc:0.96875
loss:0.14902 acc:0.96875
loss:0.11222 acc:0.98438
loss:0.08734 acc:0.96875
loss:0.09247 acc:0.96875
loss:0.11926 acc:0.95312
loss:0.04608 acc:0.98438
loss:0.07352 acc:0.96875
loss:0.05912 acc:0.98438
loss:0.06844 acc:1.00000
loss:0.09879 acc:1.00000
loss:0.11328 acc:0.96875
loss:0.14035 acc:0.96875
loss:0.18120 acc:0.95312
loss:0.09994 acc:0.95312
loss:0.14785 acc:0.95312
loss:0.07812 acc:0.96875
loss:0.15447 acc:0.95312
loss:0.18226 acc:0.93750
loss:0.14916 acc:0.96875
loss:0.19192 acc:0.93750
loss:0.10521 acc:0.98438
loss:0.19849 acc:0.96875
loss:0.16188 acc:0.95312
loss:0.05045 acc:1.00000
loss:0.23136 acc:0.95312
loss:0.07092 acc:0.98438
loss:0.12442 acc:0.96875
loss:0.16773 acc:0.95312
loss:0.16274 acc:0.95312
loss:0.28012 acc:0.92188
loss:0.08197 acc:0.98438
loss:0.05306 acc:0.98438
loss:0.07457 acc:0.98438
loss:0.17700 acc:0.93750
loss:0.10168 acc:0.98438
loss:0.20213 acc:0.92188
loss:0.20351 acc:0.92188
loss:0.10901 acc:0.96875
loss:0.04361 acc:1.00000
loss:0.16461 acc:0.96875
loss:0.20337 acc:0.92188
loss:0.11774 acc:0.98438
loss:0.19589 acc:0.92188
loss:0.14479 acc:0.96875
loss:0.16974 acc:0.93750
loss:0.09703 acc:0.96875
loss:0.06660 acc:0.98438
loss:0.08469 acc:1.00000
loss:0.09053 acc:0.96875
loss:0.15431 acc:0.90625
loss:0.10061 acc:0.98438
loss:0.10092 acc:0.98438
loss:0.17791 acc:0.96875
loss:0.13418 acc:0.95312
loss:0.04396 acc:0.98438
loss:0.16642 acc:0.96875
loss:0.05089 acc:1.00000
loss:0.17533 acc:0.93750
loss:0.16705 acc:0.95312
loss:0.10573 acc:0.95312
loss:0.15486 acc:0.96875
loss:0.12395 acc:0.95312
loss:0.12970 acc:0.96875
loss:0.18159 acc:0.95312
loss:0.05818 acc:1.00000
loss:0.10698 acc:0.93750
loss:0.03983 acc:1.00000
loss:0.18349 acc:0.96875
loss:0.08221 acc:0.98438
loss:0.12925 acc:0.96875
loss:0.11189 acc:0.96875
loss:0.17911 acc:0.96875
loss:0.27267 acc:0.95312
loss:0.06801 acc:0.98438
loss:0.03988 acc:1.00000
loss:0.03477 acc:1.00000
loss:0.10255 acc:0.96875
loss:0.09664 acc:0.98438
loss:0.11505 acc:0.96875
loss:0.09369 acc:0.98438
loss:0.20973 acc:0.95312
loss:0.16301 acc:0.93750
loss:0.08222 acc:0.96875
loss:0.10991 acc:0.95312
loss:0.09927 acc:0.96875
loss:0.04367 acc:1.00000
loss:0.06952 acc:0.96875
loss:0.13133 acc:0.96875
loss:0.13914 acc:0.96875
loss:0.27232 acc:0.96875
loss:0.08188 acc:0.96875
loss:0.20095 acc:0.96875
loss:0.08710 acc:0.96875
loss:0.05236 acc:1.00000
loss:0.11694 acc:0.96875
loss:0.11446 acc:0.96875
loss:0.14520 acc:0.96875
loss:0.18993 acc:0.90625
loss:0.17676 acc:0.96875
loss:0.10977 acc:0.95312
loss:0.12015 acc:0.98438
loss:0.15122 acc:0.96875
loss:0.19218 acc:0.95312
loss:0.14699 acc:0.93750
loss:0.24232 acc:0.90625
loss:0.21631 acc:0.90625
loss:0.06756 acc:0.98438
loss:0.17421 acc:0.95312
loss:0.12664 acc:0.95312
loss:0.15569 acc:0.95312
loss:0.12884 acc:0.96875
loss:0.11202 acc:0.96875
loss:0.05302 acc:0.98438
loss:0.20850 acc:0.93750
loss:0.12652 acc:0.96875
loss:0.23816 acc:0.92188
loss:0.06526 acc:0.98438
loss:0.05521 acc:1.00000
loss:0.21977 acc:0.93750
loss:0.04914 acc:1.00000
loss:0.17506 acc:0.92188
loss:0.06785 acc:0.98438
loss:0.15795 acc:0.95312
loss:0.13381 acc:0.93750
loss:0.07008 acc:0.98438
loss:0.19596 acc:0.95312
loss:0.19804 acc:0.92188
loss:0.22154 acc:0.93750
loss:0.14441 acc:0.93750
loss:0.08902 acc:0.98438
loss:0.13566 acc:0.96875
loss:0.08600 acc:0.96875
loss:0.19083 acc:0.98438
loss:0.07337 acc:0.98438
loss:0.16578 acc:0.93750
loss:0.15782 acc:0.95312
loss:0.10420 acc:0.98438
loss:0.07731 acc:0.98438
loss:0.12828 acc:0.96875
loss:0.06208 acc:1.00000
loss:0.25428 acc:0.92188
loss:0.03976 acc:1.00000
loss:0.19712 acc:0.90625
loss:0.15139 acc:0.92188
loss:0.06922 acc:0.98438
loss:0.12168 acc:0.98438
loss:0.04660 acc:0.98438
loss:0.12327 acc:0.96875
loss:0.22629 acc:0.95312
loss:0.12471 acc:0.95312
loss:0.04671 acc:1.00000
loss:0.14425 acc:0.96875
loss:0.19596 acc:0.93750
loss:0.03496 acc:1.00000
Look at how the loss moves.
The loss clearly seems to behave differently from what I am used to seeing with floating-point training.
This is what training MNIST with BitNet looks like.
In the early stage of training, BitNet's loss falls more slowly than that of an ordinary Linear.
Let's zoom in on the very earliest stage.
Compared to a plain Linear, BitNet converges somewhat more slowly, but a difference this small could be called within the margin of error.
Now look at the final stage of training.
Toward the end of training, the normal Linear's loss fluctuates more wildly.
Still, I did not expect that swapping BitLinear in for Linear would be this easy.
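For reference, here is a minimal sketch of how that comparison can be run. This is my reconstruction, not the exact script behind the plots: it reuses train_loader and plt from the MNIST script above and BitLinearOriginal from the implementation at the end of this post, training the same two-layer net once with nn.Linear and once with BitLinearOriginal while collecting per-batch losses.

import torch.nn as nn
from torch import optim

def train_once(linear_cls, epochs=1):
    # linear_cls is either nn.Linear or BitLinearOriginal
    model = nn.Sequential(
        nn.Flatten(),
        linear_cls(784, 128),
        nn.ReLU(),
        linear_cls(128, 10),
    )
    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    losses = []
    for _ in range(epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return losses

loss_fp = train_once(nn.Linear)
loss_bit = train_once(BitLinearOriginal)
plt.plot(loss_fp, label="nn.Linear")
plt.plot(loss_bit, label="BitLinearOriginal")
plt.legend()
plt.show()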
With that in mind, when I think about why my BitNet-based LLM fails to train well, the answer may simply be that pretraining an LLM from scratch is difficult for anyone.
And during training, BitNet looks likely to take more time than other methods.
Also, at training time BitNet inserts various extra operations compared to the conventional approach, so the compute cost goes up and the benefit on current hardware is limited. At inference time the amount of computation drops dramatically, but then there is little reason to use a GPU. That may be what the BitNet paper means by its call for new hardware.
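As a toy illustration of the inference-side claim (again my own sketch, not something from the paper): once the weights are constrained to -1, 0, and +1, a matrix-vector product needs no multiplications at all, only additions, subtractions, and skips, which is not the kind of workload GPUs are optimized for.

import numpy as np

def ternary_matvec(W, x):
    # W has entries in {-1, 0, +1}: each "multiplication" becomes an add, a subtract, or a skip
    y = np.zeros(W.shape[0], dtype=x.dtype)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            if W[i, j] == 1:
                y[i] += x[j]
            elif W[i, j] == -1:
                y[i] -= x[j]
    return y

W = np.random.choice([-1, 0, 1], size=(4, 8))
x = np.random.randn(8)
assert np.allclose(ternary_matvec(W, x), W @ x)

For reference, below is the BitLinear implementation I used: Hachi-san's BitRMSNorm plus a BitLinear based on the paper.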
import torch
from torch import nn
import torch.nn.functional as F

class BitRMSNorm(nn.Module):  # implementation by Hachi-san
    def __init__(self, hidden_size, eps=1e-6):
        """
        BitRMSNorm is equivalent to LlamaRMSNorm and T5LayerNorm
        refers: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L76C1-L90C59
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

def activation_quant(x):
    # per-token 8-bit quantization of the activations
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y

def weight_quant(w):
    # ternary quantization of the weights ({-1, 0, +1} divided by a scale)
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinearOriginal(nn.Linear):  # implementation based on the paper
    def __init__(self, in_features, out_features, bias=False, flg_before_linear=True, bits=8):
        super(BitLinearOriginal, self).__init__(in_features, out_features, bias)
        self.layernorm = nn.LayerNorm(in_features)
        self.RMSNorm = BitRMSNorm(in_features)
        self.bits = bits

    def forward(self, x):
        w = self.weight
        x_norm = self.RMSNorm(x)
        # straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
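As a quick sanity check (mine, not from the original post): the weight matrix that actually enters F.linear takes at most three distinct values, which is where the 1.58-bit figure (log2 3) comes from.

layer = BitLinearOriginal(8, 4)
x = torch.randn(2, 8)
print(layer(x).shape)  # torch.Size([2, 4])
with torch.no_grad():
    print(torch.unique(weight_quant(layer.weight)))  # at most three values: -s, 0, +s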
The complete source code is on GitHub.