見出し画像

WSL2でMatMul-Free LMを試してみる

「LLM 1Bパラメータで行列計算を完全に排除できた。メモリ消費量を学習時10倍、推論時61%Max削減」らしいMatMul-Free LMを試してみます。

使用するPCはドスパラさんの「GALLERIA UL9C-R49」。スペックは
・CPU: Intel® Core™ i9-13900HX Processor
・Mem: 64 GB
・GPU: NVIDIA® GeForce RTX™ 4090 Laptop GPU(16GB)
・GPU: NVIDIA® GeForce RTX™ 4090 (24GB)
・OS: Ubuntu22.04 on WSL2(Windows 11)
です。


1. 準備

python3 -m venv matmulfree
cd $_
source bin/activate

パッケージのインストールです。

pip install torch wheel packaging
pip install -U git+https://github.com/ridgerchu/matmulfreellm

2. 使用するモデル

Hugging Faceに事前学習モデルが3つ提供されています。

今回は、2.7Bモデルを使用して試してみます。

3. 試してみる

pythonのプロンプトに流し込むコードはこちらです。name変数にHugging Faceのモデル名を指定します。

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import mmfreelm
from transformers import AutoModelForCausalLM, AutoTokenizer
#Change here to our open-sourced model
name = 'ridger/MMfreeLM-2.7B'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()

def q(input_prompt):
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
    outputs = model.generate(input_ids, max_length=32,  do_sample=True, top_p=0.4, temperature=0.6)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

では聞いてみましょう。

>>> q("Japan is ")
Japan is 25% of the world's population and the world's largest economy in terms of GDP and the largest economy in East Asia

ちゃんと読める推論結果になっています。内容はさておき。

VRAMの使用量は10.9GBでした。

4. 学習

論文を読むに、学習に要した時間はH100 8枚で

  • 370M: 5h

  • 1.3B: 84h

  • 2.7B: 173h

らしく4090を1枚では厳しそう…。

5. まとめ

100B+のようなLLMでの学習が未検証とのことですが、日進月歩のAI界隈ですから、あっという間に「できてますが?」となるんでしょうね。

3値、すごいなぁー。

関連

X(旧Twitter)紹介いただきありがとうございます。

いいなと思ったら応援しよう!