"""
rag_graph.py — 完整文件(run_rag_stream 使用 checkpoint_blobs 读取历史)
"""

import re
import json
import pickle
import logging
from typing import TypedDict, List, Optional, AsyncGenerator, Sequence

import psycopg
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage,
    BaseMessage,
)
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, END, MessagesState
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

from app.config import settings
from app.services.embedding_service import get_vector_store

logger = logging.getLogger(__name__)


# ════════════════════════════════════════════════════════════
#  State
# ════════════════════════════════════════════════════════════

class RAGState(MessagesState):
    """Graph state shared by ``retrieve_node`` and ``generate_node``.

    Extends ``MessagesState``, so besides the fields below it also carries
    the ``messages`` channel that ``generate_node`` reads as chat history.
    """
    question: str  # user question for this turn; also used as the search query
    collection_name: str  # passed to get_vector_store() to select the collection
    top_k: int  # number of chunks requested from the similarity search
    context: str  # retrieved chunks formatted as prompt text by retrieve_node
    sources: List[dict]  # one dict per chunk: "source", "page", "content_preview"
    answer: str  # answer text; generate_node strips a leading <think> block from it
    thinking: Optional[str]  # content of the model's <think> block, or None


# ════════════════════════════════════════════════════════════
#  节点
# ════════════════════════════════════════════════════════════

async def retrieve_node(state: RAGState) -> dict:
    """Fetch the chunks most similar to the question and format them.

    Reads ``question``, ``collection_name`` and ``top_k`` from the state.

    Returns:
        A partial state update with ``context`` (prompt-ready text, or a
        fixed "nothing found" sentence when the search returns no chunks)
        and ``sources`` (one summary dict per retrieved chunk).
    """
    query = state["question"]
    logger.info(f"🔎 检索: {query[:50]}...")

    store = get_vector_store(state["collection_name"])
    hits = await store.asimilarity_search(query, k=state["top_k"])

    snippets = []
    source_entries = []
    for index, hit in enumerate(hits, start=1):
        origin = hit.metadata.get("source", "未知来源")
        page = hit.metadata.get("page", "")
        text = hit.page_content

        # The page suffix is only added when the metadata holds a truthy page.
        header = f"【参考片段 {index}】来源: {origin}"
        if page:
            header += f" (第{page}页)"
        snippets.append(f"{header}\n{text}")

        # Previews are capped at 150 characters, with an ellipsis when cut.
        preview = text if len(text) <= 150 else text[:150] + "..."
        source_entries.append(
            {"source": origin, "page": page, "content_preview": preview}
        )

    if snippets:
        context = "\n\n".join(snippets)
    else:
        context = "未找到相关文档内容。"

    logger.info(f"📚 检索到 {len(hits)} 个片段")
    return {"context": context, "sources": source_entries}


async def generate_node(state: RAGState) -> dict:
    """Ask the Ollama chat model to answer the question using the context.

    The prompt is: a system message embedding ``state["context"]``, the
    messages already in ``state["messages"]``, then the current question.

    Returns:
        A partial state update with ``messages`` (history plus this turn's
        human and AI messages), ``answer`` (text after a ``<think>`` block,
        or the whole reply when there is none) and ``thinking`` (the
        ``<think>`` content, or ``None``).
    """
    chat_model = ChatOllama(
        base_url=settings.OLLAMA_BASE_URL,
        model=settings.OLLAMA_MODEL,
        temperature=0.1,
    )

    prompt = SystemMessage(content=f"""你是一个专业的知识库助手,拥有完整的多轮对话记忆。
职责:基于知识库检索内容回答,结合历史上下文理解追问,没有相关信息则明确说明,使用中文。

当前检索到的参考资料:
{state['context']}""")

    past_messages: Sequence[BaseMessage] = state.get("messages", [])
    user_turn = HumanMessage(content=state["question"])

    logger.info(f"🤖 生成回答,历史消息数: {len(past_messages)}")
    reply = await chat_model.ainvoke([prompt, *past_messages, user_turn])

    raw_text = reply.content
    match = re.search(r"<think>(.*?)</think>", raw_text, re.DOTALL)
    if match is None:
        thinking, answer = None, raw_text
    else:
        # Everything up to the end of the first <think> block is dropped
        # from the answer; the block's content is reported separately.
        thinking = match.group(1).strip()
        answer = raw_text[match.end():].strip()

    return {
        "messages": [*past_messages, user_turn, AIMessage(content=answer)],
        "answer": answer,
        "thinking": thinking,
    }


# ════════════════════════════════════════════════════════════
#  Graph 构建
# ════════════════════════════════════════════════════════════

def _sync_conn_str() -> str:
    """Return the configured connection string in plain psycopg form.

    Replaces the SQLAlchemy-style ``postgresql+psycopg://`` scheme found in
    ``settings.PGVECTOR_CONNECTION_STRING`` with ``postgresql://``.
    """
    sqlalchemy_url = settings.PGVECTOR_CONNECTION_STRING
    return sqlalchemy_url.replace("postgresql+psycopg://", "postgresql://")


async def build_rag_graph_with_saver():
    """Build the retrieve -> generate graph backed by a Postgres checkpointer.

    Opens a dedicated psycopg connection, runs the saver's ``setup()`` and
    compiles the workflow with that saver as checkpointer.

    Returns:
        The compiled graph. Its checkpointer holds the connection opened
        here; nothing in this function closes it on success, so the caller
        owns its lifetime.

    Raises:
        Any error raised while connecting, running ``setup()`` or compiling.
        If the failure happens after the connection was opened, the
        connection is closed before the error propagates.
    """
    # Function-scope import: dict_row is only needed for this connection.
    from psycopg.rows import dict_row

    # AsyncPostgresSaver documents two requirements for connections that are
    # created manually: autocommit=True (otherwise setup() and checkpoint
    # writes are never committed) and row_factory=dict_row (the saver reads
    # rows by column name).
    conn = await psycopg.AsyncConnection.connect(
        _sync_conn_str(),
        autocommit=True,
        row_factory=dict_row,
    )
    try:
        saver = AsyncPostgresSaver(conn)
        await saver.setup()

        workflow = StateGraph(RAGState)
        workflow.add_node("retrieve", retrieve_node)
        workflow.add_node("generate", generate_node)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)

        graph = workflow.compile(checkpointer=saver)
    except BaseException:
        # Do not leak the connection when setup or compilation fails.
        await conn.close()
        raise

    logger.info("✅ RAG Graph (with PostgresSaver) 构建完成")
    return graph


# ════════════════════════════════════════════════════════════
#  从 checkpoint_blobs 反序列化历史消息
# ════════════════════════════════════════════════════════════

def _deserialize_blob(blob_data: bytes, blob_type: str) -> Optional[List[BaseMessage]]:
    """Decode a ``checkpoint_blobs`` payload into a list of messages.

    Decoders are tried in order; a failing decoder is logged at debug level
    and the next one is tried:

    1. LangGraph's ``JsonPlusSerializer.loads_typed`` - the serializer the
       checkpointer writes with, so it can rebuild message objects that
       plain msgpack cannot (they are stored as msgpack extension types).
    2. Plain msgpack, when ``blob_type`` is ``"msgpack"`` or ``"bytes"``.
    3. JSON, when ``blob_type`` is ``"json"`` or ``"text"``.
    4. pickle, as a last resort for any type.

    For decoders 1-3 the payload may be a list of messages / message dicts,
    or a dict holding such a list under ``"messages"``, ``"value"`` or
    ``"data"``.

    Args:
        blob_data: Raw value of the ``blob`` column; ``None`` yields ``None``.
        blob_type: Value of the ``type`` column stored next to the blob.

    Returns:
        The decoded messages, or ``None`` when no decoder produced a
        message list.
    """
    if blob_data is None:
        return None

    from langchain_core.messages import messages_from_dict

    def _to_messages(value, origin: str) -> Optional[List[BaseMessage]]:
        """Coerce a decoded payload into messages; None when impossible."""
        if isinstance(value, dict):
            # The list may be nested under a wrapper key.
            for key in ("messages", "value", "data"):
                nested = value.get(key)
                if isinstance(nested, list):
                    found = _to_messages(nested, origin)
                    if found is not None:
                        return found
            return None
        if not isinstance(value, list):
            return None
        if all(isinstance(item, BaseMessage) for item in value):
            return value
        try:
            return messages_from_dict(value)
        except Exception as exc:
            logger.debug(f"{origin} 消息转换失败: {exc}")
            return None

    # -- 1. LangGraph's own serializer ----------------------------------
    if isinstance(blob_data, (bytes, bytearray, memoryview)):
        try:
            from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

            decoded = JsonPlusSerializer().loads_typed(
                (blob_type, bytes(blob_data))
            )
        except Exception as exc:
            logger.debug(f"JsonPlusSerializer 反序列化失败: {exc}")
        else:
            messages = _to_messages(decoded, "JsonPlusSerializer")
            if messages is not None:
                return messages

    # -- 2. Plain msgpack -----------------------------------------------
    if blob_type in ("msgpack", "bytes"):
        try:
            import msgpack

            raw = msgpack.unpackb(blob_data, raw=False, strict_map_key=False)
        except Exception as e:
            logger.debug(f"msgpack 反序列化失败: {e}")
        else:
            messages = _to_messages(raw, "msgpack")
            if messages is not None:
                return messages

    # -- 3. JSON --------------------------------------------------------
    if blob_type in ("json", "text"):
        try:
            text = (
                blob_data.decode("utf-8")
                if isinstance(blob_data, bytes)
                else blob_data
            )
            raw = json.loads(text)
        except Exception as e:
            logger.debug(f"JSON 反序列化失败: {e}")
        else:
            messages = _to_messages(raw, "JSON")
            if messages is not None:
                return messages

    # -- 4. Last resort: pickle -----------------------------------------
    # SECURITY: pickle.loads can execute arbitrary code embedded in the
    # payload. It is kept only for backward compatibility and must never
    # see data that anyone other than this application could have written.
    # NOTE(review): consider dropping this fallback if no pickled blobs exist.
    try:
        obj = pickle.loads(blob_data)
        if isinstance(obj, list) and all(isinstance(m, BaseMessage) for m in obj):
            return obj
    except Exception as e:
        logger.debug(f"pickle 反序列化失败: {e}")

    return None


async def _load_history_from_checkpoint_blobs(
        session_id: str,
) -> List[BaseMessage]:
    """Load a session's chat history straight from ``checkpoint_blobs``.

    Selects the highest-versioned blob of the ``messages`` channel for the
    thread and decodes it with ``_deserialize_blob``.

    Columns of ``checkpoint_blobs`` this query relies on::

        thread_id | checkpoint_ns | channel  | version | type | blob

    Args:
        session_id: Used as the checkpoint ``thread_id``.

    Returns:
        The decoded history. An empty list is returned when no blob exists,
        when the blob cannot be decoded, or when any database error occurs
        (the error is logged, never raised).
    """
    conn_str = _sync_conn_str()

    try:
        conn = await psycopg.AsyncConnection.connect(conn_str)
        # The connection context manager closes the connection on exit.
        async with conn:
            # ``version`` is a text column. LangGraph writes zero-padded
            # values such as "000...0003.0.123", while run_rag_stream writes
            # "1". Casting the whole value to int raises on the former, so
            # only the leading integer part is compared numerically;
            # substring() yields NULL (sorted last) instead of raising when
            # a value has no leading digits.
            cursor = await conn.execute(
                """
                SELECT cb.type, cb.blob
                FROM checkpoint_blobs cb
                         JOIN checkpoints c
                              ON c.thread_id = cb.thread_id
                                  AND c.checkpoint_ns = cb.checkpoint_ns
                WHERE cb.thread_id = %s
                  AND cb.channel = 'messages'
                ORDER BY substring(cb.version from '^[0-9]+')::numeric
                             DESC NULLS LAST,
                         cb.version DESC
                LIMIT  1
                """,
                (session_id,),
            )
            record = await cursor.fetchone()

        if not record:
            logger.info(f"📭 会话 {session_id[:8]}... 暂无 checkpoint_blobs 记录")
            return []

        blob_type: str = record[0]  # e.g. 'msgpack' | 'json' | 'bytes'
        blob_data: bytes = record[1]

        logger.info(
            f"📦 读取 checkpoint_blobs: session={session_id[:8]}..., "
            f"type={blob_type}, size={len(blob_data) if blob_data else 0} bytes"
        )

        messages = _deserialize_blob(blob_data, blob_type)
        if messages is None:
            logger.warning("⚠️  反序列化失败,以空历史继续")
            return []

        logger.info(f"✅ 反序列化成功,历史消息数: {len(messages)}")
        return messages

    except Exception as e:
        # Best effort: a history failure must not break the current answer.
        logger.error(f"❌ 从 checkpoint_blobs 读取历史失败: {e}", exc_info=True)
        return []


# ════════════════════════════════════════════════════════════
#  对外接口
# ════════════════════════════════════════════════════════════

async def run_rag(
        question: str,
        session_id: str,
        collection_name: str = "default",
        top_k: int = 5,
) -> dict:
    """Run one non-streaming RAG turn with multi-turn memory.

    Builds the checkpointed graph, invokes it with ``session_id`` as the
    checkpoint ``thread_id`` and returns the relevant parts of the final state.

    Args:
        question: The user's question.
        session_id: Conversation identifier, used as the checkpoint thread id.
        collection_name: Vector-store collection to search.
        top_k: Number of chunks to retrieve.

    Returns:
        A dict with ``answer``, ``thinking``, ``sources`` and ``session_id``.
    """
    graph = await build_rag_graph_with_saver()
    try:
        config = {"configurable": {"thread_id": session_id}}

        initial_state = RAGState(
            question=question,
            collection_name=collection_name,
            top_k=top_k,
            context="",
            sources=[],
            answer="",
            thinking=None,
            messages=[],
        )

        result = await graph.ainvoke(initial_state, config=config)
    finally:
        # build_rag_graph_with_saver() opens a fresh database connection on
        # every call and hands it to the checkpointer. The graph is local to
        # this call, so the connection is released here rather than leaked.
        checkpointer = getattr(graph, "checkpointer", None)
        saver_conn = getattr(checkpointer, "conn", None)
        if saver_conn is not None:
            await saver_conn.close()

    return {
        "answer": result["answer"],
        "thinking": result["thinking"],
        "sources": result["sources"],
        "session_id": session_id,
    }


async def run_rag_stream(
        question: str,
        session_id: str,
        collection_name: str = "default",
        top_k: int = 5,
) -> AsyncGenerator[str, None]:
    """Stream one RAG turn as Server-Sent-Event lines, with multi-turn memory.

    History is read directly from the ``checkpoint_blobs`` table rather than
    by loading the full checkpoint through the saver.

    Yields, in order:
        * one ``sources`` event with the retrieved chunk summaries,
        * one ``token`` event per non-empty model chunk,
        * one ``done`` event carrying the extracted ``thinking`` text.

    After generation the turn is persisted through ``AsyncPostgresSaver``.
    A persistence failure is logged and does not affect the stream.

    Args:
        question: The user's question.
        session_id: Conversation identifier, used as the checkpoint thread id.
        collection_name: Vector-store collection to search.
        top_k: Number of chunks to retrieve.
    """
    # -- 1. Vector retrieval --------------------------------------------
    vector_store = get_vector_store(collection_name)
    docs = await vector_store.asimilarity_search(question, k=top_k)

    sources, context_parts = [], []
    for i, doc in enumerate(docs, 1):
        source = doc.metadata.get("source", "未知来源")
        page = doc.metadata.get("page", "")
        page_info = f" (第{page}页)" if page else ""
        context_parts.append(
            f"【参考片段 {i}】来源: {source}{page_info}\n{doc.page_content}"
        )
        sources.append({
            "source": source,
            "page": page,
            "content_preview": (
                doc.page_content[:150] + "..."
                if len(doc.page_content) > 150
                else doc.page_content
            ),
        })

    context = "\n\n".join(context_parts) if context_parts else "未找到相关文档内容。"

    # -- 2. Load history from checkpoint_blobs --------------------------
    history: List[BaseMessage] = await _load_history_from_checkpoint_blobs(session_id)
    logger.info(f"📜 会话 {session_id[:8]}... 历史消息数: {len(history)}")

    # -- 3. Assemble the full message chain -----------------------------
    system_msg = SystemMessage(content=f"""你是一个专业的知识库助手,拥有完整的多轮对话记忆。
职责:基于知识库检索内容回答,结合历史上下文理解追问,没有相关信息则明确说明,使用中文。

当前检索到的参考资料:
{context}""")

    current_user_msg = HumanMessage(content=question)
    all_messages = [system_msg] + list(history) + [current_user_msg]

    # -- 4. Streamed generation -----------------------------------------
    llm = ChatOllama(
        base_url=settings.OLLAMA_BASE_URL,
        model=settings.OLLAMA_MODEL,
        temperature=0.1,
    )

    yield f"data: {json.dumps({'type': 'sources', 'data': sources}, ensure_ascii=False)}\n\n"

    answer_parts = []
    async for chunk in llm.astream(all_messages):
        if chunk.content:
            answer_parts.append(chunk.content)
            yield f"data: {json.dumps({'type': 'token', 'data': chunk.content}, ensure_ascii=False)}\n\n"
    full_answer = "".join(answer_parts)

    # -- 5. Split off the <think> block ---------------------------------
    thinking = None
    clean_answer = full_answer
    think_match = re.search(r"<think>(.*?)</think>", full_answer, re.DOTALL)
    if think_match:
        thinking = think_match.group(1).strip()
        clean_answer = full_answer[think_match.end():].strip()

    # -- 6. Persist this turn through the Postgres saver ----------------
    try:
        # Function-scope import: dict_row is only needed for this connection.
        from psycopg.rows import dict_row

        new_messages = list(history) + [
            current_user_msg,
            AIMessage(content=clean_answer),
        ]
        config = {"configurable": {"thread_id": session_id, "checkpoint_ns": ""}}

        # ``async with`` guarantees the connection is closed even when a
        # saver call raises; previously it was only closed on success.
        # dict_row is what the saver documents for manually created
        # connections (it reads rows by column name).
        async with await psycopg.AsyncConnection.connect(
            _sync_conn_str(), row_factory=dict_row
        ) as conn:
            saver = AsyncPostgresSaver(conn)
            await saver.setup()

            # Reuse the existing checkpoint's id and version maps if any.
            existing = await saver.aget(config)
            checkpoint_id = existing["id"] if existing else session_id

            # Only human/ai messages are stored, as plain dicts in the
            # shape messages_from_dict() expects when reading them back.
            new_messages_dict_list = [
                {"type": message.type, "data": {"content": message.content}}
                for message in new_messages
                if message.type in ("human", "ai")
            ]

            new_checkpoint = {
                "v": 1,
                "ts": "",
                "id": checkpoint_id,
                "channel_values": {"messages": new_messages_dict_list},
                "channel_versions": existing.get("channel_versions", {}) if existing else {},
                "versions_seen": existing.get("versions_seen", {}) if existing else {},
                "pending_sends": [],
            }
            # Delete and re-insert run in the same transaction and are
            # committed together below.
            await saver.adelete_thread(session_id)
            await saver.aput(config, new_checkpoint, {}, {"messages": 1})
            await conn.commit()

        logger.info(
            f"💾 会话 {session_id[:8]}... 持久化完成,"
            f"共 {len(new_messages)} 条消息"
        )
    except Exception as e:
        # Best effort: the answer has already been streamed to the client.
        logger.error(f"❌ 持久化失败(不影响本次回答): {e}", exc_info=True)

    yield f"data: {json.dumps({'type': 'done', 'thinking': thinking}, ensure_ascii=False)}\n\n"

from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from sqlalchemy import text
from app.config import settings
import logging
import psycopg

logger = logging.getLogger(__name__)


class Base(DeclarativeBase):
    """Declarative base class; declares no tables or columns of its own."""


# Module-wide async SQLAlchemy engine, created at import time from
# settings.PGVECTOR_CONNECTION_STRING.
engine = create_async_engine(
    settings.PGVECTOR_CONNECTION_STRING,
    echo=False,  # no SQL statement logging
    pool_size=10,  # connections kept in the pool
    max_overflow=20,  # additional connections allowed beyond pool_size
)

# Session factory bound to the engine above; produces AsyncSession objects.
# expire_on_commit=False: loaded attributes are not expired by a commit.
AsyncSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)


async def get_db():
    """Yield one ``AsyncSession`` and close it when the consumer is done.

    Async generator; presumably meant to be used as a request-scoped
    dependency - confirm against the callers.
    """
    session_ctx = AsyncSessionLocal()
    async with session_ctx as db:
        try:
            yield db
        finally:
            # Explicit close, in addition to the context manager's own exit.
            await db.close()


async def init_db():
    """Initialise the database: checkpoint tables, pgvector and app tables.

    Steps:
        1. Run ``AsyncPostgresSaver.setup()`` on a short-lived psycopg
           connection to ``settings.DATABASE_URL``.
        2. In one SQLAlchemy transaction, enable the ``vector`` extension
           and create ``document_metadata``, ``chat_sessions``,
           ``chat_messages`` and the ``chat_messages(session_id)`` index,
           each only if it does not exist yet.
    """
    # Function-scope import: dict_row is only needed for this connection.
    from psycopg.rows import dict_row

    # The saver documents autocommit=True and row_factory=dict_row for
    # manually created connections. ``async with`` closes the connection
    # once setup is done; previously it was left open and then shadowed by
    # the SQLAlchemy connection below.
    async with await psycopg.AsyncConnection.connect(
        settings.DATABASE_URL, autocommit=True, row_factory=dict_row
    ) as checkpoint_conn:
        saver = AsyncPostgresSaver(checkpoint_conn)
        await saver.setup()

    async with engine.begin() as conn:
        # Enable the pgvector extension.
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))

        # Document metadata table.
        await conn.execute(text("""
            CREATE TABLE IF NOT EXISTS document_metadata (
                id          SERIAL PRIMARY KEY,
                filename    VARCHAR(255) NOT NULL,
                file_type   VARCHAR(50)  NOT NULL,
                file_size   BIGINT,
                chunk_count INTEGER DEFAULT 0,
                status      VARCHAR(50)  DEFAULT 'processing',
                created_at  TIMESTAMPTZ  DEFAULT NOW(),
                updated_at  TIMESTAMPTZ  DEFAULT NOW()
            )
        """))

        # Chat session table.
        await conn.execute(text("""
            CREATE TABLE IF NOT EXISTS chat_sessions (
                id              VARCHAR(64)  PRIMARY KEY,
                title           VARCHAR(255) NOT NULL DEFAULT '新对话',
                collection_name VARCHAR(100) NOT NULL DEFAULT 'default',
                message_count   INTEGER      DEFAULT 0,
                created_at      TIMESTAMPTZ  DEFAULT NOW(),
                updated_at      TIMESTAMPTZ  DEFAULT NOW()
            )
        """))

        # Message log table (for front-end display; kept separate from the
        # PostgresSaver checkpoint tables).
        await conn.execute(text("""
            CREATE TABLE IF NOT EXISTS chat_messages (
                id          SERIAL       PRIMARY KEY,
                session_id  VARCHAR(64)  NOT NULL REFERENCES chat_sessions(id) ON DELETE CASCADE,
                role        VARCHAR(20)  NOT NULL,  -- 'user' | 'assistant'
                content     TEXT         NOT NULL,
                thinking    TEXT,
                sources     JSONB,
                created_at  TIMESTAMPTZ  DEFAULT NOW()
            )
        """))

        await conn.execute(text("""
            CREATE INDEX IF NOT EXISTS idx_chat_messages_session_id
            ON chat_messages(session_id)
        """))

    # Logged after the block above exits, i.e. once the transaction has
    # been committed.
    logger.info("✅ 数据库初始化完成")
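# NOTE: the lines below are web-page residue from the site this snippet was
# copied from, not code; kept as comments so the module stays importable.
# Logo
#
# 欢迎加入 MCP 技术社区!与志同道合者携手前行,一同解锁 MCP 技术的无限可能!
#
# 更多推荐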