画像系マルチモーダルLLMであるQwen2-VLのファインチューニングの練習
はじめに
Qwen2-VLは、高性能なマルチモーダルLLMです。
本記事では、モデルのファインチューニングを試みてみます。
マルチモーダルLLMを学習するのは初めてなので、色々と試行錯誤がありそうです。
基本的には、以下のマニュアルを真似するだけ作業が完了しました。
(一部、コードを追記修正)
マシン
A100(80GB)x2、ubuntuを使います。
ファインチューニングの最低スペックは不明ですが、とりあえず試してみます。
環境構築
適当にcondaで仮想環境を作ります。pythonのバージョンは適当です。
conda create -n qwenft python=3.10 -y
conda activate qwenft
マニュアルの通りに学習関連のモジュールをpipします。
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .[llm]
pip install git+https://github.com/huggingface/transformers.git
pip install pyav qwen_vl_utils
#追加で必要だったモジュール
pip install qwen_vl_utils #推論
pip install torchvision #推論
pip install deepspeed #ファインチューニング
ここで、swiftというのは、Scalable lightWeight Infrastructure for Fine-Tuningの略で、ファインチューニング用のライブラリのようです。
iOS用のプログラミング言語は関係ないとのことです。安心しました。
推論をしてみる
とりあえず2bで推論してみます。
以下のwikipediaの写真を推論してみます。
推論コード
import torch
from swift.llm import (
get_model_tokenizer, get_template, inference, ModelType,
get_default_template_type, inference_stream
)
from swift.utils import seed_everything
#モデル初期化
model_type = ModelType.qwen2_vl_2b_instruct
template_type = get_default_template_type(model_type)
model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16,
model_kwargs={'device_map': 'auto'})
model.generation_config.max_new_tokens = 256
template = get_template(template_type, tokenizer)
seed_everything(42)
#クエリを投げる
target_url="https://upload.wikimedia.org/wikipedia/commons/thumb/0/04/Cyanocitta-cristata-004.jpg/1280px-Cyanocitta-cristata-004.jpg"
query = f"""<img>{target_url}</img>これはなんだい?"""
response, history = inference(model, template, query)
print(f'query: {query}')
print(f'response: {response}')
出力
これはブルーベリーの写真です。ブルーベリーは、冬の雪に覆われた枝に静かに座っています。
なかなかおもしろい回答が返ってきました。
既存のデータセットでファインチューニング
サンプルコードをもとに、とりあえずファインチューニングさせてみます。
LoRA, QLoRAなども対応しているようです。
CUDA_VISIBLE_DEVICES=0,1 NPROC_PER_NODE=2 swift sft \
--model_type qwen2-vl-2b-instruct \
--model_id_or_path qwen/Qwen2-VL-2B-Instruct \
--sft_type full \
--freeze_vit true \
--deepspeed default-zero2 \
--dataset latex-ocr-print#20000
無事に学習が始まりました。vramは合計で52 GBほど使うようです。20minほどで学習が終わる雰囲気でした。
(ramに余裕がありそうだったので、deepspeedの設定をdeepspeed default-zero2からdeepspeed default-zero1に変更してみたのですが、binascii系のよくわからないエラーが出たので諦めました)
学習したモデルでの推論
モデルとトークナイザの読み込みに、やや苦労しました。
とりあえず、以下のコードで動きました。
import torch
from swift.llm import get_model_tokenizer, get_template, inference, ModelType,get_default_template_type
from swift.utils import seed_everything
from transformers import Qwen2VLForConditionalGeneration
#モデルのパス
local_model_path = "/data/2024/1008qwenft/swift/output/qwen2-vl-2b-instruct/v1-20241008-163630/checkpoint-1238"
#モデル初期化
model_type = ModelType.qwen2_vl_2b_instruct
template_type = get_default_template_type(model_type)
#この方式でtokenizerのみ読み込み
_, tokenizer = get_model_tokenizer(model_type, torch.bfloat16,
load_model=False,
model_kwargs={'device_map': 'auto',
})
# ローカルモデルの読み込み
model = Qwen2VLForConditionalGeneration.from_pretrained(
local_model_path,
torch_dtype=torch.bfloat16,
device_map='auto',
)
model.generation_config.max_new_tokens = 256
template = get_template(template_type, tokenizer)
seed_everything(42)
#推論
#クエリを投げる
target_url="https://upload.wikimedia.org/wikipedia/commons/thumb/0/04/Cyanocitta-cristata-004.jpg/1280px-Cyanocitta-cristata-004.jpg"
query = f"""<img>{target_url}</img>これはなんだい?"""
response, history = inference(model, template, query)
print(f'query: {query}')
print(f'response: {response}')
出力
これは青いヤブです。
ファインチューニングに使ったデータがOCR系?のものなので、回答精度が上がる必然性もありませんが、なんとも言えない回答になりました。
(青いヤブとは一体、、?)
ちなみに、英語で聞いたら正しい回答が返ってきました。
input
query = f"""<img>{target_url}</img>what is this?"""
Output(ファインチューニング前)
This is a Blue Jay (Cyanocitta cristata). Blue Jays are known for their vibrant blue and white plumage, which helps them blend into their snowy surroundings. They are commonly found in North America and are known for their distinctive call, which is often described as a "pee-ew" sound.
(これはアオカケス (Cyanocitta cristata) です。アオカケスは鮮やかな青と白の羽毛で知られ、雪に覆われた周囲に溶け込みます。北米によく見られ、独特の鳴き声で知られています。この鳴き声は「ピーーー」という音と表現されることが多いです。)
Output (ファインチューニング後)
A blue jay perched on a snowy branch.
(雪の積もった枝に止まったアオカケス。)
感想・まとめ
ファインチューニング前のモデルは、かなり注意深くにinstruction tuningされている印象です。それをOCR用のデータセットで学習させたことで、回答の丁寧さが失われた印象です。
ファインチューニング前のモデルが、マニアックな鳥の名前も覚えていたので、やや驚きました。wikipedia系のデータは概ね学習済みなのかもしれません。
英語では答えられるけれども、日本語では回答できないケースがあることがわかりました。
日本語のデータで追加学習することで、かなり性能を上げる余地がありそうです
今後はオリジナルなデータセットでファインチューニングの練習をしていきたいと思います