0.1bのtransformerのperplexityで文章フィルタリングができるかどうかの検証
はじめに
最近は大規模言語モデルを作っています。
大規模言語モデルの学習において、事前学習データのフィルタリングは重要と言われています。
フィルタリングにはいくつもの方法がありますが、ルールベースでノイズを取り除くのは、処理速度の面でも有利です。
一方、最近は機械学習ベースのフィルタリングも注目されています。
特に、処理速度と精度を両立した(?)手法として、言語モデルのperplexity(困惑さ: モデルが入力した文章をどれくらい予測できるかの主要)を計算するアイデアが有効とされています。
例えば、2023年に発表されたPFNの言語モデルでも、0.1bクラスのtransformerが使われたそうです。
NLP2024での岡之原さんの講演でも、perplexity filterはなぜか有効、とのコメントがありました。
https://hillbig.github.io/NLP2024_WS_okanohara.pdf
というわけで、0.1bモデルでフィルタリングを試みます。
モデル学習
詳細は割愛しますが、GPT-2系の0.1 bモデルを、約30bトークンで学習させました。
最近はカリキュラム学習にハマっているので、英語wikipedia(360万件)→ CommonCrawl系日本語(1000万件)→日本語wikipedia(120万件)の順番で学習させました。
工コードは以下のリポジトリを用いています(まだ工事中)。
モデルも一応、公開中です。
Perplexityの計算コード
関数を定義します。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
import numpy as np
def perplexity(model, tokenizer, text) -> torch.Tensor:
tokenized_input = tokenizer.encode(
text, add_special_tokens=False, return_tensors="pt"
).to(model.device)
with torch.inference_mode():
output = model(tokenized_input, labels=tokenized_input)
ppl = torch.exp(output.loss)
return ppl.item()
model_name="kanhatakeyama/01b_model_30b_token"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto",
torch_dtype=torch.float16
)
適当な文章で評価した結果
適当な文章を入れてみます。
text_list=[
"吾輩は猫である。",
"最高の気分ですよ!",
"こんいちは、元気?",
"こんにちは、元気?",
"fjalkjfepiwofe",
"This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.",
]
for txt in text_list:
ppl=perplexity(model,tokenizer,txt)
print(ppl,txt.replace("\n",""))
計算結果は以下の通り。
ランダム文字列(fjalkjfepiwofe)はうまく弾けそうです。
一方で、打ち間違いの「こんいちは、元気?」のperplexityが、正しい文章である「こんにちは、元気?」を下回るという事態が発生しました。
mc4-jaで評価
一般的なwebコーパスである、mc4-jaから文章を抜き出して評価してみます。
from datasets import load_dataset
#mc4の読み込み
mc4_dataset = load_dataset('mc4', 'ja',split='train', streaming=True)
dataset=mc4_dataset
import re
line_list=[]
ppl_list=[]
for i,record in enumerate(dataset):
txt=record["text"]
print("-----")
for line in re.split(r'[。\\n]+', txt):
if line=="":
continue
line=line[:2000]
line=line.strip()
ppl=perplexity(model,tokenizer,line)
print(ppl,line.replace("\n",""))
line_list.append(line)
ppl_list.append(ppl)
if i>15:
break
#結果表示
import pandas as pd
df=pd.DataFrame({"text":line_list,"ppl":ppl_list}).sort_values("ppl")
df=df.drop_duplicates("text")
df[-30:]
perplexityが低かった文章
perplexityが高めの文章
4/5追記
きれいなテキストだけで学習させた、箱入り娘的なモデルを作りました(wikipedia, 青空文庫, NHKニュースを学習)。雑多なweb記事を読んでないので、文法的にオカシイ文章などを弾けるかも?というモチベーションです。
結果
あなた「わ」のような、崩し系の文章はpplで弾けそうなことがわかりました。
逆にいうと、今回のフィルターでは、その程度の選択性しかない印象でした。ルールベースでもよいかも?という感想です。
感想
テキストの良し悪しを、perplexityでうまく判定出来たとは言えない結果となりました。
モデルサイズを上げれば、精度が向上する可能性はありますが、計算コストも激増するので、悩みどころです(1TBクラスのテキストを処理する必要があります)。
また、実際のchatでは、ユーザーの入力が雑(≒perplexityが高い)可能性もあるので、清掃のしすぎも問題かもしれません。
まとめ
perplexity filterについては、色々とノウハウが必要そうなことがわかりました。