
AzureからGPTを使ってみる | LangGraphを使う(条件分岐:01)
以前の記事で、簡単に直線でノードをつなげるグラフを作成しましたが、
今回は、条件によって分岐するグラフを作ってみたいと思います。
ポイントは"add_conditional_edges"を使って、定義した関数で次に進むノードを決定するところになります。
条件分岐のグラフを作成する

こんなグラフを作ってみます。
#langgraph=0.2.35
from typing import Literal
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
import random
#Stateを宣言
class State(TypedDict):
value: str
#Nodeを定義
def start_node(state: State):
return {"value": "1"}
def node2(state: State):
return {"value": "2"}
def node3(state: State):
return {"value": "3"}
#分岐条件を関数として定義します。
#nが1 or 0で次に進むNodeを決定します。
def routing(state: State) -> Literal["node2", "node3"]:
n = random.randint(0,1)
if n==0:
return "node2"
if n==1:
return "node3"
#Graphの作成
graph_builder = StateGraph(State)
graph_builder.add_node("start_node", start_node)
graph_builder.add_node("node2", node2)
graph_builder.add_node("node3", node3)
graph_builder.add_node("end_node", lambda state: {"value": state["value"]})
graph_builder.set_entry_point("start_node")
#edgeでNodeをつなげます。
graph_builder.add_conditional_edges(
"start_node",
routing,#条件分岐の関数
)
graph_builder.add_edge("node2", "end_node")
graph_builder.add_edge("node3", "end_node")
graph_builder.set_finish_point("end_node")
# Graphをコンパイル
graph = graph_builder.compile()
# Graphの実行
graph.invoke(
{"value": ""},
debug=True
)
valueが、2をとっていればnode2、3をとっていればnode3を通過していることがわかります。
LLMを使用した条件分岐Graph
Userからの質問に対して、参照テキストを使用して回答するか分岐するGraphを作成してみたいと思います。

grade_node: retrieverを使用するかどうかの判断をするnode。
retrieverを使用する場合は"RETREVE"、使用しないで素のGPTとして回答する場合は、"CHAT"とだけ回答します。
grade_nodeに続くconditional_edgeで、grade_nodeの応答に応じて次に進むnodeを決定します。
chat: retrieverを使わずに、llm.invokeでuserの質問に回答するnode。
retrieve_node: retrieverを使用してuserに回答するnode。
response: 上流のStateをただreturnするnode。
retrieverの使い方は、以下の記事を参考にしてください。
LLMとRetrieverを定義
今回は、retrieverにジョジョの奇妙な冒険についてのテキストを使用します。ジョジョについての質問には、retrieverを使用するようにします。
#langchain-chroma=0.1.4
#langchain-openai=0.1.15
#langchain-text-splitters=0.2.2
#langgraph=0.2.35
#langchain=0.2.16
from langchain_chroma import Chroma
from langchain_openai import AzureOpenAIEmbeddings
from langchain_openai import AzureChatOpenAI
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv
import os
import json
# OpenAI APIキーの設定
dotenv_path = ".env"
load_dotenv(dotenv_path)
OPENAI_API_BASE = os.getenv('OPENAI_API_BASE')
OPENAI_API_VERSION = os.getenv('OPENAI_API_VERSION')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
os.environ["AZURE_OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["AZURE_OPENAI_ENDPOINT"] = OPENAI_API_BASE
# Load example document
with open("jojo.md") as f:
state_of_the_union = f.read()
# Markdown見出しで文章を分割します
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on,
strip_headers = False #コンテンツに分割されるヘッダーを含めるかどうか
)
md_header_splits = markdown_splitter.split_text(state_of_the_union)
#embedding
embed = AzureOpenAIEmbeddings(
model = 'text-embedding-3-small',
)
documents = md_header_splits
# VectorStoreの準備
vectorstore = Chroma.from_documents(
documents,
embedding=embed,
#collection_name="example_collection",
#persist_directory="./chroma_langchain_db", # Where to save data locally, remove if not neccesary
)
# Retrieverの準備
retriever = vectorstore.as_retriever()
llm = AzureChatOpenAI(
openai_api_version = OPENAI_API_VERSION,
azure_deployment = "gpt-4o-mini",
verbose=True
)
nodeとedgeを定義
from typing_extensions import TypedDict, Optional
from langchain_core.pydantic_v1 import BaseModel, Field
class State(TypedDict):
message_type: Optional[str] = None
message: Optional[str] = None
# Data model
class GradeDocuments(BaseModel):
"""ユーザーの質問に対して、参照テキストの必要性をチェックするためのバイナリスコア。"""
binary_score: str = Field(
description="ユーザーの質問に対して参照テキストの必要性を評価する。 応答は、'RETRIEVE' or 'CHAT'"
)
# LLM with function call
structured_llm_grader = llm.with_structured_output(GradeDocuments)
def grade_node(State):
#ユーザーの質問に対して参照テキストを使用するかを決定します。
system = """あなたはユーザーの質問に対する評価者です。
厳密なテストである必要はありません。目標は誤った回答を排除することです。
ユーザーの質問に正確に答えられるかどうかを答えてください。
「ジョジョの奇妙な冒険」に関する質問には、retrieverを使用して回答してください。
retrieverを使用して回答する場合は"RETRIEVE"、retrieverを使用しない場合は"CHAT"と回答してください。
"""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "User question: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
if State["message"]:
print("####grade_node####")
print("Question: ")
print(State["message"])
print("garade_nodeの判断:")
print(retrieval_grader.invoke({"question":State["message"] }))
print("#####")
return {
"message_type": retrieval_grader.invoke({"question":State["message"] }).binary_score,
"message": State["message"]
}
else:
return {"message": "No user input"}
def chat(State):
if State["message"]:
return {"message": llm.invoke(State["message"])}
return {"message": "No user input provided"}
def retrieve_node(State):
print("####retrieve_node####")
print(State)
last_message = State["message"]
print(last_message)
print("####")
# プロンプトテンプレートの準備
rag_prompt = """
提供されたコンテキストのみを使用して、この質問に答えてください。
{question}
Context:
{context}
"""
prompt = ChatPromptTemplate.from_messages([("human", rag_prompt)])
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | llm
#print("####retrieverの回答####")
#print(rag_chain.invoke(last_message))
#print("####")
return {"message": rag_chain.invoke(last_message)}
def response(State):
print("####response####")
return State
graphを構築
from langgraph.graph import StateGraph, END, START
from typing import Annotated, Literal
def routing(state: State) -> Literal["chat", "retrieve_node"]:
score = state["message_type"]
if score =="CHAT":
return "chat"
if score =="RETRIEVE":
return "retrieve_node"
graph_builder = StateGraph(State)
# ノードの追加
graph_builder.add_node("grade_node", grade_node)
graph_builder.add_node("chat", chat)
graph_builder.add_node("retrieve_node", retrieve_node)
graph_builder.add_node("response", response)
# エッジの追加
graph_builder.add_edge("chat", "response")
graph_builder.add_edge("retrieve_node", "response")
# 条件分岐
graph_builder.add_conditional_edges(
"grade_node",
routing,
)
# 開始位置、終了位置の指定
graph_builder.set_entry_point("grade_node")
graph_builder.set_finish_point("response")
# グラフ構築
graph = graph_builder.compile()
それでは、聞いてみましょう。
response = graph.invoke({"message": "DIOとはどんな人物ですか?"})
#print(response)
print(response["message"].content)
####grade_node####
Question: DIOとはどんな人物ですか?
garade_nodeの判断: binary_score='RETRIEVE'
#####
####retrieve_node####
{'message_type': 'RETRIEVE', 'message': 'DIOとはどんな人物ですか?'} DIOとはどんな人物ですか?
####
####response####
DIO(ディオ)は、かつてジョナサン・ジョースターと戦った吸血鬼であり、豪華客船の爆発で死亡したと思われていたが、実際にはジョナサンの遺体の首から下を奪い、海底で100年もの間眠り続けていた。復活後、彼はスタンド能力「ザ・ワールド」を覚醒させ、世界征服を目指して配下のスタンド使いを増やしていった。DIOはその絶大なカリスマ性と恐怖によって人々を支配し、ジョースター家との因縁を持つ重要な敵キャラクターである。最終的に、空条承太郎との壮絶な戦闘の末に敗北し、その影響力は死後も物語に影響を与え続ける。
grade_nodeが、ちゃんとretrieverを使用する判断をしています。
回答も、参照テキストに沿った回答になっています。
次に、ジョジョとは関係のないことを聞いてみます。
response = graph.invoke({"message": "日本の観光名所を教えてください。"})
print(response["message"].content)
####grade_node####
Question: 日本の観光名所を教えてください。
garade_nodeの判断: binary_score='CHAT'
#####
####response####
もちろんです!日本には多くの魅力的な観光名所があります。いくつかご紹介しますね。
1. **東京**
- **東京タワー**:シンボル的な展望台で、夜景が美しいです。
- **浅草寺**:東京で最も古い寺院で、雷門や仲見世通りも楽しめます。
- **渋谷交差点**:世界的に有名な交差点で、観光名所の一つです。
2. **京都**
- **金閣寺**:美しい金色の寺院で、庭園も素晴らしいです。
- **清水寺**:歴史的な寺院で、舞台からの景色が絶景です。
- **伏見稲荷大社**:千本鳥居が有名な神社です。
3. **大阪**
- **大阪城**:歴史的な城で、周囲の公園も美しいです。
- **道頓堀**:グルメやショッピングが楽しめるエリアで、ネオンが華やかです。
4. **広島**
- **原爆ドーム**:歴史的な遺構で、平和記念公園が隣接しています。
... - **美ら海水族館**:巨大な水槽があり、海洋生物を間近で見ることができます。
- **首里城**:琉球王国の歴史を感じられる場所です。
これ以外にも日本全国にはたくさんの観光名所がありますので、訪れる地域によって楽しみ方が異なります。どこに行くか、ぜひ計画してみてください!
ちゃんとretrieverを使用しないで回答する判断をしています。
しかし、いつから渋谷交差点が、日本の観光名所になったのでしょうか?