LLMのembeddingベクトルを使ってテキスト分類のフレームワークを提案したClusterLLMの疑似コードを公開しました!
LLMとクラスタリングを組み合わせた手法 ClusterLLM が公開されました!
ClusterLLMのメリット
従来の機械学習の組み合わせ: ClusterLLMは、従来の機械学習の手法を効果的に組み合わせることで、より高度なテキスト分類を実現しています。
継続的なチューニング: ClusterLLMのフレームワークは、継続的なpromptチューニングフローを持っています。これにより、分類の精度を向上させることができます。
ChatGPTは中立的なラベル生成能力を持っており、偏りのないラベルを持つことは、データセット全体を公平に評価・分析する上で重要なことです。
偏見を持たないラベルを使用することで、真のデータの特性やトレンドをより正確に捉えることができ、特定のクラスやカテゴリの特徴を正確に理解し、それに基づいてhard negative sampleを効果的に抽出することが可能になります。
つまり、ChatGPTは、似ているが微妙に異なるサンプルを正確に識別するのに役立つということです。ClusterLLM はこの特徴を生かした手法です。
※"Hard negative sample"は、機械学習の文脈で、正しいカテゴリと非常に似ているが、実際には異なるカテゴリに属するサンプルを指します。これらのサンプルは、分類器にとって「困難なネガティブ例」となります。
[論文]ClusterLLM: Large Language Models as a Guide for Text Clustering
著者: Yuwei Zhang, Zihan Wang, Jingbo Shang
概要
ClusterLLMは、ChatGPTのような指示調整された大規模言語モデルからのフィードバックを活用する新しいテキストクラスタリングフレームワークを紹介します。伝統的な教師なし方法と比較して、ClusterLLMには2つの魅力的な利点があります。
LLMの能力: その埋め込みがアクセス不可能であっても、LLMの能力を享受できます。
ユーザーの好みの理解: テキスト指示や少量の注釈付きデータを通じて、クラスタリングに関するユーザーの好みを理解します。
主なポイント
ChatGPTにクラスタリングの視点に関する洞察を求めるために、難しい3つ組の質問を構築します。この戦略は、小さな埋め込みの微調整に効果的であり、ChatGPTに問い合わせるのにコスト効率的です。
クラスタリングの粒度に関するChatGPTの助けを求めるために、慎重に設計されたペアワイズの質問を使用します。
14のデータセットでの広範な実験により、ClusterLLMがクラスタリングの品質を一貫して向上させることが示されました。平均コストはデータセットあたり約$0.6です。
ClusterLLMの実装ガイドライン
データの準備:
テキストデータを収集し、前処理を行います(例: トークン化、ストップワードの削除など)。
小さな埋め込みの作成:
既存のテキスト埋め込みモデル(例: Word2Vec, FastText, BERTなど)を使用して、テキストデータの埋め込みを生成します。
ChatGPTへの質問:
似ているが異なるクラスタに属するデータポイントを選択し、ChatGPTに3つ組の質問を行います。
ChatGPTの回答を使用して、小さな埋め込みを微調整します。
クラスタリングの粒度の調整:
ペアワイズの質問を設計し、ChatGPTにクラスタリングの粒度に関する質問を行います。
ChatGPTの回答に基づいて、クラスタの階層の粒度を調整します。
クラスタリングの実行:
調整された埋め込みを使用して、テキストデータのクラスタリングを行います。
評価:
クラスタリングの結果を評価するための適切なメトリクス(例: シルエット係数)を使用して、クラスタリングの品質を評価します。
ちゃんとした動くコードにしてくださるのは
きっとどなたがやってくださると信じて……Pythonの疑似コードを公開します。
Python風疑似コード
# CLUSTER LLM 疑似コード
import clustering_algorithm
import large_language_model
import numpy as np
from evaluation_metrics import silhouette_coefficient
corpus = load_corpus()
embedding_space = pre_trained_embedding(corpus)
def entropy_based_triplet_sampling(embeddings, num_triplets, cluster_assignments):
"""
2.1.1 エントロピーに基づく3つ組サンプリング
3つ組をサンプリングします。アンカーはエントロピーに基づいて選ばれ、最も近いクラスタからの2つの選択肢が付随します。
"""
triplets = []
for i in range(num_triplets):
anchor = select_instance_based_on_entropy(cluster_assignments)
close_clusters = select_closest_clusters(anchor, cluster_assignments)
c1, c2 = random.sample(close_clusters, 2)
triplet = (anchor, c1, c2)
triplets.append(triplet)
return triplets
def prompt_LLM_for_triplets(triplets, task_instruction):
"""
2.1 3つ組タスクの視点
3つ組をLLMに提示して、アンカーに最も近い選択肢を取得します。
"""
answers = []
for triplet in triplets:
query = f"{task_instruction} Which is closer to {triplet[0]}: {triplet[1]} or {triplet[2]}?"
answer = large_language_model.predict(query)
answers.append(answer)
return answers
def fine_tune_embedding_space(embeddings, answers_from_LLM, triplets):
"""
2.1.2 埋め込みの微調整
LLMからの回答に基づいてembeddingsを微調整します。
"""
for i, answer in enumerate(answers_from_LLM):
anchor, choice = triplets[i][0], triplets[i][answer]
adjust_distance(embeddings, anchor, choice)
return embeddings
def pairwise_hierarchical_sampling(embeddings, granularity_range):
"""
2.2.1 ペアワイズ階層サンプリングを使用した粒度の決定
微調整された埋め込みで階層的クラスタリングを使用して、クラスタリング階層を生成します。
"""
hierarchical_clustering = perform_hierarchical_clustering(embeddings)
pairs = []
for granularity in granularity_range:
cluster_assignments = get_cluster_assignments_at_granularity(hierarchical_clustering, granularity)
pair = sample_pair_from_clusters(cluster_assignments)
pairs.append(pair)
return pairs
def prompt_LLM_for_pairs(pairs, demo_pairs):
"""
2.2 ペアワイズタスクの粒度
インスタンスのペアをLLMに提示して、それらが同じクラスタに属しているかどうかを取得します。
"""
answers = []
for pair in pairs:
query = f"Given the demos {demo_pairs}, do {pair[0]} and {pair[1]} belong to the same category?"
answer = large_language_model.predict(query)
answers.append(answer)
return answers
def determine_granularity(answers_from_LLM):
"""
2.2.1 ペアワイズ階層サンプリングを使用した粒度の決定 (続き)
LLMからの回答を元に、最も一貫性のある粒度を選択します。
"""
consistency_scores = calculate_consistency(answers_from_LLM)
optimal_granularity = np.argmax(consistency_scores)
return optimal_granularity
## メインのロジック ##
# 1. データ準備
corpus = load_corpus()
preprocessed_corpus = preprocess_corpus(corpus) # Tokenization, stop-word removal, etc.
# 2. 小さな埋め込みを作成する
embedding_space = pre_trained_embedding(preprocessed_corpus) # Use models like Word2Vec, FastText, BERT
# 3. ChatGPTに質問する
initial_clusters = clustering_algorithm.initial_clustering(embedding_space)
triplets = entropy_based_triplet_sampling(embedding_space, num_triplets, initial_clusters)
answers_for_triplets = prompt_LLM_for_triplets(triplets, user_specified_instruction)
embedding_space = fine_tune_embedding_space(embedding_space, answers_for_triplets, triplets)
# 4. クラスタリングの粒度を調整する
pairs = pairwise_hierarchical_sampling(embedding_space, granularity_range)
answers_for_pairs = prompt_LLM_for_pairs(pairs, user_demo_pairs)
granularity = determine_granularity(answers_for_pairs)
# 5. クラスタリングの実行
clusters = clustering_algorithm(embedding_space, granularity)
# 6. 評価
score = silhouette_coefficient(clusters, embedding_space)
# クラスターとその評価スコアをチェック
clusters, score