見出し画像

Unslothで継続事前学習をやってみた


はじめに

Unslothというライブラリを使って継続事前学習をやってみました。高速チューニングに行えるということなので、以前からトライしてみたいと思っていました。
今回は、継続事前学習で新しい知識を埋め込めるかどうかを確認してみることにしました。

結論から言えば、新しい知識を埋め込めたという結果を得ることはできませんでした。

使用したモデル

使用するモデルはLlama3.2 1Bを使用しました。新しい知識を埋め込むにはモデルが小さすぎたということかもしれません。

実装

早速実装を見ていくことにします。

準備

必要ライブラリとパラメータをいくつか先に決めておきます。また、モデルとトークナイザーもこの時点で読み込んでおきます。

import os
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
from datasets import load_dataset
import wandb
os.environ["WANDB_PROJECT"] = "llama3.2-continued-pretraining"

# 基本設定
max_seq_length = 512
dtype = None
load_in_4bit = True
num_proc = 4
pretrain_output = 'pretrain_output'
random_seed = 3407

# モデルとトークナイザーの読み込み
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Llama-3.2-1B",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

4bitのQLoRAで学習を実行する設定になっています。また、再現性を高めるためにランダムシードも設定します。
学習時の最大トークンサイズは512と小さい値にしています。学習に使用するマシンスペックの問題で小さい値になっています。非常に限られたリソースで実行しています。

LoRAの設定

継続事前学習を行いたいので、lm_headとembed_tokensをLoRAのターゲットに追加をします。

model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=random_seed,
    use_rslora=True
)

Unslothのサンプルnotebookに従うと、dropoutは0でbiasはnoneにするのが最適とのことです。
r や alphaはもっと大きな値の方が良かったかもしれないです。

データセットの読み込み

データセットの読み込みは以下のように行います。

dataset = load_dataset("kajuma/CC-news-2024-July-October-cleaned", split = "train",)
dataset = dataset.train_test_split(train_size=0.8)["train"]
dataset[0]

使用するデータセットは以下を使用させていただきました。

知識が正しく追加できているかどうかを判断するために、判定用のテキストを1つ追加しました。


dataset = dataset.add_item({
    'docId': 'b0b8df9a-6c3c-4331-81f8-bf01fbfb9d1f',
    'url': '',
    'charset': 'EUC-JP',
    'date': '2025-01-10T09:00:00',
    'language': 'ja',
    'text': open('dataset/knowledges.txt', 'r').read()
})
dataset[-1]

判定用のテキスト

2020年以降の流行語大賞を振り返ると、日本社会の重要な出来事や話題が反映されています:
2020年の「3密」は、新型コロナウイルス感染症対策として小池百合子東京都知事が提唱した「密閉・密集・密接」を避けるという概念を表し、パンデミック初期の日本社会を象徴する言葉となりました。
2021年は、メジャーリーグで投打の「二刀流」として歴史的な活躍を見せた大谷翔平選手の「リアル二刀流/ショータイム」が選ばれ、日本人選手の世界での躍進を表しました。
2022年の「村神様」は、プロ野球で史上初となる高校生ドラフト1位から4年連続本塁打王を達成した村上宗隆選手のニックネームで、若き天才打者の台頭を示しています。
2023年の「アレ(A.R.E)」は、阪神タイガースの岡田彰布監督が頻繁に使用した言葉で、44年ぶりのリーグ優勝を果たした阪神の快進撃を象徴しました。
2024年は、TBSドラマ『不適切にもほどがある!』から生まれた「ふてほど」が選ばれ、厳しい社会規範の中で生きる現代人の共感を呼んだことを示しています。
この5年間の流行語は、感染症対策から、スポーツ界の活躍、そして時代を反映したドラマまで、その時々の日本社会の関心事を端的に表現しています。

データの整形

データの整形といっても学習に正しく使えるようにキー名を変更するだけです。

news_prompt = """
{}
"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    charsets = examples["charset"]
    texts  = examples["text"]
    outputs = []
    for charset,text in zip(charsets, texts):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = news_prompt.format(text) + EOS_TOKEN
        outputs.append(text)
    return { "text" : outputs, }

# フォーマット関数を適用してデータを整形
dataset = dataset.map(formatting_prompts_func, batched = True,)

学習設定

trainer = UnslothTrainer(
    model = model,                    # トレーニング対象のモデル
    tokenizer = tokenizer,            # モデル用トークナイザー
    train_dataset = dataset,          # トレーニングデータセット
    dataset_text_field = "text",      # データセット内のテキストフィールド名
    max_seq_length = max_seq_length,  # 最大シーケンス長
    dataset_num_proc = num_proc,             # データセット処理に使用するプロセス数
    neftune_noise_alpha=5,
    packing=False,
    
    args = UnslothTrainingArguments(
        per_device_train_batch_size = 8,      # 各デバイスごとのバッチサイズ
        gradient_accumulation_steps = 1,     # 勾配の累積ステップ数

        # 長時間のトレーニングに使用可能な設定
        num_train_epochs = 1,

        # 埋め込み行列には通常より2~10倍小さい学習率を選択
        learning_rate = 5e-5,                # 全体の学習率
        embedding_learning_rate = 1e-5,      # 埋め込み層の学習率

        fp16 = not is_bfloat16_supported(),  # FP16を使用(bfloat16がサポートされていない場合)
        bf16 = is_bfloat16_supported(),      # bfloat16を使用(サポートされている場合)
        logging_steps = 10,                   # ログを記録するステップ間隔
        save_total_limit = 2,
        
        optim="adamw_8bit",
        lr_scheduler_type="constant", 
        seed = random_seed,                         # 再現性のための乱数シード
        output_dir = pretrain_output,              # 出力ディレクトリ
        report_to = "wandb",                  # ログ出力先(例: "wandb"などを指定可能)
    ),
)

学習率のスケジューラはconstantで実行しました。linearなどのスケジューラで学習率を減衰した方が学習時の破綻のリスクが少ないと思いますが、モデルもデータも規模は小さいので、損失が早く減る方法をとりました。

学習の実行

trainer_stats = trainer.train()
wandb.finish()

学習の実行はこれだけ。wandb.finish()を実行しないとwandbが正しく終了しない。

モデル保存

model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
del model

LoRAで学習した結果をAdapterとして保存するのではなく、モデル本体にマージして保存します。

結果の確認

以下のコードで結果を確認します。

import os
from unsloth import FastLanguageModel
from transformers import TextStreamer

# 基本設定
max_seq_length = 2048
dtype = None
load_in_4bit = False
num_proc = 4
pretrain_output = 'pretrain_output'
sft_output = 'sft_output'
random_seed = 3407

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = './model',
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

inputs = tokenizer(
[
    "2020年以降の流行語大賞を振り返ると、"
]*1, return_tensors = "pt").to("cuda")

FastLanguageModel.for_inference(model) # Enable native 2x faster inference
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, 
        max_new_tokens = 512,
        use_cache=True,
        do_sample=False,
        repetition_penalty=1.2)

出力

<|begin_of_text|>2020年以降の流行語大賞を振り返ると、上位に「アーカイブ」が選ばれた。いったいなぜかと。

## 「アーカイブ」とはどういうこと?

もともとは、「記憶や知見から事実への信じ方」をしたりすることを指す言葉でしたが、現在ではその意義が広げてきました。
例えば、ある日、ある人に「今までのあらゆることがこのようにしてきた」と話し始められたら、その人の発言にはどれだけの証拠があるのかというものがわかりますよ。

たとえそれが「アーカイブ」のようでもうまく伝えることはできませんが、より強力な表現として使われることもありますね。

ちなみに、アーカイブの前に「レポート」「調査」などが用いられるのは、前者が情報の収集・まとめる作業で後者のほうが検討や評価を行うためです。

また、つまりアーカイブは、事前の予想と合致しない結果が出ない場合にも、事後に予期せぬ事態になっている場合は、事後の対処法を決めておけばいいとされているのですが、そもそもアーカイブってのことなのでしょうか?
(出典:『日本経済トピックス』)

「アーカイブ」
「アーキテクチャー」

これらは、もともと別の意味で使われていたのですが、近年の活用が増えていますからね。

「アーカイブ」は、もともと「記録の長さ」についてのことだったのですが、特に「アーカイブの歴史」に関連しているのではないかと考えた人がいるのです。

そうすると、次のように解釈できるのです。「アーカイブの歴史」は、もともと「記録の長さ」ではなく、もしくは「記述の多様性」であると考えられます。

「アーカイブの歴史」は、もともと「記述の多様性」であると考えられているのと同一の物体のうちのひとつの分野のことであり、もともと「記述の多様性」が続いているとされる

いいなと思ったら応援しよう!