
Easy LLM distillation with Unsloth's Google Colab notebooks

Unsloth suddenly became the center of attention with its 1.58-bit quantization of DeepSeek-R1.
I had originally thought of them as friendly hackers who kindly produce GGUF builds and the like of various LLMs.

But they have been packaging sharp techniques like this dynamic quantization into software, and they hand out Google Colab notebooks that make it easy to distill and quantize all sorts of open-weight models such as Phi-4 and Command-R.

The one for quantization

The one for distillation

These are so easy that they are perfect for anyone who wants a low-effort way to try distillation. The most impressive part is that, thanks to the dynamic quantization they implemented, the performance degradation of 4-bit LoRA (QLoRA) is kept in check, and even a free account can train a model of around 7B parameters.

Training itself finishes in a few minutes to a few tens of minutes, so the old impression that "distillation/fine-tuning is a huge undertaking" no longer applies. It was a real eye-opener.

Since you can easily speed up all kinds of LLMs this way, it might be fun to give it a try.

https://unsloth.ai/

The list of quantized models that Unsloth has prepared themselves is here.

I tried distilling Phi-4 myself, and there are a few gotchas.
First, you have to get the versions of xformers and other packages to line up.

I used Python 3.10 and CUDA 12.1, with a single A100 80GB GPU.

pip install pip3-autoremove
pip-autoremove torch torchvision torchaudio -y
pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121
pip install unsloth
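
Before going further, it doesn't hurt to confirm that the versions actually lined up. This is just a quick sanity check I'd suggest, not something from the notebook:

import torch
import xformers

# Quick sanity check that torch / CUDA / xformers ended up on matching builds.
print("torch:", torch.__version__)          # should be a +cu121 build
print("CUDA runtime:", torch.version.cuda)  # should say 12.1
print("xformers:", xformers.__version__)
print("GPU visible:", torch.cuda.is_available())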

First, just trust me and run the commands above in order.
Then save the following code, extracted from the Colab notebook, under a suitable name.

from unsloth import FastLanguageModel  # FastVisionModel for LLMs                                                                                                               
import torch                                                                                                                                                                    
max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!                                                                                                   
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.                                                                                              
                                                                                                                                                                                
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.                                                                                                     
fourbit_models = [                                                                                                                                                              
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",  # Llama-3.1 2x faster                                                                                                                
    "unsloth/Mistral-Small-Instruct-2409",  # Mistral 22b 2x faster!                                                                                                            
    "unsloth/Phi-4",  # Phi-4 2x faster!                                                                                                                                        
    "unsloth/Phi-4-unsloth-bnb-4bit",  # Phi-4 Unsloth Dynamic 4-bit Quant                                                                                                      
    "unsloth/gemma-2-9b-bnb-4bit",  # Gemma 2x faster!                                                                                                                          
    "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"  # Qwen 2.5 2x faster!                                                                                                               
    "unsloth/Llama-3.2-1B-bnb-4bit",  # NEW! Llama 3.2 models                                                                                                                   
    "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",                                                                                                                                   
    "unsloth/Llama-3.2-3B-bnb-4bit",                                                                                                                                            
    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",                                                                                                                                   
]  # More models at https://docs.unsloth.ai/get-started/all-our-models                                                                                                          
                                                                                                                                                                                
model, tokenizer = FastLanguageModel.from_pretrained(                                                                                                                           
    model_name = "unsloth/Phi-4",                                                                                                                                               
    max_seq_length = max_seq_length,                                                                                                                                            
    load_in_4bit = load_in_4bit,                                                                                                                                                
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf                                                                                           
)                                                                                                                                                                               
                                                                                                                                                                                
model = FastLanguageModel.get_peft_model(                                                                                                                                       
    model,                                                                                                                                                                      
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128                                                                                                              
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",                                                                                                                   
                      "gate_proj", "up_proj", "down_proj",],                                                                                                                    
    lora_alpha = 16,                                                                                                                                                            
    lora_dropout = 0, # Supports any, but = 0 is optimized                                                                                                                      
    bias = "none",    # Supports any, but = "none" is optimized                                                                                                                 
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!                                                                                                           
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context                                                                                           
    random_state = 3407,                                                                                                                                                        
    use_rslora = False,  # We support rank stabilized LoRA                                                                                                                      
    loftq_config = None, # And LoftQ                                                                                                                                            
)                                                                                                                                                                               
                                                                                                                                                                                
from unsloth.chat_templates import get_chat_template                                                                                                                            
                                                                                                                                                                                
tokenizer = get_chat_template(                                                                                                                                                  
    tokenizer,                                                                                                                                                                  
    chat_template = "phi-4",
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo, tokenize = False, add_generation_prompt = False
        )
        for convo in convos
    ]
    return { "text" : texts, }


from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-100k", split = "train")

from unsloth.chat_templates import standardize_sharegpt

dataset = standardize_sharegpt(dataset)
dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
)
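
At this point it's worth printing one formatted sample to confirm that the Phi-4 chat template got applied the way you expect (index 5 is arbitrary; any row works):

print(dataset[5]["conversations"])  # the original ShareGPT-style turns
print(dataset[5]["text"])           # the flattened string the trainer will see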

from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 30,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user<|im_sep|>",
    response_part="<|im_start|>assistant<|im_sep|>",
)
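
To double-check that train_on_responses_only really masks everything except the assistant turns, you can decode one tokenized sample and its labels. This is a sketch that assumes the trainer has already turned the dataset into input_ids/labels columns at this point:

sample = trainer.train_dataset[5]
print(tokenizer.decode(sample["input_ids"]))  # full prompt + response

# Positions masked with -100 (system prompt, user turns) are swapped for a
# space token, so only the assistant response should survive in this printout.
space = tokenizer(" ", add_special_tokens = False).input_ids[0]
print(tokenizer.decode([space if t == -100 else t for t in sample["labels"]]))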

trainer_stats = trainer.train()

Run this and the distillation finishes in a few minutes.

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.                                                                                                       
🦥 Unsloth Zoo will now patch everything to make training faster!                                                                                                               
==((====))==  Unsloth 2025.1.8: Fast Llama patching. Transformers: 4.47.1.                                                                                                      
   \\   /|    GPU: NVIDIA A100 80GB PCIe. Max memory: 79.256 GB. Platform: Linux.                                                                                               
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 8.0. CUDA Toolkit: 12.1. Triton: 3.1.0                                                                                                  
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post1. FA2 = False]                                                                                                        
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth                                                                                                          
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!                                                                                           
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.15s/it]
Unsloth 2025.1.8 patched 40 layers with 40 QKV layers, 40 O layers and 40 MLP layers.                                                                                           
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1                                                                                                                
   \\   /|    Num examples = 100,000 | Num Epochs = 1                                                                                                                           
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4                                                                                                       
\        /    Total batch size = 8 | Total steps = 30                                                                                                                           
 "-____-"     Number of trainable parameters = 65,536,000                                                                                                                       
  0%|                                                                                                                                                    | 0/30 [00:00<?, ?it/s]
{'loss': 0.7482, 'grad_norm': 0.17042014002799988, 'learning_rate': 4e-05, 'epoch': 0.0}                                                                                        
{'loss': 0.6629, 'grad_norm': 0.14523565769195557, 'learning_rate': 8e-05, 'epoch': 0.0}                                                                                        
{'loss': 1.0388, 'grad_norm': 0.3025570511817932, 'learning_rate': 0.00012, 'epoch': 0.0}                                                                                       
{'loss': 0.7812, 'grad_norm': 0.22732926905155182, 'learning_rate': 0.00016, 'epoch': 0.0}                                                                                      
{'loss': 0.7112, 'grad_norm': 0.2056725025177002, 'learning_rate': 0.0002, 'epoch': 0.0}                                                                                        
{'loss': 0.8405, 'grad_norm': 0.24817104637622833, 'learning_rate': 0.000192, 'epoch': 0.0}                                                                                     
{'loss': 0.473, 'grad_norm': 0.12781551480293274, 'learning_rate': 0.00018400000000000003, 'epoch': 0.0}                                                                        
{'loss': 0.8386, 'grad_norm': 0.1284501999616623, 'learning_rate': 0.00017600000000000002, 'epoch': 0.0}                                                                        
{'loss': 0.6032, 'grad_norm': 0.09723766893148422, 'learning_rate': 0.000168, 'epoch': 0.0}                                                                                     
{'loss': 0.528, 'grad_norm': 0.12025056034326553, 'learning_rate': 0.00016, 'epoch': 0.0}                                                                                       
{'loss': 0.7025, 'grad_norm': 0.10142025351524353, 'learning_rate': 0.000152, 'epoch': 0.0}                                                                                     
{'loss': 0.7531, 'grad_norm': 0.6547413468360901, 'learning_rate': 0.000144, 'epoch': 0.0}                                                                                      
{'loss': 0.7773, 'grad_norm': 0.13310837745666504, 'learning_rate': 0.00013600000000000003, 'epoch': 0.0}                                                                       
{'loss': 0.4542, 'grad_norm': 0.08028502762317657, 'learning_rate': 0.00012800000000000002, 'epoch': 0.0}                                                                       
{'loss': 0.6358, 'grad_norm': 0.11321067065000534, 'learning_rate': 0.00012, 'epoch': 0.0}                                                                                      
{'loss': 0.3977, 'grad_norm': 0.09965520352125168, 'learning_rate': 0.00011200000000000001, 'epoch': 0.0}                                                                       
{'loss': 0.7854, 'grad_norm': 0.11347711086273193, 'learning_rate': 0.00010400000000000001, 'epoch': 0.0}                                                                       
{'loss': 0.6585, 'grad_norm': 0.1086692288517952, 'learning_rate': 9.6e-05, 'epoch': 0.0}                                                                                       
{'loss': 0.6072, 'grad_norm': 0.1224820464849472, 'learning_rate': 8.800000000000001e-05, 'epoch': 0.0}                                                                         
{'loss': 0.7727, 'grad_norm': 0.1163286566734314, 'learning_rate': 8e-05, 'epoch': 0.0}                                                                                         
{'loss': 0.734, 'grad_norm': 0.07792412489652634, 'learning_rate': 7.2e-05, 'epoch': 0.0}                                                                                       
{'loss': 0.5374, 'grad_norm': 0.07801351696252823, 'learning_rate': 6.400000000000001e-05, 'epoch': 0.0}                                                                        
{'loss': 0.9119, 'grad_norm': 0.14749355614185333, 'learning_rate': 5.6000000000000006e-05, 'epoch': 0.0}                                                                       
{'loss': 0.7244, 'grad_norm': 0.17585600912570953, 'learning_rate': 4.8e-05, 'epoch': 0.0}                                                                                      
{'loss': 0.5003, 'grad_norm': 0.10795021057128906, 'learning_rate': 4e-05, 'epoch': 0.0}                                                                                        
{'loss': 0.5919, 'grad_norm': 0.08922930806875229, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.0}                                                                       
{'loss': 0.6601, 'grad_norm': 0.10317248851060867, 'learning_rate': 2.4e-05, 'epoch': 0.0}                                                                                      
{'loss': 0.6584, 'grad_norm': 0.17016059160232544, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.0}                                                                       
{'loss': 0.6996, 'grad_norm': 0.10287228971719742, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}                                                                        
{'loss': 0.7713, 'grad_norm': 0.08957896381616592, 'learning_rate': 0.0, 'epoch': 0.0}                                                                                          
{'train_runtime': 138.8609, 'train_samples_per_second': 1.728, 'train_steps_per_second': 0.216, 'train_loss': 0.6853088637193044, 'epoch': 0.0}                                 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [02:18<00:00,  4.63s/it]

I was surprised at how easy it was. That said, the loss didn't come down all that much, so this run by itself probably isn't of much practical use.
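
For what it's worth, if you want to poke at the trained adapter anyway, something along these lines should do it. Treat it as a rough sketch: model and tokenizer are the objects from the script above, the prompt is made up, and the output directory name is just a placeholder.

# Switch Unsloth into its faster inference mode and generate one reply.
FastLanguageModel.for_inference(model)

messages = [{"role": "user", "content": "Explain LoRA in one sentence."}]
inputs = tokenizer.apply_chat_template(
    messages, tokenize = True, add_generation_prompt = True, return_tensors = "pt"
).to(model.device)

outputs = model.generate(input_ids = inputs, max_new_tokens = 128, use_cache = True)
print(tokenizer.decode(outputs[0], skip_special_tokens = True))

# Save only the LoRA adapter weights (small), not the full merged model.
model.save_pretrained("phi4_lora_adapter")
tokenizer.save_pretrained("phi4_lora_adapter")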