如何添加节点重试策略¶
在许多用例中,你可能希望为节点设置自定义重试策略,例如,当你调用 API、查询数据库或调用大语言模型(LLM)等情况时。
环境设置¶
首先,让我们安装所需的包并设置 API 密钥
import getpass
import os
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("ANTHROPIC_API_KEY")
为 LangGraph 开发设置 LangSmith
注册 LangSmith 以快速发现问题并提升你的 LangGraph 项目的性能。LangSmith 允许你使用跟踪数据来调试、测试和监控使用 LangGraph 构建的大语言模型应用程序 — 点击 此处 了解更多关于如何开始使用的信息。
为了配置重试策略,你必须将 retry
参数传递给 add_node。retry
参数接受一个名为 RetryPolicy
的元组对象。下面我们使用默认参数实例化一个 RetryPolicy
对象:
RetryPolicy(initial_interval=0.5, backoff_factor=2.0, max_interval=128.0, max_attempts=3, jitter=True, retry_on=<function default_retry_on at 0x78b964b89940>)
默认情况下,retry_on
参数使用 default_retry_on
函数,该函数会对除以下异常之外的任何异常进行重试:
ValueError
TypeError
ArithmeticError
ImportError
LookupError
NameError
SyntaxError
RuntimeError
ReferenceError
StopIteration
StopAsyncIteration
OSError
此外,对于来自 requests
和 httpx
等流行 HTTP 请求库的异常,它仅在状态码为 5xx 时进行重试。
向节点传递重试策略¶
最后,我们可以在调用 add_node 函数时传递 RetryPolicy
对象。在下面的示例中,我们为每个节点分别传递了两种不同的重试策略:
import operator
import sqlite3
from typing import Annotated, Sequence
from typing_extensions import TypedDict
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage
from langgraph.graph import END, StateGraph, START
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage
db = SQLDatabase.from_uri("sqlite:///:memory:")
model = ChatAnthropic(model_name="claude-2.1")
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
def query_database(state):
query_result = db.run("SELECT * FROM Artist LIMIT 10;")
return {"messages": [AIMessage(content=query_result)]}
def call_model(state):
response = model.invoke(state["messages"])
return {"messages": [response]}
# Define a new graph
builder = StateGraph(AgentState)
builder.add_node(
"query_database",
query_database,
retry=RetryPolicy(retry_on=sqlite3.OperationalError),
)
builder.add_node("model", call_model, retry=RetryPolicy(max_attempts=5))
builder.add_edge(START, "model")
builder.add_edge("model", "query_database")
builder.add_edge("query_database", END)
graph = builder.compile()
API Reference: BaseMessage | SQLDatabase | AIMessage