見出し画像

マルチモーダルembeddigモデルE5-Vを試してみる

画像と文字列を同じ埋め込みベクトル化できるマルチモーダルなembeddingモデルE5-Vというものを知ったので、試してみました。

画像と文字列を共にベクトル化できるとなると、先行するものとしてCLIPやSigLIPがありますが、このE5-Vは画像も理解するLLMであるLLaVA-NeXT-8Bをベースにしていることから、文章理解力が上がっているようです(上記論文参考)。

画像と文字列とでモダリティギャップがない、つまりは、似た意味なら画像だろうと文字列だろうと近いベクトルになる、という主張のようです。右側がE5-Vのイメージ

論文より https://arxiv.org/abs/2407.12580

とりあえず、サンプルプログラムを元に、動かしてみました。

動作環境
Windows 11, Python 3.12.4, GeForce RTX 4090 (VRAM 24GB)

画像はWikipediaから借りることにしました。

リャマの画像2つと、ニホンザルと、折鶴

文字列としては、'高原で暮らすリャマ', 'ボリビアのリャマ', '日本の温泉地にいるサル', '日本の伝統工芸'をチョイス。あえて「折鶴」や「折り紙」と言わずに、「日本の伝統工芸」として、意味を近く解釈するかどうかを見てみたかったからです。

これらで類似度マトリックスをとってみると、こんな感じでした。

文字列「高原で暮らすリャマ」「ボリビアのリャマ」と、画像の「リャマ1」「リャマ2」はそれぞれ、色が薄い、つまり類似度が近いと判定されています。

文字列「日本の伝統工芸」と画像「折り鶴」の組み合わせも、他に比べれば色が薄い。リャマよりは近いと判断できていそうです。

これはなかなか面白い。より長文で試してみたいところです。

今回試したコード

必要そうなライブラリをインストールします。

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install transformers accelerate matplotlib bitsandbytes requests pillow matplotlib-fontja

ほぼサンプルのままですが、一応コードを。

import torch
import torch.nn.functional as F
import requests
from PIL import Image, UnidentifiedImageError
from io import BytesIO
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import numpy as np
import matplotlib.pyplot as plt
import matplotlib_fontja

# テンプレートの定義
llama3_template = ('<|start_header_id|>user<|end_header_id|>\n\n{'
                   '}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n')

# プロセッサとモデルの初期化
processor = LlavaNextProcessor.from_pretrained('royokong/e5-v')
model = LlavaNextForConditionalGeneration.from_pretrained('royokong/e5-v', torch_dtype=torch.float16, load_in_4bit=True)

# プロンプトの定義
img_prompt = llama3_template.format('<image>\nSummary above image in one word: ')
text_prompt = llama3_template.format('<sent>\nSummary above sentence in one word: ')

# テキストと画像URLのリスト
texts = ['高原で暮らすリャマ',
         'ボリビアのリャマ',
         '日本の温泉地にいるサル',
         '日本の伝統工芸']

images_urls = ['https://upload.wikimedia.org/wikipedia/commons/1/1a/Dos_llamas_en_Bolivia_%28deciembre_2001%29.jpg',
               'https://upload.wikimedia.org/wikipedia/commons/f/f0/Llama_La_Paz_Bolivia.jpg',
               'https://upload.wikimedia.org/wikipedia/commons/8/8c/Jigokudani_hotspring_in_Nagano_Japan_001.jpg',
               'https://upload.wikimedia.org/wikipedia/commons/f/f8/%E6%8A%98%E9%B6%B4_WUXGA.jpg']

# テキストの処理
text_inputs = processor([text_prompt.replace('<sent>', text) for text in texts], return_tensors="pt", padding=True).to(
    'cuda')

# 画像のダウンロードと処理
images = []
# ブラウザのユーザーエージェントを偽装するためのヘッダー
headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 '
                  'Safari/537.36'
}
for url in images_urls:
    response = requests.get(url, headers=headers)
    if response.status_code == 200:
        img_data = BytesIO(response.content)
        try:
            img = Image.open(img_data)
            images.append(img)
        except UnidentifiedImageError:
            print(f"画像を開けませんでした: {url}")
    else:
        print(f"画像のダウンロードに失敗しました: {url} ステータスコード: {response.status_code}")

img_inputs = processor([img_prompt] * len(images), images, return_tensors="pt", padding=True).to('cuda')

# 埋め込みの生成
with torch.no_grad():
    text_embs = model(**text_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
    img_embs = model(**img_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]

# 埋め込みの正規化
text_embs = F.normalize(text_embs, dim=-1)
img_embs = F.normalize(img_embs, dim=-1)

# Combine text and image embeddings
combined_embs = torch.cat((text_embs, img_embs), dim=0)

# Calculate similarity matrix
similarity_matrix = combined_embs @ combined_embs.t()

# Convert to numpy for visualization
similarity_matrix_np = similarity_matrix.cpu().numpy()

# Mask the lower triangle of the matrix
mask = np.triu(np.ones_like(similarity_matrix_np, dtype=bool))
masked_similarity_matrix = np.where(mask, similarity_matrix_np, np.nan)  # Use np.nan to mask

# Define the labels for the x and y axes
labels = ['高原で暮らすリャマ', 'ボリビアのリャマ', '日本の温泉地にいるサル', '日本の伝統工芸',
          '画像:リャマ1', '画像:リャマ2', '画像:サル', '画像:折り鶴']


# Visualize the masked similarity matrix
plt.figure(figsize=(8, 8))
plt.imshow(masked_similarity_matrix, cmap='viridis')
plt.colorbar()
plt.title('文字列と画像を一緒にした類似度マトリックス')
# Set the custom labels for the x and y axes
plt.xticks(ticks=np.arange(len(labels)), labels=labels, rotation=45, ha="right")
plt.yticks(ticks=np.arange(len(labels)), labels=labels)
plt.show()

ライセンスについて

LLaVA-NeXT-8BはLLaMA-3 8Bを元にした、画像も解釈できるLLMです。商用利用はNGの研究目的モデルなので、このE5-Vもそのライセンスを引き継いでいると思います。また、LLaMA-3 8Bのライセンスなども引き継いでいるので、ライセンス回りは複雑です。E5-Vはあくまでも研究目的と理解するのが良さそうです。

画像について

クリエイティブ・コモンズ 表示-継承 3.0 非移植 1992, photo by SHIBUYA K.

クリエイティブ・コモンズ 表示-継承 3.0 非移植

クリエイティブ・コモンズ 表示-継承 3.0 非移植 作者Yosemite

パブリックドメイン


いいなと思ったら応援しよう!