見出し画像

AzureからGPTを使ってみる | LangGraphを使う(条件分岐:01)


以前の記事で、簡単に直線でノードをつなげるグラフを作成しましたが、

今回は、条件によって分岐するグラフを作ってみたいと思います。
ポイントは"add_conditional_edges"を使って、定義した関数で次に進むノードを決定するところになります。

条件分岐のグラフを作成する

こんなグラフを作ってみます。

#langgraph=0.2.35 
from typing import Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph

import random

#Stateを宣言
class State(TypedDict):
    value: str

#Nodeを定義    
def start_node(state: State):
    return {"value": "1"}

def node2(state: State):
    return {"value": "2"}

def node3(state: State):
    return {"value": "3"}

#分岐条件を関数として定義します。
#nが1 or 0で次に進むNodeを決定します。
def routing(state: State) -> Literal["node2", "node3"]:
    n = random.randint(0,1)
    if n==0:
        return "node2"
    if n==1:
        return "node3"

#Graphの作成
graph_builder = StateGraph(State)

graph_builder.add_node("start_node", start_node)
graph_builder.add_node("node2", node2)
graph_builder.add_node("node3", node3)
graph_builder.add_node("end_node", lambda state: {"value": state["value"]})

graph_builder.set_entry_point("start_node")

#edgeでNodeをつなげます。
graph_builder.add_conditional_edges(
    "start_node",
    routing,#条件分岐の関数
)

graph_builder.add_edge("node2", "end_node")
graph_builder.add_edge("node3", "end_node")

graph_builder.set_finish_point("end_node")

# Graphをコンパイル
graph = graph_builder.compile()

# Graphの実行
graph.invoke(
    {"value": ""}, 
    debug=True
    )

valueが、2をとっていればnode2、3をとっていればnode3を通過していることがわかります。

LLMを使用した条件分岐Graph


Userからの質問に対して、参照テキストを使用して回答するか分岐するGraphを作成してみたいと思います。

grade_node: retrieverを使用するかどうかの判断をするnode。
retrieverを使用する場合は"RETREVE"、使用しないで素のGPTとして回答する場合は、"CHAT"とだけ回答します。
grade_nodeに続くconditional_edgeで、grade_nodeの応答に応じて次に進むnodeを決定します。
chat: retrieverを使わずに、llm.invokeでuserの質問に回答するnode。
retrieve_node: retrieverを使用してuserに回答するnode。
response: 上流のStateをただreturnするnode。

retrieverの使い方は、以下の記事を参考にしてください。

LLMとRetrieverを定義

今回は、retrieverにジョジョの奇妙な冒険についてのテキストを使用します。ジョジョについての質問には、retrieverを使用するようにします。

#langchain-chroma=0.1.4
#langchain-openai=0.1.15
#langchain-text-splitters=0.2.2  
#langgraph=0.2.35 
#langchain=0.2.16   

from langchain_chroma import Chroma

from langchain_openai import AzureOpenAIEmbeddings
from langchain_openai import AzureChatOpenAI
from langchain_text_splitters import MarkdownHeaderTextSplitter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

from dotenv import load_dotenv
import os
import json

# OpenAI APIキーの設定
dotenv_path = ".env"
load_dotenv(dotenv_path)

OPENAI_API_BASE = os.getenv('OPENAI_API_BASE')
OPENAI_API_VERSION = os.getenv('OPENAI_API_VERSION')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

os.environ["AZURE_OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["AZURE_OPENAI_ENDPOINT"] = OPENAI_API_BASE

# Load example document
with open("jojo.md") as f:
    state_of_the_union = f.read()

# Markdown見出しで文章を分割します    
headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
    ("####", "Header 4"),
]    

markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on,
    strip_headers = False #コンテンツに分割されるヘッダーを含めるかどうか
    )

md_header_splits = markdown_splitter.split_text(state_of_the_union)

#embedding
embed = AzureOpenAIEmbeddings(
    model = 'text-embedding-3-small',
)
documents = md_header_splits

# VectorStoreの準備
vectorstore = Chroma.from_documents(
    documents,
    embedding=embed,
    #collection_name="example_collection",
    #persist_directory="./chroma_langchain_db",  # Where to save data locally, remove if not neccesary
)   

# Retrieverの準備
retriever = vectorstore.as_retriever()

llm = AzureChatOpenAI(
    openai_api_version = OPENAI_API_VERSION,
    azure_deployment = "gpt-4o-mini",
    verbose=True
                      )

nodeとedgeを定義

from typing_extensions import TypedDict, Optional
from langchain_core.pydantic_v1 import BaseModel, Field

class State(TypedDict):
    message_type: Optional[str] = None
    message: Optional[str] = None

# Data model
class GradeDocuments(BaseModel):
    """ユーザーの質問に対して、参照テキストの必要性をチェックするためのバイナリスコア。"""

    binary_score: str = Field(
        description="ユーザーの質問に対して参照テキストの必要性を評価する。 応答は、'RETRIEVE' or 'CHAT'"
    )
    
# LLM with function call

structured_llm_grader = llm.with_structured_output(GradeDocuments)

def grade_node(State):
    #ユーザーの質問に対して参照テキストを使用するかを決定します。
    system = """あなたはユーザーの質問に対する評価者です。
    
    厳密なテストである必要はありません。目標は誤った回答を排除することです。
    ユーザーの質問に正確に答えられるかどうかを答えてください。
    「ジョジョの奇妙な冒険」に関する質問には、retrieverを使用して回答してください。
    
    retrieverを使用して回答する場合は"RETRIEVE"、retrieverを使用しない場合は"CHAT"と回答してください。
    """
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "User question: {question}"),
        ]
    )

    retrieval_grader = grade_prompt | structured_llm_grader
    
    if State["message"]:
        print("####grade_node####")
        print("Question: ")
        print(State["message"])
        print("garade_nodeの判断:")
        print(retrieval_grader.invoke({"question":State["message"] }))
        print("#####")

        return {
            "message_type": retrieval_grader.invoke({"question":State["message"] }).binary_score,
            "message": State["message"]
        }
    else:
        return {"message": "No user input"}
        
        
def chat(State):
    if State["message"]:
        return {"message": llm.invoke(State["message"])}
    return {"message": "No user input provided"}        
        

def retrieve_node(State):
    print("####retrieve_node####")
    print(State)
    last_message = State["message"]
    print(last_message)
    print("####")
    # プロンプトテンプレートの準備
    rag_prompt = """
    提供されたコンテキストのみを使用して、この質問に答えてください。

    {question}

    Context:
    {context}
    """

    prompt = ChatPromptTemplate.from_messages([("human", rag_prompt)]) 
    rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | llm
    #print("####retrieverの回答####")
    #print(rag_chain.invoke(last_message))
    #print("####")
    return {"message": rag_chain.invoke(last_message)}
def response(State):
    print("####response####")
    return State

graphを構築

from langgraph.graph import StateGraph, END, START
from typing import Annotated, Literal

def routing(state: State) -> Literal["chat", "retrieve_node"]:
    score = state["message_type"]
    if score =="CHAT":
        return "chat"
    if score =="RETRIEVE":
        return "retrieve_node"

graph_builder = StateGraph(State)

# ノードの追加
graph_builder.add_node("grade_node", grade_node)
graph_builder.add_node("chat", chat)
graph_builder.add_node("retrieve_node", retrieve_node)
graph_builder.add_node("response", response)

# エッジの追加
graph_builder.add_edge("chat", "response")
graph_builder.add_edge("retrieve_node", "response")
# 条件分岐
graph_builder.add_conditional_edges(
    "grade_node", 
    routing,
    )

# 開始位置、終了位置の指定
graph_builder.set_entry_point("grade_node")
graph_builder.set_finish_point("response")

# グラフ構築
graph = graph_builder.compile()

それでは、聞いてみましょう。

response = graph.invoke({"message": "DIOとはどんな人物ですか?"})

#print(response)
print(response["message"].content)

####grade_node####
Question: DIOとはどんな人物ですか?
garade_nodeの判断: binary_score='RETRIEVE'
#####
####retrieve_node####
{'message_type': 'RETRIEVE', 'message': 'DIOとはどんな人物ですか?'} DIOとはどんな人物ですか?
####
####response####
DIO(ディオ)は、かつてジョナサン・ジョースターと戦った吸血鬼であり、豪華客船の爆発で死亡したと思われていたが、実際にはジョナサンの遺体の首から下を奪い、海底で100年もの間眠り続けていた。復活後、彼はスタンド能力「ザ・ワールド」を覚醒させ、世界征服を目指して配下のスタンド使いを増やしていった。DIOはその絶大なカリスマ性と恐怖によって人々を支配し、ジョースター家との因縁を持つ重要な敵キャラクターである。最終的に、空条承太郎との壮絶な戦闘の末に敗北し、その影響力は死後も物語に影響を与え続ける。

grade_nodeが、ちゃんとretrieverを使用する判断をしています。
回答も、参照テキストに沿った回答になっています。

次に、ジョジョとは関係のないことを聞いてみます。

response = graph.invoke({"message": "日本の観光名所を教えてください。"})

print(response["message"].content)

####grade_node####
Question: 日本の観光名所を教えてください。
garade_nodeの判断: binary_score='CHAT'
#####
####response####
もちろんです!日本には多くの魅力的な観光名所があります。いくつかご紹介しますね。
1. **東京**
- **東京タワー**:シンボル的な展望台で、夜景が美しいです。
- **浅草寺**:東京で最も古い寺院で、雷門や仲見世通りも楽しめます。
- **渋谷交差点**:世界的に有名な交差点で、観光名所の一つです。
2. **京都**
- **金閣寺**:美しい金色の寺院で、庭園も素晴らしいです。
- **清水寺**:歴史的な寺院で、舞台からの景色が絶景です。
- **伏見稲荷大社**:千本鳥居が有名な神社です。
3. **大阪**
- **大阪城**:歴史的な城で、周囲の公園も美しいです。
- **道頓堀**:グルメやショッピングが楽しめるエリアで、ネオンが華やかです。
4. **広島**
- **原爆ドーム**:歴史的な遺構で、平和記念公園が隣接しています。

... - **美ら海水族館**:巨大な水槽があり、海洋生物を間近で見ることができます。
- **首里城**:琉球王国の歴史を感じられる場所です。
これ以外にも日本全国にはたくさんの観光名所がありますので、訪れる地域によって楽しみ方が異なります。どこに行くか、ぜひ計画してみてください!

ちゃんとretrieverを使用しないで回答する判断をしています。
しかし、いつから渋谷交差点が、日本の観光名所になったのでしょうか?

いいなと思ったら応援しよう!