見出し画像

unslothを使って、小型LLMを「継続事前学習」する方法

LLM(大規模言語モデル)の性能を向上させる方法の一つに、
継続事前学習があります。

継続事前学習とは、
事前学習済みのLLMに対し、
さらに知識を習得させることを言います。

例えが適切ではないかもしれませんが、
社会人がリスキリングするみたいなものですね。

継続事前学習のイメージ


今回、unslothというライブラリーを使って
継続事前学習を実装してみましたので、
その詳細についてご紹介していきたいと思います。

ちなみに、自分が調べた限り、
ネット上に「継続事前学習」に関する記事は、
あまり存在しません。



改めて継続事前学習とは何か?


その名の通り、事前学習の継続です。

通常、LLMはインターネット上の膨大なテキストデータを使って事前学習をしています。
事前学習によって、LLMは豊富な知識を身につけていますが、
それはあくまで、一般的なものにすぎません。
また、過去の一時点のものにすぎません。

例えば、
LLMが事前学習していない
専門分野の知識だったり、最新の知識だったりを
LLMが習得するためには、
LLMにもう一度勉強させるしかないのです。

わかりやすいところで、
医療や金融といった専門分野を想像してみてください。

LLMは、どの程度、これらの知識を習得しているのでしょうか?

確かにインターネットには、そこそこ情報が流れているかもしれません。
ただ、専門分野の場合、ネットにある情報は全体のごく一部である可能性が高いです。
現に、自分は金融機関に勤めていますが、金融に関して、Google検索から得られる情報はまだまだ少ないと思っています。

特定の専門分野に特化したLLMを作ろうと思ったら、
特定の専門分野にかかる知識をLLMに再学習させる必要があるのです。

これが継続事前学習です。

特定の専門分野に限らず、
LLMが事前学習していない分野を学習させるのに
継続事前学習は有効な手段になります。

まとめると、継続事前学習によって、
既存のLLMに新たなデータを加えて再学習させることで、
モデルを目的に応じてカスタマイズできるのです。

ちなみに、この手法は、
モデルの中身が公開されているオープンソースのLLM(例:LLaMAやQwen)でないとできません。


継続事前学習の実装方法は?


unslothというライブラリーを使って
継続事前学習を実装する方法について解説していきます。


unslothは、継続事前学習に限らず、
SFT(教師ありファインチューニング)やDPO(選好チューニング)など、
LLMの学習全般について、
大幅に高速化できるライブラリです。

従来の方法と比較して、最大で2倍の速さ、メモリも最大で70%少ない使用量での実装が可能と言われています。

具体的なコードを解説しながら見ていきます。
なお、こちらの記事を参考にしています。

https://unsloth.ai/blog/contpretraining

ライブラリーのインストール

unslothをインストールします。

%%capture
!pip install unsloth

!pip uninstall unsloth -y
!pip install --upgrade --no-cache-dir --no-deps \
    git+https://github.com/unslothai/unsloth.git

既存パッケージをアップグレードします。

!pip install --upgrade torch
!pip install --upgrade xformers


HuggingFaceのトークンをセットしておきます。

from google.colab import userdata
HF_TOKEN=userdata.get('HF_TOKEN')


モデルの設定・ダウンロード

モデルは、Llama、Mistral、Qwen、gemmaなど、
あらゆるオープンソースのLLMが使えそうです。

この記事では、大規模言語モデル研究開発センター(LLMC)が開発したllm-jp-3-13Bモデルを使っていきます。


from unsloth import FastLanguageModel
import torch
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"

max_seq_length = 2048  # モデルが処理できるシーケンスの最大長
dtype = None  # 適切なデータ型を自動的に検出
load_in_4bit = True  # 4ビット量子化でメモリ効率を向上


# モデルとトークナイザをロード
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "llm-jp/llm-jp-3-13b",  # 使用するモデルを選択。
    max_seq_length = max_seq_length,        
    dtype = dtype,                          
    load_in_4bit = load_in_4bit,            
    # token = "hf_...",                     # 特定の制限付きモデル(例: Meta-Llama)用のアクセストークンを指定
)

【パラメータの補足説明】
max_seq_length:

モデルが処理できるシーケンスの最大長 。任意の値を設定可能です!
dtype:
データ型を指定。自動検出にはNoneを指定。
load_in_4bit:
Trueに設定することで、4ビット量子化を行うことができ、メモリ効率が向上します。


LoRAの設定

学習を効率的に行うためLoRAを使います。

LoRAが対象とする線型層は、
ファインチューニングの場合、
"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
ですが、
継続事前学習の場合、
"embed_tokens", "lm_head"を追加するのが特徴です。

model = FastLanguageModel.get_peft_model(
    model,
    r = 128,  # 任意の値を指定可能(0以上)。推奨値: 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "embed_tokens", "lm_head",], # 継続的な事前学習を行う場合に追加
    lora_alpha = 32,      # LoRAのスケールファクター
    lora_dropout = 0,     # 任意の値を設定可能
    bias = "none",        # 任意の値を設定可能
    # [新機能] "unsloth" を使用するとVRAM消費が30%削減され、バッチサイズを2倍に拡張可能!
    use_gradient_checkpointing = "unsloth", # Trueまたは"unsloth"を指定して超長文のコンテキストに対応
    random_state = 3407,  # 再現性を確保するための乱数シード
    use_rslora = True,    # ランク安定化LoRAをサポート
    loftq_config = None,  # LoftQもサポート
)


データセットの準備

継続事前学習用のデータセットを用意します。
こちらのデータセットを使わせてもらいました。

from datasets import load_dataset

dataset = load_dataset("kajuma/CC-news-2024-July-October-cleaned", split = "train",)

dataset = dataset.train_test_split(train_size = 0.30)["train"]

dataset.train_test_split():
そのままだとデータ量が多いので、
train_test_splitでランダムに30%のデータに絞っています。

次に、モデルが読み込めるように、
datasetのフォーマットを整えます。

news_prompt = """
###text:{}
###
"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    charsets = examples["charset"]
    texts  = examples["text"]
    outputs = []
    for charset,text in zip(charsets, texts):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = news_prompt.format(text) + EOS_TOKEN
        outputs.append(text)
    return { "text" : outputs, }
pass

# フォーマット関数を適用してデータを整形
dataset = dataset.map(formatting_prompts_func, batched = True,)


学習パラメータの設定

from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments

trainer = UnslothTrainer(
    model = model,                    # トレーニング対象のモデル
    tokenizer = tokenizer,            # モデル用トークナイザー
    train_dataset = dataset,          # トレーニングデータセット
    dataset_text_field = "text",      # データセット内のテキストフィールド名
    max_seq_length = max_seq_length,  # 最大シーケンス長
    dataset_num_proc = 2,             # データセット処理に使用するプロセス数

    args = UnslothTrainingArguments(
        per_device_train_batch_size = 2,      # 各デバイスごとのバッチサイズ
        gradient_accumulation_steps = 8,     # 勾配の累積ステップ数

        # 長時間のトレーニングに使用可能な設定
        max_steps = 120,                     # トレーニングの最大ステップ数
        warmup_steps = 10,                   # ウォームアップステップ数
        # warmup_ratio = 0.1,                # ウォームアップ比率(オプション)
        # num_train_epochs = 1,              # トレーニングのエポック数(オプション)

        # 埋め込み行列には通常より2~10倍小さい学習率を選択
        learning_rate = 5e-5,                # 全体の学習率
        embedding_learning_rate = 1e-5,      # 埋め込み層の学習率

        fp16 = not is_bfloat16_supported(),  # FP16を使用(bfloat16がサポートされていない場合)
        bf16 = is_bfloat16_supported(),      # bfloat16を使用(サポートされている場合)
        logging_steps = 1,                   # ログを記録するステップ間隔
        optim = "adamw_8bit",                # 8ビット版AdamWオプティマイザーを使用
        weight_decay = 0.01,                 # 重み減衰率
        lr_scheduler_type = "linear",        # 学習率スケジューラのタイプ
        seed = 3407,                         # 再現性のための乱数シード
        output_dir = "outputs",              # 出力ディレクトリ
        report_to = "none",                  # ログ出力先(例: "wandb"などを指定可能)
    ),
)


学習の実行

trainer_stats = trainer.train()

こんな感じで学習が進みます。

モデルの保存

HuggingFaceに保存します。

# LoRAアダプタだけ保存
new_model_id = "****/llm-jp-3-13b-CP"
model.push_to_hub_merged(
    new_model_id+"_lora",
    tokenizer=tokenizer,
    save_method="lora",
    token=HF_TOKEN,
    private=True
)



さいごに

継続事前学習をしたモデルは、
この後、教師ありファインチューニング(指示チューニング)することで、
人間の意図を汲んだ回答を生成することができるようになります。
指示チューニングについては、こちらの記事を参考にしてください。

最後までお読みいただきありがとうございました。


いいなと思ったら応援しよう!

Non
よろしければサポートお願いします! いただいたサポートはクリエイターとしての活動費に使わせていただきます!