CLIPモデルで画像特徴点の抽出とElasticsearchで類似画像検索
類似画像検索システムを検討するにあたってCLIP(2021年2月にOpenAIによって公開された,言語と画像のマルチモーダルモデル)を試してみました。
1.Elasticsearchのマッピング定義
import json
from elasticsearch import Elasticsearch
es = Elasticsearch("http://0.0.0.0:9200")
# インデックス名
index_name = "test_index"
# インデックスを削除
#response = es.indices.delete(index=index_name)
mapping = {
"mappings": {
"properties": {
"metadata": {
"properties": {
"image_code": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
}
},
"vector": {
"type": "dense_vector",
"dims": 512,
"index": True,
"similarity": "cosine"
}
}
}
}
# マッピングを作成
response = es.indices.create(index=index_name, body=mapping)
2.CLIPで画像のベクトル化とElasticsearchへのデータ投入
%%time
import os
import numpy as np
from PIL import Image
import torch
from clip import clip
import json
from elasticsearch import Elasticsearch
# CLIPモデルのロード
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# 画像のベクトル化
def extract_features(image_path):
# 画像の前処理
image = Image.open(image_path)
image = preprocess(image).unsqueeze(0).to(device)
# 特徴量の抽出
with torch.no_grad():
image_features = model.encode_image(image).cpu().numpy()
return image_features / np.linalg.norm(image_features)
es = Elasticsearch("http://0.0.0.0:9200")
# rowsは画像パスのコードが格納された配列
for row in rows:
image_path = "../image/"+row["image_code"]+".jpg"
if os.path.isfile(image_path):
features = extract_features(image_path)
# 登録するデータ
data = {
"metadata":{
"image_code": str(row["image_code"]),
},
"vector":features[0]
}
response = es.index(index="test_index", body=data)
3.類似画像検索
%%time
import pprint
import sys
import os
import numpy as np
from PIL import Image
import torch
from clip import clip
import json
from elasticsearch import Elasticsearch
# CLIPモデルのロード
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# 画像のベクトル化
def extract_features(image_path):
# 画像の前処理
image = Image.open(image_path)
image = preprocess(image).unsqueeze(0).to(device)
# 特徴量の抽出
with torch.no_grad():
image_features = model.encode_image(image).cpu().numpy()
return image_features / np.linalg.norm(image_features)
image_path = "./test.jpg"
query_features = extract_features(image_path)
#==========================================================
es = Elasticsearch("http://0.0.0.0:9200")
# ドキュメントを検索するためのクエリ
# Elasticsearchが負のスコアを許容しないため+1
q = {
"size": 20,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0 ",
"params": {"query_vector": query_features[0]}
}
}
}
}
# ドキュメントを検索
result = es.search(index="test_index", body=q)
# 検索結果からドキュメントの内容のみ表示
docs=[]
content = []
for document in result["hits"]["hits"]:
content.append([document["_score"], document["_source"]["metadata"]["image_code"]])
es.close()
print(content)