
Easy LLM distillation with Unsloth's Google Colab notebooks
Unsloth shot into the spotlight with its 1.58-bit quantization of DeepSeek-R1.
My image of them used to be that of friendly hackers who kindly produce GGUFs and the like for LLMs.
It turns out, though, that they have been packaging sharp techniques like this dynamic quantization into software, and handing out Google Colab notebooks that make it easy to distill or quantize all sorts of open-weight models such as Phi-4 and Command-R.
The one for quantization
The one for distillation
These are so simple that they're perfect for anyone who wants to try distillation casually. What's most impressive is that, thanks to the dynamic quantization they implemented, the accuracy loss of 4-bit LoRA (QLoRA) is kept in check, so even a free account can train a model of around 7B parameters.
Training itself finishes in anywhere from a few minutes to a few tens of minutes, so the old image of "distillation/fine-tuning is a huge ordeal" no longer really applies. It was an eye-opener.
Since you can easily speed up all kinds of LLMs this way, it might be fun to give it a try.
The list of quantized models Unsloth has prepared themselves is here.
I tried distilling Phi-4 myself, and there are a few gotchas.
First, you have to line up the versions of xformers and related packages.
I used Python 3.10 and CUDA 12.1, with a single A100 80GB GPU.
pip install pip3-autoremove
pip-autoremove torch torchvision torchaudio -y
pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121
pip install unsloth
First, just trust me and run the commands above in order.
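Optionally, before going further, a quick sanity check like the one below can save some head-scratching. This is just my own habit, not part of Unsloth's instructions; it only confirms that the CUDA 12.1 builds actually landed.
import torch
import xformers
print("torch:", torch.__version__)                    # expect something like 2.5.1+cu121
print("cuda available:", torch.cuda.is_available())   # should be True on the A100
print("xformers:", xformers.__version__)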
Next, save the following code, lifted from the Colab notebook, under whatever filename you like.
from unsloth import FastLanguageModel # use FastVisionModel for vision LLMs
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",    # Llama-3.1 2x faster
    "unsloth/Mistral-Small-Instruct-2409",   # Mistral 22b 2x faster!
    "unsloth/Phi-4",                         # Phi-4 2x faster!
    "unsloth/Phi-4-unsloth-bnb-4bit",        # Phi-4 Unsloth Dynamic 4-bit Quant
    "unsloth/gemma-2-9b-bnb-4bit",           # Gemma 2x faster!
    "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",  # Qwen 2.5 2x faster!
    "unsloth/Llama-3.2-1B-bnb-4bit",         # NEW! Llama 3.2 models
    "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
    "unsloth/Llama-3.2-3B-bnb-4bit",
    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
] # More models at https://docs.unsloth.ai/get-started/all-our-models
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-4",
    max_seq_length = max_seq_length,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "phi-4",
)
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo, tokenize = False, add_generation_prompt = False
        )
        for convo in convos
    ]
    return { "text" : texts, }
from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-100k", split = "train")
from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(
    formatting_prompts_func,
    batched = True,
)
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 30,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|im_start|>user<|im_sep|>",
    response_part = "<|im_start|>assistant<|im_sep|>",
)
trainer_stats = trainer.train()
Run this and the distillation finishes in a few minutes.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))== Unsloth 2025.1.8: Fast Llama patching. Transformers: 4.47.1.
\\ /| GPU: NVIDIA A100 80GB PCIe. Max memory: 79.256 GB. Platform: Linux.
O^O/ \_/ \ Torch: 2.5.1+cu121. CUDA: 8.0. CUDA Toolkit: 12.1. Triton: 3.1.0
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post1. FA2 = False]
"-____-" Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00, 1.15s/it]
Unsloth 2025.1.8 patched 40 layers with 40 QKV layers, 40 O layers and 40 MLP layers.
==((====))== Unsloth - 2x faster free finetuning | Num GPUs = 1
\\ /| Num examples = 100,000 | Num Epochs = 1
O^O/ \_/ \ Batch size per device = 2 | Gradient Accumulation steps = 4
\ / Total batch size = 8 | Total steps = 30
"-____-" Number of trainable parameters = 65,536,000
0%| | 0/30 [00:00<?, ?it/s]
{'loss': 0.7482, 'grad_norm': 0.17042014002799988, 'learning_rate': 4e-05, 'epoch': 0.0}
{'loss': 0.6629, 'grad_norm': 0.14523565769195557, 'learning_rate': 8e-05, 'epoch': 0.0}
{'loss': 1.0388, 'grad_norm': 0.3025570511817932, 'learning_rate': 0.00012, 'epoch': 0.0}
{'loss': 0.7812, 'grad_norm': 0.22732926905155182, 'learning_rate': 0.00016, 'epoch': 0.0}
{'loss': 0.7112, 'grad_norm': 0.2056725025177002, 'learning_rate': 0.0002, 'epoch': 0.0}
{'loss': 0.8405, 'grad_norm': 0.24817104637622833, 'learning_rate': 0.000192, 'epoch': 0.0}
{'loss': 0.473, 'grad_norm': 0.12781551480293274, 'learning_rate': 0.00018400000000000003, 'epoch': 0.0}
{'loss': 0.8386, 'grad_norm': 0.1284501999616623, 'learning_rate': 0.00017600000000000002, 'epoch': 0.0}
{'loss': 0.6032, 'grad_norm': 0.09723766893148422, 'learning_rate': 0.000168, 'epoch': 0.0}
{'loss': 0.528, 'grad_norm': 0.12025056034326553, 'learning_rate': 0.00016, 'epoch': 0.0}
{'loss': 0.7025, 'grad_norm': 0.10142025351524353, 'learning_rate': 0.000152, 'epoch': 0.0}
{'loss': 0.7531, 'grad_norm': 0.6547413468360901, 'learning_rate': 0.000144, 'epoch': 0.0}
{'loss': 0.7773, 'grad_norm': 0.13310837745666504, 'learning_rate': 0.00013600000000000003, 'epoch': 0.0}
{'loss': 0.4542, 'grad_norm': 0.08028502762317657, 'learning_rate': 0.00012800000000000002, 'epoch': 0.0}
{'loss': 0.6358, 'grad_norm': 0.11321067065000534, 'learning_rate': 0.00012, 'epoch': 0.0}
{'loss': 0.3977, 'grad_norm': 0.09965520352125168, 'learning_rate': 0.00011200000000000001, 'epoch': 0.0}
{'loss': 0.7854, 'grad_norm': 0.11347711086273193, 'learning_rate': 0.00010400000000000001, 'epoch': 0.0}
{'loss': 0.6585, 'grad_norm': 0.1086692288517952, 'learning_rate': 9.6e-05, 'epoch': 0.0}
{'loss': 0.6072, 'grad_norm': 0.1224820464849472, 'learning_rate': 8.800000000000001e-05, 'epoch': 0.0}
{'loss': 0.7727, 'grad_norm': 0.1163286566734314, 'learning_rate': 8e-05, 'epoch': 0.0}
{'loss': 0.734, 'grad_norm': 0.07792412489652634, 'learning_rate': 7.2e-05, 'epoch': 0.0}
{'loss': 0.5374, 'grad_norm': 0.07801351696252823, 'learning_rate': 6.400000000000001e-05, 'epoch': 0.0}
{'loss': 0.9119, 'grad_norm': 0.14749355614185333, 'learning_rate': 5.6000000000000006e-05, 'epoch': 0.0}
{'loss': 0.7244, 'grad_norm': 0.17585600912570953, 'learning_rate': 4.8e-05, 'epoch': 0.0}
{'loss': 0.5003, 'grad_norm': 0.10795021057128906, 'learning_rate': 4e-05, 'epoch': 0.0}
{'loss': 0.5919, 'grad_norm': 0.08922930806875229, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.0}
{'loss': 0.6601, 'grad_norm': 0.10317248851060867, 'learning_rate': 2.4e-05, 'epoch': 0.0}
{'loss': 0.6584, 'grad_norm': 0.17016059160232544, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.0}
{'loss': 0.6996, 'grad_norm': 0.10287228971719742, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}
{'loss': 0.7713, 'grad_norm': 0.08957896381616592, 'learning_rate': 0.0, 'epoch': 0.0}
{'train_runtime': 138.8609, 'train_samples_per_second': 1.728, 'train_steps_per_second': 0.216, 'train_loss': 0.6853088637193044, 'epoch': 0.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [02:18<00:00, 4.63s/it]
I was surprised at how easy it was. That said, the loss doesn't drop all that much, so on its own this may not be practically useful.
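The post ends here, but if you want a qualitative check on the result, you can save the LoRA adapter and run a quick generation. The snippet below is only a sketch based on Unsloth's documented API, not part of the notebook code above; the "lora_model" directory name and the test prompt are placeholders of my own.
# Sketch, not from the original notebook: save the adapter, then try one generation.
model.save_pretrained("lora_model")      # saves only the LoRA adapter weights
tokenizer.save_pretrained("lora_model")

FastLanguageModel.for_inference(model)   # switch Unsloth into its faster inference mode

messages = [{"role": "user", "content": "Explain LoRA fine-tuning in one sentence."}]
inputs = tokenizer.apply_chat_template(
    messages, tokenize = True, add_generation_prompt = True, return_tensors = "pt"
).to("cuda")
outputs = model.generate(input_ids = inputs, max_new_tokens = 128, use_cache = True)
print(tokenizer.decode(outputs[0], skip_special_tokens = True))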