
Properties of BitNet That Emerged from Training It on MNIST

I've been wrestling with BitNet for a little under a month now. BitNet is the 1-bit (also described as 1.58-bit) quantized neural network that Microsoft claims to have invented.
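
Incidentally, the "1.58 bits" comes from each weight taking one of the three values {-1, 0, +1}, so the information content per weight is log2(3). A quick sanity check:

import math

# Each BitNet b1.58 weight takes one of three values {-1, 0, +1},
# so it carries log2(3) bits of information.
print(math.log2(3))  # ≈ 1.585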

Using some code I found lying around, my very first run got the loss down to about 2, but an LLM isn't really practical unless the loss drops below 1.

Every run since then has ended up around 6, or 5 at best, so it seems my first attempt just happened to go well.
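
As a rough way to read those loss numbers: cross-entropy loss (in nats) corresponds to a perplexity of exp(loss), so a loss of 1 means the model is about as uncertain as picking among ~2.7 equally likely tokens, while a loss of 6 means roughly 400.

import math

# Cross-entropy loss (in nats) maps to perplexity via exp(loss).
for loss in (1.0, 2.0, 5.0, 6.0):
    print(f"loss {loss:.1f} -> perplexity {math.exp(loss):.1f}")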

But since it never got better no matter how long I waited, I started to wonder whether I needed to rethink BitNet's properties from the ground up, so I went back to basics and tried training it on a logic gate.

For the BitNet code base I used Hachi-san's code together with an implementation of Microsoft's official paper.

The first thing I tried was code like this:

from bitnet import *
import torch
from torch import optim
import torch.nn as nn

# Truth table of a 2-input AND gate
input = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
output = torch.tensor([[0.], [0.], [0.], [1.]])

#layer = BitLinear(2,1)
model = nn.Sequential(
    #nn.Linear(2,1),
    BitLinear(2, 2),
    nn.ReLU(),
    BitLinear(2, 1),
    nn.Sigmoid()
)

optimizer = optim.AdamW(model.parameters(), lr=0.001)
loss_fn = nn.BCELoss()
for i in range(90000):
    #for x,t in zip(input,output):
    optimizer.zero_grad()
    y = model(input)           # forward pass over the whole truth table
    loss = loss_fn(y, output)
    print(loss)
    loss.backward()
    optimizer.step()
y = model(input)
print(y)

It's a very simple three-layer perceptron with two inputs and one output (the truth table above is an AND gate).

I tried over and over to get this to train, but it could not learn it at all.

AI researchers of old reasoned that if a network cannot even learn a logic gate, it certainly cannot learn anything more complex, and gave up on such models early. Presumably one reason BitNet's effectiveness went unrecognized until today lies in properties like this.

That said, my feeling is: isn't reproducing a logic gate with only three weight states too much of a gamble to begin with? The initial state is random, so finding a solution seems to owe more to luck (the initialization) than to learning.
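
To get a feel for how coarse that search space is, here is a rough brute-force sketch (my own illustration: it ignores biases and the normalization/scaling that the real BitLinear applies, and only checks the sign of the pre-sigmoid output):

import itertools
import torch

# Enumerate every ternary weight assignment for a tiny 2-2-1 ReLU network and
# count how many produce the AND pattern, i.e. a positive pre-sigmoid output
# only for the input (1, 1). No biases, no norm/scale terms.
X = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
target = torch.tensor([False, False, False, True])

vals = (-1.0, 0.0, 1.0)
hits, total = 0, 0
for w1 in itertools.product(vals, repeat=4):        # first layer: 2x2 weights
    W1 = torch.tensor(w1).reshape(2, 2)
    for w2 in itertools.product(vals, repeat=2):    # second layer: 1x2 weights
        W2 = torch.tensor(w2)
        h = torch.relu(X @ W1.T)                    # hidden activations
        z = h @ W2                                  # pre-sigmoid output
        total += 1
        if torch.equal(z > 0, target):
            hits += 1

print(f"{hits} / {total} ternary weight assignments realize AND")

Counting it this way, only a handful of the 729 assignments hit the pattern, which fits the feeling that the initialization has to get lucky.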

So I ran an experiment to see what happens when you apply it to a slightly more complex problem:
MNIST.

Writing the MNIST code myself was a hassle, so I had Claude-3 write it and only swapped the layers to BitLinear. The comments are Claude-3's too, by the way. Almost too convenient.

from bitnet import *
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Download and preprocess the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = BitLinearOriginal(784, 128)
        self.fc2 = BitLinearOriginal(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Net()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Initialize a list for recording the model parameters over time
param_history = []

def calculate_accuracy(output, target):
    _, predicted = torch.max(output, 1)
    correct = (predicted == target).sum().item()
    accuracy = correct / target.size(0)
    return accuracy

# Training loop
for epoch in range(50):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # Record the parameters
        param_history.append([p.clone().detach().numpy() for p in model.parameters()])

        optimizer.step()
        acc = calculate_accuracy(output, target)
        print(f"loss:{loss}  acc:{acc}")
        if acc > 0.99:
            break  # note: this only ends the current epoch's batch loop

When I trained with this, it did in fact learn.
It learns so quickly, though, that something about it still smells a little fishy.

So, to be sure, I had it print the early phase of training in detail and confirmed that it really does go from a low-accuracy state to a high-accuracy state.

$ python mnist.py 
loss:2.41108	acc:0.04688
loss:2.18548	acc:0.23438
loss:2.00562	acc:0.42188
loss:1.94573	acc:0.45312
loss:1.82240	acc:0.48438
loss:1.76757	acc:0.56250
loss:1.55644	acc:0.59375
loss:1.39945	acc:0.81250
loss:1.53707	acc:0.64062
loss:1.37297	acc:0.70312
loss:1.30705	acc:0.76562
loss:1.29431	acc:0.78125
loss:1.17070	acc:0.81250
loss:1.20207	acc:0.76562
loss:1.22176	acc:0.70312
loss:1.17043	acc:0.76562
loss:1.29147	acc:0.60938
loss:1.13530	acc:0.79688
loss:1.13138	acc:0.70312
loss:1.10515	acc:0.70312
loss:1.12286	acc:0.70312
loss:1.03146	acc:0.81250
loss:0.90347	acc:0.84375
loss:1.00460	acc:0.81250
loss:0.99084	acc:0.78125
loss:0.82783	acc:0.89062
loss:1.06856	acc:0.75000
loss:0.89284	acc:0.81250
loss:1.14347	acc:0.70312
loss:0.86533	acc:0.82812
loss:0.76226	acc:0.89062
loss:0.98805	acc:0.76562
loss:0.77720	acc:0.87500
loss:0.91726	acc:0.76562
loss:0.84186	acc:0.79688
loss:0.89622	acc:0.79688
loss:0.77819	acc:0.79688
loss:0.64624	acc:0.96875
loss:0.85572	acc:0.81250
loss:0.81650	acc:0.82812
loss:0.70959	acc:0.92188
loss:0.79852	acc:0.79688
loss:0.58357	acc:0.92188
loss:0.69148	acc:0.87500
loss:0.68034	acc:0.85938
loss:0.71641	acc:0.90625
loss:0.52498	acc:0.90625
loss:0.62418	acc:0.90625
loss:0.57807	acc:0.84375
loss:0.60706	acc:0.87500
loss:0.49764	acc:0.92188
loss:0.55832	acc:0.92188
loss:0.51869	acc:0.89062
loss:0.81124	acc:0.75000
loss:0.57983	acc:0.87500
loss:0.65688	acc:0.87500
loss:0.65231	acc:0.82812
loss:0.58111	acc:0.90625
loss:0.53314	acc:0.89062
loss:0.59246	acc:0.87500
loss:0.56978	acc:0.87500
loss:0.54524	acc:0.89062
loss:0.43568	acc:0.90625
loss:0.67516	acc:0.79688
loss:0.50416	acc:0.90625
loss:0.47741	acc:0.89062
loss:0.53961	acc:0.87500
loss:0.54430	acc:0.90625
loss:0.59080	acc:0.79688
loss:0.39868	acc:0.93750
loss:0.45935	acc:0.92188
loss:0.46753	acc:0.90625
loss:0.55835	acc:0.89062
loss:0.45517	acc:0.90625
loss:0.51455	acc:0.87500
loss:0.45387	acc:0.90625
loss:0.55507	acc:0.82812
loss:0.36060	acc:0.95312
loss:0.34529	acc:0.93750
loss:0.44868	acc:0.87500
loss:0.51357	acc:0.85938
loss:0.51937	acc:0.87500
loss:0.43396	acc:0.89062
loss:0.40336	acc:0.90625
loss:0.39574	acc:0.92188
loss:0.31379	acc:0.95312
loss:0.29613	acc:0.96875
loss:0.44163	acc:0.89062
loss:0.54892	acc:0.85938
loss:0.34996	acc:0.90625
loss:0.43126	acc:0.90625
loss:0.47235	acc:0.85938
loss:0.38899	acc:0.93750
loss:0.38725	acc:0.90625
loss:0.51207	acc:0.82812
loss:0.58700	acc:0.78125
loss:0.40778	acc:0.92188
loss:0.51206	acc:0.87500
loss:0.46178	acc:0.89062
loss:0.31953	acc:0.96875
loss:0.38844	acc:0.92188
loss:0.36591	acc:0.90625
loss:0.33379	acc:0.90625
loss:0.30978	acc:0.93750
loss:0.47275	acc:0.85938
loss:0.44161	acc:0.89062
loss:0.40364	acc:0.90625
loss:0.24938	acc:0.96875
loss:0.24971	acc:0.98438
loss:0.41120	acc:0.90625
loss:0.51807	acc:0.84375
loss:0.49253	acc:0.89062
loss:0.46641	acc:0.92188
loss:0.44445	acc:0.87500
loss:0.40190	acc:0.90625
loss:0.42140	acc:0.90625
loss:0.28105	acc:0.93750
loss:0.42497	acc:0.87500
loss:0.58897	acc:0.81250
loss:0.37514	acc:0.89062
loss:0.35815	acc:0.92188
loss:0.39548	acc:0.89062
loss:0.35628	acc:0.96875
loss:0.38870	acc:0.90625
loss:0.35709	acc:0.93750
loss:0.40669	acc:0.87500
loss:0.20514	acc:1.00000
loss:0.30028	acc:0.92188
loss:0.26510	acc:0.96875
loss:0.53790	acc:0.76562
loss:0.33812	acc:0.92188
loss:0.22588	acc:0.95312
loss:0.40264	acc:0.87500
loss:0.21730	acc:0.98438
loss:0.44845	acc:0.89062
loss:0.31983	acc:0.92188
loss:0.45905	acc:0.87500
loss:0.32539	acc:0.90625
loss:0.30357	acc:0.89062
loss:0.36034	acc:0.90625
loss:0.31498	acc:0.95312
loss:0.23514	acc:0.93750
loss:0.42586	acc:0.90625
loss:0.31543	acc:0.93750
loss:0.19443	acc:0.98438
loss:0.37593	acc:0.90625
loss:0.26972	acc:0.96875
loss:0.23483	acc:0.96875
loss:0.25687	acc:0.92188
loss:0.29415	acc:0.90625
loss:0.34700	acc:0.85938
loss:0.40503	acc:0.89062
loss:0.46366	acc:0.85938
loss:0.32815	acc:0.92188
loss:0.34704	acc:0.90625
loss:0.18917	acc:0.96875
loss:0.37811	acc:0.87500
loss:0.34091	acc:0.87500
loss:0.31946	acc:0.89062
loss:0.23625	acc:0.95312
loss:0.19402	acc:0.96875
loss:0.35543	acc:0.89062
loss:0.18609	acc:0.98438
loss:0.34984	acc:0.90625
loss:0.29880	acc:0.95312
loss:0.23121	acc:0.93750
loss:0.27589	acc:0.93750
loss:0.16535	acc:0.98438
loss:0.27656	acc:0.95312
loss:0.27613	acc:0.90625
loss:0.17787	acc:1.00000
loss:0.22559	acc:0.96875
loss:0.18576	acc:0.95312
loss:0.20565	acc:0.96875
loss:0.33294	acc:0.92188
loss:0.37279	acc:0.89062
loss:0.23660	acc:0.92188
loss:0.19797	acc:0.96875
loss:0.51400	acc:0.87500
loss:0.23547	acc:0.92188
loss:0.32005	acc:0.89062
loss:0.38082	acc:0.87500
loss:0.37463	acc:0.85938
loss:0.32102	acc:0.90625
loss:0.30536	acc:0.89062
loss:0.19124	acc:0.96875
loss:0.38712	acc:0.90625
loss:0.24493	acc:0.92188
loss:0.25722	acc:0.93750
loss:0.37526	acc:0.90625
loss:0.24581	acc:0.93750
loss:0.18983	acc:0.95312
loss:0.19335	acc:0.93750
loss:0.26011	acc:0.92188
loss:0.16785	acc:0.96875
loss:0.31089	acc:0.92188
loss:0.41377	acc:0.87500
loss:0.31412	acc:0.90625
loss:0.18055	acc:0.95312
loss:0.21896	acc:0.95312
loss:0.22426	acc:0.93750
loss:0.17735	acc:0.96875
loss:0.12969	acc:0.98438
loss:0.25509	acc:0.95312
loss:0.36201	acc:0.85938
loss:0.30848	acc:0.90625
loss:0.37621	acc:0.89062
loss:0.45327	acc:0.84375
loss:0.29644	acc:0.90625
loss:0.26313	acc:0.92188
loss:0.44245	acc:0.89062
loss:0.27033	acc:0.90625
loss:0.29708	acc:0.89062
loss:0.30118	acc:0.87500
loss:0.27827	acc:0.93750
loss:0.21405	acc:0.95312
loss:0.15615	acc:0.96875
loss:0.20493	acc:0.95312
loss:0.39926	acc:0.85938
loss:0.33024	acc:0.90625
loss:0.33339	acc:0.90625
loss:0.35240	acc:0.92188
loss:0.28695	acc:0.93750
loss:0.24316	acc:0.93750
loss:0.22253	acc:0.93750
loss:0.21479	acc:0.95312
loss:0.29799	acc:0.95312
loss:0.33985	acc:0.87500
loss:0.27656	acc:0.95312
loss:0.34378	acc:0.90625
loss:0.20759	acc:0.93750
loss:0.32314	acc:0.93750
loss:0.30703	acc:0.89062
loss:0.23211	acc:0.95312
loss:0.18815	acc:0.95312
loss:0.36385	acc:0.90625
loss:0.25207	acc:0.90625
loss:0.26671	acc:0.92188
loss:0.36773	acc:0.87500
loss:0.28102	acc:0.90625
loss:0.30742	acc:0.92188
loss:0.29756	acc:0.90625
loss:0.39121	acc:0.85938
loss:0.30177	acc:0.92188
loss:0.32003	acc:0.89062
loss:0.33404	acc:0.89062
loss:0.14668	acc:0.98438
loss:0.22110	acc:0.96875
loss:0.25015	acc:0.93750
loss:0.29098	acc:0.93750
loss:0.19943	acc:0.92188
loss:0.23941	acc:0.92188
loss:0.15813	acc:0.96875
loss:0.32510	acc:0.90625
loss:0.24515	acc:0.92188
loss:0.13650	acc:0.96875
loss:0.34347	acc:0.90625
loss:0.37198	acc:0.89062
loss:0.19716	acc:0.93750
loss:0.20936	acc:0.95312
loss:0.22594	acc:0.92188
loss:0.31214	acc:0.84375
loss:0.22110	acc:0.95312
loss:0.16838	acc:0.95312
loss:0.17118	acc:0.98438
loss:0.27213	acc:0.93750
loss:0.36451	acc:0.87500
loss:0.13542	acc:0.96875
loss:0.11854	acc:0.98438
loss:0.37582	acc:0.87500
loss:0.20354	acc:0.95312
loss:0.18940	acc:0.93750
loss:0.25606	acc:0.90625
loss:0.24460	acc:0.90625
loss:0.20000	acc:0.95312
loss:0.17487	acc:0.93750
loss:0.18902	acc:0.93750
loss:0.18991	acc:0.96875
loss:0.20206	acc:0.93750
loss:0.22814	acc:0.90625
loss:0.26283	acc:0.92188
loss:0.09305	acc:0.98438
loss:0.21910	acc:0.93750
loss:0.22599	acc:0.95312
loss:0.31945	acc:0.90625
loss:0.32415	acc:0.93750
loss:0.46490	acc:0.87500
loss:0.32415	acc:0.89062
loss:0.10184	acc:1.00000
loss:0.21687	acc:0.95312
loss:0.19604	acc:0.96875
loss:0.17721	acc:0.96875
loss:0.32351	acc:0.89062
loss:0.12123	acc:0.98438
loss:0.23676	acc:0.96875
loss:0.25150	acc:0.93750
loss:0.43190	acc:0.92188
loss:0.20894	acc:0.95312
loss:0.09201	acc:0.98438
loss:0.17339	acc:0.96875
loss:0.32940	acc:0.87500
loss:0.34605	acc:0.89062
loss:0.27641	acc:0.92188
loss:0.27213	acc:0.93750
loss:0.13855	acc:0.96875
loss:0.29624	acc:0.87500
loss:0.20382	acc:0.89062
loss:0.18497	acc:0.96875
loss:0.18618	acc:0.95312
loss:0.16166	acc:0.96875
loss:0.20708	acc:0.93750
loss:0.17241	acc:0.96875
loss:0.20468	acc:0.93750
loss:0.22756	acc:0.93750
loss:0.16883	acc:0.96875
loss:0.22368	acc:0.93750
loss:0.20490	acc:0.95312
loss:0.25694	acc:0.89062
loss:0.15531	acc:0.96875
loss:0.30559	acc:0.89062
loss:0.22385	acc:0.95312
loss:0.18270	acc:0.95312
loss:0.24967	acc:0.92188
loss:0.10997	acc:0.98438
loss:0.25074	acc:0.92188
loss:0.21005	acc:0.93750
loss:0.13791	acc:0.96875
loss:0.35392	acc:0.92188
loss:0.24487	acc:0.93750
loss:0.23756	acc:0.95312
loss:0.26027	acc:0.89062
loss:0.28201	acc:0.92188
loss:0.26195	acc:0.89062
loss:0.26110	acc:0.90625
loss:0.18067	acc:0.92188
loss:0.23617	acc:0.92188
loss:0.25224	acc:0.93750
loss:0.22404	acc:0.93750
loss:0.26276	acc:0.90625
loss:0.16056	acc:0.93750
loss:0.30658	acc:0.92188
loss:0.21245	acc:0.92188
loss:0.13786	acc:0.96875
loss:0.28125	acc:0.87500
loss:0.18568	acc:0.95312
loss:0.15106	acc:0.95312
loss:0.19052	acc:0.96875
loss:0.31077	acc:0.89062
loss:0.15633	acc:0.98438
loss:0.30772	acc:0.93750
loss:0.25433	acc:0.90625
loss:0.23351	acc:0.89062
loss:0.21224	acc:0.96875
loss:0.14788	acc:0.96875
loss:0.20176	acc:0.95312
loss:0.16864	acc:0.95312
loss:0.14577	acc:0.95312
loss:0.25367	acc:0.92188
loss:0.15585	acc:0.96875
loss:0.10192	acc:0.98438
loss:0.29795	acc:0.93750
loss:0.15300	acc:0.96875
loss:0.20459	acc:0.92188
loss:0.40997	acc:0.89062
loss:0.27964	acc:0.92188
loss:0.20234	acc:0.96875
loss:0.33700	acc:0.90625
loss:0.31243	acc:0.89062
loss:0.19571	acc:0.93750
loss:0.32722	acc:0.92188
loss:0.20196	acc:0.95312
loss:0.15080	acc:0.95312
loss:0.10359	acc:0.96875
loss:0.31100	acc:0.92188
loss:0.23306	acc:0.95312
loss:0.22770	acc:0.96875
loss:0.30542	acc:0.93750
loss:0.15798	acc:0.93750
loss:0.10705	acc:0.96875
loss:0.23852	acc:0.90625
loss:0.20319	acc:0.95312
loss:0.21662	acc:0.93750
loss:0.22358	acc:0.93750
loss:0.11667	acc:0.98438
loss:0.19136	acc:0.95312
loss:0.13301	acc:0.96875
loss:0.34246	acc:0.90625
loss:0.19936	acc:0.95312
loss:0.27872	acc:0.90625
loss:0.28035	acc:0.87500
loss:0.13166	acc:0.96875
loss:0.15417	acc:0.96875
loss:0.20961	acc:0.93750
loss:0.27052	acc:0.90625
loss:0.16499	acc:0.95312
loss:0.12835	acc:0.95312
loss:0.21926	acc:0.92188
loss:0.26808	acc:0.93750
loss:0.21866	acc:0.92188
loss:0.25416	acc:0.93750
loss:0.20989	acc:0.95312
loss:0.12020	acc:0.98438
loss:0.13833	acc:0.95312
loss:0.16170	acc:0.95312
loss:0.36952	acc:0.89062
loss:0.13549	acc:0.95312
loss:0.14136	acc:0.95312
loss:0.28028	acc:0.92188
loss:0.13040	acc:0.98438
loss:0.16315	acc:0.95312
loss:0.16745	acc:0.98438
loss:0.33258	acc:0.93750
loss:0.17818	acc:0.95312
loss:0.15952	acc:0.98438
loss:0.25364	acc:0.92188
loss:0.30436	acc:0.92188
loss:0.18246	acc:0.95312
loss:0.24509	acc:0.93750
loss:0.17627	acc:0.95312
loss:0.21730	acc:0.93750
loss:0.25020	acc:0.93750
loss:0.31416	acc:0.89062
loss:0.11529	acc:0.98438
loss:0.17905	acc:0.92188
loss:0.19555	acc:0.95312
loss:0.16704	acc:0.95312
loss:0.19937	acc:0.93750
loss:0.19374	acc:0.92188
loss:0.13501	acc:0.96875
loss:0.15464	acc:0.95312
loss:0.14569	acc:0.96875
loss:0.24573	acc:0.87500
loss:0.32774	acc:0.87500
loss:0.12643	acc:0.96875
loss:0.34496	acc:0.92188
loss:0.20144	acc:0.95312
loss:0.17076	acc:0.95312
loss:0.24230	acc:0.95312
loss:0.23473	acc:0.95312
loss:0.30967	acc:0.92188
loss:0.15197	acc:0.98438
loss:0.12085	acc:0.96875
loss:0.10847	acc:0.96875
loss:0.30175	acc:0.93750
loss:0.18536	acc:0.96875
loss:0.12876	acc:0.96875
loss:0.21572	acc:0.93750
loss:0.18642	acc:0.95312
loss:0.15870	acc:0.95312
loss:0.18767	acc:0.95312
loss:0.14790	acc:0.96875
loss:0.22414	acc:0.92188
loss:0.16376	acc:0.95312
loss:0.28045	acc:0.92188
loss:0.21378	acc:0.95312
loss:0.15277	acc:0.96875
loss:0.15975	acc:0.96875
loss:0.10392	acc:0.98438
loss:0.20615	acc:0.93750
loss:0.12076	acc:0.96875
loss:0.19137	acc:0.93750
loss:0.13918	acc:0.98438
loss:0.22372	acc:0.95312
loss:0.12778	acc:1.00000
loss:0.10982	acc:0.98438
loss:0.15405	acc:0.92188
loss:0.26379	acc:0.92188
loss:0.14794	acc:0.93750
loss:0.18582	acc:0.95312
loss:0.22060	acc:0.96875
loss:0.19452	acc:0.93750
loss:0.24615	acc:0.90625
loss:0.28249	acc:0.93750
loss:0.09747	acc:0.98438
loss:0.13903	acc:0.96875
loss:0.19606	acc:0.92188
loss:0.06221	acc:0.98438
loss:0.24209	acc:0.92188
loss:0.14922	acc:0.95312
loss:0.18266	acc:0.95312
loss:0.18010	acc:0.93750
loss:0.18483	acc:0.96875
loss:0.11312	acc:0.98438
loss:0.13887	acc:0.95312
loss:0.31176	acc:0.89062
loss:0.18163	acc:0.93750
loss:0.08486	acc:0.98438
loss:0.16347	acc:0.93750
loss:0.05822	acc:1.00000
loss:0.29235	acc:0.92188
loss:0.18129	acc:0.96875
loss:0.45740	acc:0.92188
loss:0.18626	acc:0.96875
loss:0.23226	acc:0.90625
loss:0.26856	acc:0.93750
loss:0.27806	acc:0.90625
loss:0.38523	acc:0.89062
loss:0.32276	acc:0.87500
loss:0.20932	acc:0.92188
loss:0.28150	acc:0.85938
loss:0.33946	acc:0.92188
loss:0.18500	acc:0.93750
loss:0.19951	acc:0.90625
loss:0.08960	acc:0.98438
loss:0.21694	acc:0.93750
loss:0.26588	acc:0.90625
loss:0.09444	acc:0.96875
loss:0.19150	acc:0.93750
loss:0.16699	acc:0.98438
loss:0.22975	acc:0.95312
loss:0.14448	acc:0.96875
loss:0.34983	acc:0.90625
loss:0.19037	acc:0.95312
loss:0.14290	acc:0.95312
loss:0.17222	acc:0.95312
loss:0.19792	acc:0.95312
loss:0.15288	acc:0.95312
loss:0.11856	acc:0.98438
loss:0.17556	acc:0.95312
loss:0.17508	acc:0.93750
loss:0.26778	acc:0.89062
loss:0.24382	acc:0.92188
loss:0.16512	acc:0.96875
loss:0.12264	acc:0.96875
loss:0.23912	acc:0.95312
loss:0.15832	acc:0.96875
loss:0.30289	acc:0.92188
loss:0.36022	acc:0.89062
loss:0.06129	acc:1.00000
loss:0.08585	acc:1.00000
loss:0.12552	acc:0.98438
loss:0.11478	acc:0.95312
loss:0.16231	acc:0.96875
loss:0.28971	acc:0.92188
loss:0.17992	acc:0.96875
loss:0.15360	acc:0.96875
loss:0.14996	acc:0.96875
loss:0.12951	acc:0.96875
loss:0.14277	acc:0.95312
loss:0.17858	acc:0.95312
loss:0.18334	acc:0.95312
loss:0.12234	acc:0.96875
loss:0.11103	acc:0.96875
loss:0.17434	acc:0.95312
loss:0.15722	acc:0.93750
loss:0.10310	acc:0.96875
loss:0.08445	acc:0.98438
loss:0.14952	acc:0.93750
loss:0.26554	acc:0.92188
loss:0.18645	acc:0.96875
loss:0.21273	acc:0.92188
loss:0.07908	acc:0.96875
loss:0.14561	acc:0.96875
loss:0.09148	acc:1.00000
loss:0.12898	acc:0.98438
loss:0.20631	acc:0.96875
loss:0.19616	acc:0.93750
loss:0.11032	acc:0.96875
loss:0.12006	acc:0.98438
loss:0.26047	acc:0.90625
loss:0.14672	acc:0.95312
loss:0.20060	acc:0.95312
loss:0.13894	acc:0.96875
loss:0.23488	acc:0.92188
loss:0.11752	acc:0.95312
loss:0.10076	acc:0.98438
loss:0.16476	acc:0.95312
loss:0.11161	acc:0.95312
loss:0.24382	acc:0.95312
loss:0.10563	acc:0.96875
loss:0.24099	acc:0.89062
loss:0.25432	acc:0.92188
loss:0.26607	acc:0.89062
loss:0.18009	acc:0.93750
loss:0.09952	acc:0.98438
loss:0.28454	acc:0.89062
loss:0.08948	acc:0.95312
loss:0.21783	acc:0.92188
loss:0.15621	acc:0.95312
loss:0.25727	acc:0.92188
loss:0.24574	acc:0.90625
loss:0.09447	acc:0.96875
loss:0.09729	acc:0.96875
loss:0.08467	acc:0.98438
loss:0.13028	acc:0.95312
loss:0.21107	acc:0.92188
loss:0.07938	acc:0.98438
loss:0.08832	acc:0.98438
loss:0.13598	acc:0.93750
loss:0.10205	acc:0.95312
loss:0.16719	acc:0.95312
loss:0.25444	acc:0.93750
loss:0.07656	acc:1.00000
loss:0.11371	acc:0.95312
loss:0.18071	acc:0.96875
loss:0.18862	acc:0.92188
loss:0.10793	acc:0.98438
loss:0.27224	acc:0.93750
loss:0.12655	acc:0.96875
loss:0.22090	acc:0.92188
loss:0.13526	acc:0.95312
loss:0.14960	acc:0.96875
loss:0.07072	acc:0.98438
loss:0.24943	acc:0.90625
loss:0.10602	acc:0.96875
loss:0.13111	acc:0.96875
loss:0.24879	acc:0.92188
loss:0.30222	acc:0.89062
loss:0.16071	acc:0.95312
loss:0.19176	acc:0.93750
loss:0.04011	acc:1.00000
loss:0.10362	acc:0.96875
loss:0.15975	acc:0.93750
loss:0.14810	acc:0.96875
loss:0.11138	acc:0.96875
loss:0.16518	acc:0.93750
loss:0.09359	acc:0.98438
loss:0.14622	acc:0.96875
loss:0.12750	acc:0.96875
loss:0.11546	acc:0.96875
loss:0.20611	acc:0.95312
loss:0.34404	acc:0.92188
loss:0.10114	acc:0.98438
loss:0.16506	acc:0.95312
loss:0.17545	acc:0.95312
loss:0.26379	acc:0.92188
loss:0.21496	acc:0.90625
loss:0.26072	acc:0.93750
loss:0.07525	acc:0.98438
loss:0.20520	acc:0.93750
loss:0.07231	acc:0.98438
loss:0.08859	acc:0.96875
loss:0.16149	acc:0.95312
loss:0.15149	acc:0.95312
loss:0.18882	acc:0.93750
loss:0.10414	acc:0.96875
loss:0.19076	acc:0.98438
loss:0.14612	acc:0.95312
loss:0.11762	acc:0.96875
loss:0.11453	acc:0.96875
loss:0.15372	acc:0.93750
loss:0.16458	acc:0.95312
loss:0.05676	acc:0.98438
loss:0.20118	acc:0.92188
loss:0.11675	acc:0.96875
loss:0.08715	acc:0.98438
loss:0.08951	acc:0.98438
loss:0.17705	acc:0.93750
loss:0.22871	acc:0.95312
loss:0.14870	acc:0.96875
loss:0.16718	acc:0.92188
loss:0.15694	acc:0.95312
loss:0.12892	acc:0.98438
loss:0.16018	acc:0.96875
loss:0.10527	acc:0.96875
loss:0.06917	acc:1.00000
loss:0.24590	acc:0.92188
loss:0.07912	acc:1.00000
loss:0.12666	acc:0.96875
loss:0.17376	acc:0.95312
loss:0.14857	acc:0.92188
loss:0.33249	acc:0.90625
loss:0.11874	acc:0.98438
loss:0.04404	acc:1.00000
loss:0.10323	acc:0.98438
loss:0.26786	acc:0.90625
loss:0.25149	acc:0.92188
loss:0.16042	acc:0.95312
loss:0.27602	acc:0.92188
loss:0.16729	acc:0.98438
loss:0.13113	acc:0.98438
loss:0.25255	acc:0.95312
loss:0.17137	acc:0.93750
loss:0.18273	acc:0.96875
loss:0.06922	acc:1.00000
loss:0.24624	acc:0.90625
loss:0.18232	acc:0.90625
loss:0.07420	acc:0.98438
loss:0.15953	acc:0.98438
loss:0.07206	acc:0.96875
loss:0.18828	acc:0.93750
loss:0.19701	acc:0.96875
loss:0.15299	acc:0.93750
loss:0.07865	acc:0.98438
loss:0.20735	acc:0.93750
loss:0.17156	acc:0.93750
loss:0.07509	acc:1.00000
loss:0.14986	acc:0.95312
loss:0.13248	acc:0.96875
loss:0.14388	acc:0.95312
loss:0.05571	acc:1.00000
loss:0.09408	acc:0.96875
loss:0.24787	acc:0.93750
loss:0.19877	acc:0.92188
loss:0.08191	acc:0.98438
loss:0.31069	acc:0.92188
loss:0.12103	acc:0.96875
loss:0.24595	acc:0.90625
loss:0.14003	acc:0.98438
loss:0.08255	acc:0.98438
loss:0.17961	acc:0.93750
loss:0.13665	acc:0.93750
loss:0.23543	acc:0.92188
loss:0.25226	acc:0.90625
loss:0.12286	acc:0.95312
loss:0.17261	acc:0.95312
loss:0.13196	acc:0.95312
loss:0.22068	acc:0.90625
loss:0.38338	acc:0.90625
loss:0.09904	acc:0.96875
loss:0.06626	acc:0.98438
loss:0.15176	acc:0.96875
loss:0.14975	acc:0.95312
loss:0.29266	acc:0.93750
loss:0.08018	acc:0.98438
loss:0.22623	acc:0.93750
loss:0.09631	acc:0.96875
loss:0.13066	acc:0.96875
loss:0.13228	acc:0.95312
loss:0.14615	acc:0.92188
loss:0.14123	acc:0.96875
loss:0.26964	acc:0.89062
loss:0.08345	acc:0.96875
loss:0.18795	acc:0.92188
loss:0.14019	acc:0.96875
loss:0.18857	acc:0.95312
loss:0.20577	acc:0.92188
loss:0.11165	acc:0.98438
loss:0.11152	acc:0.95312
loss:0.10814	acc:0.96875
loss:0.10788	acc:0.98438
loss:0.10583	acc:0.95312
loss:0.08795	acc:0.98438
loss:0.16285	acc:0.95312
loss:0.13543	acc:0.96875
loss:0.14316	acc:0.95312
loss:0.20426	acc:0.96875
loss:0.06097	acc:1.00000
loss:0.08930	acc:0.98438
loss:0.14915	acc:0.93750
loss:0.05893	acc:1.00000
loss:0.03859	acc:1.00000
loss:0.10800	acc:0.96875
loss:0.26009	acc:0.90625
loss:0.07896	acc:1.00000
loss:0.06844	acc:0.98438
loss:0.10040	acc:0.98438
loss:0.08292	acc:0.98438
loss:0.37432	acc:0.87500
loss:0.16169	acc:0.95312
loss:0.22969	acc:0.90625
loss:0.15961	acc:0.93750
loss:0.13284	acc:0.98438
loss:0.09883	acc:0.98438
loss:0.14311	acc:0.96875
loss:0.18110	acc:0.93750
loss:0.20608	acc:0.93750
loss:0.08434	acc:0.98438
loss:0.13159	acc:0.95312
loss:0.10497	acc:0.95312
loss:0.13571	acc:0.96875
loss:0.15379	acc:0.96875
loss:0.12703	acc:0.96875
loss:0.22704	acc:0.90625
loss:0.11205	acc:0.95312
loss:0.13118	acc:0.95312
loss:0.11921	acc:0.96875
loss:0.13128	acc:0.96875
loss:0.16961	acc:0.98438
loss:0.26039	acc:0.92188
loss:0.10821	acc:0.96875
loss:0.15362	acc:0.92188
loss:0.20221	acc:0.93750
loss:0.25353	acc:0.93750
loss:0.17650	acc:0.92188
loss:0.17608	acc:0.92188
loss:0.10742	acc:0.96875
loss:0.14668	acc:0.95312
loss:0.11982	acc:0.95312
loss:0.14160	acc:0.96875
loss:0.14920	acc:0.96875
loss:0.07769	acc:0.98438
loss:0.07435	acc:0.98438
loss:0.12596	acc:0.96875
loss:0.06209	acc:0.98438
loss:0.23671	acc:0.93750
loss:0.32712	acc:0.90625
loss:0.16227	acc:0.93750
loss:0.07742	acc:0.98438
loss:0.11433	acc:0.96875
loss:0.14617	acc:0.95312
loss:0.10022	acc:0.98438
loss:0.11036	acc:0.96875
loss:0.22448	acc:0.95312
loss:0.17712	acc:0.92188
loss:0.14171	acc:0.95312
loss:0.14837	acc:0.95312
loss:0.05253	acc:1.00000
loss:0.05911	acc:1.00000
loss:0.16225	acc:0.95312
loss:0.22476	acc:0.92188
loss:0.18999	acc:0.93750
loss:0.17341	acc:0.90625
loss:0.09191	acc:0.98438
loss:0.18701	acc:0.95312
loss:0.11330	acc:0.96875
loss:0.17069	acc:0.96875
loss:0.07430	acc:0.98438
loss:0.14603	acc:0.96875
loss:0.04871	acc:1.00000
loss:0.28533	acc:0.92188
loss:0.18542	acc:0.93750
loss:0.23417	acc:0.90625
loss:0.14049	acc:0.93750
loss:0.25800	acc:0.90625
loss:0.11138	acc:0.98438
loss:0.07213	acc:0.98438
loss:0.07436	acc:0.98438
loss:0.14025	acc:0.95312
loss:0.07013	acc:0.98438
loss:0.17246	acc:0.95312
loss:0.18267	acc:0.95312
loss:0.22772	acc:0.93750
loss:0.06264	acc:0.98438
loss:0.02746	acc:1.00000
loss:0.24368	acc:0.89062
loss:0.05390	acc:1.00000
loss:0.19132	acc:0.95312
loss:0.19019	acc:0.96875
loss:0.13184	acc:0.96875
loss:0.07464	acc:0.98438
loss:0.09115	acc:0.98438
loss:0.09529	acc:0.96875
loss:0.08990	acc:0.96875
loss:0.07885	acc:0.96875
loss:0.09846	acc:0.95312
loss:0.14663	acc:0.95312
loss:0.22904	acc:0.93750
loss:0.11402	acc:0.96875
loss:0.19972	acc:0.93750
loss:0.20436	acc:0.96875
loss:0.32500	acc:0.87500
loss:0.11743	acc:0.95312
loss:0.17370	acc:0.92188
loss:0.08371	acc:0.98438
loss:0.08599	acc:0.96875
loss:0.12539	acc:0.95312
loss:0.12278	acc:0.98438
loss:0.10975	acc:0.96875
loss:0.10831	acc:0.98438
loss:0.06411	acc:0.96875
loss:0.23990	acc:0.93750
loss:0.08271	acc:0.98438
loss:0.10399	acc:0.96875
loss:0.06397	acc:1.00000
loss:0.08254	acc:0.96875
loss:0.09192	acc:0.98438
loss:0.17281	acc:0.95312
loss:0.06051	acc:1.00000
loss:0.14834	acc:0.98438
loss:0.09404	acc:0.96875
loss:0.06473	acc:0.98438
loss:0.15557	acc:0.98438
loss:0.14979	acc:0.96875
loss:0.19210	acc:0.95312
loss:0.13649	acc:0.93750
loss:0.13135	acc:0.93750
loss:0.21046	acc:0.90625
loss:0.10882	acc:0.95312
loss:0.17906	acc:0.93750
loss:0.19661	acc:0.93750
loss:0.10849	acc:0.95312
loss:0.07449	acc:1.00000
loss:0.24108	acc:0.90625
loss:0.06498	acc:0.98438
loss:0.17111	acc:0.95312
loss:0.19536	acc:0.93750
loss:0.08626	acc:0.95312
loss:0.12976	acc:0.98438
loss:0.04587	acc:1.00000
loss:0.31982	acc:0.90625
loss:0.09512	acc:0.98438
loss:0.23773	acc:0.90625
loss:0.20744	acc:0.95312
loss:0.23242	acc:0.95312
loss:0.15539	acc:0.96875
loss:0.14239	acc:0.93750
loss:0.08568	acc:0.96875
loss:0.14298	acc:0.95312
loss:0.11364	acc:0.95312
loss:0.15121	acc:0.98438
loss:0.10850	acc:0.93750
loss:0.13976	acc:0.95312
loss:0.19362	acc:0.93750
loss:0.08078	acc:0.98438
loss:0.05183	acc:0.98438
loss:0.15727	acc:0.96875
loss:0.07899	acc:0.96875
loss:0.05916	acc:0.98438
loss:0.09187	acc:0.96875
loss:0.07978	acc:0.98438
loss:0.07401	acc:0.98438
loss:0.20396	acc:0.95312
loss:0.04685	acc:1.00000
loss:0.17921	acc:0.96875
loss:0.05774	acc:1.00000
loss:0.09466	acc:0.96875
loss:0.21201	acc:0.95312
loss:0.12821	acc:0.95312
loss:0.08764	acc:0.98438
loss:0.11809	acc:0.96875
loss:0.16199	acc:0.96875
loss:0.11255	acc:0.98438
loss:0.14778	acc:0.95312
loss:0.08199	acc:0.96875
loss:0.06982	acc:0.96875
loss:0.08677	acc:0.96875
loss:0.14956	acc:0.95312
loss:0.15698	acc:0.93750
loss:0.12428	acc:0.96875
loss:0.19111	acc:0.93750
loss:0.04491	acc:1.00000
loss:0.19884	acc:0.90625
loss:0.24804	acc:0.93750
loss:0.13953	acc:0.96875
loss:0.24288	acc:0.90625
loss:0.15888	acc:0.95312
loss:0.17232	acc:0.95312
loss:0.09322	acc:0.96875
loss:0.14902	acc:0.96875
loss:0.11222	acc:0.98438
loss:0.08734	acc:0.96875
loss:0.09247	acc:0.96875
loss:0.11926	acc:0.95312
loss:0.04608	acc:0.98438
loss:0.07352	acc:0.96875
loss:0.05912	acc:0.98438
loss:0.06844	acc:1.00000
loss:0.09879	acc:1.00000
loss:0.11328	acc:0.96875
loss:0.14035	acc:0.96875
loss:0.18120	acc:0.95312
loss:0.09994	acc:0.95312
loss:0.14785	acc:0.95312
loss:0.07812	acc:0.96875
loss:0.15447	acc:0.95312
loss:0.18226	acc:0.93750
loss:0.14916	acc:0.96875
loss:0.19192	acc:0.93750
loss:0.10521	acc:0.98438
loss:0.19849	acc:0.96875
loss:0.16188	acc:0.95312
loss:0.05045	acc:1.00000
loss:0.23136	acc:0.95312
loss:0.07092	acc:0.98438
loss:0.12442	acc:0.96875
loss:0.16773	acc:0.95312
loss:0.16274	acc:0.95312
loss:0.28012	acc:0.92188
loss:0.08197	acc:0.98438
loss:0.05306	acc:0.98438
loss:0.07457	acc:0.98438
loss:0.17700	acc:0.93750
loss:0.10168	acc:0.98438
loss:0.20213	acc:0.92188
loss:0.20351	acc:0.92188
loss:0.10901	acc:0.96875
loss:0.04361	acc:1.00000
loss:0.16461	acc:0.96875
loss:0.20337	acc:0.92188
loss:0.11774	acc:0.98438
loss:0.19589	acc:0.92188
loss:0.14479	acc:0.96875
loss:0.16974	acc:0.93750
loss:0.09703	acc:0.96875
loss:0.06660	acc:0.98438
loss:0.08469	acc:1.00000
loss:0.09053	acc:0.96875
loss:0.15431	acc:0.90625
loss:0.10061	acc:0.98438
loss:0.10092	acc:0.98438
loss:0.17791	acc:0.96875
loss:0.13418	acc:0.95312
loss:0.04396	acc:0.98438
loss:0.16642	acc:0.96875
loss:0.05089	acc:1.00000
loss:0.17533	acc:0.93750
loss:0.16705	acc:0.95312
loss:0.10573	acc:0.95312
loss:0.15486	acc:0.96875
loss:0.12395	acc:0.95312
loss:0.12970	acc:0.96875
loss:0.18159	acc:0.95312
loss:0.05818	acc:1.00000
loss:0.10698	acc:0.93750
loss:0.03983	acc:1.00000
loss:0.18349	acc:0.96875
loss:0.08221	acc:0.98438
loss:0.12925	acc:0.96875
loss:0.11189	acc:0.96875
loss:0.17911	acc:0.96875
loss:0.27267	acc:0.95312
loss:0.06801	acc:0.98438
loss:0.03988	acc:1.00000
loss:0.03477	acc:1.00000
loss:0.10255	acc:0.96875
loss:0.09664	acc:0.98438
loss:0.11505	acc:0.96875
loss:0.09369	acc:0.98438
loss:0.20973	acc:0.95312
loss:0.16301	acc:0.93750
loss:0.08222	acc:0.96875
loss:0.10991	acc:0.95312
loss:0.09927	acc:0.96875
loss:0.04367	acc:1.00000
loss:0.06952	acc:0.96875
loss:0.13133	acc:0.96875
loss:0.13914	acc:0.96875
loss:0.27232	acc:0.96875
loss:0.08188	acc:0.96875
loss:0.20095	acc:0.96875
loss:0.08710	acc:0.96875
loss:0.05236	acc:1.00000
loss:0.11694	acc:0.96875
loss:0.11446	acc:0.96875
loss:0.14520	acc:0.96875
loss:0.18993	acc:0.90625
loss:0.17676	acc:0.96875
loss:0.10977	acc:0.95312
loss:0.12015	acc:0.98438
loss:0.15122	acc:0.96875
loss:0.19218	acc:0.95312
loss:0.14699	acc:0.93750
loss:0.24232	acc:0.90625
loss:0.21631	acc:0.90625
loss:0.06756	acc:0.98438
loss:0.17421	acc:0.95312
loss:0.12664	acc:0.95312
loss:0.15569	acc:0.95312
loss:0.12884	acc:0.96875
loss:0.11202	acc:0.96875
loss:0.05302	acc:0.98438
loss:0.20850	acc:0.93750
loss:0.12652	acc:0.96875
loss:0.23816	acc:0.92188
loss:0.06526	acc:0.98438
loss:0.05521	acc:1.00000
loss:0.21977	acc:0.93750
loss:0.04914	acc:1.00000
loss:0.17506	acc:0.92188
loss:0.06785	acc:0.98438
loss:0.15795	acc:0.95312
loss:0.13381	acc:0.93750
loss:0.07008	acc:0.98438
loss:0.19596	acc:0.95312
loss:0.19804	acc:0.92188
loss:0.22154	acc:0.93750
loss:0.14441	acc:0.93750
loss:0.08902	acc:0.98438
loss:0.13566	acc:0.96875
loss:0.08600	acc:0.96875
loss:0.19083	acc:0.98438
loss:0.07337	acc:0.98438
loss:0.16578	acc:0.93750
loss:0.15782	acc:0.95312
loss:0.10420	acc:0.98438
loss:0.07731	acc:0.98438
loss:0.12828	acc:0.96875
loss:0.06208	acc:1.00000
loss:0.25428	acc:0.92188
loss:0.03976	acc:1.00000
loss:0.19712	acc:0.90625
loss:0.15139	acc:0.92188
loss:0.06922	acc:0.98438
loss:0.12168	acc:0.98438
loss:0.04660	acc:0.98438
loss:0.12327	acc:0.96875
loss:0.22629	acc:0.95312
loss:0.12471	acc:0.95312
loss:0.04671	acc:1.00000
loss:0.14425	acc:0.96875
loss:0.19596	acc:0.93750
loss:0.03496	acc:1.00000

Pay attention to how the loss moves.
The loss clearly seems to behave differently from what we are used to with floating-point training.

This is what training MNIST with BitNet looks like.

In the initial phase of training, BitNet's loss falls more slowly than an ordinary Linear's.

Let's zoom in on the very earliest phase.

Compared with a plain Linear, BitNet converges somewhat more slowly, but a difference of this size could be called within the margin of error.

Now let's look at the final phase of training.

At the end of training, it is the vanilla Linear whose loss fluctuates more.
In any case, I never expected that you could swap in a replacement for Linear this easily.
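
For completeness, here is roughly how such a comparison can be set up (a sketch; it reuses train_loader and BitLinearOriginal from the code above, and make_model / train_and_record are helpers I am introducing here, not part of the original script):

import matplotlib.pyplot as plt
import torch.nn as nn
from torch import optim

# Train two otherwise identical models, one with nn.Linear and one with
# BitLinearOriginal, and record the per-batch loss so both curves can be
# plotted on the same axes.
def make_model(layer_cls):
    return nn.Sequential(nn.Flatten(), layer_cls(784, 128), nn.ReLU(), layer_cls(128, 10))

def train_and_record(model, loader, steps=500):
    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    losses = []
    for step, (data, target) in enumerate(loader):
        if step >= steps:
            break
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

plt.plot(train_and_record(make_model(nn.Linear), train_loader), label="nn.Linear")
plt.plot(train_and_record(make_model(BitLinearOriginal), train_loader), label="BitLinearOriginal")
plt.xlabel("batch")
plt.ylabel("loss")
plt.legend()
plt.show()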

With that in mind, when I ask myself why an LLM built on BitNet won't train well, perhaps the real answer is simply that pretraining an LLM from scratch is hard for everyone, BitNet or not.

And during training, BitNet looks like it will take more time than other methods.
Training also involves extra computation: compared with the conventional approach, BitNet inserts various additional steps, so the amount of work grows and the benefit you get on current hardware is limited. At inference time the amount of computation drops dramatically, but then there is little point in using a GPU. That may be what the BitNet paper means by its call for "new hardware".
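
To make the "computation drops dramatically at inference" point concrete, here is a tiny sketch (my own illustration, not taken from the paper): once the weights are frozen to {-1, 0, +1}, a matrix-vector product needs no multiplications at all, only additions and subtractions plus a single rescale.

import torch

# With ternary weights, y = W @ x reduces to adds/subtracts plus one rescale.
W = torch.tensor([[ 1., 0., -1.],
                  [-1., 1.,  0.]])      # ternary weight matrix
scale = 0.7                             # per-tensor scale fixed after training
x = torch.tensor([0.3, -1.2, 0.5])

y_mul = scale * (W @ x)                 # the usual multiply-accumulate form

# add/subtract-only form: add x[j] where the weight is +1, subtract where -1
y_add = torch.stack([x[row == 1].sum() - x[row == -1].sum() for row in W])
y_add = scale * y_add

print(torch.allclose(y_mul, y_add))     # True

For reference, the BitLinear implementation I used (BitLinearOriginal based on the paper, with the BitRMSNorm from Hachi-san's code) is below: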

import torch
from torch import nn
import torch.nn.functional as F

class BitRMSNorm(nn.Module):  # implementation by Hachi-san
    def __init__(self, hidden_size, eps=1e-6):
        """
        BitRMSNorm is equivalent to LlamaRMSNorm and T5LayerNorm
        refers: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L76C1-L90C59
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

def activation_quant(x):
    # Per-token 8-bit quantization of the activations
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y

def weight_quant(w):
    # Ternary (1.58-bit) quantization of the weights to {-1, 0, +1}
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinearOriginal(nn.Linear):  # implementation based on the paper
    def __init__(self, in_features, out_features, bias=False, flg_before_linear=True, bits=8):
        super(BitLinearOriginal, self).__init__(in_features, out_features, bias)
        self.layernorm = nn.LayerNorm(in_features)
        self.RMSNorm = BitRMSNorm(in_features)
        self.bits = bits

    def forward(self, x):
        w = self.weight
        x_norm = self.RMSNorm(x)
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y

The complete source code is on GitHub.