AI Agent 工作流引擎:DAG 编排、动态路由与容错设计
·
系列目录:本文是「AI 应用开发进阶实战」系列的第 4 篇。前面我们构建了 RAG、MCP 工具链和知识图谱,本篇进入 Agent 的核心——如何设计一个可靠的、可观测的工作流引擎。
一、为什么 Agent 需要工作流引擎?
1.1 从简单链到复杂工作流
最简单的 Agent:
user_input → LLM → output
加了 RAG 后:
user_input → retrieve → LLM → output (2步)
加了工具调用后:
user_input → LLM → tool_call → tool_result → LLM → output (循环)
加了多 Agent 后:
user_input → planner → [worker1, worker2, worker3] → aggregator → output
随着复杂度增长,直接写 if/else + while 循环的代码会迅速失控。工作流引擎提供:
| 能力 | 无引擎(裸写) | 有引擎 |
|---|---|---|
| 可视化 | 代码即文档,难以理解 | DAG 图一目了然 |
| 错误处理 | 到处 try/except,容易遗漏 | 声明式重试、降级、回退 |
| 并行执行 | asyncio.gather 自己管理 | 引擎自动拓扑排序并行 |
| 状态持久化 | 手动存 Redis/DB | 引擎内置检查点 |
| 可观测性 | 自己加日志 | 每个节点自动追踪 |
二、工作流引擎核心设计
2.1 DAG 工作流的数据结构
# workflow_engine.py
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set
from enum import Enum
import asyncio
import time
import json
class NodeStatus(Enum):
    """Status values recorded for a workflow node during one execution."""

    PENDING = "pending"
    RUNNING = "running"    # set right before a handler attempt starts
    SUCCESS = "success"    # handler returned within its timeout
    FAILED = "failed"      # every attempt raised or timed out
    SKIPPED = "skipped"    # the node's condition evaluated to False
@dataclass
class WorkflowNode:
    """One node of a workflow: a handler plus its execution policy."""

    # Unique node name; also the key under which its output is stored.
    name: str
    # Async callable invoked as ``handler(context)``; its return value is the
    # node's output.
    handler: Callable
    # Names of the nodes this node depends on.
    inputs: List[str] = field(default_factory=list)
    # Number of retries after the first failed attempt.
    retry_count: int = 0
    # Seconds to wait between attempts.
    retry_delay: float = 1.0
    # Per-attempt timeout in seconds.
    timeout: float = 60.0
    # Optional predicate ``fn(context) -> bool``; the node is skipped when it
    # returns a falsy value.
    condition: Optional[Callable] = None
    # Policy once all attempts failed: "fail" | "skip" | "continue".
    on_failure: str = "fail"
@dataclass
class Workflow:
    """Definition of a DAG workflow: named nodes plus directed edges."""

    name: str
    # Forward reference so this definition does not depend on declaration order.
    nodes: Dict[str, "WorkflowNode"]
    edges: List[tuple]  # [(from_node, to_node), ...]

    def validate(self) -> bool:
        """Check that the workflow is a well-formed DAG.

        Returns:
            True when every edge references known nodes and the graph has
            no cycle.

        Raises:
            ValueError: if an edge endpoint is not a key of ``nodes`` or the
                graph contains a cycle.
        """
        from collections import deque

        # 1. Every edge endpoint must be a declared node.
        known = set(self.nodes)
        for src, dst in self.edges:
            if src not in known:
                raise ValueError(f"Edge source '{src}' not found in nodes")
            if dst not in known:
                raise ValueError(f"Edge target '{dst}' not found in nodes")

        # 2. Kahn's algorithm: if a topological order cannot consume every
        #    node, the remaining nodes sit on a cycle.
        in_degree = dict.fromkeys(known, 0)
        successors = {node_name: [] for node_name in known}
        for src, dst in self.edges:
            successors[src].append(dst)
            in_degree[dst] += 1

        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        frontier = deque(n for n, degree in in_degree.items() if degree == 0)
        visited = 0
        while frontier:
            current = frontier.popleft()
            visited += 1
            for neighbor in successors[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    frontier.append(neighbor)

        if visited != len(known):
            raise ValueError("Workflow contains a cycle!")
        return True
@dataclass
class WorkflowContext:
    """Execution context used to pass data between workflow nodes."""

    inputs: Dict[str, Any] = field(default_factory=dict)        # user-supplied inputs
    node_outputs: Dict[str, Any] = field(default_factory=dict)  # output of each node, keyed by node name
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default=None):
        """Look up ``key`` in node outputs first, then in the inputs.

        A node output is returned whenever the key is present, even if the
        stored value is falsy (0, "", [], False, None). Testing the value
        with ``or`` would silently fall through to the inputs in that case.

        Args:
            key: Node name or input name to look up.
            default: Value returned when the key is in neither mapping.
        """
        if key in self.node_outputs:
            return self.node_outputs[key]
        return self.inputs.get(key, default)
2.2 工作流执行引擎
class WorkflowEngine:
    """Execute DAG workflows, running independent nodes concurrently.

    Bookkeeping for every run (node statuses, per-node results, timings) is
    kept in memory in ``self.executions``, keyed by an execution id.
    """

    def __init__(self):
        self.executions: Dict[str, dict] = {}  # exec_id -> execution record

    async def execute(self, workflow: Workflow, context: WorkflowContext) -> dict:
        """Run ``workflow`` and return the outputs of its nodes.

        A node is started as soon as all of its upstream nodes have finished
        (successfully, skipped, or failed with a non-fatal policy). At most
        10 nodes run at the same time.

        Args:
            workflow: The workflow to run; validated before execution.
            context: Shared context. Node outputs are written to
                ``context.node_outputs`` under the node's name.

        Returns:
            ``context.node_outputs``.

        Raises:
            ValueError: if the workflow fails validation.
            RuntimeError: if a node with ``on_failure="fail"`` exhausts its
                attempts. Nodes still running at that point are cancelled.
        """
        from collections import deque

        workflow.validate()

        # Same id format as before, with a numeric suffix only when the same
        # workflow is started more than once within one second; otherwise
        # the later run would overwrite the earlier run's record.
        base_id = f"{workflow.name}_{int(time.time())}"
        exec_id = base_id
        suffix = 1
        while exec_id in self.executions:
            exec_id = f"{base_id}_{suffix}"
            suffix += 1

        record = {
            "workflow": workflow.name,
            "start_time": time.time(),
            "node_statuses": {},
            "results": {},
        }
        self.executions[exec_id] = record

        # Dependency graph: remaining upstream count and downstream list.
        in_degree = {name: 0 for name in workflow.nodes}
        dependents = {name: [] for name in workflow.nodes}
        for src, dst in workflow.edges:
            in_degree[dst] += 1
            dependents[src].append(dst)

        print(f"\n{'='*60}")
        print(f"Workflow: {workflow.name}")
        print(f"Nodes: {len(workflow.nodes)}, Edges: {len(workflow.edges)}")
        print(f"{'='*60}\n")

        max_concurrency = 10
        ready = deque(name for name, degree in in_degree.items() if degree == 0)
        running = set()

        async def run_node(node_name):
            """Run (or skip) one node and return its name."""
            node = workflow.nodes[node_name]
            if node.condition and not node.condition(context):
                print(f" [{node_name}] SKIPPED (condition=False)")
                record["node_statuses"][node_name] = NodeStatus.SKIPPED
            else:
                # Looked up on self at call time so subclasses or instance
                # level wrappers of _execute_node are honoured.
                context.node_outputs[node_name] = await self._execute_node(
                    node, context, exec_id
                )
            return node_name

        try:
            while ready or running:
                # Start as many ready nodes as the concurrency cap allows.
                while ready and len(running) < max_concurrency:
                    running.add(asyncio.create_task(run_node(ready.popleft())))

                done, running = await asyncio.wait(
                    running, return_when=asyncio.FIRST_COMPLETED
                )

                failure = None
                for task in done:
                    error = task.exception()
                    if error is not None:
                        # Remember the first failure but still inspect the
                        # other finished tasks so no exception goes unread.
                        failure = failure or error
                        continue
                    # Unlock downstream nodes whose dependencies are all done.
                    for dependent in dependents[task.result()]:
                        in_degree[dependent] -= 1
                        if in_degree[dependent] == 0:
                            ready.append(dependent)
                if failure is not None:
                    raise failure
        finally:
            # On failure or cancellation, stop whatever is still in flight
            # instead of leaving orphaned tasks behind.
            for task in running:
                task.cancel()
            if running:
                await asyncio.gather(*running, return_exceptions=True)
            record["end_time"] = time.time()
            record["duration"] = record["end_time"] - record["start_time"]

        print(f"\nWorkflow completed in {record['duration']:.1f}s")
        return context.node_outputs

    async def _execute_node(
        self,
        node: WorkflowNode,
        context: WorkflowContext,
        exec_id: str
    ) -> Any:
        """Run a single node with timeout and retry handling.

        Args:
            node: The node to run.
            context: Context passed to the node's handler.
            exec_id: Key of the execution record in ``self.executions``.

        Returns:
            The handler's result on success; ``None`` when all attempts
            failed and ``on_failure == "skip"``; ``{"error": <message>}``
            when all attempts failed and ``on_failure == "continue"``.

        Raises:
            RuntimeError: when all attempts failed and ``on_failure`` is
                neither "skip" nor "continue".
        """
        record = self.executions[exec_id]
        # A negative retry_count would otherwise mean "never run the handler".
        total_attempts = max(node.retry_count, 0) + 1
        last_error = None

        for attempt in range(1, total_attempts + 1):
            print(f" [{node.name}] Running... (attempt {attempt}/{total_attempts})")
            record["node_statuses"][node.name] = NodeStatus.RUNNING
            try:
                result = await asyncio.wait_for(
                    node.handler(context),
                    timeout=node.timeout
                )
            except asyncio.TimeoutError:
                last_error = f"Timeout after {node.timeout}s"
                print(f" [{node.name}] TIMEOUT")
            except Exception as e:
                last_error = str(e)
                print(f" [{node.name}] ERROR: {e}")
            else:
                record["node_statuses"][node.name] = NodeStatus.SUCCESS
                record["results"][node.name] = {
                    "status": "success",
                    "attempts": attempt
                }
                print(f" [{node.name}] SUCCESS")
                return result

            # Back off before the next attempt, but not after the last one.
            if attempt < total_attempts:
                await asyncio.sleep(node.retry_delay)

        # Every attempt failed: record it, then apply the node's policy.
        record["node_statuses"][node.name] = NodeStatus.FAILED
        record["results"][node.name] = {
            "status": "failed",
            "error": last_error,
            "attempts": total_attempts
        }

        if node.on_failure == "skip":
            print(f" [{node.name}] FAILED → SKIPPING (on_failure=skip)")
            return None
        if node.on_failure == "continue":
            print(f" [{node.name}] FAILED → CONTINUING (on_failure=continue)")
            return {"error": last_error}
        # "fail"
        raise RuntimeError(f"Node '{node.name}' failed: {last_error}")
2.3 构建示例:文档处理工作流
# example_workflow.py
from workflow_engine import Workflow, WorkflowNode, WorkflowContext, WorkflowEngine
# === 定义节点处理器 ===
async def load_documents(ctx: "WorkflowContext") -> dict:
    """Load Markdown documents from the input path.

    Reads ``ctx.inputs["doc_path"]`` (default ``"./docs"``), searches it
    recursively for ``*.md`` files and loads at most 20 of them.

    Returns:
        ``{"documents": [...], "count": n}`` where each document is a dict
        with ``path``, ``content`` and ``size`` (bytes on disk).
    """
    import glob
    import os

    root = ctx.inputs.get("doc_path", "./docs")
    pattern = os.path.join(root, "**/*.md")
    # glob returns paths in arbitrary order; sort so the 20-file cap always
    # selects the same files for the same directory.
    paths = sorted(glob.glob(pattern, recursive=True))[:20]

    documents = []
    for path in paths:
        with open(path, "r", encoding="utf-8") as fp:
            content = fp.read()
        documents.append({
            "path": path,
            "content": content,
            "size": os.path.getsize(path),
        })

    print(f" Loaded {len(documents)} documents")
    return {"documents": documents, "count": len(documents)}
async def chunk_documents(ctx: "WorkflowContext") -> dict:
    """Split the loaded documents into paragraph-sized chunks.

    Reads the ``load_docs`` node output. Paragraphs are separated by blank
    lines; a paragraph is kept only if it is longer than 50 characters after
    stripping, and is truncated to 1000 characters.

    Returns:
        ``{"chunks": [...], "count": n}``.
    """
    chunks = []
    for doc in ctx.node_outputs["load_docs"]["documents"]:
        source = doc["path"]
        for index, paragraph in enumerate(doc["content"].split("\n\n")):
            text = paragraph.strip()
            if len(text) <= 50:
                continue  # too short to be a useful chunk
            chunks.append({
                "source": source,
                # The index counts every paragraph, including skipped ones.
                "chunk_id": f"{source}#{index}",
                "content": text[:1000],
            })

    total = len(chunks)
    print(f" Split into {total} chunks")
    return {"chunks": chunks, "count": total}
async def generate_embeddings(ctx: "WorkflowContext") -> dict:
    """Produce an embedding record for every chunk.

    Reads the ``chunk_docs`` node output. The vectors are placeholders
    (128 values of 0.1); a real implementation would call an embedding API.

    Returns:
        ``{"embeddings": [...], "count": n}``; each record carries the chunk
        id, the vector and the first 100 characters of the chunk.
    """
    embeddings = [
        {
            "chunk_id": chunk["chunk_id"],
            "embedding": [0.1] * 128,  # mock vector
            "content": chunk["content"][:100],
        }
        for chunk in ctx.node_outputs["chunk_docs"]["chunks"]
    ]
    total = len(embeddings)
    print(f" Generated {total} embeddings")
    return {"embeddings": embeddings, "count": total}
async def extract_entities(ctx: "WorkflowContext") -> dict:
    """Count keyword occurrences across the loaded documents.

    Reads the ``load_docs`` node output. This is a stand-in for real entity
    extraction: it counts a fixed keyword list with ``str.count``.

    Returns:
        ``{"entities": {keyword: total_count}}`` containing only keywords
        that occur at least once.
    """
    keywords = ["AI", "LLM", "Agent", "RAG"]
    entities = {}
    for doc in ctx.node_outputs["load_docs"]["documents"]:
        text = doc["content"]
        for word in keywords:
            hits = text.count(word)
            if hits > 0:
                entities[word] = entities.get(word, 0) + hits

    print(f" Extracted {len(entities)} entity types")
    return {"entities": entities}
async def merge_and_index(ctx: "WorkflowContext") -> dict:
    """Combine embedding and entity results into one index summary.

    Reads the ``gen_embeddings`` and ``extract_entities`` node outputs.

    Returns:
        A dict with ``total_chunks``, ``total_entities`` and
        ``top_entities`` (up to five ``(name, count)`` pairs, highest count
        first).
    """
    outputs = ctx.node_outputs
    embeddings = outputs["gen_embeddings"]["embeddings"]
    entities = outputs["extract_entities"]["entities"]

    ranked = sorted(entities.items(), key=lambda pair: pair[1], reverse=True)
    index = {
        "total_chunks": len(embeddings),
        "total_entities": len(entities),
        "top_entities": ranked[:5],
    }
    print(f" Built index: {index}")
    return index
async def quality_check(ctx: "WorkflowContext") -> dict:
    """Check that the built index contains at least 5 chunks.

    Reads the ``merge_index`` node output.

    Returns:
        ``{"passed": bool, "reason": str, "stats": <index>}``.
    """
    index = ctx.node_outputs["merge_index"]
    chunk_total = index["total_chunks"]
    passed = chunk_total >= 5  # minimum acceptable number of chunks

    if passed:
        reason, status = "OK", "PASS"
    else:
        reason, status = f"Only {chunk_total} chunks (< 5)", "FAIL"

    print(f" Quality Check: {status}")
    return {"passed": passed, "reason": reason, "stats": index}
async def notify_result(ctx: "WorkflowContext") -> dict:
    """Build and print the final notification message.

    Reads the ``quality_check`` node output. That output is not guaranteed
    to be a quality-check result: a failed node whose policy is
    ``on_failure="skip"`` stores ``None`` and ``on_failure="continue"``
    stores ``{"error": ...}``, so both shapes are handled here instead of
    raising on ``qc["passed"]``.

    Returns:
        ``{"message": str, "sent": True}``.
    """
    qc = ctx.node_outputs.get("quality_check")

    if isinstance(qc, dict) and "passed" in qc:
        if qc["passed"]:
            msg = f"Workflow succeeded. Index built with {qc['stats']['total_chunks']} chunks."
        else:
            msg = f"Workflow completed but quality check failed: {qc['reason']}"
    else:
        # The quality check itself did not produce a result.
        msg = "Workflow completed but quality check produced no result"
        error = qc.get("error") if isinstance(qc, dict) else None
        if error:
            msg = f"{msg}: {error}"

    print(f" Notification: {msg}")
    return {"message": msg, "sent": True}
# === 定义工作流 ===
def build_document_workflow():
    """Assemble the document-processing DAG workflow.

    Returns:
        A ``Workflow`` named ``"document_processing"`` with seven nodes:
        loading fans out into chunking/embedding and entity extraction,
        which join again at the merge step before quality check and
        notification.
    """

    def make(name, handler, deps=None, retry=1, timeout=60,
             condition=None, on_failure="fail"):
        """Create a WorkflowNode using this workflow's defaults."""
        return WorkflowNode(
            name=name,
            handler=handler,
            inputs=deps or [],
            retry_count=retry,
            timeout=timeout,
            condition=condition,
            on_failure=on_failure,
        )

    node_list = [
        make("load_docs", load_documents),
        make("chunk_docs", chunk_documents, deps=["load_docs"]),
        make("gen_embeddings", generate_embeddings, deps=["chunk_docs"], timeout=120),
        make("extract_entities", extract_entities, deps=["load_docs"]),
        make("merge_index", merge_and_index, deps=["gen_embeddings", "extract_entities"]),
        make("quality_check", quality_check, deps=["merge_index"], on_failure="skip"),
        make("notify", notify_result, deps=["quality_check"], retry=2),
    ]
    nodes = {item.name: item for item in node_list}

    # Edges are what the engine schedules on; they mirror the deps above.
    edges = [
        ("load_docs", "chunk_docs"),
        ("load_docs", "extract_entities"),
        ("chunk_docs", "gen_embeddings"),
        ("gen_embeddings", "merge_index"),
        ("extract_entities", "merge_index"),
        ("merge_index", "quality_check"),
        ("quality_check", "notify"),
    ]

    return Workflow(name="document_processing", nodes=nodes, edges=edges)
# === 运行 ===
async def main():
    """Run the document-processing workflow once and print each node's output."""
    # Imported here because this example module only imports names from
    # workflow_engine at the top.
    import json

    engine = WorkflowEngine()
    workflow = build_document_workflow()
    context = WorkflowContext(inputs={
        "doc_path": "./sample_docs",
        "user_id": "user_123"
    })

    try:
        results = await engine.execute(workflow, context)
        print("\nFinal Results:")
        for node_name, output in results.items():
            # Truncate each dump so large outputs do not flood the console.
            print(f" {node_name}: {json.dumps(output, indent=2, ensure_ascii=False)[:200]}")
    except Exception as e:
        # Top-level boundary of the demo script: report and exit normally.
        print(f"\nWorkflow failed: {e}")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
执行流程可视化:
load_docs
/ \
/ \
chunk_docs extract_entities ← 并行执行!
| |
gen_embeddings |
\ /
\ /
merge_index
|
quality_check
|
notify
三、动态路由:条件分支
3.1 条件节点
# 条件节点示例:根据文档类型分流
async def classify_document(ctx: "WorkflowContext") -> str:
    """Classify the input document and return a routing label.

    Reads ``ctx.inputs["document"]["content"]`` (empty string when absent)
    and matches it against fixed keywords.

    Returns:
        ``"contract"``, ``"code"`` or ``"general"``; the contract keywords
        are checked first.
    """
    content = ctx.inputs.get("document", {}).get("content", "")

    if any(keyword in content for keyword in ("合同", "协议")):
        return "contract"
    if any(keyword in content for keyword in ("代码", "function")):
        return "code"
    return "general"
# 条件函数:只有分类为 "contract" 时才执行
def is_contract(ctx: "WorkflowContext") -> bool:
    """Return True when the ``classify_doc`` node output is ``"contract"``."""
    label = ctx.node_outputs.get("classify_doc")
    return label == "contract"
def is_code(ctx: "WorkflowContext") -> bool:
    """Return True when the ``classify_doc`` node output is ``"code"``."""
    label = ctx.node_outputs.get("classify_doc")
    return label == "code"
# 工作流中的条件节点
# Conditional nodes: each one carries a predicate that the engine evaluates
# before running the handler.
# NOTE(review): parse_contract_terms and analyze_code_structure are not
# defined in this snippet — they must be provided elsewhere.
contract_parser = WorkflowNode(
    name="parse_contract",
    handler=parse_contract_terms,
    condition=is_contract  # run only when classify_doc returned "contract"
)
code_analyzer = WorkflowNode(
    name="analyze_code",
    handler=analyze_code_structure,
    condition=is_code  # run only when classify_doc returned "code"
)
3.2 LLM 驱动的动态路由
from openai import OpenAI
async def llm_router(ctx: "WorkflowContext") -> str:
    """Ask an LLM which branch the workflow should take next.

    Reads ``api_key`` and ``query`` from ``ctx.inputs`` and sends the query
    with a fixed system prompt that lists the allowed routing labels.

    Returns:
        The model's reply with surrounding whitespace stripped (an empty
        string when the reply has no text content). The reply is not
        validated against the label list.
    """
    import asyncio

    client = OpenAI(api_key=ctx.inputs.get("api_key"))
    user_query = ctx.inputs.get("query")

    system_prompt = (
        "分析用户意图,返回以下路由标签之一:\n"
        "- SEARCH: 需要检索知识库\n"
        "- CODE: 需要生成/分析代码\n"
        "- CHART: 需要数据可视化\n"
        "- CHAT: 简单对话即可回答\n"
        "- COMPLEX: 需要多步推理\n"
        "只回复一个标签。"
    )

    def request():
        return client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_query},
            ],
            temperature=0,
        )

    # The OpenAI client used here is synchronous. Calling it directly would
    # block the event loop and stall every other node running in parallel,
    # so the request is pushed to the default thread pool.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, request)

    # message.content may be None (e.g. a reply without text).
    route = (response.choices[0].message.content or "").strip()
    print(f" LLM Router: {user_query[:50]}... → {route}")
    return route
# 动态路由版本
class DynamicWorkflow(Workflow):
    """Workflow whose next node can be chosen at run time.

    Nodes are executed one at a time, starting at the start node. After each
    node, the next one is chosen by that node's routing function if one is
    registered, otherwise by the first fixed edge that leaves the node.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dynamic_edges: Dict[str, Callable] = {}  # node -> routing_fn

    def add_dynamic_route(self, from_node: str, routing_fn: Callable):
        """Register ``routing_fn(ctx) -> target_node_name`` for ``from_node``."""
        self.dynamic_edges[from_node] = routing_fn

    async def execute(self, context: WorkflowContext, start: str = "start",
                      engine=None) -> dict:
        """Walk the workflow from ``start``, resolving routes as it goes.

        Args:
            context: Shared context; node outputs are written into it.
            start: Name of the first node to run.
            engine: ``WorkflowEngine`` used to run nodes (retry, timeout and
                failure policy). A fresh one is created when omitted.

        Returns:
            Mapping of executed node name to its result. The walk ends when
            no next node is found or the next name is not a known node.
        """
        import time

        # Node execution lives on the engine, not on the workflow.
        if engine is None:
            engine = WorkflowEngine()

        # _execute_node writes into an execution record, so one must exist.
        exec_id = f"{self.name}_{time.time_ns()}"
        engine.executions[exec_id] = {
            "workflow": self.name,
            "start_time": time.time(),
            "node_statuses": {},
            "results": {},
        }

        # Fixed routes: edges is a list of (src, dst) pairs; a {src: dst}
        # mapping is accepted too. The first declared target per source wins.
        pairs = self.edges.items() if isinstance(self.edges, dict) else self.edges
        fixed_next = {}
        for src, dst in pairs:
            fixed_next.setdefault(src, dst)

        results = {}
        current = start
        while current:
            node = self.nodes.get(current)
            if not node:
                break
            result = await engine._execute_node(node, context, exec_id)
            results[current] = result
            context.node_outputs[current] = result

            # A dynamic route takes precedence over a fixed edge.
            route_fn = self.dynamic_edges.get(current)
            if route_fn is not None:
                current = route_fn(context)
            else:
                current = fixed_next.get(current)
        return results
四、状态机模式:复杂交互流程
对于需要多轮交互、状态转换的工作流(如审批流程),DAG 并不适用:这类流程存在回路(例如"退回修改"后重新提交),而 DAG 不允许出现环——此时应改用有限状态机。
from enum import Enum, auto
class ApprovalState(Enum):
    """States of the approval flow; values are assigned by ``auto()``."""

    DRAFT = auto()
    SUBMITTED = auto()
    REVIEWING = auto()
    APPROVED = auto()
    REJECTED = auto()
    REVISION_NEEDED = auto()
class StateMachineWorkflow:
    """Approval workflow modelled as a finite state machine.

    ``transitions`` maps each non-terminal state to the actions allowed in
    it and the state each action leads to. Optional handlers can be attached
    to individual ``(from_state, to_state)`` transitions.
    """

    def __init__(self):
        self.transitions = {
            ApprovalState.DRAFT: {
                "submit": ApprovalState.SUBMITTED,
            },
            ApprovalState.SUBMITTED: {
                "start_review": ApprovalState.REVIEWING,
                "auto_approve": ApprovalState.APPROVED,  # approve without review
            },
            ApprovalState.REVIEWING: {
                "approve": ApprovalState.APPROVED,
                "reject": ApprovalState.REJECTED,
                "request_revision": ApprovalState.REVISION_NEEDED,
            },
            ApprovalState.REVISION_NEEDED: {
                "resubmit": ApprovalState.SUBMITTED,
            },
        }
        self.handlers: Dict[tuple, Callable] = {}  # (from, to) -> handler

    def on_transition(self, from_state, to_state, handler):
        """Register an async ``handler(context)`` for one transition."""
        self.handlers[(from_state, to_state)] = handler

    async def run(self, context: WorkflowContext) -> dict:
        """Drive the state machine from DRAFT until APPROVED or REJECTED.

        Returns:
            ``{"final_state": <name>, "history": [...]}`` where each history
            entry records ``from``, ``to`` and ``action``.

        Raises:
            ValueError: if the chosen action is not allowed in the current
                state.
            NotImplementedError: if ``_decide_action`` has not been
                overridden.
        """
        terminal_states = (ApprovalState.APPROVED, ApprovalState.REJECTED)
        state = ApprovalState.DRAFT
        history = []

        while state not in terminal_states:
            # Ask the decision hook which action to take in this state.
            action = await self._decide_action(state, context)
            allowed = self.transitions.get(state, {})
            if action not in allowed:
                raise ValueError(f"Invalid transition: {state} → {action}")
            next_state = allowed[action]

            # Run the handler registered for this transition, if any, and
            # store its result under "<FROM>_<TO>".
            handler = self.handlers.get((state, next_state))
            if handler:
                result = await handler(context)
                context.node_outputs[f"{state.name}_{next_state.name}"] = result

            history.append({"from": state.name, "to": next_state.name, "action": action})
            state = next_state

        return {
            "final_state": state.name,
            "history": history
        }

    async def _decide_action(self, state: ApprovalState, ctx: WorkflowContext) -> str:
        """Choose the action to take in ``state`` (e.g. by asking an LLM).

        This is an extension point with no default implementation. It raises
        instead of silently returning ``None``, which ``run`` would report as
        an invalid transition to ``None``.
        """
        raise NotImplementedError(
            "StateMachineWorkflow._decide_action must be overridden to choose an action"
        )
五、可观测性
class ObservableWorkflowEngine(WorkflowEngine):
    """Workflow engine that records a trace (one span per node) for each run.

    Finished traces are appended to ``self.traces``. Tracing works by
    temporarily wrapping ``_execute_node`` on the instance.
    NOTE(review): because the wrapper lives on the instance, overlapping
    ``execute`` calls on the same engine object would mix their spans —
    use one engine per concurrent run.
    """

    def __init__(self):
        super().__init__()
        self.traces = []

    async def execute(self, workflow, context):
        """Run the workflow, record a full trace and print a report.

        Returns and raises exactly what ``WorkflowEngine.execute`` does.
        """
        trace_id = f"trace_{int(time.time()*1000)}"
        trace = {
            "trace_id": trace_id,
            "workflow": workflow.name,
            "start_time": time.time(),
            "spans": []
        }

        original_execute = self._execute_node

        async def traced_execute(node, ctx, exec_id):
            """Run one node through the original executor and record a span."""
            span_start = time.time()
            span = {"node": node.name}
            try:
                result = await original_execute(node, ctx, exec_id)
            except BaseException as e:
                # BaseException so cancellation also yields a complete span.
                span["status"] = "failed"
                span["error"] = str(e)
                raise
            else:
                # Nodes that failed but whose policy is skip/continue return
                # normally and are therefore recorded as "success" here.
                span["status"] = "success"
                return result
            finally:
                span["duration"] = time.time() - span_start
                trace["spans"].append(span)

        self._execute_node = traced_execute
        try:
            result = await super().execute(workflow, context)
        except BaseException:
            trace["status"] = "failed"
            raise
        else:
            trace["status"] = "success"
            return result
        finally:
            # Undo the wrapper; otherwise each run would wrap the previous
            # wrapper and keep appending spans to earlier traces.
            self._execute_node = original_execute
            trace["end_time"] = time.time()
            trace["total_duration"] = trace["end_time"] - trace["start_time"]
            self.traces.append(trace)
            # Print the trace report.
            self._print_trace_report(trace)

    def _print_trace_report(self, trace: dict):
        """Print the trace header and one bar per span, slowest first."""
        print(f"\n{'='*60}")
        print(f"Trace: {trace['trace_id']}")
        print(f"Workflow: {trace['workflow']} | Status: {trace['status']}")
        print(f"Duration: {trace['total_duration']:.2f}s")
        print(f"{'='*60}")

        total = trace["total_duration"]
        # Sort by duration, longest first.
        spans = sorted(trace["spans"], key=lambda s: s["duration"], reverse=True)
        for span in spans:
            # Bar length is the span's share of the total, out of 30 cells;
            # guard against a zero total on very fast runs.
            bar_len = int(span["duration"] / total * 30) if total > 0 else 0
            bar = "█" * bar_len
            status_icon = "✓" if span["status"] == "success" else "✗"
            print(f" {status_icon} {span['node']:<20s} {span['duration']:.1f}s {bar}")
六、总结
工作流引擎是 Agent 从"能跑"到"可靠"的关键:
DAG 编排 → 声明式定义任务依赖,自动并行
动态路由 → LLM 决策流程分支,灵活应对变化
重试机制 → 自动处理瞬时故障
超时控制 → 防止节点无限等待
条件执行 → 跳过不需要的分支
状态持久化 → 从失败节点恢复,不重做已完成工作(注意:本文示例引擎只把执行记录保存在内存的 executions 字典中,要真正做到断点恢复,还需把检查点写入 Redis/数据库等外部存储)
可观测性 → 追踪每个节点的耗时和状态
下一篇——系列最终篇:多 Agent 协作——任务分解、通信协议与并行编排。将一个复杂任务自动拆解给多个 Agent 并行执行。
本文完整代码已开源。下一篇:多 Agent 协作(最终篇,即将发布)
更多推荐


所有评论(0)