text2text-japaneseによる要約のファインチューニング
「text2text-japanese」が公開されたので、日本語による要約のファインチューニングを試してみました。
1. text2text-japanese
「text2text-japanese」は、「gpt2-japanese」のモデルを使って、文章から文章に変換するタスクのファインチューニングを行うためのプログラムです。
「質問回答」「チャットボット」「要約 」などに利用できます。
・質問回答 : 入力「質問」 → 出力「回答」
・チャットボット : 入力「発話」 → 出力「応答」
・要約 : 入力「文章」 → 出力「要約」
2. 要約データセットの準備
はじめに、「要約データセット」を準備します。
「ThreeLineSummaryDataset」の「train.csv」にLivedoorニュースのIDのリストがあるので、それを使います。
Livedoorニュースの「3行要約」と「本文」をスクレイピングで取得します。bs4でスクレイピングするコードは次のとおりです。サーバーに負荷をかけないように、10秒に1回だけ通信するようにしています。
from urllib.request import urlopen
from bs4 import BeautifulSoup
from bs4.element import NavigableString
from pprint import pprint
import time
# コンテンツの取得
def get_content(id):
time.sleep(10)
URL = 'https://news.livedoor.com/article/detail/'+id+'/'
print(URL)
try:
with urlopen(URL) as res:
# 本文
output1 = ''
html = res.read().decode('euc_jp', 'ignore')
soup = BeautifulSoup(html, 'html.parser')
lineList = soup.select('.articleBody p')
for line in lineList:
if len(line.contents) > 0 and type(line.contents[0]) == NavigableString:
output1 += line.contents[0].strip()
if output1 == '': # 記事がない
return
output1 += '\n'
# 3行要約
output0 = ''
summaryList = soup.select('.summaryList li')
for summary in summaryList:
output0 += summary.contents[0].strip()+'\t'
if output0 == '': # 記事がない
return
# 出力
print(output0+output1)
with open('output.tsv', mode='a') as f:
f.writelines(output0+output1)
except Exception:
print('Exception')
# IDリストの生成の取得
idList = []
with open('train.csv', mode='r') as f:
lines = f.readlines()
for line in lines:
id = line.strip().split(',')[3].split('.')[0]
idList.append(id)
# コンテンツの取得
for i in range(0, 10): # 取得したい記事のINDEXをここで指定
print('index:', i)
get_content(idList[i])
「output.tsv」に以下のようなtsvのデータ形式で取得できます。<tab>はタブ(\t)になります。
要約1 <tab> 要約2 <tab> 要約3 <tab> 本文
:
今回は練習のため1000件ほど取得しました。
3. 要約データセットをtext2text-japaneseの書式に変換
「text2text-japanese」の学習データは、「train_data」フォルダ下に、1ファイルに1ペア(本文と要約)ずつ、以下の書式で記述します。
本文<|SEP|>要約
要約データセットをtext2text-japaneseの書式に変換するコードは、次のとおりです。要約は3行ありますが、練習なので1行目だけを使います。前処理はURL表記の削除のみ行います。
import re
# URL表記の削除
def del_url(str):
return re.sub(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)", "" , str)
# text2text-japaneseの書式に変換
dataList = []
with open('output.tsv', mode='r') as f:
lines = f.readlines()
for line in lines:
strs = line.strip().split('\t')
dataList.append(del_url(strs[3])+'<|SEP|>'+del_url(strs[0])+'\n')
# 出力
index = 0
for data in dataList:
with open('train_data/'+str(index)+'.txt', mode='w') as f:
f.write(data)
index += 1
train_dataフォルダ下に0.txt〜999.txtが生成されます。
【例】 0.txt
新橋と秋葉原に店を構える「岡むら屋」。この組み合わせではすでに「肉じゃが」という殿堂入りメニューがありますが、〜省略〜 <|SEP|>岡むら屋から、期間限定の新メニュー「じゃが肉めし」が登場する
4. 要約のファインチューニング
「Google Colab」で要約のファインチューニングを行います。
(1) Googleドライブのフォルダの準備
学習結果を永続化したいので、Googleドライブのフォルダを準備します。
# Googleドライブのフォルダの準備
from google.colab import drive
drive.mount('/content/drive')
!mkdir -p /content/drive/'My Drive'/gpt/
%cd /content/drive/'My Drive'/gpt/
(2) text2text-japaneseのインストール
# インストール
!git clone https://github.com/tanreinama/text2text-japanese.git
%cd text2text-japanese
!pip install jaconv
(3) 「text2text-japanese」フォルダに先程作成した「train_data」を配置。
(4) 学習用ファイルの生成。
「pairtext_1.pkl」〜「pairtext_7.pkl」が生成されます。
%%time
# 学習用ファイルの生成
!python make_data.py --src_dir train_data --dst_file pairtext --split_tag '<|SEP|>'
Wall time: 4min 36s
「make_data.py」の引数は、次のとおりです。
・--src_dir : 入力フォルダ
・--dst_file : 出力フォルダ
・--num_process : プロセス数 (デフォルト:8)
・--split_tag : 分割タグ
(5) モデルのダウンロード。
「medium」を使いたいところですが、「Google Colab」ではメモリが足りなかったので「small」を使います(「Google Colab Pro 」使いたい)。
%%time
# モデルのダウンロード
!wget https://www.nama.ne.jp/models/gpt2ja-small.tar.bz2
!tar xvfj gpt2ja-small.tar.bz2
Wall time: 3min 1s
(6) ファインチューニングの実行。
練習なので10000ステップ学習します。checkpoint/run1/にモデルが生成されます。
%%time
# 学習の実行
!python training.py --base_model gpt2ja-small --run_name run1 --dataset "pairtext_*.pkl" --max_train_steps 10000
[1 | 2.95] loss=1.57 avg=1.57
[2 | 3.33] loss=3.19 avg=2.39
[3 | 3.70] loss=2.48 avg=2.42
:
[9997 | 3994.64] loss=0.03 avg=0.06
[9998 | 3995.05] loss=0.03 avg=0.06
[9999 | 3995.44] loss=0.03 avg=0.06
Saving checkpoint/run2/model-10000
CPU times: user 13.8 s, sys: 2.85 s, total: 16.6 s
Wall time: 1h 6min 49s
「training.py」の引数は、次のとおりです。
・--dataset : 入力pklファイル
・--base_model : モデルファイル
・--batch_size : バッチサイズ (デフォルト:1)
・--optim : オプティマイザ (adam,adagrad,sgd, デフォルト:adam)
・--learning_rate : 学習率 (デフォルト:3e-6)
・--warmup_steps : 学習率のウォーミングアップステップ (デフォルト:0)
・--max_train_steps : 最大学習ステップ数 (デフォルト:-1)
・--run_name : 実行ID (checkpoint内のサブフォルダ名)
・--save_every : Nステップ毎のモデルの保存 (デフォルト:10000)
・--max_answer_len : 最大回答トークン数
・--gpu : GPUのID (デフォルト:0)
・--train_type : 学習種別 (QtoA,AtoQ, デフォルト:QtoA)
(7) 推論の実行...の前にmodel.pyの修正。
推論をそのまま実行すると、以下のエラーがでました。
Traceback (most recent call last):
File "text2text.py", line 42, in <module>
hparams = HParams(**params)
TypeError: __init__() got an unexpected keyword argument 'n_prediction'
そこで、「model.py」を以下のように修正して、「n_prediction」を受け取る(けど何もしない)ようにしました。
class HParams:
def __init__(self,
n_vocab=0,
n_ctx=1024,
n_embd=768,
n_head=12,
n_layer=12):
↓
class HParams:
def __init__(self,
n_vocab=0,
n_ctx=1024,
n_embd=768,
n_head=12,
n_layer=12,
n_prediction=0):
(8) 推論の実行。
ドラゴンボールの序盤のあらすじを要約してもらいました。
!python text2text.py --model checkpoint/run1 --context "地球の人里離れた山奥に住む尻尾の生えた少年・孫悟空はある日、西の都からやって来た少女ブルマと出会う。そこで、7つ集めると神龍が現れ、どんな願いでも一つだけ叶えてくれるというドラゴンボールの存在を、さらに育ての親である孫悟飯の形見として大切に持っていた球がその1つ「四星球」であることを知り、ブルマと共に残りのドラゴンボールを探す旅に出る。 人さらいのウーロンや盗賊のヤムチャなどを巻き込んだボール探しの末、世界征服を企むピラフ一味にボールを奪われ神龍を呼び出されるが、ウーロンがとっさに言い放った下らない願いを叶えてもらうことで一味の野望を阻止する。その後、悟空は旅の途中に知り合った武術の達人・亀仙人の下で、後に親友となるクリリンと共に8か月間にわたる修行を積み、その成果を確かめるために世界一の武術の達人を決める天下一武道会に出場し、変装して出場していた亀仙人に敗れるも準優勝を果たす。 悟空は再び修行の旅へと出発し、ドラゴンボールの悪用を企むレッドリボン軍との闘いや、孫悟飯との再会などを経てさらに強さを増していく。さらに3年後の天下一武道会では、亀仙流のライバルである鶴仙流の天津飯と闘うが、あと一歩のところで敗れ、前回と同じく準優勝に終わる。"
「ドラゴンボール超」の孫悟空の活躍を紹介している。
データセットと学習の量が少なく、モデルもsmallを使ったので、うまくいくか不安でしたが、それっぽい要約がでました。
ランダムさを増やすと、次のようになりました。
!python text2text.py --num_generate 10 --top_k 2 --model checkpoint/run1 --context "地球の人里離れた山奥に住む尻尾の生えた少年・孫悟空はある日、西の都からやって来た少女ブルマと出会う。そこで、7つ集めると神龍が現れ、どんな願いでも一つだけ叶えてくれるというドラゴンボールの存在を、さらに育ての親である孫悟飯の形見として大切に持っていた球がその1つ「四星球」であることを知り、ブルマと共に残りのドラゴンボールを探す旅に出る。 人さらいのウーロンや盗賊のヤムチャなどを巻き込んだボール探しの末、世界征服を企むピラフ一味にボールを奪われ神龍を呼び出されるが、ウーロンがとっさに言い放った下らない願いを叶えてもらうことで一味の野望を阻止する。その後、悟空は旅の途中に知り合った武術の達人・亀仙人の下で、後に親友となるクリリンと共に8か月間にわたる修行を積み、その成果を確かめるために世界一の武術の達人を決める天下一武道会に出場し、変装して出場していた亀仙人に敗れるも準優勝を果たす。 悟空は再び修行の旅へと出発し、ドラゴンボールの悪用を企むレッドリボン軍との闘いや、孫悟飯との再会などを経てさらに強さを増していく。さらに3年後の天下一武道会では、亀仙流のライバルである鶴仙流の天津飯と闘うが、あと一歩のところで敗れ、前回と同じく準優勝に終わる。"
孫悟飯の活躍を目の当たりにした悟空は、「7つ集めるとドラゴンボール」を育てるため、世界征服を企むピラフ一味と闘う。 映画「7つ集めると決めるな」のポスタービジュアルが話題の「7つ子」と孫悟空の「7つ子プロジェクト」が始動する
========
「ドラゴンボール超」の「孫悟空」の存在を知った悟空は、孫悟空の力になりたいと思い、「孫悟空」の力になりたいと思う。
========
「ドラゴンボール超」の孫悟飯に憧れるが「ドラゴンボール超」の孫悟空に勝てず2016年の世界最強の武道会では、孫悟飯に敗れ2017年の世界最強の武道会には出られなくなっていった
========
「ドラゴンボール超」の「孫悟空」の活躍に注目
========
「ドラゴンボール超」の孫悟空の活躍を紹介している
========
孫悟飯の活躍を目の当たりにした悟空は、孫悟飯が「7つ集めると神龍が現れ、どんな願でも一つだけ叶えてくれるという』という「ドラゴンボールの存在」を発見する
========
「亀の子」を育てた孫悟飯に憧れ、7つ集めると神龍が現れ、どんな願いででも一つだけ叶えてくれるという「ドラゴンボールの名もない少年」が完成する
========
「亀の子のプロレス」を終えた悟空は孫悟飯と再会を果たし、孫悟飯と2人で世界征服を企む「7つ集めるとドラゴンボール」を育てる
========
「亀の子」と「ドラゴンボール」の隠された力を紹介している
========
「ドラゴンボール超」の「孫悟空」の存在を知った悟空は、孫悟空の力になりたいと思うようになるが、「ドラゴンボール超」の存在を知り、悟空の力になりたいと思うようになる
「text2text.py」の引数は、次のとおりです。
・--model : モデル (デフォルト:gpt2ja-medium)
・--output_file : 出力ファイル
・--context : コンテキスト
・--num_generate : 生成数 (デフォルト:1)
・--top_k : 各ステップでk個の単語からランダム選択 (デフォルト :1)
・--top_p : 生成テキストを累積確率に制限 (デフォルト:0(制限なし))
・--temperature : 温度 (0.0〜1.0, デフォルト:1.0)
・--allow_duplicate_line : 重複した行を自動削除しない
・--full_sentences : 複数回のモデル実行に入力文章を分割して、全ての出力文章を繋げて出力
・--gpu : GPUのID (デフォルト:0)
5. ROUGE1-F1による評価
「ROUGE1-F1」は、文章を適切に要約できているかを評価するための指標です。「モデルが生成した要約」と「人間が生成した要約」を比較します。
「valid_data」フォルダに学習データと同じ書式で、評価データを配置し、以下のコマンドを実行します。
!python make_score.py --model checkpoint/run1 --src_dir valid_data --split_tag '<|SEP|>'
score = 0.043855953613521
これが増えれば、より適切に要約ができていると判定できます。
「make_score.py」の引数は、次のとおりです。
・--src_dir : 評価データのフォルダ
・--model : モデル
・--max_answer_len : 回答トークンの最大数
・--min_answer_len : 回答トークンの最小数
・--train_type : 学習種別 (QtoA,AtoQ, デフォルト:QtoA)
・--dataset_type : データセット種別 (デフォルト:'split')
・--split_tag : 分割タグ (デフォルト:'<|SP_QA|>')
・--top_k : 各ステップでk個の単語からランダム選択 (デフォルト :1)
・--top_p : 生成テキストを累積確率に制限 (デフォルト:0(制限なし))
・--temperature : 温度 (0.0〜1.0, デフォルト:1.0)
・--gpu : GPUのID (デフォルト:0)
・--verbose : 詳細出力