WSL2でMatMul-Free LMを試してみる
「LLM 1Bパラメータで行列計算を完全に排除できた。メモリ消費量を学習時10倍、推論時61%Max削減」らしいMatMul-Free LMを試してみます。
使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。
1. 準備
python3 -m venv matmulfree
cd $_
source bin/activate
パッケージのインストールです。
pip install torch wheel packaging
pip install -U git+https://github.com/ridgerchu/matmulfreellm
2. 使用するモデル
Hugging Faceに事前学習モデルが3つ提供されています。
今回は、2.7Bモデルを使用して試してみます。
3. 試してみる
pythonのプロンプトに流し込むコードはこちらです。name変数にHugging Faceのモデル名を指定します。
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import mmfreelm
from transformers import AutoModelForCausalLM, AutoTokenizer
#Change here to our open-sourced model
name = 'ridger/MMfreeLM-2.7B'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()
def q(input_prompt):
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
では聞いてみましょう。
>>> q("Japan is ")
Japan is 25% of the world's population and the world's largest economy in terms of GDP and the largest economy in East Asia
ちゃんと読める推論結果になっています。内容はさておき。
VRAMの使用量は10.9GBでした。
4. 学習
論文を読むに、学習に要した時間はH100 8枚で
370M: 5h
1.3B: 84h
2.7B: 173h
らしく4090を1枚では厳しそう…。
5. まとめ
100B+のようなLLMでの学習が未検証とのことですが、日進月歩のAI界隈ですから、あっという間に「できてますが?」となるんでしょうね。
3値、すごいなぁー。
関連
X(旧Twitter)紹介いただきありがとうございます。