一、为什么需要DAG工作流引擎

AI Agent的核心能力不是单次LLM调用,而是多步骤任务编排。一个典型的Agent请求链路:

用户意图识别 → 工具选择 → 并行调用(搜索+数据库查询) → 结果聚合 → LLM合成 → 输出

这个链条有明确的依赖关系:聚合必须等并行调用全部完成,合成依赖聚合结果。这正是有向无环图(DAG) 的天然建模场景。

问题 朴素方案 DAG方案
依赖管理 if-else硬编码 拓扑排序自动解析
并行执行 手动goroutine 层级分组自动并行
错误传播 调用链断裂 节点级重试+降级
可观测性 无结构化日志 节点粒度trace

下面用Go从零实现一个生产可用的DAG工作流引擎。


二、核心数据结构

package dag

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// NodeFunc 节点执行函数
// input 包含该节点所有上游节点的输出,key为节点ID
// 返回的输出会被下游节点消费
type NodeFunc func(ctx context.Context, input map[string]interface{}) (interface{}, error)

// Node 工作流中的一个任务节点
type Node struct {
    ID        string            // 唯一标识
    DependsOn []string          // 依赖的上游节点ID列表
    Handler   NodeFunc          // 执行逻辑
    Retries   int               // 失败重试次数,0表示不重试
    Timeout   time.Duration     // 单节点超时,0表示不限制
}

// WorkflowState 工作流执行过程中的共享状态
type WorkflowState struct {
    mu       sync.RWMutex
    results  map[string]interface{} // nodeID -> output
    errors   map[string]error       // nodeID -> error
    statuses map[string]string      // nodeID -> "pending"|"running"|"done"|"failed"
}

func newWorkflowState() *WorkflowState {
    return &WorkflowState{
        results:  make(map[string]interface{}),
        errors:   make(map[string]error),
        statuses: make(map[string]string),
    }
}

func (ws *WorkflowState) setResult(nodeID string, result interface{}) {
    ws.mu.Lock()
    defer ws.mu.Unlock()
    ws.results[nodeID] = result
    ws.statuses[nodeID] = "done"
}

func (ws *WorkflowState) setError(nodeID string, err error) {
    ws.mu.Lock()
    defer ws.mu.Unlock()
    ws.errors[nodeID] = err
    ws.statuses[nodeID] = "failed"
}

func (ws *WorkflowState) getResult(nodeID string) (interface{}, bool) {
    ws.mu.RLock()
    defer ws.mu.RUnlock()
    v, ok := ws.results[nodeID]
    return v, ok
}

// collectInputs 收集某节点所有上游的输出
func (ws *WorkflowState) collectInputs(dependsOn []string) map[string]interface{} {
    ws.mu.RLock()
    defer ws.mu.RUnlock()
    input := make(map[string]interface{}, len(dependsOn))
    for _, dep := range dependsOn {
        if v, ok := ws.results[dep]; ok {
            input[dep] = v
        }
    }
    return input
}

// DAG 有向无环图工作流引擎
type DAG struct {
    nodes map[string]*Node     // 节点映射
    edges map[string][]string  // 邻接表: nodeID -> 下游节点列表
    mu    sync.RWMutex
}

// NewDAG 创建空DAG
func NewDAG() *DAG {
    return &DAG{
        nodes: make(map[string]*Node),
        edges: make(map[string][]string),
    }
}

// AddNode 添加节点(会做ID唯一性检查)
func (d *DAG) AddNode(n *Node) error {
    d.mu.Lock()
    defer d.mu.Unlock()
    if _, exists := d.nodes[n.ID]; exists {
        return fmt.Errorf("节点 %s 已存在", n.ID)
    }
    d.nodes[n.ID] = n
    return nil
}

// Build 构建邻接表(所有节点添加完毕后调用)
func (d *DAG) Build() error {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.edges = make(map[string][]string)
    nodeSet := make(map[string]bool)
    for id := range d.nodes {
        nodeSet[id] = true
    }
    for id, node := range d.nodes {
        for _, dep := range node.DependsOn {
            if !nodeSet[dep] {
                return fmt.Errorf("节点 %s 依赖的 %s 不存在", id, dep)
            }
            d.edges[dep] = append(d.edges[dep], id)
        }
    }
    return nil
}

三、拓扑排序与层级并行分组

DAG并行执行的核心思路:同一层级的节点可以并发,层级之间必须串行

用Kahn算法做拓扑排序的同时完成层级分组:

// topoLevels 拓扑排序并按层级分组,同时检测环
// 返回:按层级分组的节点ID列表
func (d *DAG) topoLevels() ([][]string, error) {
    // 计算入度
    indegree := make(map[string]int)
    for id := range d.nodes {
        indegree[id] = len(d.nodes[id].DependsOn)
    }

    // 第一层:入度为0的节点
    var currentLevel []string
    for id, deg := range indegree {
        if deg == 0 {
            currentLevel = append(currentLevel, id)
        }
    }

    var levels [][]string
    visited := 0
    total := len(d.nodes)

    for len(currentLevel) > 0 {
        levels = append(levels, currentLevel)
        var nextLevel []string

        // 对本层每个节点,将其下游节点的入度减1
        for _, nodeID := range currentLevel {
            visited++
            for _, downstream := range d.edges[nodeID] {
                indegree[downstream]--
                if indegree[downstream] == 0 {
                    nextLevel = append(nextLevel, downstream)
                }
            }
        }
        currentLevel = nextLevel
    }

    // 访问数 < 总节点数 = 存在环
    if visited < total {
        return nil, fmt.Errorf("DAG中存在循环依赖,已访问 %d/%d 个节点", visited, total)
    }

    return levels, nil
}

算法复杂度:O(V+E),其中V为节点数,E为边数。


四、执行引擎

// Execute 按拓扑层级并行执行所有节点
func (d *DAG) Execute(ctx context.Context) (*WorkflowState, error) {
    // 1. 拓扑排序+层级分组
    levels, err := d.topoLevels()
    if err != nil {
        return nil, fmt.Errorf("拓扑排序失败: %w", err)
    }

    state := newWorkflowState()

    // 2. 逐层执行
    for levelIdx, level := range levels {
        select {
        case <-ctx.Done():
            return state, ctx.Err()
        default:
        }

        var wg sync.WaitGroup
        errCh := make(chan error, len(level))

        // 同一层级内所有节点并发执行
        for _, nodeID := range level {
            wg.Add(1)
            go func(nid string) {
                defer wg.Done()
                if err := d.executeNode(ctx, nid, state); err != nil {
                    errCh <- fmt.Errorf("[层级%d] 节点 %s 执行失败: %w", levelIdx, nid, err)
                }
            }(nodeID)
        }

        wg.Wait()
        close(errCh)

        // 收集本层错误
        var levelErrs []error
        for err := range errCh {
            levelErrs = append(levelErrs, err)
        }
        if len(levelErrs) > 0 {
            return state, fmt.Errorf("层级 %d 执行异常: %v", levelIdx, levelErrs)
        }
    }

    return state, nil
}

// executeNode 执行单个节点(含重试和超时)
func (d *DAG) executeNode(ctx context.Context, nodeID string, state *WorkflowState) error {
    node := d.nodes[nodeID]

    // 收集上游输入
    input := state.collectInputs(node.DependsOn)

    // 带重试的执行
    var lastErr error
    maxAttempts := node.Retries + 1
    for attempt := 0; attempt < maxAttempts; attempt++ {
        if attempt > 0 {
            // 指数退避:重试间隔 = 100ms * 2^(attempt-1)
            backoff := time.Duration(100*(1<<(attempt-1))) * time.Millisecond
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(backoff):
            }
        }

        // 超时控制
        execCtx := ctx
        var cancel context.CancelFunc
        if node.Timeout > 0 {
            execCtx, cancel = context.WithTimeout(ctx, node.Timeout)
        }

        result, err := node.Handler(execCtx, input)
        if cancel != nil {
            cancel()
        }

        if err == nil {
            state.setResult(nodeID, result)
            return nil
        }

        lastErr = err
    }

    state.setError(nodeID, lastErr)
    return fmt.Errorf("节点 %s 重试 %d 次后仍失败: %w", nodeID, node.Retries, lastErr)
}

// Result 获取最终执行结果(所有叶子节点的输出)
func (d *DAG) Result(state *WorkflowState) map[string]interface{} {
    // 叶子节点 = 不在任何节点的dependsOn中出现
    hasDownstream := make(map[string]bool)
    for _, node := range d.nodes {
        for _, dep := range node.DependsOn {
            hasDownstream[dep] = true
        }
    }

    result := make(map[string]interface{})
    state.mu.RLock()
    defer state.mu.RUnlock()
    for id, res := range state.results {
        if !hasDownstream[id] {
            result[id] = res
        }
    }
    return result
}

五、实战:AI Agent意图识别流水线

用一个完整示例展示引擎能力——构建一个AI Agent的意图识别→工具调用→结果合成流水线:

package main

import (
    "context"
    "fmt"
    "strings"
    "time"

    "github.com/yourorg/workflow/dag"
)

func main() {
    d := dag.NewDAG()

    // 节点1:意图识别(入口节点)
    d.AddNode(&dag.Node{
        ID:      "intent_classifier",
        Handler: intentClassifier,
        Retries: 2,
        Timeout: 5 * time.Second,
    })

    // 节点2/3:并行搜索(依赖意图识别结果)
    d.AddNode(&dag.Node{
        ID:        "web_search",
        DependsOn: []string{"intent_classifier"},
        Handler:   webSearch,
        Retries:   1,
        Timeout:   10 * time.Second,
    })
    d.AddNode(&dag.Node{
        ID:        "db_query",
        DependsOn: []string{"intent_classifier"},
        Handler:   databaseQuery,
        Retries:   2,
        Timeout:   8 * time.Second,
    })

    // 节点4:结果筛选(依赖搜索和查询结果)
    d.AddNode(&dag.Node{
        ID:        "result_filter",
        DependsOn: []string{"web_search", "db_query"},
        Handler:   resultFilter,
    })

    // 节点5:LLM合成最终答案(依赖筛选结果)
    d.AddNode(&dag.Node{
        ID:        "llm_synthesize",
        DependsOn: []string{"result_filter"},
        Handler:   llmSynthesize,
        Timeout:   15 * time.Second,
    })

    // 构建拓扑
    if err := d.Build(); err != nil {
        panic(err)
    }

    // 执行
    ctx := context.Background()
    state, err := d.Execute(ctx)
    if err != nil {
        fmt.Printf("工作流执行失败: %v\n", err)
        return
    }

    fmt.Println("=== 执行结果 ===")
    for id, res := range d.Result(state) {
        fmt.Printf("[%s]: %v\n", id, res)
    }
}

// --- 节点处理函数 ---

func intentClassifier(ctx context.Context, input map[string]interface{}) (interface{}, error) {
    // 模拟LLM意图分类
    query, _ := input["origin"].(string)
    time.Sleep(100 * time.Millisecond) // 模拟网络延迟

    switch {
    case strings.Contains(query, "价格") || strings.Contains(query, "多少钱"):
        return map[string]string{"intent": "price_query", "entity": query}, nil
    case strings.Contains(query, "教程") || strings.Contains(query, "怎么"):
        return map[string]string{"intent": "howto", "entity": query}, nil
    default:
        return map[string]string{"intent": "general", "entity": query}, nil
    }
}

func webSearch(ctx context.Context, input map[string]interface{}) (interface{}, error) {
    intentInfo := input["intent_classifier"].(map[string]string)
    time.Sleep(200 * time.Millisecond)
    return []string{
        fmt.Sprintf("搜索结果1: %s 相关文档", intentInfo["entity"]),
        "搜索结果2: 社区讨论帖",
        "搜索结果3: 官方文档",
    }, nil
}

func databaseQuery(ctx context.Context, input map[string]interface{}) (interface{}, error) {
    intentInfo := input["intent_classifier"].(map[string]string)
    time.Sleep(150 * time.Millisecond)
    return map[string]interface{}{
        "count": 42,
        "items": []string{"记录A", "记录B", "记录C"},
    }, nil
}

func resultFilter(ctx context.Context, input map[string]interface{}) (interface{}, error) {
    webResults := input["web_search"].([]string)
    dbResults := input["db_query"].(map[string]interface{})
    // 合并去重 + 相关性排序
    merged := append(webResults,
        fmt.Sprintf("数据库命中 %d 条", dbResults["count"]))
    return merged, nil
}

func llmSynthesize(ctx context.Context, input map[string]interface{}) (interface{}, error) {
    filtered := input["result_filter"].([]string)
    time.Sleep(300 * time.Millisecond)
    return fmt.Sprintf("基于 %d 条信息来源的综合答案...", len(filtered)), nil
}

运行输出:

=== 执行结果 ===
[llm_synthesize]: 基于 4 条信息来源的综合答案...

执行时间线

0ms       50ms      100ms     150ms     200ms     250ms     300ms     350ms
|         |         |         |         |         |         |         |
intent_classifier --+
                    web_search --------------------+
                    db_query ------------+
                                         result_filter --+
                                                         llm_synthesize ------+

两个搜索节点并行执行,总耗时约为串行的一半。


六、工程化增强

6.1 上下文传播

在实际AI Agent场景中,需要在整个工作流中传播traceID、用户会话等元数据:

type contextKey string

const (
    TraceIDKey contextKey = "trace_id"
    SessionKey contextKey = "session_id"
)

// 在执行入口注入
ctx = context.WithValue(ctx, TraceIDKey, uuid.New().String())

6.2 熔断器

// 熔断器(简化版)
type CircuitBreaker struct {
    failureCount int
    threshold    int
    mu           sync.Mutex
}

func (cb *CircuitBreaker) Call(fn func() error) error {
    cb.mu.Lock()
    if cb.failureCount >= cb.threshold {
        cb.mu.Unlock()
        return fmt.Errorf("熔断器已打开,拒绝请求")
    }
    cb.mu.Unlock()

    err := fn()
    if err != nil {
        cb.mu.Lock()
        cb.failureCount++
        cb.mu.Unlock()
    }
    return err
}

6.3 执行追踪

type ExecutionTrace struct {
    NodeID    string        `json:"node_id"`
    StartTime time.Time     `json:"start_time"`
    Duration  time.Duration `json:"duration"`
    Attempt   int           `json:"attempt"`
    Error     string        `json:"error,omitempty"`
}

七、总结

特性 实现方式 工程价值
依赖管理 Kahn拓扑排序 声明式任务编排,消除硬编码
并行执行 层级分组+goroutine 充分利用多核,减少端到端延迟
超时控制 context.WithTimeout 防止单个节点卡死整个流程
重试机制 指数退避 处理瞬时故障
环检测 拓扑排序visited计数 启动时即发现配置错误
Logo

欢迎加入 MCP 技术社区!与志同道合者携手前行,一同解锁 MCP 技术的无限可能!

更多推荐