0.1bのtransformerのperplexityで文章フィルタリングができるかどうかの検証


はじめに

最近は大規模言語モデルを作っています。

大規模言語モデルの学習において、事前学習データのフィルタリングは重要と言われています。

フィルタリングにはいくつもの方法がありますが、ルールベースでノイズを取り除くのは、処理速度の面でも有利です。

一方、最近は機械学習ベースのフィルタリングも注目されています。

特に、処理速度と精度を両立した(?)手法として、言語モデルのperplexity(困惑さ: モデルが入力した文章をどれくらい予測できるかの主要)を計算するアイデアが有効とされています。

例えば、2023年に発表されたPFNの言語モデルでも、0.1bクラスのtransformerが使われたそうです。

NLP2024での岡之原さんの講演でも、perplexity filterはなぜか有効、とのコメントがありました。

https://hillbig.github.io/NLP2024_WS_okanohara.pdf

というわけで、0.1bモデルでフィルタリングを試みます。

モデル学習

詳細は割愛しますが、GPT-2系の0.1 bモデルを、約30bトークンで学習させました。
最近はカリキュラム学習にハマっているので、英語wikipedia(360万件)→ CommonCrawl系日本語(1000万件)→日本語wikipedia(120万件)の順番で学習させました。

工コードは以下のリポジトリを用いています(まだ工事中)。

モデルも一応、公開中です。

Perplexityの計算コード

関数を定義します。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
import numpy as np

def perplexity(model, tokenizer, text) -> torch.Tensor:
    tokenized_input = tokenizer.encode(
        text, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)
    with torch.inference_mode():
        output = model(tokenized_input, labels=tokenized_input)
    ppl = torch.exp(output.loss)
    return ppl.item()


model_name="kanhatakeyama/01b_model_30b_token"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto",
    torch_dtype=torch.float16
)

適当な文章で評価した結果

適当な文章を入れてみます。

text_list=[
    "吾輩は猫である。",
    "最高の気分ですよ!",
    "こんいちは、元気?",
    "こんにちは、元気?",
    "fjalkjfepiwofe",
    "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.",
]
for txt in text_list:
    ppl=perplexity(model,tokenizer,txt)
    print(ppl,txt.replace("\n",""))

計算結果は以下の通り。

ランダム文字列(fjalkjfepiwofe)はうまく弾けそうです。
一方で、打ち間違いの「こんいちは、元気?」perplexityが、正しい文章である「こんにちは、元気?」を下回るという事態が発生しました。

mc4-jaで評価

一般的なwebコーパスである、mc4-jaから文章を抜き出して評価してみます。

from datasets import load_dataset
#mc4の読み込み
mc4_dataset = load_dataset('mc4', 'ja',split='train', streaming=True)

dataset=mc4_dataset

import re
line_list=[]
ppl_list=[]
for i,record in enumerate(dataset):
    txt=record["text"]
    print("-----")

    for line in re.split(r'[。\\n]+', txt):
        if line=="":
            continue
        line=line[:2000]
        line=line.strip()
        ppl=perplexity(model,tokenizer,line)
        print(ppl,line.replace("\n",""))
        line_list.append(line)
        ppl_list.append(ppl)

    if i>15:
        break

#結果表示
import pandas as pd
df=pd.DataFrame({"text":line_list,"ppl":ppl_list}).sort_values("ppl")
df=df.drop_duplicates("text")
df[-30:]

perplexityが低かった文章

perplexityが高めの文章


4/5追記

きれいなテキストだけで学習させた、箱入り娘的なモデルを作りました(wikipedia, 青空文庫, NHKニュースを学習)。雑多なweb記事を読んでないので、文法的にオカシイ文章などを弾けるかも?というモチベーションです。

結果

pplが小さい文章
pplが高い文章

あなた「わ」のような、崩し系の文章はpplで弾けそうなことがわかりました。
逆にいうと、今回のフィルターでは、その程度の選択性しかない印象でした。ルールベースでもよいかも?という感想です。

感想

テキストの良し悪しを、perplexityでうまく判定出来たとは言えない結果となりました。
モデルサイズを上げれば、精度が向上する可能性はありますが、計算コストも激増するので、悩みどころです(1TBクラスのテキストを処理する必要があります)。

また、実際のchatでは、ユーザーの入力が雑(≒perplexityが高い)可能性もあるので、清掃のしすぎも問題かもしれません。

まとめ

perplexity filterについては、色々とノウハウが必要そうなことがわかりました。



いいなと思ったら応援しよう!