(備忘)保存したFAISS indexを使ってTimeWeightedVectorStoreRetrieverを動かす【LangChain】
Time Weighted VectorStore Retriever
最近、LangChainにTimeWeightedVectorStoreRetrieverというRetrieverが実装されました。これは、意味の類似性とドキュメントのアクセス時刻に基づいて検索結果をランク付けすることができます。これにより、関連性と新鮮さを同時に考慮できます。
例えば、キャラクターが以前に学んだ情報が時間経過とともに徐々に薄れる場合や、新しい情報を優先して取得する場合に役立ちます。また、キャラクター同士の対話や物語の進行に応じて、適切な情報を提供するために、検索結果の優先度を時間や関連性に基づいて調整することもできます。
他にも、トレンド情報やニュースなどの更新が早いドキュメントの検索において、古い情報よりも最新の情報を優先して取得する場合などにも役立ちそうです。
TimeWeightedVectorStoreRetrieverは、FAISSインデックスを使用して効率的に検索できるように設計されています。FAISSインデックスはディスクに保存できますが、TimeWeightedVectorStoreRetrieverにはself.memory_streamという変数にドキュメントのリストが格納されるため、単に保存したインデックスを読み込むだけでは動作しません。
memory_streamの保存
TimeWeightedVectorStoreRetrieverクラスのadd_documentsメソッドを見てみましょう。このメソッドは、ドキュメントのリストを受け取り、それらをself.memory_streamに追加します。しかし、このself.memory_streamはFAISSインデックスとは別に格納されており、プログラムの実行が終了すると失われてしまいます。そのため、FAISSインデックスを保存・読み込みするだけでは、self.memory_streamの情報が失われてしまい、正常に動作しなくなります。
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
current_time = kwargs.get("current_time", datetime.now())
dup_docs = [deepcopy(d) for d in documents]
for i, doc in enumerate(dup_docs):
if "last_accessed_at" not in doc.metadata:
doc.metadata["last_accessed_at"] = current_time
if "created_at" not in doc.metadata:
doc.metadata["created_at"] = current_time
doc.metadata["buffer_idx"] = len(self.memory_stream) + i
self.memory_stream.extend(dup_docs)
return self.vectorstore.add_documents(dup_docs, **kwargs)
そこで、TimeWeightedVectorStoreRetrieverを継承し、self.memory_streamをディスクに保存・読み込む2つの関数を追加したTimeWeightedVectorStoreRetrieverWithPersistenceクラスを作成しました。
memory_streamを保存・読み込みすることで、プログラムの実行が終了した後も状態を引き継げるはずです。
import pickle,os
from langchain.retrievers import TimeWeightedVectorStoreRetriever
class TimeWeightedVectorStoreRetrieverWithPersistence(TimeWeightedVectorStoreRetriever):
persistent_path: str
def save_memory_stream(self) -> None:
"""Save the memory stream to disk."""
if not os.path.exists(self.persistent_path):
os.makedirs(self.persistent_path)
memory_stream_path = os.path.join(self.persistent_path, "memory_stream.pkl")
with open(memory_stream_path, 'wb') as file:
pickle.dump(self.memory_stream, file)
self.vectorstore.save_local(self.persistent_path)
def load_memory_stream(self) -> None:
"""Load the memory stream."""
memory_stream_path = os.path.join(self.persistent_path, "memory_stream.pkl")
if os.path.exists(memory_stream_path):
with open(memory_stream_path, 'rb') as file:
self.memory_stream = pickle.load(file)
save_memory_stream関数:
self.memory_streamとFAISSインデックスをディスクに保存します。FAISSインデックスも保存している理由は、保存し忘れるとmemory_streamと不整合になるためです。
load_memory_stream関数:
ディスクからmemory_streamを読み込み、self.memory_streamに代入します。
使い方
まず、FAISS.load_localで既存のFAISSインデックスを読み込みます。
次に、persistent_pathを渡してretrieverを作成します。
その後に、load_memory_stream()を呼ぶことで、保存されていたmemory_stream.pklが読み込まれた状態になります。
vectorstore = FAISS.load_local(persistent_path, embeddings_model)
retriever = TimeWeightedVectorStoreRetrieverWithPersistence(vectorstore=vectorstore, persistent_path=persistent_path)
retriever.load_memory_stream()
retriever.save_memory_stream()を呼ぶと、retriever作成時に渡したpersistent_pathにFAISSインデックスとmemory_stream.pklが保存されます。
retriever.save_memory_stream()
まとめ
TimeWeightedVectorStoreRetrieverの状態の永続化方法として、self.memory_streamをディスクに保存・読み込む関数を作成しました。
ただ、他の実装をしっかり読み込んだわけではないので、なにか問題があるかも知れません。
TimeWeightedVectorStoreRetrieverは、チャットボットなどの記憶機構として非常に便利な気がしています。以前の記事で書いたものに組み込んだりしてみたいです。
今回の記事は、作成したクラスや目的などの情報をGPT-4に渡して、文章のベースを作ってもらいました。
この記事が気に入ったらサポートをしてみませんか?