Local-LLM+LongLLMLingua, RAG series 3/n
RAGシリーズ3回目。
今回はLlamaindexでLongLLMLinguaを用いたRAGです。
通常のRAGではpromptが長くなりがちで、計算コストが嵩む(ChatGPTなどを使用する場合は費用が嵩む)、性能が低下する(ex.:“Lost in the middle”)、などの課題が生じます。それらの課題を、適切にpromptを圧縮するLLMLinguaとRerankingなどを組み合わせることで克服する手法です。
0. 環境
1. text読込
今回はLongLLMLinguaの論文を用いました。
from llama_index.node_parser import SimpleNodeParser
from llama_index import SimpleDirectoryReader
path = r".\nlp\LongLLMLingua"
documents = SimpleDirectoryReader(path).load_data()
node_parser = SimpleNodeParser.from_defaults(chunk_size=256)
base_nodes = node_parser.get_nodes_from_documents(documents)
2. LLMとembeddingのモデル指定
今回はLLMにZephyr-7B-βを4bit量子化し、embeddingにはBAAI/bge-small-en-v1.5を使用しました。
import torch
from transformers import BitsAndBytesConfig
from llama_index.llms import HuggingFaceLLM
from llama_index import ServiceContext
from llama_index.embeddings import resolve_embed_model
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
llm = HuggingFaceLLM(
model_name="HuggingFaceH4/zephyr-7b-beta",
tokenizer_name="HuggingFaceH4/zephyr-7b-beta",
context_window=2048,
max_new_tokens=512,
model_kwargs={"quantization_config": quantization_config},
generate_kwargs={"temperature": 0.1, "top_k": 50, "top_p": 0.95},
device_map="auto",
)
embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")
service_context = ServiceContext.from_defaults(
llm=llm, embed_model=embed_model, chunk_size=256
)
3. index, retrieverの設定
from llama_index import VectorStoreIndex, ServiceContext
base_index = VectorStoreIndex(base_nodes, service_context=service_context)
base_retriever = base_index.as_retriever(similarity_top_k=5)
4. LongLLMLinguaポストプロセスの設定
LlamaindexのLongLLMLinguaPostprocessorでLongLLMLinguaを実装することができます。
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import CompactAndRefine
from llama_index.indices.postprocessor import LongLLMLinguaPostprocessor
node_postprocessor = LongLLMLinguaPostprocessor(
instruction_str="Given the context, please answer the final question",
target_token=300,
rank_method="longllmlingua",
additional_compress_kwargs={
"condition_compare": True,
"condition_in_question": "after",
"context_budget": "+100",
"reorder_context": "sort",
"dynamic_context_compression_ratio": 0.3,
},
)
5. 方法1. Step-by-Step
質問は下記のLongLLMLinguaの利点について聞いてみます。
question = "What are the advantages of LongLLMLingua?"
retrieverの設定
retrieved_nodes = base_retriever.retrieve(question)
synthesizer = CompactAndRefine()
from llama_index.indices.query.schema import QueryBundle
new_retrieved_nodes = node_postprocessor.postprocess_nodes(
retrieved_nodes, query_bundle=QueryBundle(query_str=question)
)
圧縮前後の比較
original_contexts = "\n\n".join([n.get_content() for n in retrieved_nodes])
compressed_contexts = "\n\n".join([n.get_content() for n in new_retrieved_nodes])
original_tokens = node_postprocessor._llm_lingua.get_token_length(original_contexts)
compressed_tokens = node_postprocessor._llm_lingua.get_token_length(compressed_contexts)
print(compressed_contexts)
print()
print("Original Tokens:", original_tokens)
print("Compressed Tokens:", compressed_tokens)
print("Comressed Ratio:", f"{original_tokens/(compressed_tokens + 1e-5):.2f}x")
今回は引用する文章が短いので効果は限定的ですが、それでも1/8程度に圧縮されています。
LongLLMLinguaで圧縮されたpromptで質疑
response = synthesizer.synthesize(question, new_retrieved_nodes)
print(str(response))
翻訳
1/8に圧縮したpromptでも適切な回答が得られています。
計算コストで言えば1/60程度になるため、申し分ない結果のように感じます。引用する文章が長ければ長いほど、計算コストの恩恵も大きく、また回答精度も通常のRAGに対して優位性があるようなので、非常に優秀な手法のように思います。
6. 方法2. End-to-End
上記同様
retriever_query_engine = RetrieverQueryEngine.from_args(
base_retriever, node_postprocessors=[node_postprocessor]
)
response = retriever_query_engine.query(question)
print(str(response))
先ほどと同じ回答が得られました。
7. 比較例:通常のRAG
query_engine_base = RetrieverQueryEngine.from_args(
base_retriever, service_context=service_context
)
base_response = query_engine_base.query(question)
print(str(base_response))
やはりLongLLMLinguaに対して計算時間は非常にかかったものの、優秀な回答です。
7. 参考
https://github.com/microsoft/LLMLingua/blob/main/examples/RAGLlamaIndex.ipynb