BERTベースモデルのFine-TuningにTrainerクラスを利用する
こんにちは、エンジニアのすずきです。
以前、IBM論文の参考コードでTabBERTモデルの事前学習を行い、Fine-Tuningについては自作コードを実装しました。
自作コードで一応Fine-Tuningをできるようになったのですが、F1スコアなどのメトリクスを計算するだけでも面倒さを感じていました。
事前学習のときと同様にTransformersのTrainerクラスを使えればメトリクスも簡単に出せるのに...といろいろ調べてみたところ、下流タスク用のヘッドをボディ(事前学習済モデル)に付加したものについても、普通にTrainerクラスが使えることがわかりました。
そんなわけで、今回は以前の自作コードをTrainerクラスを使ってリファクタリングしました。
また、今回のリファクタリングのついでに、WandBの導入やオーバーサンプリング処理の追加も行ったので、最後の方におまけで書いています。
実装
ポイントとなる部分だけ書きます。
モデルの作成
Trainerクラスでモデルを扱うためには以下がポイントとなります。
PreTrainedModelを継承する
出力をlossとlogitsのタプルで返す
公式Docsに記載されているように、TrainerクラスではPreTrainedModelで動作するように最適化されるようです。
nn.Moduleでも大丈夫とのことでしたが、実際にこちらを継承したらエラーがでました。
また、自作コードではforwardでlogits(予測結果)のみを返すようにしていたのですが、lossとlogitsのタプルで返すようにしました。
lossを返すためにinitで損失関数loss_fnを指定し、推論時にもモデルを使用することを想定して、損失関数loss_fnを指定しない場合はloss=Noneを返すようにしました。
なお、下流タスク(分類)のヘッドとして、事前学習済モデルにLSTM層とLinear層を付加しています。
from transformers import BertConfig, BertModel, PreTrainedModel
import torch.nn as nn
import torch
class VisitorReactionModel(PreTrainedModel):
def __init__(self,
config='./output_pretraining/action_history/checkpoint-500/config.json',
num_categories=2,
loss_fn=None,
pretrained_model='./output_pretraining/action_history/checkpoint-500/pytorch_model.bin'):
super().__init__(config=config, num_categories=num_categories, loss_fn=loss_fn, pretrained_model=pretrained_model)
self.model = BertModel.from_pretrained(pretrained_model, config=config)
self.lstm = nn.LSTM(self.config.hidden_size, self.config.hidden_size, batch_first=True)
self.regressor = nn.Linear(self.config.hidden_size, num_categories)
self.loss_fn = loss_fn
def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
output_attentions=False,
output_hidden_states=False,
labels=None):
outputs = self.model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
out, _ = self.lstm(outputs[0], None)
sequence_output = out[:, -1, :]
logits = self.regressor(sequence_output)
loss=None
if labels is not None and self.loss_fn is not None:
loss = self.loss_fn(logits, torch.max(labels, 1)[1])
# ModelOutputだとlossがスカラーじゃないというエラーが出るためTupleで返す
return loss, logits
compute_metricsの作成
モデルをTrainerクラスに適用できるようになったのですが、これだけだとprecision, recall, f1といったメトリクスを導出することができません。
そんなときに使用するのがcompute_metricsとなります。
※Transformersの3系バージョンだと使用できなかったので、4.26.0へ事前にバージョンアップしています。
引数の型はEvalPrediction、戻り値の型はOptional[Dict[str, float]]となります。
def compute_metrics(res: EvalPrediction):
logits = res.predictions.argmax(axis=1)
labels = res.label_ids.argmax(axis=1)
precision = precision_score(labels, logits, average='macro')
recall = recall_score(labels, logits, average='macro')
f1 = f1_score(labels, logits, average='macro')
return {
'precision': precision,
'recall': recall,
'f1': f1
}
あとは、Trainerの引数にmodelとcompute_metricsを与えれば、メトリクスが計算されます。
分類タスクでFine-Tuningを行うため、損失関数にはCrossEntropyを用いています。
loss_fn = CrossEntropyLoss()
model = VisitorReactionModel(config=config, pretrained_model=pretrained_model, loss_fn=loss_fn)
training_args = TrainingArguments(
output_dir=args.output_dir, # output directory
num_train_epochs=args.num_train_epochs, # total number of training epochs
per_device_train_batch_size=args.num_train_batch_size,
per_device_eval_batch_size=args.num_eval_batch_size,
save_steps=args.save_steps,
do_train=True,
do_eval=True,
evaluation_strategy="epoch", # epochかsteps(デフォルト500)ごとに評価
overwrite_output_dir=True,
save_total_limit=1,
report_to="wandb"
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics
)
その他
今回のリファクタリングでいくつか改善も行ったので、おまけで記載します。
WandBの導入
Trainerの利用でWandBの導入が楽になったので、リファクタリングにあわせて実装しました。
Fine-Tuningのコード内に、WandBのログインと初期化のコードを追加します。
このとき、dotenvを利用して、ローカル学習の際は.envから、SageMaker Training Jobsの際はEstimatorsのhyperparametersからAPIキーを読み込むようにします。
load_dotenv()
WANDB_API_KEY = os.getenv('SM_HP_WANDB_API_KEY')
wandb.login(key=WANDB_API_KEY) # Pass your W&B API key here
wandb.init(project="tabformer-opt") # Add your W&B project name
estimator = Estimator(
image_uri="",
role=role,
instance_type="ml.g4dn.2xlarge",
instance_count=1,
base_job_name="tabformer-opt-fine-tuning",
output_path="",
code_location="",
sagemaker_session=session,
entry_point="fine-tuning.sh",
dependencies=["tabformer-opt"],
hyperparameters={
"data_root": "/opt/ml/input/data/input_data/",
"data_fname": "",
"output_dir": "/opt/ml/model/",
"model_path": "/opt/ml/input/data/input_model/",
"wandb_api_key": <APIキー>
}
)
あとは、TrainingArgumentsにreport_to="wandb"を追加するだけで、学習結果が記録されるようになります。
オーバーサンプリング
今回、ポジティブラベルが10 %以下の不均衡データを使用しており、そのままモデルで学習を行ってもprecisionやrecallが低い結果となってしまいます。
そのため、少数派のポジティブラベルデータをオーバーサンプリングで増やすようにしました。
この際、単純なデータ複製で過学習を起こさないために、少数派のデータからランダムでデータを選択し、そのデータからランダムで選択された近傍点を用いて、両者の合成データを作成する、SMOTEという手法を用いました。
SMOTE処理を以下の関数にまとめ、データ前処理のコードに加えました。
def overSampling(data):
sm = SMOTE(random_state=42)
X = data.drop(columns='reaction', axis=1)
y = data['reaction']
X_sample, Y_sample = sm.fit_resample(X, y)
over_sampling = pd.DataFrame()
over_sampling = X_sample
over_sampling['reaction'] = Y_sample
return over_sampling
参考資料
採用情報
バックエンドが得意な方を募集中です。
AWSやバックエンドの経験があれば、インフラ設計やパフォーマンスチューニングなどなんでもお任せします。
もしご興味があれば、採用情報ページの画面左下のボタンからチャット(かWeb通話)でお声がけいただけると幸いです。
最近、YOUTRUSTにも登録しました。
カジュアル面談に興味がある方はぜひ…!