BERTとFaissを使って、類似キーワードを出力してみた
はじめに
こんにちは!株式会社POLでエンジニアをやっている @mejihabenatawa です!
POLは「研究者の可能性を最大化するプラットフォームを創造する」をビジョンに、理系学生に特化した採用サービス、および研究開発者・技術者に特化した転職/採用サービスの2サービスを運営しています。
今回、GWに機械学習周りの勉強をしようと思い、いろんなものを触っていたのですが、そのうちの一つについてブログを書きます!
今回やったこと
・BERTで単語ベクトルを取得
・Faissを使って、ベクトルが近いキーワードを取得
BERTで単語ベクトルを取得
キーワードは技術系のキーワードを拾ってきました。いくつか例を示すとこんな感じです。
research_tag_df['tag'].head()
## 実行結果
0 タンパク質結晶構造学
1 カウンターパルセーター
2 気胸モニタリング
3 配向性
4 選択体系機能言語理論
Name: tag, dtype: object
これらの単語ベクトルを取得したいので、BERTを import して以下のように関数を定義する。
# BERT
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM
from transformers import BertJapaneseTokenizer, BertModel
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
model_bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking', output_hidden_states=True)
model_bert.eval()
def calc_embedding(text):
bert_tokens = tokenizer.tokenize(text)
ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + bert_tokens[:126] + ["[SEP]"])
tokens_tensor = torch.tensor(ids).reshape(1, -1)
with torch.no_grad():
output = model_bert(tokens_tensor)
return output[1].numpy()
実際にいれる。
value_np = research_tag_df['tag'].map(calc_embedding)
Faissを使ってベクトルが近いキーワードを取得
Faiss は Facebook社が開発している近傍探索のライブラリです。以下のように index にデータを追加することで検索することができるようになり、search関数の返り値として、value_np に各データに対して似ているデータの 配列の位置とベクトル同士の距離を取得することができる。
index = faiss.IndexFlatL2(value_np.shape[1])
index.add(value_np)
D, I = index.search(target_np, 10)
print(I[:5])
print(D[:5])
print(
## 実行結果
[
[ 0 6630 1499 9407 2992 2129 172 3741 7228 836]
[ 1 9153 7314 3096 8410 6710 9305 3890 8014 6145]
]
[
[ 0. 12.640289 12.712769 16.70929 17.302261 17.314362 17.566467
17.58229 17.605682 17.645508]
[ 0. 23.757477 23.898285 24.858276 25.150726 26.226501 26.32248
26.343475 26.71762 26.73346 ]
]
実際にどの単語と単語が似ているのかを知るために表示するとこうなる。
## 「タンパク質結晶構造学」に似ている単語を表示
research_tag_df['tag'][I[0]]
## 実行結果
0 タンパク質結晶構造学
6630 分子構造解析学 カルバ糖
1499 分子構造解析学
9407 骨組構造力学
2992 生物分類学
2129 細胞免疫学
172 遺伝子発現変動解析
3741 統合生理心理学
7228 神経科学一般生物物理学
836 認知作業療法学
Name: tag, dtype: object
## 「気胸モニタリング」に似ている単語を表示
research_tag_df['tag'][I[2]]
## 実行結果
2 気胸モニタリング
4363 損傷モニタリング
2298 質量イメージング
6753 河川モニタリング
1722 シナプス伝達
9519 圧ストレス
8950 文書管理
6404 リアルタイムレンダリング
3868 リアルタイムスケジューリング
3094 強制対流沸騰
Name: tag, dtype: object
## 「隣人祭り」に似ている単語を表示
research_tag_df['tag'][I[6]]
## 実行結果
6 隣人祭り
4537 社会福祉実習
6108 ウォーキングイベント
7207 伝承遊び
9512 農村コミュニティ
401 職場談話
4896 職場ストレス
8002 都市祭礼
7285 地域福祉と契約制度
8858 防災計画支援
Name: tag, dtype: object
(隣人祭りってなんだとは気になりつつ...)BERTは単語単位ではなく、toknizer でトークン単位に分けているので(タンパク質結晶構造学とかは ['タンパク質', '結晶', '構造', '学'] とかに分解される)、その影響で単語の末尾が強く反映されている気がする。
おわりに
BERTのEmbeddingは本来であれば、文脈を考慮することができるのがメリットなのであまり良い使い方はしていないが、自分のBERTの理解の浅さに気づく結果がでてきてよかった。もう少しBERTを理解して、使用する出力層とか考えてできるようにすると良さそう。続編書きたい。
そして、株式会社POLではエンジニア、デザイナー、プロダクトマネージャーを大募集してます!お話しだけでも構いませんのでお気軽にお声がけください!!!