見出し画像

BERTとFaissを使って、類似キーワードを出力してみた


はじめに

こんにちは!株式会社POLでエンジニアをやっている @mejihabenatawa です!

POLは「研究者の可能性を最大化するプラットフォームを創造する」をビジョンに、理系学生に特化した採用サービス、および研究開発者・技術者に特化した転職/採用サービスの2サービスを運営しています。

今回、GWに機械学習周りの勉強をしようと思い、いろんなものを触っていたのですが、そのうちの一つについてブログを書きます!

今回やったこと

・BERTで単語ベクトルを取得

・Faissを使って、ベクトルが近いキーワードを取得

BERTで単語ベクトルを取得

キーワードは技術系のキーワードを拾ってきました。いくつか例を示すとこんな感じです。

research_tag_df['tag'].head()

## 実行結果

0     タンパク質結晶構造学
1    カウンターパルセーター
2       気胸モニタリング
3            配向性
4     選択体系機能言語理論
Name: tag, dtype: object

これらの単語ベクトルを取得したいので、BERTを import して以下のように関数を定義する。

# BERT
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM
from transformers import BertJapaneseTokenizer, BertModel
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
model_bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking', output_hidden_states=True)
model_bert.eval()

def calc_embedding(text):
   bert_tokens = tokenizer.tokenize(text)
   ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + bert_tokens[:126] + ["[SEP]"])
   tokens_tensor = torch.tensor(ids).reshape(1, -1)
   with torch.no_grad():
       output = model_bert(tokens_tensor)
   return output[1].numpy()

実際にいれる。

value_np = research_tag_df['tag'].map(calc_embedding)

Faissを使ってベクトルが近いキーワードを取得

Faiss は Facebook社が開発している近傍探索のライブラリです。以下のように index にデータを追加することで検索することができるようになり、search関数の返り値として、value_np に各データに対して似ているデータの 配列の位置とベクトル同士の距離を取得することができる。

index = faiss.IndexFlatL2(value_np.shape[1])
index.add(value_np)

D, I = index.search(target_np, 10)
print(I[:5])
print(D[:5])
print(

## 実行結果

[
    [   0 6630 1499 9407 2992 2129  172 3741 7228  836]
    [   1 9153 7314 3096 8410 6710 9305 3890 8014 6145]
]

[
    [ 0.       12.640289 12.712769 16.70929  17.302261 17.314362 17.566467
 17.58229  17.605682 17.645508]
    [ 0.       23.757477 23.898285 24.858276 25.150726 26.226501 26.32248
 26.343475 26.71762  26.73346 ]
]

実際にどの単語と単語が似ているのかを知るために表示するとこうなる。

## 「タンパク質結晶構造学」に似ている単語を表示

research_tag_df['tag'][I[0]]

## 実行結果

0         タンパク質結晶構造学
6630    分子構造解析学 カルバ糖
1499         分子構造解析学
9407          骨組構造力学
2992           生物分類学
2129           細胞免疫学
172        遺伝子発現変動解析
3741         統合生理心理学
7228     神経科学一般生物物理学
836          認知作業療法学
Name: tag, dtype: object

## 「気胸モニタリング」に似ている単語を表示

research_tag_df['tag'][I[2]]

## 実行結果

2             気胸モニタリング
4363          損傷モニタリング
2298          質量イメージング
6753          河川モニタリング
1722            シナプス伝達
9519             圧ストレス
8950              文書管理
6404      リアルタイムレンダリング
3868    リアルタイムスケジューリング
3094            強制対流沸騰
Name: tag, dtype: object

## 「隣人祭り」に似ている単語を表示

research_tag_df['tag'][I[6]]

## 実行結果

6             隣人祭り
4537        社会福祉実習
6108    ウォーキングイベント
7207          伝承遊び
9512      農村コミュニティ
401           職場談話
4896        職場ストレス
8002          都市祭礼
7285     地域福祉と契約制度
8858        防災計画支援
Name: tag, dtype: object

(隣人祭りってなんだとは気になりつつ...)BERTは単語単位ではなく、toknizer でトークン単位に分けているので(タンパク質結晶構造学とかは ['タンパク質', '結晶', '構造', '学'] とかに分解される)、その影響で単語の末尾が強く反映されている気がする。

おわりに

BERTのEmbeddingは本来であれば、文脈を考慮することができるのがメリットなのであまり良い使い方はしていないが、自分のBERTの理解の浅さに気づく結果がでてきてよかった。もう少しBERTを理解して、使用する出力層とか考えてできるようにすると良さそう。続編書きたい。


そして、株式会社POLではエンジニア、デザイナー、プロダクトマネージャーを大募集してます!お話しだけでも構いませんのでお気軽にお声がけください!!!


いいなと思ったら応援しよう!