Go实现DAG工作流引擎:AI Agent任务编排内核
·
一、为什么需要DAG工作流引擎
AI Agent的核心能力不是单次LLM调用,而是多步骤任务编排。一个典型的Agent请求链路:
用户意图识别 → 工具选择 → 并行调用(搜索+数据库查询) → 结果聚合 → LLM合成 → 输出
这个链条有明确的依赖关系:聚合必须等并行调用全部完成,合成依赖聚合结果。这正是有向无环图(DAG) 的天然建模场景。
| 问题 | 朴素方案 | DAG方案 |
|---|---|---|
| 依赖管理 | if-else硬编码 | 拓扑排序自动解析 |
| 并行执行 | 手动goroutine | 层级分组自动并行 |
| 错误传播 | 调用链断裂 | 节点级重试+降级 |
| 可观测性 | 无结构化日志 | 节点粒度trace |
下面用Go从零实现一个生产可用的DAG工作流引擎。
二、核心数据结构
package dag
import (
"context"
"fmt"
"sync"
"time"
)
// NodeFunc 节点执行函数
// input 包含该节点所有上游节点的输出,key为节点ID
// 返回的输出会被下游节点消费
type NodeFunc func(ctx context.Context, input map[string]interface{}) (interface{}, error)
// Node 工作流中的一个任务节点
type Node struct {
ID string // 唯一标识
DependsOn []string // 依赖的上游节点ID列表
Handler NodeFunc // 执行逻辑
Retries int // 失败重试次数,0表示不重试
Timeout time.Duration // 单节点超时,0表示不限制
}
// WorkflowState 工作流执行过程中的共享状态
type WorkflowState struct {
mu sync.RWMutex
results map[string]interface{} // nodeID -> output
errors map[string]error // nodeID -> error
statuses map[string]string // nodeID -> "pending"|"running"|"done"|"failed"
}
func newWorkflowState() *WorkflowState {
return &WorkflowState{
results: make(map[string]interface{}),
errors: make(map[string]error),
statuses: make(map[string]string),
}
}
func (ws *WorkflowState) setResult(nodeID string, result interface{}) {
ws.mu.Lock()
defer ws.mu.Unlock()
ws.results[nodeID] = result
ws.statuses[nodeID] = "done"
}
func (ws *WorkflowState) setError(nodeID string, err error) {
ws.mu.Lock()
defer ws.mu.Unlock()
ws.errors[nodeID] = err
ws.statuses[nodeID] = "failed"
}
func (ws *WorkflowState) getResult(nodeID string) (interface{}, bool) {
ws.mu.RLock()
defer ws.mu.RUnlock()
v, ok := ws.results[nodeID]
return v, ok
}
// collectInputs 收集某节点所有上游的输出
func (ws *WorkflowState) collectInputs(dependsOn []string) map[string]interface{} {
ws.mu.RLock()
defer ws.mu.RUnlock()
input := make(map[string]interface{}, len(dependsOn))
for _, dep := range dependsOn {
if v, ok := ws.results[dep]; ok {
input[dep] = v
}
}
return input
}
// DAG 有向无环图工作流引擎
type DAG struct {
nodes map[string]*Node // 节点映射
edges map[string][]string // 邻接表: nodeID -> 下游节点列表
mu sync.RWMutex
}
// NewDAG 创建空DAG
func NewDAG() *DAG {
return &DAG{
nodes: make(map[string]*Node),
edges: make(map[string][]string),
}
}
// AddNode 添加节点(会做ID唯一性检查)
func (d *DAG) AddNode(n *Node) error {
d.mu.Lock()
defer d.mu.Unlock()
if _, exists := d.nodes[n.ID]; exists {
return fmt.Errorf("节点 %s 已存在", n.ID)
}
d.nodes[n.ID] = n
return nil
}
// Build 构建邻接表(所有节点添加完毕后调用)
func (d *DAG) Build() error {
d.mu.Lock()
defer d.mu.Unlock()
d.edges = make(map[string][]string)
nodeSet := make(map[string]bool)
for id := range d.nodes {
nodeSet[id] = true
}
for id, node := range d.nodes {
for _, dep := range node.DependsOn {
if !nodeSet[dep] {
return fmt.Errorf("节点 %s 依赖的 %s 不存在", id, dep)
}
d.edges[dep] = append(d.edges[dep], id)
}
}
return nil
}
三、拓扑排序与层级并行分组
DAG并行执行的核心思路:同一层级的节点可以并发,层级之间必须串行。
用Kahn算法做拓扑排序的同时完成层级分组:
// topoLevels 拓扑排序并按层级分组,同时检测环
// 返回:按层级分组的节点ID列表
func (d *DAG) topoLevels() ([][]string, error) {
// 计算入度
indegree := make(map[string]int)
for id := range d.nodes {
indegree[id] = len(d.nodes[id].DependsOn)
}
// 第一层:入度为0的节点
var currentLevel []string
for id, deg := range indegree {
if deg == 0 {
currentLevel = append(currentLevel, id)
}
}
var levels [][]string
visited := 0
total := len(d.nodes)
for len(currentLevel) > 0 {
levels = append(levels, currentLevel)
var nextLevel []string
// 对本层每个节点,将其下游节点的入度减1
for _, nodeID := range currentLevel {
visited++
for _, downstream := range d.edges[nodeID] {
indegree[downstream]--
if indegree[downstream] == 0 {
nextLevel = append(nextLevel, downstream)
}
}
}
currentLevel = nextLevel
}
// 访问数 < 总节点数 = 存在环
if visited < total {
return nil, fmt.Errorf("DAG中存在循环依赖,已访问 %d/%d 个节点", visited, total)
}
return levels, nil
}
算法复杂度:O(V+E),其中V为节点数,E为边数。
四、执行引擎
// Execute 按拓扑层级并行执行所有节点
func (d *DAG) Execute(ctx context.Context) (*WorkflowState, error) {
// 1. 拓扑排序+层级分组
levels, err := d.topoLevels()
if err != nil {
return nil, fmt.Errorf("拓扑排序失败: %w", err)
}
state := newWorkflowState()
// 2. 逐层执行
for levelIdx, level := range levels {
select {
case <-ctx.Done():
return state, ctx.Err()
default:
}
var wg sync.WaitGroup
errCh := make(chan error, len(level))
// 同一层级内所有节点并发执行
for _, nodeID := range level {
wg.Add(1)
go func(nid string) {
defer wg.Done()
if err := d.executeNode(ctx, nid, state); err != nil {
errCh <- fmt.Errorf("[层级%d] 节点 %s 执行失败: %w", levelIdx, nid, err)
}
}(nodeID)
}
wg.Wait()
close(errCh)
// 收集本层错误
var levelErrs []error
for err := range errCh {
levelErrs = append(levelErrs, err)
}
if len(levelErrs) > 0 {
return state, fmt.Errorf("层级 %d 执行异常: %v", levelIdx, levelErrs)
}
}
return state, nil
}
// executeNode 执行单个节点(含重试和超时)
func (d *DAG) executeNode(ctx context.Context, nodeID string, state *WorkflowState) error {
node := d.nodes[nodeID]
// 收集上游输入
input := state.collectInputs(node.DependsOn)
// 带重试的执行
var lastErr error
maxAttempts := node.Retries + 1
for attempt := 0; attempt < maxAttempts; attempt++ {
if attempt > 0 {
// 指数退避:重试间隔 = 100ms * 2^(attempt-1)
backoff := time.Duration(100*(1<<(attempt-1))) * time.Millisecond
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
}
// 超时控制
execCtx := ctx
var cancel context.CancelFunc
if node.Timeout > 0 {
execCtx, cancel = context.WithTimeout(ctx, node.Timeout)
}
result, err := node.Handler(execCtx, input)
if cancel != nil {
cancel()
}
if err == nil {
state.setResult(nodeID, result)
return nil
}
lastErr = err
}
state.setError(nodeID, lastErr)
return fmt.Errorf("节点 %s 重试 %d 次后仍失败: %w", nodeID, node.Retries, lastErr)
}
// Result 获取最终执行结果(所有叶子节点的输出)
func (d *DAG) Result(state *WorkflowState) map[string]interface{} {
// 叶子节点 = 不在任何节点的dependsOn中出现
hasDownstream := make(map[string]bool)
for _, node := range d.nodes {
for _, dep := range node.DependsOn {
hasDownstream[dep] = true
}
}
result := make(map[string]interface{})
state.mu.RLock()
defer state.mu.RUnlock()
for id, res := range state.results {
if !hasDownstream[id] {
result[id] = res
}
}
return result
}
五、实战:AI Agent意图识别流水线
用一个完整示例展示引擎能力——构建一个AI Agent的意图识别→工具调用→结果合成流水线:
package main
import (
"context"
"fmt"
"strings"
"time"
"github.com/yourorg/workflow/dag"
)
func main() {
d := dag.NewDAG()
// 节点1:意图识别(入口节点)
d.AddNode(&dag.Node{
ID: "intent_classifier",
Handler: intentClassifier,
Retries: 2,
Timeout: 5 * time.Second,
})
// 节点2/3:并行搜索(依赖意图识别结果)
d.AddNode(&dag.Node{
ID: "web_search",
DependsOn: []string{"intent_classifier"},
Handler: webSearch,
Retries: 1,
Timeout: 10 * time.Second,
})
d.AddNode(&dag.Node{
ID: "db_query",
DependsOn: []string{"intent_classifier"},
Handler: databaseQuery,
Retries: 2,
Timeout: 8 * time.Second,
})
// 节点4:结果筛选(依赖搜索和查询结果)
d.AddNode(&dag.Node{
ID: "result_filter",
DependsOn: []string{"web_search", "db_query"},
Handler: resultFilter,
})
// 节点5:LLM合成最终答案(依赖筛选结果)
d.AddNode(&dag.Node{
ID: "llm_synthesize",
DependsOn: []string{"result_filter"},
Handler: llmSynthesize,
Timeout: 15 * time.Second,
})
// 构建拓扑
if err := d.Build(); err != nil {
panic(err)
}
// 执行
ctx := context.Background()
state, err := d.Execute(ctx)
if err != nil {
fmt.Printf("工作流执行失败: %v\n", err)
return
}
fmt.Println("=== 执行结果 ===")
for id, res := range d.Result(state) {
fmt.Printf("[%s]: %v\n", id, res)
}
}
// --- 节点处理函数 ---
func intentClassifier(ctx context.Context, input map[string]interface{}) (interface{}, error) {
// 模拟LLM意图分类
query, _ := input["origin"].(string)
time.Sleep(100 * time.Millisecond) // 模拟网络延迟
switch {
case strings.Contains(query, "价格") || strings.Contains(query, "多少钱"):
return map[string]string{"intent": "price_query", "entity": query}, nil
case strings.Contains(query, "教程") || strings.Contains(query, "怎么"):
return map[string]string{"intent": "howto", "entity": query}, nil
default:
return map[string]string{"intent": "general", "entity": query}, nil
}
}
func webSearch(ctx context.Context, input map[string]interface{}) (interface{}, error) {
intentInfo := input["intent_classifier"].(map[string]string)
time.Sleep(200 * time.Millisecond)
return []string{
fmt.Sprintf("搜索结果1: %s 相关文档", intentInfo["entity"]),
"搜索结果2: 社区讨论帖",
"搜索结果3: 官方文档",
}, nil
}
func databaseQuery(ctx context.Context, input map[string]interface{}) (interface{}, error) {
intentInfo := input["intent_classifier"].(map[string]string)
time.Sleep(150 * time.Millisecond)
return map[string]interface{}{
"count": 42,
"items": []string{"记录A", "记录B", "记录C"},
}, nil
}
func resultFilter(ctx context.Context, input map[string]interface{}) (interface{}, error) {
webResults := input["web_search"].([]string)
dbResults := input["db_query"].(map[string]interface{})
// 合并去重 + 相关性排序
merged := append(webResults,
fmt.Sprintf("数据库命中 %d 条", dbResults["count"]))
return merged, nil
}
func llmSynthesize(ctx context.Context, input map[string]interface{}) (interface{}, error) {
filtered := input["result_filter"].([]string)
time.Sleep(300 * time.Millisecond)
return fmt.Sprintf("基于 %d 条信息来源的综合答案...", len(filtered)), nil
}
运行输出:
=== 执行结果 ===
[llm_synthesize]: 基于 4 条信息来源的综合答案...
执行时间线:
0ms 50ms 100ms 150ms 200ms 250ms 300ms 350ms
| | | | | | | |
intent_classifier --+
web_search --------------------+
db_query ------------+
result_filter --+
llm_synthesize ------+
两个搜索节点并行执行,总耗时约为串行的一半。
六、工程化增强
6.1 上下文传播
在实际AI Agent场景中,需要在整个工作流中传播traceID、用户会话等元数据:
type contextKey string
const (
TraceIDKey contextKey = "trace_id"
SessionKey contextKey = "session_id"
)
// 在执行入口注入
ctx = context.WithValue(ctx, TraceIDKey, uuid.New().String())
6.2 熔断器
// 熔断器(简化版)
type CircuitBreaker struct {
failureCount int
threshold int
mu sync.Mutex
}
func (cb *CircuitBreaker) Call(fn func() error) error {
cb.mu.Lock()
if cb.failureCount >= cb.threshold {
cb.mu.Unlock()
return fmt.Errorf("熔断器已打开,拒绝请求")
}
cb.mu.Unlock()
err := fn()
if err != nil {
cb.mu.Lock()
cb.failureCount++
cb.mu.Unlock()
}
return err
}
6.3 执行追踪
type ExecutionTrace struct {
NodeID string `json:"node_id"`
StartTime time.Time `json:"start_time"`
Duration time.Duration `json:"duration"`
Attempt int `json:"attempt"`
Error string `json:"error,omitempty"`
}
七、总结
| 特性 | 实现方式 | 工程价值 |
|---|---|---|
| 依赖管理 | Kahn拓扑排序 | 声明式任务编排,消除硬编码 |
| 并行执行 | 层级分组+goroutine | 充分利用多核,减少端到端延迟 |
| 超时控制 | context.WithTimeout | 防止单个节点卡死整个流程 |
| 重试机制 | 指数退避 | 处理瞬时故障 |
| 环检测 | 拓扑排序visited计数 | 启动时即发现配置错误 |
更多推荐
所有评论(0)