
How to create map-reduce branches for parallel execution

Prerequisites

This guide assumes familiarity with the following:

Map-Reduce

Map-reduce operations are essential for efficient task decomposition and parallel processing. This approach involves breaking a task into smaller sub-tasks, processing each sub-task in parallel, and aggregating the results across all completed sub-tasks.
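As a framework-free sketch of the idea (the names here are illustrative, not LangGraph APIs): map a worker function over the sub-tasks in parallel, then reduce the results into a single value.

from concurrent.futures import ThreadPoolExecutor


def process(subtask: str) -> str:
    # Stand-in for the per-item work, e.g. an LLM call
    return subtask.upper()


subtasks = ["lions", "elephants", "penguins"]
with ThreadPoolExecutor() as pool:
    results = list(pool.map(process, subtasks))  # "map": run in parallel
summary = ", ".join(results)  # "reduce": aggregate the results
print(summary)  # LIONS, ELEPHANTS, PENGUINS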

Consider this example: given a general topic from the user, generate a list of related subjects, generate a joke for each subject, and select the best joke from the resulting list. In this design pattern, a first node may generate a list of objects (e.g., related subjects), and we want to apply some other node (e.g., joke generation) to all of those objects (e.g., subjects). However, two main challenges arise.

(1) The number of objects (e.g., subjects) may be unknown ahead of time when we lay out the graph, meaning the number of edges may be unknown; and (2) the input state to the downstream node should be different, one for each generated object.

LangGraph addresses these challenges through its Send API. Using conditional edges, Send can distribute different states (e.g., subjects) to multiple instances of a node (e.g., joke generation). Importantly, the sent state can differ from the core graph's state, allowing for flexible and dynamic workflow management.
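In code, the pattern looks like this (a minimal sketch; my_node and the state keys are placeholders, not part of the graph built below):

from langgraph.types import Send


def fan_out(state: dict):
    # Returned from a conditional edge: one Send per item, each carrying
    # its own private state to the target node
    return [Send("my_node", {"item": item}) for item in state["items"]]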


Setup

First, let's install the required packages and set our API keys.

%%capture --no-stderr
%pip install -U langchain-anthropic langgraph
import os
import getpass


def _set_env(name: str):
    if not os.getenv(name):
        os.environ[name] = getpass.getpass(f"{name}: ")


_set_env("ANTHROPIC_API_KEY")

Set up LangSmith for LangGraph development

Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph. Read more about how to get started here.

Define the graph

Using Pydantic with LangChain

This notebook uses Pydantic v2 BaseModel, which requires langchain-core >= 0.3. Using langchain-core < 0.3 will result in errors due to mixing of Pydantic v1 and v2 BaseModels.
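If you want to verify the installed version first (a quick sanity check, not part of the original notebook):

from importlib.metadata import version

print(version("langchain-core"))  # should be >= 0.3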

import operator
from typing import Annotated
from typing_extensions import TypedDict

from langchain_anthropic import ChatAnthropic

from langgraph.types import Send
from langgraph.graph import END, StateGraph, START

from pydantic import BaseModel, Field

# Model and prompts
# Define model and prompts we will use
subjects_prompt = """Generate a comma separated list of between 2 and 5 examples related to: {topic}."""
joke_prompt = """Generate a joke about {subject}"""
best_joke_prompt = """Below are a bunch of jokes about {topic}. Select the best one! Return the ID of the best one.

{jokes}"""


class Subjects(BaseModel):
    subjects: list[str]


class Joke(BaseModel):
    joke: str


class BestJoke(BaseModel):
    id: int = Field(description="Index of the best joke, starting with 0", ge=0)


model = ChatAnthropic(model="claude-3-5-sonnet-20240620")

# Graph components: define the components that will make up the graph


# This will be the overall state of the main graph.
# It will contain a topic (which we expect the user to provide)
# and then will generate a list of subjects, and then a joke for
# each subject
class OverallState(TypedDict):
    topic: str
    subjects: list
    # Notice here we use the operator.add
    # This is because we want to combine all the jokes we generate
    # from individual nodes back into one list - this is essentially
    # the "reduce" part
    jokes: Annotated[list, operator.add]
    best_selected_joke: str


# This will be the state of the node that we will "map" all
# subjects to in order to generate a joke
class JokeState(TypedDict):
    subject: str


# This is the function we will use to generate the subjects of the jokes
def generate_topics(state: OverallState):
    prompt = subjects_prompt.format(topic=state["topic"])
    response = model.with_structured_output(Subjects).invoke(prompt)
    return {"subjects": response.subjects}


# Here we generate a joke, given a subject
def generate_joke(state: JokeState):
    prompt = joke_prompt.format(subject=state["subject"])
    response = model.with_structured_output(Joke).invoke(prompt)
    return {"jokes": [response.joke]}


# Here we define the logic to map out over the generated subjects
# We will use this as an edge in the graph
def continue_to_jokes(state: OverallState):
    # We will return a list of `Send` objects
    # Each `Send` object consists of the name of a node in the graph
    # as well as the state to send to that node
    return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]


# Here we will judge the best joke
def best_joke(state: OverallState):
    jokes = "\n\n".join(state["jokes"])
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=jokes)
    response = model.with_structured_output(BestJoke).invoke(prompt)
    return {"best_selected_joke": state["jokes"][response.id]}


# Construct the graph: here we put everything together to construct our graph
graph = StateGraph(OverallState)
graph.add_node("generate_topics", generate_topics)
graph.add_node("generate_joke", generate_joke)
graph.add_node("best_joke", best_joke)
graph.add_edge(START, "generate_topics")
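# The conditional edge calls continue_to_jokes, which returns Send objects;
# the runtime then fans out to one "generate_joke" run per subject. The last
# argument lists the possible destinations so the graph can be rendered.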
graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])
graph.add_edge("generate_joke", "best_joke")
graph.add_edge("best_joke", END)
app = graph.compile()
from IPython.display import Image

Image(app.get_graph().draw_mermaid_png())

Use the graph

# Call the graph: here we call it to generate a list of jokes
for s in app.stream({"topic": "animals"}):
    print(s)
{'generate_topics': {'subjects': ['Lions', 'Elephants', 'Penguins', 'Dolphins']}}
{'generate_joke': {'jokes': ["Why don't elephants use computers? They're afraid of the mouse!"]}}
{'generate_joke': {'jokes': ["Why don't dolphins use smartphones? Because they're afraid of phishing!"]}}
{'generate_joke': {'jokes': ["Why don't you see penguins in Britain? Because they're afraid of Wales!"]}}
{'generate_joke': {'jokes': ["Why don't lions like fast food? Because they can't catch it!"]}}
{'best_joke': {'best_selected_joke': "Why don't dolphins use smartphones? Because they're afraid of phishing!"}}
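
Because the mapped branches run in parallel, the generate_joke updates can stream back in any order. If you only need the final aggregated state rather than the per-step updates, invoke works too (a small usage sketch):

final_state = app.invoke({"topic": "animals"})
print(final_state["best_selected_joke"])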
