LangGraph实现聊天机器人
·
"""
rag_graph.py — 完整文件(run_rag_stream 使用 checkpoint_blobs 读取历史)
"""
import re
import json
import pickle
import logging
from typing import TypedDict, List, Optional, AsyncGenerator, Sequence
import psycopg
from langchain_core.messages import (
HumanMessage,
AIMessage,
SystemMessage,
BaseMessage,
)
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, END, MessagesState
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.config import settings
from app.services.embedding_service import get_vector_store
logger = logging.getLogger(__name__)
# ════════════════════════════════════════════════════════════
# State
# ════════════════════════════════════════════════════════════
class RAGState(MessagesState):
    """State passed between the nodes of the RAG graph.

    Extends ``MessagesState`` with the request inputs and the values the
    nodes in this module write back.
    """

    # Request inputs, read by retrieve_node / generate_node.
    question: str
    collection_name: str
    top_k: int

    # Returned by retrieve_node.
    context: str
    sources: List[dict]

    # Returned by generate_node.
    answer: str
    thinking: Optional[str]
# ════════════════════════════════════════════════════════════
# 节点
# ════════════════════════════════════════════════════════════
async def retrieve_node(state: RAGState) -> dict:
    """Retrieve the passages most similar to the question.

    Reads ``question``, ``collection_name`` and ``top_k`` from the state and
    runs an async similarity search against the collection's vector store.

    Returns:
        A partial state update with two keys:

        * ``context`` - the numbered passages joined by blank lines, or a
          fixed "nothing found" sentence when the search returns nothing.
        * ``sources`` - one dict per passage with ``source``, ``page`` and a
          ``content_preview`` (first 150 characters plus ``...`` when the
          passage is longer).
    """
    question = state["question"]
    logger.info(f"🔎 检索: {question[:50]}...")

    vector_store = get_vector_store(state["collection_name"])
    docs = await vector_store.asimilarity_search(question, k=state["top_k"])

    sources, context_parts = [], []
    for i, doc in enumerate(docs, 1):
        source = doc.metadata.get("source", "未知来源")
        page = doc.metadata.get("page", "")
        # Test for "missing" explicitly: a plain truthiness check would also
        # drop the label for page number 0 (loaders may number pages from 0).
        page_info = "" if page in ("", None) else f" (第{page}页)"
        context_parts.append(
            f"【参考片段 {i}】来源: {source}{page_info}\n{doc.page_content}"
        )
        sources.append({
            "source": source,
            "page": page,
            "content_preview": (
                doc.page_content[:150] + "..."
                if len(doc.page_content) > 150
                else doc.page_content
            ),
        })

    context = "\n\n".join(context_parts) if context_parts else "未找到相关文档内容。"
    logger.info(f"📚 检索到 {len(docs)} 个片段")
    return {"context": context, "sources": sources}
async def generate_node(state: RAGState) -> dict:
    """Answer the question with the chat model, using context and history.

    The prompt is: a system message embedding ``state['context']``, the
    messages already in the state, then the current question. If the reply
    contains a ``<think>...</think>`` section, its content is returned as
    ``thinking`` and only the text after it is kept as the answer.

    Returns:
        A partial state update with ``messages`` (previous messages plus the
        new human/AI pair), ``answer`` and ``thinking`` (``None`` when the
        reply has no think section).
    """
    chat_model = ChatOllama(
        base_url=settings.OLLAMA_BASE_URL,
        model=settings.OLLAMA_MODEL,
        temperature=0.1,
    )

    system_prompt = SystemMessage(content=f"""你是一个专业的知识库助手,拥有完整的多轮对话记忆。
职责:基于知识库检索内容回答,结合历史上下文理解追问,没有相关信息则明确说明,使用中文。
当前检索到的参考资料:
{state['context']}""")
    prior_turns: Sequence[BaseMessage] = state.get("messages", [])
    user_turn = HumanMessage(content=state["question"])
    prompt = [system_prompt, *prior_turns, user_turn]

    logger.info(f"🤖 生成回答,历史消息数: {len(prior_turns)}")
    reply = await chat_model.ainvoke(prompt)
    raw_text = reply.content

    # Separate the optional reasoning section from the visible answer.
    think_block = re.search(r"<think>(.*?)</think>", raw_text, re.DOTALL)
    if think_block is None:
        thinking, answer = None, raw_text
    else:
        thinking = think_block.group(1).strip()
        answer = raw_text[think_block.end():].strip()

    return {
        "messages": [*prior_turns, user_turn, AIMessage(content=answer)],
        "answer": answer,
        "thinking": thinking,
    }
# ════════════════════════════════════════════════════════════
# Graph 构建
# ════════════════════════════════════════════════════════════
def _sync_conn_str() -> str:
    """Return the configured connection string with a plain ``postgresql://`` scheme.

    ``settings.PGVECTOR_CONNECTION_STRING`` is also handed to SQLAlchemy in
    this file; here its ``postgresql+psycopg://`` prefix is swapped out so
    the result can be passed to ``psycopg.AsyncConnection.connect``.
    """
    sqlalchemy_scheme = "postgresql+psycopg://"
    plain_scheme = "postgresql://"
    configured = settings.PGVECTOR_CONNECTION_STRING
    return configured.replace(sqlalchemy_scheme, plain_scheme)
async def build_rag_graph_with_saver():
    """Compile the ``retrieve -> generate`` graph with a Postgres checkpointer.

    A dedicated psycopg connection is opened for the checkpointer. It stays
    open after this function returns because the compiled graph keeps using
    it; the caller owns it and should close it once done with the graph.

    Returns:
        The compiled graph.

    Raises:
        Whatever connecting, ``saver.setup()`` or compilation raises; the
        connection is closed before the error propagates.
    """
    from psycopg.rows import dict_row

    # The LangGraph Postgres checkpointer documentation requires manually
    # created connections to use autocommit=True (so setup() and checkpoint
    # writes are actually committed) and row_factory=dict_row (rows are read
    # by column name).
    conn = await psycopg.AsyncConnection.connect(
        _sync_conn_str(),
        autocommit=True,
        row_factory=dict_row,
    )
    try:
        saver = AsyncPostgresSaver(conn)
        await saver.setup()

        workflow = StateGraph(RAGState)
        workflow.add_node("retrieve", retrieve_node)
        workflow.add_node("generate", generate_node)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        graph = workflow.compile(checkpointer=saver)
    except BaseException:
        # Includes cancellation: never leave a half-initialised connection
        # behind when no graph is returned to take ownership of it.
        await conn.close()
        raise

    logger.info("✅ RAG Graph (with PostgresSaver) 构建完成")
    return graph
# ════════════════════════════════════════════════════════════
# 从 checkpoint_blobs 反序列化历史消息
# ════════════════════════════════════════════════════════════
def _messages_from_payload(payload, fmt: str) -> Optional[List[BaseMessage]]:
    """Rebuild messages from a decoded payload.

    Accepts either a list of message dicts, or a dict that holds such a list
    under one of the keys ``messages``, ``value`` or ``data``. Returns
    ``None`` when no candidate list can be converted.
    """
    from langchain_core.messages import messages_from_dict

    if isinstance(payload, list):
        candidates = [payload]
    elif isinstance(payload, dict):
        candidates = [payload.get(key) for key in ("messages", "value", "data")]
    else:
        return None

    for candidate in candidates:
        if not isinstance(candidate, list):
            continue
        try:
            return messages_from_dict(candidate)
        except Exception as e:
            # Not a message-dict list; keep looking, but leave a trace.
            logger.debug(f"{fmt} payload could not be converted to messages: {e}")
    return None


def _deserialize_blob(blob_data: bytes, blob_type: str) -> Optional[List[BaseMessage]]:
    """Decode a ``checkpoint_blobs`` payload into a list of messages.

    Decoders are tried in this order:

    * msgpack - when ``blob_type`` is ``"msgpack"`` or ``"bytes"``
    * JSON    - when ``blob_type`` is ``"json"`` or ``"text"``
    * pickle  - always, as the last fallback

    Args:
        blob_data: Raw blob content; ``None`` yields ``None``.
        blob_type: Value of the blob's ``type`` column.

    Returns:
        The decoded messages, or ``None`` when every decoder fails. Decoder
        failures are logged at debug level and never raised.
    """
    if blob_data is None:
        return None

    # msgpack
    if blob_type in ("msgpack", "bytes"):
        try:
            import msgpack

            raw = msgpack.unpackb(blob_data, raw=False, strict_map_key=False)
        except Exception as e:
            logger.debug(f"msgpack 反序列化失败: {e}")
        else:
            messages = _messages_from_payload(raw, "msgpack")
            if messages is not None:
                return messages

    # JSON
    if blob_type in ("json", "text"):
        try:
            text = blob_data.decode("utf-8") if isinstance(blob_data, bytes) else blob_data
            raw = json.loads(text)
        except Exception as e:
            logger.debug(f"JSON 反序列化失败: {e}")
        else:
            messages = _messages_from_payload(raw, "json")
            if messages is not None:
                return messages

    # Last resort: pickle.
    # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
    # Kept for backward compatibility; only safe while every writer to the
    # checkpoint tables is trusted.
    try:
        obj = pickle.loads(blob_data)
        if isinstance(obj, list) and all(isinstance(m, BaseMessage) for m in obj):
            return obj
    except Exception as e:
        logger.debug(f"pickle 反序列化失败: {e}")

    return None
async def _load_history_from_checkpoint_blobs(
    session_id: str,
) -> List[BaseMessage]:
    """Load a session's message history straight from ``checkpoint_blobs``.

    Selects the highest-versioned blob of channel ``messages`` for the thread
    and decodes it with ``_deserialize_blob``.

    Columns of ``checkpoint_blobs`` used or assumed here:

        thread_id | checkpoint_ns | channel  | version | type | blob

    Args:
        session_id: Thread id of the conversation.

    Returns:
        The decoded messages. An empty list is returned when there is no
        blob, when decoding fails, or when any database error occurs (the
        error is logged, never raised).
    """
    conn_str = _sync_conn_str()
    try:
        async with await psycopg.AsyncConnection.connect(conn_str) as conn:
            # Order by the numeric prefix of the version instead of casting
            # the whole value to int: versions generated by LangGraph look
            # like "<zero-padded counter>.<fraction>", which `::int` rejects,
            # while plain integers such as "1" still sort correctly.
            cursor = await conn.execute(
                """
                SELECT cb.type, cb.blob
                FROM checkpoint_blobs cb
                JOIN checkpoints c
                ON c.thread_id = cb.thread_id
                AND c.checkpoint_ns = cb.checkpoint_ns
                WHERE cb.thread_id = %s
                AND cb.channel = 'messages'
                ORDER BY NULLIF(split_part(cb.version::text, '.', 1), '')::numeric
                    DESC NULLS LAST
                LIMIT 1
                """,
                (session_id,),
            )
            record = await cursor.fetchone()

        if not record:
            logger.info(f"📭 会话 {session_id[:8]}... 暂无 checkpoint_blobs 记录")
            return []

        blob_type: str = record[0]
        blob_data: bytes = record[1]
        logger.info(
            f"📦 读取 checkpoint_blobs: session={session_id[:8]}..., "
            f"type={blob_type}, size={len(blob_data) if blob_data else 0} bytes"
        )

        messages = _deserialize_blob(blob_data, blob_type)
        if messages is None:
            logger.warning("⚠️ 反序列化失败,以空历史继续")
            return []

        logger.info(f"✅ 反序列化成功,历史消息数: {len(messages)}")
        return messages
    except Exception as e:
        # Best effort by design: a history failure must not block answering.
        logger.error(f"❌ 从 checkpoint_blobs 读取历史失败: {e}", exc_info=True)
        return []
# ════════════════════════════════════════════════════════════
# 对外接口
# ════════════════════════════════════════════════════════════
async def run_rag(
    question: str,
    session_id: str,
    collection_name: str = "default",
    top_k: int = 5,
) -> dict:
    """Run the RAG graph once, without streaming, with multi-turn memory.

    Args:
        question: The user's question.
        session_id: Conversation id, used as the checkpointer ``thread_id``.
        collection_name: Vector-store collection to search.
        top_k: Number of passages to retrieve.

    Returns:
        A dict with ``answer``, ``thinking``, ``sources`` and ``session_id``.
    """
    graph = await build_rag_graph_with_saver()

    # build_rag_graph_with_saver() opens a new connection on every call and
    # hands it to the checkpointer. Presumably it is reachable as
    # graph.checkpointer.conn; the getattr guards turn the cleanup below
    # into a no-op if that assumption does not hold.
    checkpointer_conn = getattr(getattr(graph, "checkpointer", None), "conn", None)

    config = {"configurable": {"thread_id": session_id}}
    initial_state = RAGState(
        question=question,
        collection_name=collection_name,
        top_k=top_k,
        context="",
        sources=[],
        answer="",
        thinking=None,
        messages=[],
    )

    try:
        result = await graph.ainvoke(initial_state, config=config)
        if checkpointer_conn is not None:
            # Makes the checkpoint durable when the connection is not in
            # autocommit mode; does nothing when no transaction is open.
            await checkpointer_conn.commit()
    finally:
        if checkpointer_conn is not None:
            # One connection per request: close it instead of leaking it.
            await checkpointer_conn.close()

    return {
        "answer": result["answer"],
        "thinking": result["thinking"],
        "sources": result["sources"],
        "session_id": session_id,
    }
async def run_rag_stream(
    question: str,
    session_id: str,
    collection_name: str = "default",
    top_k: int = 5,
) -> AsyncGenerator[str, None]:
    """Stream a RAG answer as Server-Sent Events, with multi-turn memory.

    Steps: retrieve passages, load the history via
    ``_load_history_from_checkpoint_blobs``, stream the model output, split
    off an optional ``<think>`` section, then persist the updated history
    through ``AsyncPostgresSaver``. A persistence failure is logged and does
    not affect the stream.

    Args:
        question: The user's question.
        session_id: Conversation id, used as the checkpointer ``thread_id``.
        collection_name: Vector-store collection to search.
        top_k: Number of passages to retrieve.

    Yields:
        ``data: <json>`` frames: one ``sources`` event, one ``token`` event
        per non-empty model chunk, and a final ``done`` event carrying the
        extracted thinking text (or ``None``).
    """
    # 1. Vector retrieval.
    vector_store = get_vector_store(collection_name)
    docs = await vector_store.asimilarity_search(question, k=top_k)

    sources, context_parts = [], []
    for i, doc in enumerate(docs, 1):
        source = doc.metadata.get("source", "未知来源")
        page = doc.metadata.get("page", "")
        # Test for "missing" explicitly: a plain truthiness check would also
        # drop the label for page number 0 (loaders may number pages from 0).
        page_info = "" if page in ("", None) else f" (第{page}页)"
        context_parts.append(
            f"【参考片段 {i}】来源: {source}{page_info}\n{doc.page_content}"
        )
        sources.append({
            "source": source,
            "page": page,
            "content_preview": (
                doc.page_content[:150] + "..."
                if len(doc.page_content) > 150
                else doc.page_content
            ),
        })
    context = "\n\n".join(context_parts) if context_parts else "未找到相关文档内容。"

    # 2. History, read directly from checkpoint_blobs.
    history: List[BaseMessage] = await _load_history_from_checkpoint_blobs(session_id)
    logger.info(f"📜 会话 {session_id[:8]}... 历史消息数: {len(history)}")

    # 3. Full prompt: system message, history, current question.
    system_msg = SystemMessage(content=f"""你是一个专业的知识库助手,拥有完整的多轮对话记忆。
职责:基于知识库检索内容回答,结合历史上下文理解追问,没有相关信息则明确说明,使用中文。
当前检索到的参考资料:
{context}""")
    current_user_msg = HumanMessage(content=question)
    all_messages = [system_msg] + list(history) + [current_user_msg]

    # 4. Streamed generation.
    llm = ChatOllama(
        base_url=settings.OLLAMA_BASE_URL,
        model=settings.OLLAMA_MODEL,
        temperature=0.1,
    )
    yield f"data: {json.dumps({'type': 'sources', 'data': sources}, ensure_ascii=False)}\n\n"

    answer_chunks = []
    async for chunk in llm.astream(all_messages):
        if chunk.content:
            answer_chunks.append(chunk.content)
            yield f"data: {json.dumps({'type': 'token', 'data': chunk.content}, ensure_ascii=False)}\n\n"
    full_answer = "".join(answer_chunks)

    # 5. Separate the optional reasoning section from the stored answer.
    thinking = None
    clean_answer = full_answer
    think_match = re.search(r"<think>(.*?)</think>", full_answer, re.DOTALL)
    if think_match:
        thinking = think_match.group(1).strip()
        clean_answer = full_answer[think_match.end():].strip()

    # 6. Persist this turn (best effort).
    try:
        from psycopg.rows import dict_row

        new_messages = list(history) + [
            current_user_msg,
            AIMessage(content=clean_answer),
        ]
        # Stored as plain message dicts, the shape _deserialize_blob reads
        # back with messages_from_dict. Only human/ai messages are kept.
        new_messages_dict_list = [
            {"type": message.type, "data": {"content": message.content}}
            for message in new_messages
            if message.type in ("human", "ai")
        ]
        config = {"configurable": {"thread_id": session_id, "checkpoint_ns": ""}}

        # `async with` commits on success, rolls back on error and always
        # closes the connection, so a failure no longer leaks a connection
        # with an open transaction. Autocommit is deliberately not enabled:
        # the delete + put below must be applied together or not at all.
        async with await psycopg.AsyncConnection.connect(
            _sync_conn_str(), row_factory=dict_row
        ) as conn:
            saver = AsyncPostgresSaver(conn)
            await saver.setup()

            # Reuse the existing checkpoint's id and version bookkeeping.
            existing = await saver.aget(config)
            checkpoint_id = existing["id"] if existing else session_id
            new_checkpoint = {
                "v": 1,
                "ts": "",
                "id": checkpoint_id,
                "channel_values": {"messages": new_messages_dict_list},
                "channel_versions": existing.get("channel_versions", {}) if existing else {},
                "versions_seen": existing.get("versions_seen", {}) if existing else {},
                "pending_sends": [],
            }
            await saver.adelete_thread(session_id)
            await saver.aput(config, new_checkpoint, {}, {"messages": 1})

        logger.info(
            f"💾 会话 {session_id[:8]}... 持久化完成,"
            f"共 {len(new_messages)} 条消息"
        )
    except Exception as e:
        logger.error(f"❌ 持久化失败(不影响本次回答): {e}", exc_info=True)

    yield f"data: {json.dumps({'type': 'done', 'thinking': thinking}, ensure_ascii=False)}\n\n"
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from sqlalchemy import text
from app.config import settings
import logging
import psycopg
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
    """Declarative base for ORM models (no models are defined in the visible code)."""
# Module-wide async engine built from the configured connection string,
# with SQL echo off and a pool of 10 connections plus up to 20 overflow.
engine = create_async_engine(
    settings.PGVECTOR_CONNECTION_STRING,
    echo=False,
    pool_size=10,
    max_overflow=20,
)

# Factory for AsyncSession objects bound to the engine above; instances are
# not expired on commit.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_db():
    """Yield a database session and close it when the consumer is done.

    Async generator; presumably used as a web-framework dependency (confirm
    against the callers).

    Yields:
        An ``AsyncSession`` created by ``AsyncSessionLocal``.
    """
    # The session's async context manager already closes it on exit, on both
    # the success and the error path, so no extra try/finally is needed.
    async with AsyncSessionLocal() as session:
        yield session
async def init_db():
    """Initialise the database.

    Runs the LangGraph checkpointer setup, then, in one engine transaction,
    enables the pgvector extension and creates the ``document_metadata``,
    ``chat_sessions`` and ``chat_messages`` tables and the
    ``chat_messages(session_id)`` index if they do not exist yet.
    """
    from psycopg.rows import dict_row

    # The checkpointer setup gets its own short-lived connection, configured
    # as the LangGraph Postgres checkpointer documentation requires
    # (autocommit=True, row_factory=dict_row). It is closed as soon as
    # setup() returns instead of being left open.
    async with await psycopg.AsyncConnection.connect(
        settings.DATABASE_URL,
        autocommit=True,
        row_factory=dict_row,
    ) as saver_conn:
        saver = AsyncPostgresSaver(saver_conn)
        await saver.setup()

    async with engine.begin() as conn:
        # Enable the pgvector extension.
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))

        # Document metadata table.
        await conn.execute(text("""
            CREATE TABLE IF NOT EXISTS document_metadata (
            id SERIAL PRIMARY KEY,
            filename VARCHAR(255) NOT NULL,
            file_type VARCHAR(50) NOT NULL,
            file_size BIGINT,
            chunk_count INTEGER DEFAULT 0,
            status VARCHAR(50) DEFAULT 'processing',
            created_at TIMESTAMPTZ DEFAULT NOW(),
            updated_at TIMESTAMPTZ DEFAULT NOW()
            )
        """))

        # Chat session table.
        await conn.execute(text("""
            CREATE TABLE IF NOT EXISTS chat_sessions (
            id VARCHAR(64) PRIMARY KEY,
            title VARCHAR(255) NOT NULL DEFAULT '新对话',
            collection_name VARCHAR(100) NOT NULL DEFAULT 'default',
            message_count INTEGER DEFAULT 0,
            created_at TIMESTAMPTZ DEFAULT NOW(),
            updated_at TIMESTAMPTZ DEFAULT NOW()
            )
        """))

        # Message log table (for front-end display; kept separate from the
        # PostgresSaver checkpoint tables).
        await conn.execute(text("""
            CREATE TABLE IF NOT EXISTS chat_messages (
            id SERIAL PRIMARY KEY,
            session_id VARCHAR(64) NOT NULL REFERENCES chat_sessions(id) ON DELETE CASCADE,
            role VARCHAR(20) NOT NULL, -- 'user' | 'assistant'
            content TEXT NOT NULL,
            thinking TEXT,
            sources JSONB,
            created_at TIMESTAMPTZ DEFAULT NOW()
            )
        """))
        await conn.execute(text("""
            CREATE INDEX IF NOT EXISTS idx_chat_messages_session_id
            ON chat_messages(session_id)
        """))

    logger.info("✅ 数据库初始化完成")
更多推荐


所有评论(0)