見出し画像

Google Colab で RedPajama-INCITE のLoRA ファインチューニングを試す

「Google Colab」で「RedPajama-INCITE-Base-3B」のLoRA ファインチューニングを試したので、まとめました。

【注意】Google Colab Pro/Pro+のプレミアム(A100)で試しました。VRAMは11.5GB必要でした。

1. lora-instruct

以下のリポジトリの「PEFT」で「RedPajama-INCITE」を学習させるコード「finetuning.py」を使わせてもらいました。

2. Colabでの実行

Colabでの実行手順は、次のとおりです。

(1) メニュー「編集→ノートブックの設定」で、「ハードウェアアクセラレータ」で「GPU」で「A100」を選択。

(2) Googleドライブのマウント

# Googleドライブのマウント
from google.colab import drive
drive.mount('/content/drive')

(3) 作業フォルダへの移動

# 作業フォルダへの移動
import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd '/content/drive/My Drive/work'

(3) リポジトリのクローン。

# リポジトリのクローン
!git clone https://github.com/leehanchung/lora-instruct
%cd lora-instruct

(4) poetryをrequirements.txtで出力。

# poetryをrequirements.txtで出力
#!pip install poetry
#!poetry export --without-hashes --with dev --output requirements.txt

(5) requirements.txtを開き、「poetry @git+XXXX」を「poetry」に変更。
(6) requirements.txtの実行。

!pip install -r requirements.txt

(7) LoRA ファインチューニングの実行。
デフォルトのデータセットは「yahma/alpaca-cleaned」が設定されています。

# LoRA ファインチューニングの実行
!python finetune.py \
    --base_model 'togethercomputer/RedPajama-INCITE-Base-3B-v1' \
    --output_dir './lora-redpajama'

W&Bの使用を聞かれたら回答します。(以下では使用しない"3"を選択)

wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3

2時間ほどで、「./lora-redpajama」にLoRAモデルが出力されました。

・lora-redpajama フォルダ
 ・checkpoint-1000 フォルダ
 ・checkpoint-800 フォルダ
 ・checkpoint-600 フォルダ
 ・runs フォルダ
 ・adapter_config.json
 ・adapter_model.bin

3. Colabでの推論

Colabでの推論手順は、次のとおりです。

(1) トークナイザーとモデルの準備

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained(
    "togethercomputer/RedPajama-INCITE-Base-3B-v1"
)
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/RedPajama-INCITE-Base-3B-v1", 
    torch_dtype=torch.float16
)

(2) PEFTモデルの追加。

from peft import PeftModel

# PEFTモデルの追加
model = PeftModel.from_pretrained(model, "./lora-redpajama", device_map="auto")
model = model.to("cuda:0")

(3) Instruction用のプロンプトテンプレートの準備。

# Instruction用のプロンプトテンプレートの準備
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""

(4) 推論の実行。

# プロンプトの準備
prompt = generate_prompt("Give three tips for staying healthy.")
print("--[prompt]--\n" + prompt + "----")

# 推論の実行
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
input_length = inputs.input_ids.shape[1]
outputs = model.generate(
    **inputs, 
    max_new_tokens=128, 
    do_sample=True, 
    temperature=0.7, 
    top_p=0.7, 
    top_k=50, 
    return_dict_in_generate=True
)
token = outputs.sequences[0, input_length:]
output_str = tokenizer.decode(token)

# 確認
print(output_str)
--[prompt]--
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Give three tips for staying healthy.

### Response:
----

1. Make sure you get enough sleep.
2. Exercise daily.
3. Eat a healthy diet.

### Explanation:

1. Sleep is important for maintaining good health.
2. Exercise is important for maintaining good health.
3. Eating a healthy diet is important for maintaining good health.

回答後に、不要なテキスト(### Explanation:…)がついてますが、学習できてそうなことがわかります。

【おまけ】 LoRAを未設定の場合の確認

LoRAを未設定の場合の結果も確認してみます。

--[prompt]--
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Give three tips for staying healthy.

### Response:
----
I am a fan of the [Green Smoothie Diet](http://www.greensmoothiediet.com/).

I like to drink a green smoothie every morning.

Alpacaデータセット内の同一質問の回答は、次のとおりです。

    {
        "instruction": "Give three tips for staying healthy.",
        "input": "",
        "output": "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
    },

LoRA ファインチューニングが影響を与えているこは、確認できました。



いいなと思ったら応援しよう!