RAG Day2 LCEL
RAG Day2 はLCELでのRAG実装についてです。
LCELとはLangchainのLLMアプリケーション開発を簡素化するためのLangChainフレームワークの一部で、チェーンを簡単に組むことができます。
LCELには以下のような特徴があります。
チェーンのシンプルな表現
ストリーミングのサポート
非同期のサポート
バッチのサポート
RetrievalQAをそのまま使うとチェーンが組みにくくて使いにくいので今後は基本的にLCELを使っていくことになります。
早速実装 DAY1でやったRAGの実装をLCELで実装していきます。
ソースコードはこちらです。
下準備
Day1と共通している箇所です。
ドキュメントのロード ~ retrieverの作成・プロンプト・LLMインスタンスの作成などは共通しているので、説明は省略します。
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import MarkdownHeaderTextSplitter,RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
#Document load
loader = DirectoryLoader("../datasets/company_documents_dataset_1/", glob="**/*.txt",recursive=True)
raw_docs = loader.load()
# Document split
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on,
return_each_line=False,
strip_headers = False
)
docs = []
for raw_doc in raw_docs:
source = raw_doc.metadata["source"]
spilited_docs = markdown_splitter.split_text(raw_doc.page_content)
for doc in spilited_docs:
doc.metadata["source"] = source#metadataにsourceを加える
docs = docs + spilited_docs
markdown_splited_docs = docs
text_splitter = RecursiveCharacterTextSplitter(chunk_size = 300,chunk_overlap=50)
docs = text_splitter.split_documents(docs)
# Embd
vectorstore = Chroma.from_documents(persist_directory="./vecstore/index", documents=docs, embedding=OpenAIEmbeddings())
#llm
llm = ChatOpenAI(model_name="gpt-3.5-turbo",temperature=0)
# retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
# prompt
prompt = PromptTemplate(
input_variables=["context","question"],
template="""以下の参考用のテキストの一部を参照して、Questionに回答してください。もし参考用のテキストの中に回答に役立つ情報が含まれていなければ、分からない、と答えてください。
{context}
Question: {question}
Answer: """
)
LCELでチェーンを組む
ここからが前回と違うところです。
まずは今回のLCELでのコードと前回のコードを比較するために
#chain
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
rag_chain = (
{
"question":RunnablePassthrough(),#invokeで指定したtextが入る。
"context":retriever
}
|prompt
|llm
|StrOutputParser()
)
question = "社長の名前は?"
rag_chain.invoke(question)
'漆黒 花太郎'
コード解説です!
LCELではチェーンを "| "で表すことができます
"| "の上から下に実行されると考えればいいので、直感的でわかりやすいです。
実行は簡単で、rag_chain.invoke(question)で実行可能です。
このコードで一番よくわからないのは、RunnablePassthrough()だと思います。
RunnablePassthrough()を言葉で説明する前に図で説明するとこのようになります。
言葉で説明すると、rag_chain.invoke(question)で渡したquestion引数の値をそのまま、promptに渡しているのです。
なぜ、このような処理が必要なのか見ていきます。
まず、prompt引数をもう一度見てみましょう
prompt = PromptTemplate(
input_variables=["context","question"],
template="""以下の参考用のテキストの一部を参照して、Questionに回答してください。もし参考用のテキストの中に回答に役立つ情報が含まれていなければ、分からない、と答えてください。
{context}
Question: {question}
Answer: """
)
input_variables=["context","question"]からcontextとquestionの二つの引数を受け取ってるのがわかりますよね。
これはRAGが検索結果であるcontextをもとにquestionを回答するpromptなので当然だと言えます。
もし、このコードが下記のようになってるとどうなるでしょうか?
#このコードはinputエラーが出ます。
rag_chain = (
{
"context":retriever
}
|prompt
|llm
|StrOutputParser()
)
question = "社長の名前は?"
rag_chain.invoke(question)
LCELでは上から順番に実行していくので、prompt作成時に"question"キーの値がないので、エラーが出ます。
questionキーにはinvokeの引数の値をそのまま渡したい
contextキーにはinvokeの引数で検索した結果を渡したい
ここで、RunnablePassthrough()の出番です。
RunnablePassthrough()では、処理を飛ばして、次のchainにそのまま引数を渡すことができるのです。
また、実はこのコードは実はとあるコードが省略して書かれています。
省略せずに書くと次のようになります
#chain
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
rag_chain = (
RunnableParallel({
"question":RunnablePassthrough(),#invokeで指定したtextが入る。
"context":retriever
})
|prompt
|llm
|StrOutputParser()
)
question = "社長の名前は?"
rag_chain.invoke(question)
省略せずに書くと、辞書型変数が、RunnableParallelの引数として渡されてるのがわかりますね。
RunnableParallelとは何をしているかというと、名前の通りParallelに値を渡すことができる関数です。
つまり、rag_chain.invoke(question)のquestion引数の値を、
辞書型変数の"question"キーと"context"に渡しているということになります。
LCELのRunnable系の関数についてはこちらが詳しく書かれていたので参考にしてください。
チェーンを可視化する
rag_chain.get_graph().print_ascii()
+---------------------------------+
| Parallel<question,context>Input |
+---------------------------------+
** **
*** ***
** **
+-------------+ +----------------------+
| Passthrough | | VectorStoreRetriever |
+-------------+ +----------------------+
** **
*** ***
** **
+----------------------------------+
| Parallel<question,context>Output |
+----------------------------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+------------+
| ChatOpenAI |
+------------+
*
*
*
+-----------------+
| StrOutputParser |
+-----------------+
*
*
*
+-----------------------+
| StrOutputParserOutput |
+-----------------------+
最後に一つLCELの便利な関数を覚えて、今日のnoteは終了です。
この.print_ascii()関数は簡単にチェーンを可視化することができます。
| Parallel<question,context>Input |
+---------------------------------+
** **
*** ***
** **
+-------------+ +----------------------+
| Passthrough | | VectorStoreRetriever |
+-------------+ +----------------------+
** **
*** ***
** **
+----------------------------------+
| Parallel<question,context>Output |
+----------------------------------+
引数がPassthroughされてParallel<question,context>Outputとなっているのがわかりますね。
このようにLCELにはチェーンに関する非常に便利な関数が用意されているようです。
終わりに
今日はここまで。次回からはやっとRAGの色々な手法をやっていきたいと思います。
この記事が気に入ったらサポートをしてみませんか?