基于 RAG 的 AI 应用从理论到工业级实战
引言:当 ChatGPT 开始 “胡说八道”,我们如何给 AI 注入精准记忆?
凌晨三点,你正在用 AI 助手写项目文档。你问它:“我们项目的数据库表结构是怎样的?”AI 自信地生成了一段看似专业的回答,但你仔细一看 —— 表名错了,字段类型错了,连主键都搞错了。更可怕的是,它用极其肯定的语气说:“根据我们的系统设计……”
这就是大模型最致命的 “幻觉” 问题。它不知道 “不知道”,只会基于训练数据中的统计规律生成看似合理的答案。但今天,我要告诉你一个革命性的解决方案:RAG(检索增强生成)—— 给大模型装上一个 “外挂大脑”,让它能实时查阅你的专属知识库,给出精准、可靠、有依据的回答。
本文将从零开始,带你构建一个完整的工业级 RAG 系统。无论你是 AI 新手还是经验丰富的开发者,都能在这里找到从理论到实践的完整路径。
第一部分:RAG 为何是 AI 应用的 “下一件大事”?
1.1 大模型的先天缺陷与 RAG 的诞生
让我们先理解问题的本质。当前的大语言模型(LLM)存在三个核心限制:
1. 知识固化问题

python


# 大模型的“知识截止日期”困境
class TraditionalLLM:
def __init__(self, training_cutoff_date="2023-01"):
self
.knowledge_cutoff =
training_cutoff_date
# 模型不知道这之后发生的任何事情
def answer_question(self, question):
# 比如问:"2024年Python的最新版本是多少?"
# 如果训练数据只到2023年,它会回答3.11
# 但正确答案是3.12(2024年10月发布)
return "基于我的训练数据,最新版本是Python 3.11"
2. 幻觉问题

python


# 大模型如何“编造”答案
def generate_response(prompt):
# 模型没有访问真实数据的能力
# 只能基于语言模式生成“听起来合理”的答案
if "公司财务数据" in prompt:
# 可能生成:2023年营收1.2亿元,增长15%
# 但实际数据可能是:营收9500万元,下降5%
return generate_plausible_financial_statement()
if "技术架构" in prompt:
# 可能混合不同项目的架构细节
return mix_architectures_from_training_data()
3. 缺乏私有数据访问

python


# 你的私有数据对通用大模型是“黑盒”
company_knowledge_base
= {
"内部API文档": "...",
"客户数据库": "...",
"项目代码库": "...",
"会议纪要": "..."
}
# 通用大模型无法访问这些
chatgpt
.answer("我们上周评审会的结论是什么?")
# 输出: "我无法访问您公司的内部会议记录"
1.2 RAG 的核心思想:像人类一样 “查资料再回答”
RAG 的工作流程模拟了人类专家的思考过程:

plaintext


传统大模型回答:
用户问题 → 模型记忆 → 生成答案(可能错误)
RAG系统回答:
用户问题 → 检索相关文档 → 结合文档上下文 → 生成准确答案
神经科学类比:
• 大模型本身:像大脑的前额叶皮层,负责语言生成和推理
• 向量数据库:像大脑的海马体,负责记忆检索
• 文档存储:像大脑的皮层,负责长期记忆存储
1.3 为什么现在爆发?技术栈的成熟
RAG 概念并不新,但直到最近才成为主流,因为:
1. 大模型能力突破:GPT-4、Claude-3 等模型的理解能力足够强
2. 向量数据库成熟:Milvus、Pinecone、Weaviate 等专用数据库出现
3. 嵌入模型优化:Sentence-BERT、OpenAI Embeddings 等质量大幅提升
4. 开源生态完善:LangChain、LlamaIndex 等框架降低了开发门槛
第二部分:RAG 系统架构深度解析
2.1 整体架构:从数据到答案的完整流水线

plaintext


┌─────────────────────────────────────────────────────────┐
│ 工业级RAG系统架构 │
├─────────────┬─────────────┬─────────────┬───────────────┤
│ 数据层 │ 索引层 │ 检索层 │ 生成层 │
├─────────────┼─────────────┼─────────────┼───────────────┤
│ • 文档存储 │ • 文本分块 │ • 向量检索 │ • 提示工程 │
│ • 对象存储 │ • 向量化 │ • 混合检索 │ • 上下文压缩 │
│ • 数据库 │ • 元数据 │ • 重排序 │ • 流式输出 │
│ • API接口 │ • 索引构建 │ • 相关性 │ • 引用溯源 │
└─────────────┴─────────────┴─────────────┴───────────────┘
2.2 核心组件详解
组件 1:文档处理管道(Document Processing Pipeline)

python


class DocumentProcessor:
"""工业级文档处理管道"""
def __init__(self, chunk_strategy="semantic", min_chunk_size=200, max_chunk_size=1000):
self
.chunk_strategy =
chunk_strategy
self
.min_chunk_size =
min_chunk_size
self
.max_chunk_size =
max_chunk_size
def process_document(self, document_path, document_type):
"""
处理各种类型的文档
支持: PDF, Word, Excel, PPT, Markdown, HTML, 纯文本
"""
# 1. 提取原始文本
raw_text
= self._extract_text(document_path, document_type)
# 2. 清理和规范化
cleaned_text
= self._clean_text(raw_text)
# 3. 智能分块(关键步骤)
if self.chunk_strategy == "semantic":
chunks
= self._semantic_chunking(cleaned_text)
elif self.chunk_strategy == "recursive":
chunks
= self._recursive_chunking(cleaned_text)
elif self.chunk_strategy == "fixed":
chunks
= self._fixed_size_chunking(cleaned_text)
# 4. 提取元数据
metadata
= self._extract_metadata(
document_path
, cleaned_text,
document_type
)
# 5. 质量检查
valid_chunks
= self._quality_filter(chunks)
return {
"chunks": valid_chunks,
"metadata": metadata,
"original_path": document_path,
"processing_time": time.time() -
start_time
}
def _semantic_chunking(self, text):
"""
基于语义的分块策略
保持语义完整性的同时控制块大小
"""
chunks
= []
# 使用句子分割
sentences
= self._split_into_sentences(text)
current_chunk
= []
current_length
= 0
for sentence in sentences:
sentence_length
= len(sentence)
# 如果当前块太小,继续添加
if current_length + sentence_length < self.min_chunk_size:
current_chunk
.append(sentence)
current_length
+=
sentence_length
continue
# 如果添加后不会超过最大限制
if current_length + sentence_length <= self.max_chunk_size:
current_chunk
.append(sentence)
current_length
+=
sentence_length
else:
# 保存当前块
if current_chunk:
chunk_text
= " ".join(current_chunk)
chunks
.append(chunk_text)
# 开始新块
current_chunk
= [sentence]
current_length
=
sentence_length
# 添加最后一个块
if current_chunk:
chunk_text
= " ".join(current_chunk)
chunks
.append(chunk_text)
return
chunks
def _extract_metadata(self, file_path, text, doc_type):
"""提取丰富的元数据"""
metadata
= {
"source": file_path,
"type": doc_type,
"file_size": os.path.getsize(file_path),
"processed_date": datetime.now().isoformat(),
"total_chars": len(text),
"estimated_tokens": len(text) // 4, # 粗略估计
}
# 提取文档特定元数据
if doc_type == "pdf":
metadata
.update(self._extract_pdf_metadata(file_path))
elif doc_type == "docx":
metadata
.update(self._extract_docx_metadata(file_path))
# 从内容提取关键信息
content_metadata
= self._analyze_content(text)
metadata
.update(content_metadata)
return
metadata
def _analyze_content(self, text):
"""分析内容特征"""
# 简单的关键词提取
words
= text.lower().split()
word_freq
= Counter(words)
top_keywords
= word_freq.most_common(10)
# 检测文档类型
doc_category
= self._classify_document(text)
# 估计阅读时间(按200字/分钟)
reading_time_minutes
= len(text.split()) // 200
return {
"top_keywords": [kw[0] for kw in top_keywords],
"category": doc_category,
"estimated_reading_minutes": max(1, reading_time_minutes),
"has_code_blocks": "```" in text,
"has_tables": "|" in text and "-" in text,
}
组件 2:向量化引擎(Embedding Engine)

python


class EmbeddingEngine:
"""多模型向量化引擎"""
def __init__(self, model_name="text-embedding-3-small", cache_enabled=True):
"""
支持多种嵌入模型:
- OpenAI: text-embedding-3-small/large
- 本地: sentence-transformers/all-MiniLM-L6-v2
- 国产: BGE, m3e
"""
self
.model_name =
model_name
self
.cache_enabled =
cache_enabled
self
.cache = {} if cache_enabled else None
self
.dimension = self._get_model_dimension(model_name)
# 初始化模型
self
.model = self._load_model(model_name)
def _load_model(self, model_name):
"""加载嵌入模型"""
if model_name.startswith("text-embedding"):
# OpenAI模型
return OpenAIModelWrapper(model_name)
elif "sentence-transformers" in model_name:
# 本地Sentence Transformers模型
from sentence_transformers import
SentenceTransformer
return SentenceTransformer(model_name)
elif model_name.startswith("BGE"):
# 智源BGE模型
return BGEModelWrapper(model_name)
else:
raise ValueError(f"不支持的模型: {model_name}")
def embed(self, texts, batch_size=32, **kwargs):
"""批量生成嵌入向量"""
# 检查缓存
if self.cache_enabled:
cached_results
= []
uncached_texts
= []
uncached_indices
= []
for i, text in enumerate(texts):
cache_key
= self._get_cache_key(text, kwargs)
if cache_key in self.cache:
cached_results
.append((i, self.cache[cache_key]))
else:
uncached_texts
.append(text)
uncached_indices
.append(i)
# 如果有缓存命中
if cached_results:
print(f"缓存命中: {len(cached_results)}/{len(texts)}")
else:
uncached_texts
=
texts
uncached_indices
= list(range(len(texts)))
# 处理未缓存的文本
if uncached_texts:
embeddings
= self._batch_embed(uncached_texts, batch_size, **kwargs)
# 更新缓存
if self.cache_enabled:
for idx, text, embedding in zip(uncached_indices, uncached_texts, embeddings):
cache_key
= self._get_cache_key(text, kwargs)
self
.cache[cache_key] =
embedding
# 合并结果
if cached_results:
all_embeddings
= [None] * len(texts)
for i, emb in cached_results:
all_embeddings
[i] =
emb
for i, emb in zip(uncached_indices, embeddings):
all_embeddings
[i] =
emb
return np.array(all_embeddings)
else:
return np.array(embeddings)
else:
# 全部来自缓存
all_embeddings
= [emb for _, emb in sorted(cached_results)]
return np.array(all_embeddings)
def _batch_embed(self, texts, batch_size, **kwargs):
"""批量处理嵌入"""
all_embeddings
= []
for i in range(0, len(texts), batch_size):
batch
= texts[i:i+batch_size]
# 进度显示
progress
= (i + len(batch)) / len(texts) * 100
print(f"嵌入处理: {progress:.1f}%", end="\r")
# 根据模型类型调用
if isinstance(self.model, OpenAIModelWrapper):
batch_embeddings
= self.model.embed(batch, **kwargs)
else:
# 本地模型
batch_embeddings
= self.model.encode(
batch
,
show_progress_bar
=False,
**
kwargs
)
all_embeddings
.extend(batch_embeddings)
print(f"嵌入完成: {len(texts)}个文本")
return
all_embeddings
def get_similarity(self, query_embedding, document_embeddings, metric="cosine"):
"""计算相似度"""
if metric == "cosine":
# 余弦相似度
query_norm
= np.linalg.norm(query_embedding)
doc_norms
= np.linalg.norm(document_embeddings, axis=1)
# 避免除零
if query_norm == 0:
return np.zeros(len(document_embeddings))
similarities
= np.dot(document_embeddings, query_embedding)
similarities
= similarities / (doc_norms * query_norm)
# 处理可能的数值误差
similarities
= np.clip(similarities, -1.0, 1.0)
return
similarities
elif metric == "euclidean":
# 欧氏距离(转换为相似度)
distances
= np.linalg.norm(document_embeddings - query_embedding, axis=1)
# 将距离转换为相似度(距离越小,相似度越高)
max_distance
= np.max(distances)
if max_distance > 0:
similarities
= 1 - (distances / max_distance)
else:
similarities
= np.ones_like(distances)
return
similarities
elif metric == "dot":
# 点积相似度
return np.dot(document_embeddings, query_embedding)
else:
raise ValueError(f"不支持的相似度度量: {metric}")
def optimize_for_domain(self, domain_texts, epochs=3):
"""
领域自适应优化
在特定领域文本上微调嵌入模型
"""
if not hasattr(self.model, "fit"):
print("当前模型不支持微调")
return
print(f"开始领域自适应优化,使用{len(domain_texts)}个样本")
# 准备训练数据
# 这里可以使用对比学习、三元组损失等方法
train_data
= self._prepare_training_data(domain_texts)
# 微调模型
self
.model.fit(
train_objectives
=train_data,
epochs
=epochs,
warmup_steps
=100,
show_progress_bar
=True
)
print("领域自适应优化完成")
组件 3:智能检索器(Hybrid Retriever)

python


class HybridRetriever:
"""混合检索器:结合向量检索和关键词检索"""
def __init__(self, vector_store, keyword_store=None):
self
.vector_store = vector_store # 向量数据库
self
.keyword_store = keyword_store # 关键词索引(如Elasticsearch)
self
.reranker = None # 重排序模型
def retrieve(self, query, top_k=10, alpha=0.7, use_rerank=True):
"""
混合检索
alpha: 向量检索权重 (1-alpha): 关键词检索权重
"""
# 1. 并行执行两种检索
vector_results
= self._vector_retrieve(query, top_k * 2)
keyword_results
= self._keyword_retrieve(query, top_k * 2)
# 2. 分数归一化
vector_scores
= self._normalize_scores(
[r["score"] for r in vector_results]
)
keyword_scores
= self._normalize_scores(
[r["score"] for r in keyword_results]
)
# 3. 融合结果
all_docs
= {}
# 处理向量检索结果
for i, result in enumerate(vector_results):
doc_id
= result["id"]
if doc_id not in all_docs:
all_docs
[doc_id] = {
"content": result["content"],
"metadata": result["metadata"],
"vector_score": vector_scores[i],
"keyword_score": 0.0
}
else:
all_docs
[doc_id]["vector_score"] = vector_scores[i]
# 处理关键词检索结果
for i, result in enumerate(keyword_results):
doc_id
= result["id"]
if doc_id not in all_docs:
all_docs
[doc_id] = {
"content": result["content"],
"metadata": result["metadata"],
"vector_score": 0.0,
"keyword_score": keyword_scores[i]
}
else:
all_docs
[doc_id]["keyword_score"] = keyword_scores[i]
# 4. 计算综合分数
for doc_id in all_docs:
doc
= all_docs[doc_id]
combined_score
= (
alpha
* doc["vector_score"] +
(1 - alpha) * doc["keyword_score"]
)
doc
["combined_score"] =
combined_score
# 5. 按综合分数排序
sorted_docs
= sorted(
all_docs
.items(),
key
=lambda x: x[1]["combined_score"],
reverse
=True
)
# 6. 重排序(可选)
if use_rerank and self.reranker:
reranked_docs
= self._rerank(query, sorted_docs[:top_k*3])
final_results
= reranked_docs[:top_k]
else:
final_results
= [
{"id": doc_id, **doc_info}
for doc_id, doc_info in sorted_docs[:top_k]
]
# 7. 添加检索元数据
for result in final_results:
result
["retrieval_metadata"] = {
"vector_score": result.get("vector_score", 0),
"keyword_score": result.get("keyword_score", 0),
"combined_score": result.get("combined_score", 0),
"retrieval_method": "hybrid",
"alpha":
alpha
}
return
final_results
def _vector_retrieve(self, query, top_k):
"""向量检索"""
# 生成查询向量
query_embedding
= self.embedding_engine.embed([query])[0]
# 在向量数据库中搜索
results
= self.vector_store.search(
query_embedding
,
top_k
=top_k,
filter_conditions
=None # 可以添加元数据过滤
)
return
results
def _keyword_retrieve(self, query, top_k):
"""关键词检索"""
if not self.keyword_store:
# 如果没有关键词索引,返回空结果
return []
# 提取查询关键词
keywords
= self._extract_keywords(query)
# 在关键词索引中搜索
results
= self.keyword_store.search(
query
=keywords,
top_k
=top_k,
fields
=["content", "title", "keywords"]
)
return
results
def _rerank(self, query, candidates):
"""使用重排序模型优化结果"""
if not self.reranker:
# 初始化重排序模型
self
.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 准备重排序数据
pairs
= [(query, cand["content"]) for cand in candidates]
# 计算相关性分数
scores
= self.reranker.predict(pairs)
# 更新分数并重新排序
for i, cand in enumerate(candidates):
cand
["rerank_score"] = float(scores[i])
# 可以调整最终分数,比如结合原始分数和重排序分数
cand
["final_score"] = (
0.3 * cand.get("combined_score", 0) +
0.7 * cand["rerank_score"]
)
# 按最终分数排序
reranked
= sorted(
candidates
,
key
=lambda x: x["final_score"],
reverse
=True
)
return
reranked
def _extract_keywords(self, text):
"""从文本提取关键词"""
# 使用TF-IDF或TextRank等算法
# 这里使用简化的版本
# 移除停用词
words
= text.lower().split()
stop_words
= set(["the", "a", "an", "and", "or", "but", "in", "on", "at"])
filtered_words
= [w for w in words if w not in stop_words and len(w) > 2]
# 统计词频
word_freq
= Counter(filtered_words)
# 返回最高频的词
top_keywords
= [word for word, _ in word_freq.most_common(5)]
return
top_keywords
第三部分:从零搭建工业级 RAG 系统
3.1 环境准备与架构选型

yaml


# docker-compose.yml - 完整的RAG基础设施
version: '3.8'
services:
# 向量数据库 - 存储文档向量
milvus:
image: milvusdb/milvus:
v2.3.3
container_name: rag-
milvus
environment:
ETCD_ENDPOINTS: etcd:2379
MINIO_ADDRESS: minio:9000
volumes:
- ./volumes/milvus:
/var/lib/milvus
ports:
- "19530:19530"
- "9091:9091"
depends_on:
-
etcd
-
minio
networks:
- rag-
network
# 对象存储 - 存储原始文档和索引
minio:
image: minio/minio:RELEASE.2023-03-20T20-16-
18Z
container_name: rag-
minio
environment:
MINIO_ROOT_USER:
minioadmin
MINIO_ROOT_PASSWORD:
minioadmin
volumes:
- ./volumes/minio:
/data
ports:
- "9000:9000"
- "9001:9001"
command: server /data --console-address ":
9001"
networks:
- rag-
network
# 关键词检索(可选) - 用于混合检索
elasticsearch:
image: elasticsearch:
8.11.0
container_name: rag-
elasticsearch
environment:
- discovery.type=single-
node
-
xpack.security.enabled=false
- "ES_JAVA_OPTS=-Xms512m -Xmx512m"
volumes:
- ./volumes/elasticsearch:
/usr/share/elasticsearch/data
ports:
- "9200:9200"
networks:
- rag-
network
# 关系数据库 - 存储元数据和系统状态
postgres:
image: postgres:15-
alpine
container_name: rag-
postgres
environment:
POSTGRES_DB:
rag_system
POSTGRES_USER:
rag_user
POSTGRES_PASSWORD:
rag_password
volumes:
- ./volumes/postgres:
/var/lib/postgresql/data
- ./init-scripts:/docker-entrypoint-
initdb.d
ports:
- "5432:5432"
networks:
- rag-
network
# 缓存层 - 提高检索性能
redis:
image: redis:7-
alpine
container_name: rag-
redis
ports:
- "6379:6379"
volumes:
- ./volumes/redis:
/data
networks:
- rag-
network
# API服务 - 提供REST接口
rag-api:
build:
./api
container_name: rag-
api
environment:
-
MILVUS_HOST=milvus
-
MILVUS_PORT=19530
- MINIO_ENDPOINT=http://minio:9000
- POSTGRES_DSN=postgresql://rag_user:rag_password@postgres:
5432/rag_system
- REDIS_URL=redis://redis:
6379/0
ports:
- "8000:8000"
depends_on:
-
milvus
-
minio
-
postgres
-
redis
volumes:
- ./api:
/app
- ./data:
/data
networks:
- rag-
network
# 监控系统 - 监控系统状态
grafana:
image: grafana/grafana:
10.0.0
container_name: rag-
grafana
ports:
- "3000:3000"
volumes:
- ./volumes/grafana:
/var/lib/grafana
- ./monitoring/dashboards:
/etc/grafana/provisioning/dashboards
environment:
-
GF_SECURITY_ADMIN_PASSWORD=admin
networks:
- rag-
network
prometheus:
image: prom/prometheus:
v2.45.0
container_name: rag-
prometheus
ports:
- "9090:9090"
volumes:
- ./monitoring/prometheus.yml:
/etc/prometheus/prometheus.yml
- ./volumes/prometheus:
/prometheus
networks:
- rag-
network
networks:
rag-network:
driver:
bridge
volumes:
milvus:
minio:
elasticsearch:
postgres:
redis:
grafana:
prometheus:
3.2 核心系统实现

python


# rag_system.py - 完整的RAG系统实现
import
asyncio
import
hashlib
import
json
import
time
from datetime import
datetime
from typing import Dict, List, Optional,
Any
from dataclasses import dataclass,
asdict
from enum import
Enum
import numpy as
np
from pymilvus import Collection,
connections
from minio import
Minio
from redis import
Redis
import
psycopg2
from psycopg2.extras import
RealDictCursor
class DocumentStatus(Enum):
"""文档处理状态"""
PENDING
= "pending"
PROCESSING
= "processing"
INDEXED
= "indexed"
FAILED
= "failed"
@dataclass
class Document:
"""文档元数据"""
id: str
filename
: str
filepath
: str
file_type
: str
file_size
: int
status
:
DocumentStatus
chunks_count
: int = 0
vector_count
: int = 0
metadata
: Dict[str, Any] = None
created_at
: datetime = None
updated_at
: datetime = None
processing_time
: float = 0.0
def __post_init__(self):
if self.metadata is None:
self
.metadata = {}
if self.created_at is None:
self
.created_at = datetime.now()
if self.updated_at is None:
self
.updated_at = self.
created_at
class RAGSystem:
"""完整的RAG系统"""
def __init__(self, config: Dict[str, Any]):
self
.config =
config
self
.document_processor = DocumentProcessor()
self
.embedding_engine = EmbeddingEngine(
model_name
=config.get("embedding_model", "text-embedding-3-small")
)
# 初始化存储连接
self
._init_storage()
# 初始化集合
self
._init_collections()
# 统计信息
self
.stats = {
"documents_processed": 0,
"chunks_indexed": 0,
"queries_processed": 0,
"avg_response_time": 0.0
}
def _init_storage(self):
"""初始化所有存储连接"""
# 连接Milvus
connections
.connect(
alias
="default",
host
=self.config["milvus_host"],
port
=self.config["milvus_port"]
)
# 连接MinIO
self
.minio_client = Minio(
endpoint
=self.config["minio_endpoint"],
access_key
=self.config["minio_access_key"],
secret_key
=self.config["minio_secret_key"],
secure
=False
)
# 连接Redis
self
.redis_client = Redis(
host
=self.config["redis_host"],
port
=self.config["redis_port"],
db
=0,
decode_responses
=True
)
# 连接PostgreSQL
self
.pg_conn = psycopg2.connect(
host
=self.config["postgres_host"],
port
=self.config["postgres_port"],
database
=self.config["postgres_db"],
user
=self.config["postgres_user"],
password
=self.config["postgres_password"]
)
# 创建数据库表
self
._create_tables()
def _create_tables(self):
"""创建必要的数据库表"""
with self.pg_conn.cursor() as cursor:
# 文档表
cursor
.execute(
"""
CREATE TABLE IF NOT EXISTS documents (
id VARCHAR(64) PRIMARY KEY,
filename VARCHAR(512) NOT NULL,
filepath VARCHAR(1024) NOT NULL,
file_type VARCHAR(32),
file_size INTEGER,
status VARCHAR(32) DEFAULT 'pending',
chunks_count INTEGER DEFAULT 0,
vector_count INTEGER DEFAULT 0,
metadata JSONB,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
processing_time FLOAT DEFAULT 0.0
)
"""
)
# 查询历史表
cursor
.execute(
"""
CREATE TABLE IF NOT EXISTS query_history (
id SERIAL PRIMARY KEY,
query_text TEXT NOT NULL,
response_text TEXT,
retrieved_documents JSONB,
metadata JSONB,
response_time FLOAT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 系统指标表
cursor
.execute(
"""
CREATE TABLE IF NOT EXISTS system_metrics (
id SERIAL PRIMARY KEY,
metric_name VARCHAR(128) NOT NULL,
metric_value FLOAT NOT NULL,
labels JSONB,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
self
.pg_conn.commit()
def _init_collections(self):
"""初始化Milvus集合"""
# 检查集合是否存在
from pymilvus import
utility
collection_name
= self.config.get("collection_name", "rag_documents")
if not utility.has_collection(collection_name):
# 创建集合
from pymilvus import FieldSchema, CollectionSchema,
DataType
fields
= [
FieldSchema
(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=64),
FieldSchema
(name="document_id", dtype=DataType.VARCHAR, max_length=64),
FieldSchema
(name="chunk_index", dtype=DataType.INT64),
FieldSchema
(name="content", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema
(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.embedding_engine.dimension),
FieldSchema
(name="metadata", dtype=DataType.JSON),
FieldSchema
(name="created_at", dtype=DataType.INT64),
]
schema
= CollectionSchema(
fields
=fields,
description
="RAG文档向量存储",
enable_dynamic_field
=True
)
self
.collection = Collection(
name
=collection_name,
schema
=schema,
using
="default",
shards_num
=2
)
# 创建索引
index_params
= {
"index_type": "IVF_FLAT",
"metric_type": "COSINE",
"params": {"nlist": 1024}
}
self
.collection.create_index(
field_name
="vector",
index_params
=index_params,
index_name
="vector_idx"
)
print(f"创建集合: {collection_name}")
else:
self
.collection = Collection(collection_name)
print(f"加载现有集合: {collection_name}")
# 加载集合到内存
self
.collection.load()
async def ingest_document(self, file_path: str, metadata: Dict = None) -> Dict:
"""
摄取文档的完整流程
"""
start_time
= time.time()
try:
# 1. 生成文档ID
doc_id
= self._generate_document_id(file_path)
# 2. 保存到数据库(初始状态)
document
= Document(
id=doc_id,
filename
=os.path.basename(file_path),
filepath
=file_path,
file_type
=self._detect_file_type(file_path),
file_size
=os.path.getsize(file_path),
status
=DocumentStatus.PROCESSING,
metadata
=metadata or {}
)
self
._save_document(document)
# 3. 上传到MinIO(原始文件备份)
minio_path
= f"documents/{doc_id}/{os.path.basename(file_path)}"
self
.minio_client.fput_object(
bucket_name
="rag-documents",
object_name
=minio_path,
file_path
=file_path,
metadata
=
metadata
)
# 4. 处理文档
processing_result
= self.document_processor.process_document(
file_path
, document.
file_type
)
# 5. 生成向量
chunks
= processing_result["chunks"]
chunk_vectors
= self.embedding_engine.embed(chunks)
# 6. 存储到Milvus
entities
= self._prepare_entities(
doc_id
, chunks, chunk_vectors, processing_result["metadata"]
)
insert_result
= self.collection.insert(entities)
self
.collection.flush()
# 7. 更新文档状态
document
.status = DocumentStatus.
INDEXED
document
.chunks_count = len(chunks)
document
.vector_count = len(entities)
document
.processing_time = time.time() -
start_time
document
.updated_at = datetime.now()
self
._update_document(document)
# 8. 更新统计信息
self
.stats["documents_processed"] += 1
self
.stats["chunks_indexed"] += len(chunks)
return {
"success": True,
"document_id": doc_id,
"chunks_count": len(chunks),
"processing_time": document.processing_time,
"minio_path":
minio_path
}
except Exception as e:
# 处理失败
error_msg
= str(e)
print(f"文档摄取失败: {error_msg}")
# 更新文档状态为失败
if 'document' in locals():
document
.status = DocumentStatus.
FAILED
document
.metadata["error"] =
error_msg
self
._update_document(document)
return {
"success": False,
"error": error_msg,
"document_id": doc_id if 'doc_id' in locals() else None
}
def _prepare_entities(self, doc_id, chunks, vectors, metadata):
"""准备插入到Milvus的实体"""
entities
= []
current_time
= int(time.time())
for i, (chunk, vector) in enumerate(zip(chunks, vectors)):
# 生成块ID
chunk_id
= f"{doc_id}_{i:04d}"
entity
= {
"id": chunk_id,
"document_id": doc_id,
"chunk_index": i,
"content": chunk,
"vector": vector.tolist(),
"metadata": {
"document_id": doc_id,
"chunk_index": i,
"chunk_length": len(chunk),
"word_count": len(chunk.split()),
**
metadata
},
"created_at":
current_time
}
entities
.append(entity)
return
entities
async def query(self, question: str, top_k: int = 5, use_rerank: bool = True) -> Dict:
"""
查询RAG系统
"""
start_time
= time.time()
try:
# 1. 生成查询向量
query_vector
= self.embedding_engine.embed([question])[0]
# 2. 向量检索
search_params
= {
"metric_type": "COSINE",
"params": {"nprobe": 16}
}
search_results
= self.collection.search(
data
=[query_vector],
anns_field
="vector",
param
=search_params,
limit
=top_k * 2, # 多取一些用于重排序
output_fields
=["content", "metadata", "document_id"]
)
# 3. 解析检索结果
retrieved_docs
= []
for hits in search_results:
for hit in hits:
doc
= {
"id": hit.id,
"content": hit.entity.get("content"),
"metadata": hit.entity.get("metadata", {}),
"score": hit.score,
"distance": hit.distance,
"document_id": hit.entity.get("document_id")
}
retrieved_docs
.append(doc)
# 4. 重排序(可选)
if use_rerank and len(retrieved_docs) > 1:
retrieved_docs
= self._rerank_documents(question, retrieved_docs)
# 5. 构建上下文
context
= self._build_context(retrieved_docs[:top_k])
# 6. 调用LLM生成答案
answer
= await self._generate_answer(question, context)
# 7. 记录查询历史
response_time
= time.time() -
start_time
self
._log_query(question, answer, retrieved_docs, response_time)
# 8. 更新统计
self
.stats["queries_processed"] += 1
self
.stats["avg_response_time"] = (
(self.stats["avg_response_time"] * (self.stats["queries_processed"] - 1) + response_time) /
self
.stats["queries_processed"]
)
return {
"success": True,
"question": question,
"answer": answer["response"],
"context": context,
"retrieved_documents": retrieved_docs[:top_k],
"response_time": response_time,
"citations": answer.get("citations", []),
"confidence": answer.get("confidence", 0.0)
}
except Exception as e:
error_msg
= str(e)
print(f"查询失败: {error_msg}")
return {
"success": False,
"error": error_msg,
"question": question,
"response_time": time.time() -
start_time
}
async def _generate_answer(self, question: str, context: str) -> Dict:
"""使用LLM生成答案"""
# 这里可以使用OpenAI、Claude、本地模型等
# 示例使用OpenAI
import
openai
prompt
=
f"""基于以下上下文信息,回答用户的问题。
如果上下文信息不足以回答问题,请如实说明你不知道。
上下文信息:
{context}
用户问题:
{question}
请按照以下格式回答:
1. 首先给出直接答案
2. 然后引用上下文中的相关部分(使用【引用】标记)
3. 最后说明答案的置信度
答案:"""
try:
response
= await openai.ChatCompletion.acreate(
model
="gpt-4-turbo-preview",
messages
=[
{"role": "system", "content": "你是一个专业的助手,基于提供的上下文回答问题。"},
{"role": "user", "content": prompt}
],
temperature
=0.3,
max_tokens
=1000
)
answer_text
= response.choices[0].message.
content
# 解析答案,提取引用和置信度
citations
= self._extract_citations(answer_text)
confidence
= self._estimate_confidence(answer_text, question, context)
return {
"response": answer_text,
"citations": citations,
"confidence": confidence,
"model": "gpt-4-turbo-preview",
"tokens_used": response.usage.
total_tokens
}
except Exception as e:
# 如果调用失败,使用备用方案
return {
"response": f"基于上下文信息:{context[:500]}...\n\n问题:{question}\n\n答案:相关信息如上所示。",
"citations": [],
"confidence": 0.5,
"model": "fallback",
"error": str(e)
}
def _build_context(self, documents: List[Dict]) -> str:
"""构建LLM上下文"""
context_parts
= []
for i, doc in enumerate(documents, 1):
content
= doc["content"]
metadata
= doc.get("metadata", {})
source
= metadata.get("source", "未知来源")
context_part
= f"[文档{i} - 来源: {source}]\n{content}\n"
# 添加相关性分数(可选)
if "score" in doc:
context_part
+= f"[相关性分数: {doc['score']:.3f}]\n"
context_parts
.append(context_part)
# 添加分隔符
context
= "\n" + "="*50 + "\n".join(context_parts) + "="*50
return
context
def _rerank_documents(self, question: str, documents: List[Dict]) -> List[Dict]:
"""重排序检索结果"""
# 这里可以使用专门的reranker模型
# 示例使用简单的基于内容的重新排序
for doc in documents:
# 计算内容相似度(简单版本)
content
= doc["content"]
# 检查问题关键词是否在内容中
question_words
= set(question.lower().split())
content_words
= set(content.lower().split())
keyword_overlap
= len(question_words & content_words) / len(question_words)
# 更新分数(结合向量相似度和关键词重叠)
vector_score
= doc.get("score", 0)
doc
["rerank_score"] = 0.7 * vector_score + 0.3 *
keyword_overlap
# 按重排序分数排序
documents
.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
return
documents
def _generate_document_id(self, file_path: str) -> str:
"""生成文档ID"""
# 基于文件内容和路径生成唯一ID
file_hash
= hashlib.md5()
with open(file_path, 'rb') as f:
# 读取文件前1MB用于哈希
chunk
= f.read(1024 * 1024)
file_hash
.update(chunk)
# 添加文件路径和大小
file_hash
.update(file_path.encode())
file_hash
.update(str(os.path.getsize(file_path)).encode())
return file_hash.hexdigest()[:16]
def _save_document(self, document: Document):
"""保存文档到数据库"""
with self.pg_conn.cursor() as cursor:
cursor
.execute(
"""
INSERT INTO documents
(id, filename, filepath, file_type, file_size, status, metadata, chunks_count, vector_count, created_at, updated_at, processing_time)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
, (
document
.id,
document
.filename,
document
.filepath,
document
.file_type,
document
.file_size,
document
.status.value,
json
.dumps(document.metadata),
document
.chunks_count,
document
.vector_count,
document
.created_at,
document
.updated_at,
document
.
processing_time
))
self
.pg_conn.commit()
def _update_document(self, document: Document):
"""更新文档状态"""
with self.pg_conn.cursor() as cursor:
cursor
.execute(
"""
UPDATE documents
SET status = %s,
chunks_count = %s,
vector_count = %s,
metadata = %s,
updated_at = %s,
processing_time = %s
WHERE id = %s
"""
, (
document
.status.value,
document
.chunks_count,
document
.vector_count,
json
.dumps(document.metadata),
document
.updated_at,
document
.processing_time,
document
.id
))
self
.pg_conn.commit()
def _log_query(self, question: str, answer: Dict, documents: List[Dict], response_time: float):
"""记录查询历史"""
with self.pg_conn.cursor() as cursor:
cursor
.execute(
"""
INSERT INTO query_history
(query_text, response_text, retrieved_documents, metadata, response_time)
VALUES (%s, %s, %s, %s, %s)
"""
, (
question
,
answer
.get("response", ""),
json
.dumps(documents[:5]), # 只保存前5个文档
json
.dumps({
"model": answer.get("model"),
"confidence": answer.get("confidence"),
"citations_count": len(answer.get("citations", [])),
"tokens_used": answer.get("tokens_used", 0)
}),
response_time
))
self
.pg_conn.commit()
def get_system_stats(self) -> Dict:
"""获取系统统计信息"""
with self.pg_conn.cursor(cursor_factory=RealDictCursor) as cursor:
# 文档统计
cursor
.execute(
"""
SELECT
COUNT(*) as total_documents,
SUM(chunks_count) as total_chunks,
AVG(processing_time) as avg_processing_time,
status,
COUNT(*) as count_by_status
FROM documents
GROUP BY status
"""
)
doc_stats
= cursor.fetchall()
# 查询统计
cursor
.execute(
"""
SELECT
COUNT(*) as total_queries,
AVG(response_time) as avg_response_time,
DATE(created_at) as query_date,
COUNT(*) as daily_queries
FROM query_history
WHERE created_at >= NOW() - INTERVAL '7 days'
GROUP BY DATE(created_at)
ORDER BY query_date DESC
"""
)
query_stats
= cursor.fetchall()
return {
"document_statistics": doc_stats,
"query_statistics": query_stats,
"system_stats": self.stats,
"collection_info": {
"name": self.collection.name,
"num_entities": self.collection.num_entities,
"is_loaded": self.collection.
is_loaded
},
"timestamp": datetime.now().isoformat()
}
async def batch_ingest(self, directory_path: str, file_patterns: List[str] = None):
"""批量摄取目录中的文档"""
if file_patterns is None:
file_patterns
= ["*.pdf", "*.docx", "*.txt", "*.md", "*.html"]
# 收集文件
file_paths
= []
for pattern in file_patterns:
for file_path in Path(directory_path).rglob(pattern):
file_paths
.append(str(file_path))
print(f"找到 {len(file_paths)} 个文件待处理")
results
= {
"total": len(file_paths),
"success": 0,
"failed": 0,
"details": []
}
# 使用信号量控制并发数
semaphore
= asyncio.Semaphore(5) # 同时处理5个文件
async def process_file(file_path):
async with semaphore:
try:
result
= await self.ingest_document(file_path)
return
result
except Exception as e:
return {"success": False, "error": str(e), "file_path": file_path}
# 创建任务
tasks
= [process_file(fp) for fp in file_paths]
# 并发执行
for i, task in enumerate(asyncio.as_completed(tasks)):
result
= await
task
results
["details"].append(result)
if result["success"]:
results
["success"] += 1
else:
results
["failed"] += 1
# 进度显示
progress
= (i + 1) / len(file_paths) * 100
print(f"处理进度: {progress:.1f}% ({i+1}/{len(file_paths)})")
print(f"批量处理完成: {results['success']} 成功, {results['failed']} 失败")
return
results
def cleanup(self):
"""清理资源"""
if hasattr(self, 'collection'):
self
.collection.release()
if hasattr(self, 'pg_conn'):
self
.pg_conn.close()
print("RAG系统已关闭")
3.3 Web API 接口

python


# api/main.py - FastAPI Web接口
from fastapi import FastAPI, HTTPException, UploadFile, File, Query,
BackgroundTasks
from fastapi.middleware.cors import
CORSMiddleware
from fastapi.responses import JSONResponse,
StreamingResponse
from pydantic import
BaseModel
from typing import List,
Optional
import
asyncio
import
json
import
uuid
from rag_system import
RAGSystem
app
= FastAPI(
title
="RAG系统API",
description
="基于检索增强生成的智能问答系统",
version
="1.0.0"
)
# CORS配置
app
.add_middleware(
CORSMiddleware
,
allow_origins
=["*"],
allow_credentials
=True,
allow_methods
=["*"],
allow_headers
=["*"],
)
# 全局RAG系统实例
rag_system
= None
class QueryRequest(BaseModel):
question
: str
top_k
: int = 5
use_rerank
: bool = True
stream
: bool = False
class QueryResponse(BaseModel):
success
: bool
question
: str
answer
: str
response_time
: float
confidence
: float
retrieved_count
: int
citations
: List[str] = []
class DocumentUploadResponse(BaseModel):
success
: bool
document_id
: Optional[str] = None
message
: str
processing_time
: Optional[float] = None
class SystemStatsResponse(BaseModel):
documents_total
: int
chunks_total
: int
queries_total
: int
avg_response_time
: float
system_status
: str
@app.on_event("startup")
async def startup_event():
"""启动时初始化RAG系统"""
global
rag_system
try:
# 从环境变量或配置文件中加载配置
config
= {
"milvus_host": "localhost",
"milvus_port": "19530",
"minio_endpoint": "localhost:9000",
"minio_access_key": "minioadmin",
"minio_secret_key": "minioadmin",
"redis_host": "localhost",
"redis_port": 6379,
"postgres_host": "localhost",
"postgres_port": 5432,
"postgres_db": "rag_system",
"postgres_user": "rag_user",
"postgres_password": "rag_password",
"embedding_model": "text-embedding-3-small",
"collection_name": "rag_documents"
}
rag_system
= RAGSystem(config)
print("RAG系统初始化完成")
except Exception as e:
print(f"RAG系统初始化失败: {e}")
raise
@app.on_event("shutdown")
async def shutdown_event():
"""关闭时清理资源"""
if rag_system:
rag_system
.cleanup()
print("RAG系统已关闭")
@app.get("/")
async def root():
"""根路径,返回系统信息"""
return {
"service": "RAG System API",
"version": "1.0.0",
"status": "running",
"endpoints": {
"query": "/api/query",
"upload": "/api/upload",
"batch_upload": "/api/batch-upload",
"stats": "/api/stats",
"health": "/api/health"
}
}
@app.post("/api/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
"""
查询RAG系统
Args:
question: 用户问题
top_k: 返回的文档数量
use_rerank: 是否使用重排序
stream: 是否流式返回
Returns:
包含答案和相关信息的响应
"""
if not rag_system:
raise HTTPException(status_code=503, detail="RAG系统未就绪")
if not request.question.strip():
raise HTTPException(status_code=400, detail="问题不能为空")
try:
if request.stream:
# 流式响应
return StreamingResponse(
stream_query
(request.question, request.top_k, request.use_rerank),
media_type
="text/event-stream"
)
else:
# 普通响应
result
= await rag_system.query(
question
=request.question,
top_k
=request.top_k,
use_rerank
=request.
use_rerank
)
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("error", "查询失败"))
return QueryResponse(
success
=True,
question
=result["question"],
answer
=result["answer"],
response_time
=result["response_time"],
confidence
=result.get("confidence", 0.0),
retrieved_count
=len(result.get("retrieved_documents", [])),
citations
=result.get("citations", [])
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
async def stream_query(question: str, top_k: int, use_rerank: bool):
"""流式查询生成器"""
try:
# 先检索文档
result
= await rag_system.query(
question
=question,
top_k
=top_k,
use_rerank
=
use_rerank
)
if not result["success"]:
yield f"data: {json.dumps({'error': result.get('error')})}\n\n"
return
# 流式生成答案
context
= result.get("context", "")
retrieved_docs
= result.get("retrieved_documents", [])
# 发送检索结果
yield f"data: {json.dumps({'type': 'retrieval_complete', 'count': len(retrieved_docs)})}\n\n"
# 模拟流式生成(实际应使用支持流式的LLM)
answer_parts
= result["answer"].split()
for i in range(0, len(answer_parts), 3):
chunk
= " ".join(answer_parts[i:i+3])
yield f"data: {json.dumps({'type': 'text_chunk', 'content': chunk})}\n\n"
await asyncio.sleep(0.05) # 模拟生成延迟
# 发送完成信号
yield f"data: {json.dumps({'type': 'complete', 'response_time': result['response_time']})}\n\n"
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
@app.post("/api/upload", response_model=DocumentUploadResponse)
async def upload_document(
file: UploadFile = File(...),
metadata
: Optional[str] = Query(None)
):
"""
上传并处理单个文档
Args:
file: 上传的文件
metadata: 可选的元数据(JSON字符串)
Returns:
上传处理结果
"""
if not rag_system:
raise HTTPException(status_code=503, detail="RAG系统未就绪")
# 检查文件类型
allowed_types
= [
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
"text/markdown",
"text/html"
]
if file.content_type not in allowed_types:
raise HTTPException(
status_code
=400,
detail
=f"不支持的文件类型: {file.content_type}"
)
try:
# 保存上传文件到临时目录
temp_dir
= "/tmp/rag_uploads"
os
.makedirs(temp_dir, exist_ok=True)
file_path
= os.path.join(temp_dir, f"{uuid.uuid4()}_{file.filename}")
with open(file_path, "wb") as f:
content
= await file.read()
f
.write(content)
# 解析元数据
metadata_dict
= {}
if metadata:
try:
metadata_dict
= json.loads(metadata)
except json.JSONDecodeError:
metadata_dict
= {"raw_metadata": metadata}
# 处理文档
result
= await rag_system.ingest_document(file_path, metadata_dict)
# 清理临时文件
try:
os
.remove(file_path)
except:
pass
if result["success"]:
return DocumentUploadResponse(
success
=True,
document_id
=result["document_id"],
message
="文档处理成功",
processing_time
=result["processing_time"]
)
else:
return DocumentUploadResponse(
success
=False,
message
=f"文档处理失败: {result.get('error')}"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/batch-upload")
async def batch_upload(
background_tasks
: BackgroundTasks,
directory_path
: Optional[str] = None,
zip_file
: Optional[UploadFile] = File(None)
):
"""
批量上传文档
Args:
directory_path: 本地目录路径(开发环境)
zip_file: 压缩文件(生产环境)
Returns:
批量任务ID
"""
if not rag_system:
raise HTTPException(status_code=503, detail="RAG系统未就绪")
# 生成任务ID
task_id
= str(uuid.uuid4())
# 在后台处理批量任务
background_tasks
.add_task(
process_batch_upload
,
task_id
,
directory_path
,
zip_file
)
return {
"task_id": task_id,
"message": "批量处理任务已开始",
"status_url": f"/api/batch-status/{task_id}"
}
async def process_batch_upload(task_id: str, directory_path: str, zip_file: UploadFile):
"""处理批量上传的后台任务"""
# 这里实现批量处理逻辑
# 可以使用Redis或数据库存储任务状态
try:
# 如果是ZIP文件,先解压
if zip_file:
# 解压逻辑...
pass
# 批量处理
results
= await rag_system.batch_ingest(directory_path)
# 保存任务结果
# ...
except Exception as e:
# 保存错误信息
# ...
pass
@app.get("/api/batch-status/{task_id}")
async def get_batch_status(task_id: str):
"""获取批量任务状态"""
# 从Redis或数据库获取任务状态
# ...
return {
"task_id": task_id,
"status": "processing",
"progress": 65.5,
"processed": 130,
"total": 200,
"success": 125,
"failed": 5
}
@app.get("/api/stats", response_model=SystemStatsResponse)
async def get_system_stats():
"""获取系统统计信息"""
if not rag_system:
raise HTTPException(status_code=503, detail="RAG系统未就绪")
try:
stats
= rag_system.get_system_stats()
# 提取关键指标
doc_stats
= stats.get("document_statistics", [])
query_stats
= stats.get("query_statistics", [])
total_docs
= sum(item["total_documents"] for item in doc_stats) if doc_stats else 0
total_chunks
= sum(item["total_chunks"] for item in doc_stats) if doc_stats else 0
total_queries
= sum(item["total_queries"] for item in query_stats) if query_stats else 0
avg_response_time
= rag_system.stats["avg_response_time"]
return SystemStatsResponse(
documents_total
=total_docs,
chunks_total
=total_chunks,
queries_total
=total_queries,
avg_response_time
=avg_response_time,
system_status
="healthy" if rag_system.collection.is_loaded else "degraded"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/health")
async def health_check():
"""健康检查端点"""
checks
= {
"rag_system_initialized": rag_system is not None,
"milvus_connected": False,
"postgres_connected": False,
"minio_connected": False,
"redis_connected": False
}
if rag_system:
try:
# 检查Milvus连接
from pymilvus import
utility
checks
["milvus_connected"] = utility.get_server_version() is not None
# 检查PostgreSQL连接
with rag_system.pg_conn.cursor() as cursor:
cursor
.execute("SELECT 1")
checks
["postgres_connected"] = cursor.fetchone()[0] == 1
# 检查MinIO连接
try:
rag_system
.minio_client.list_buckets()
checks
["minio_connected"] = True
except:
checks
["minio_connected"] = False
# 检查Redis连接
try:
rag_system
.redis_client.ping()
checks
["redis_connected"] = True
except:
checks
["redis_connected"] = False
except Exception as e:
pass
all_healthy
= all(checks.values())
return {
"status": "healthy" if all_healthy else "unhealthy",
"checks": checks,
"timestamp": datetime.now().isoformat()
}
@app.get("/api/documents")
async def list_documents(
page
: int = 1,
page_size
: int = 20,
status
: Optional[str] = None
):
"""列出所有文档"""
if not rag_system:
raise HTTPException(status_code=503, detail="RAG系统未就绪")
try:
with rag_system.pg_conn.cursor(cursor_factory=RealDictCursor) as cursor:
# 构建查询
query
= "SELECT * FROM documents"
params
= []
if status:
query
+= " WHERE status = %s"
params
.append(status)
# 添加分页
offset
= (page - 1) *
page_size
query
+= " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params
.extend([page_size, offset])
cursor
.execute(query, params)
documents
= cursor.fetchall()
# 获取总数
count_query
= "SELECT COUNT(*) as total FROM documents"
if status:
count_query
+= " WHERE status = %s"
cursor
.execute(count_query, params[:1] if status else [])
total
= cursor.fetchone()["total"]
return {
"documents": documents,
"pagination": {
"page": page,
"page_size": page_size,
"total": total,
"total_pages": (total + page_size - 1) //
page_size
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/documents/{document_id}")
async def delete_document(document_id: str):
"""删除文档及其所有块"""
if not rag_system:
raise HTTPException(status_code=503, detail="RAG系统未就绪")
try:
# 1. 从Milvus删除相关块
rag_system
.collection.delete(f"document_id == '{document_id}'")
rag_system
.collection.flush()
# 2. 从数据库删除文档记录
with rag_system.pg_conn.cursor() as cursor:
cursor
.execute("DELETE FROM documents WHERE id = %s", (document_id,))
rag_system
.pg_conn.commit()
return {
"success": True,
"message": f"文档 {document_id} 已删除",
"deleted_at": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# SSE端点用于实时更新
@app.get("/api/events")
async def event_stream():
"""服务器发送事件端点,用于实时更新"""
async def event_generator():
# 这里可以实现实时通知逻辑
# 例如:新文档处理完成、系统状态变化等
while True:
# 模拟发送心跳
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': datetime.now().isoformat()})}\n\n"
# 检查是否有新事件
# ...
await asyncio.sleep(30) # 30秒发送一次心跳
return StreamingResponse(
event_generator
(),
media_type
="text/event-stream"
)
第四部分:性能优化与生产部署
4.1 性能优化策略

python


# optimization.py - RAG系统性能优化
import
asyncio
from concurrent.futures import
ThreadPoolExecutor
from functools import
lru_cache
import numpy as
np
from typing import List,
Tuple
class RAGOptimizer:
"""RAG系统性能优化器"""
def __init__(self, rag_system, config: Dict):
self
.rag_system =
rag_system
self
.config =
config
# 线程池用于CPU密集型操作
self
.thread_pool = ThreadPoolExecutor(
max_workers
=config.get("max_workers", 4)
)
# 缓存配置
self
.enable_caching = config.get("enable_caching", True)
self
.cache_ttl = config.get("cache_ttl", 3600) # 1小时
# 性能监控
self
.metrics = {
"query_latency": [],
"embedding_latency": [],
"retrieval_latency": [],
"generation_latency": []
}
async def optimized_query(self, question: str, **kwargs) -> Dict:
"""优化后的查询流程"""
start_time
= time.time()
# 1. 查询缓存
if self.enable_caching:
cached_result
= self._get_from_cache(question, kwargs)
if cached_result:
print(f"缓存命中: {question[:50]}...")
cached_result
["from_cache"] = True
return
cached_result
# 2. 并行处理
query_vector_task
= asyncio.create_task(
self
._async_embed([question])
)
# 3. 检索优化
search_params
= self._optimize_search_params(question, kwargs)
# 等待向量生成完成
query_vector
= (await query_vector_task)[0]
# 4. 智能检索
retrieval_start
= time.time()
# 使用混合检索
if self.config.get("use_hybrid_retrieval", True):
results
= await self._hybrid_retrieve(
query_vector
, question, **
kwargs
)
else:
results
= await self._vector_retrieve(
query_vector
, **
kwargs
)
retrieval_time
= time.time() -
retrieval_start
self
.metrics["retrieval_latency"].append(retrieval_time)
# 5. 动态上下文构建
context
= self._dynamic_context_building(results, question)
# 6. 流式生成(如果启用)
if kwargs.get("stream", False):
answer
= await self._stream_generate(question, context)
else:
answer
= await self._generate_answer(question, context)
# 7. 更新缓存
if self.enable_caching:
result
= {
"question": question,
"answer": answer,
"context": context,
"retrieved_docs": results,
"response_time": time.time() -
start_time
}
self
._set_to_cache(question, kwargs, result)
# 8. 记录性能指标
total_time
= time.time() -
start_time
self
.metrics["query_latency"].append(total_time)
return {
"question": question,
"answer": answer,
"context": context,
"retrieved_docs": results,
"response_time": total_time,
"performance_metrics": {
"retrieval_time": retrieval_time,
"generation_time": total_time - retrieval_time,
"from_cache": False
}
}
async def _async_embed(self, texts: List[str]) -> List[np.ndarray]:
"""异步生成嵌入向量"""
# 将CPU密集型操作放到线程池
loop
= asyncio.get_event_loop()
start_time
= time.time()
embeddings
= await loop.run_in_executor(
self
.thread_pool,
lambda: self.rag_system.embedding_engine.embed(texts)
)
embed_time
= time.time() -
start_time
self
.metrics["embedding_latency"].append(embed_time)
return
embeddings
def _optimize_search_params(self, question: str, kwargs: Dict) -> Dict:
"""根据问题动态优化搜索参数"""
# 基础参数
params
= {
"metric_type": "COSINE",
"params": {"nprobe": 16}
}
# 根据问题长度调整
question_length
= len(question)
if question_length < 20:
# 短问题,可能更具体,减少搜索范围
params
["params"]["nprobe"] = 8
params
["top_k"] = kwargs.get("top_k", 5) * 2
elif question_length > 100:
# 长问题,可能更复杂,扩大搜索范围
params
["params"]["nprobe"] = 32
params
["top_k"] = kwargs.get("top_k", 5) * 3
else:
# 中等长度问题
params
["params"]["nprobe"] = 16
params
["top_k"] = kwargs.get("top_k", 5) * 2
# 根据问题类型调整
question_lower
= question.lower()
if any(word in question_lower for word in ["how", "why", "explain"]):
# 解释性问题,需要更多上下文
params
["top_k"] = min(params["top_k"] * 1.5, 20)
elif any(word in question_lower for word in ["what", "when", "where", "who"]):
# 事实性问题,可以更精确
params
["params"]["nprobe"] = max(8, params["params"]["nprobe"] // 2)
return
params
async def _hybrid_retrieve(self, query_vector: np.ndarray,
question
: str, **kwargs) -> List[Dict]:
"""混合检索优化"""
top_k
= kwargs.get("top_k", 5)
# 并行执行向量检索和关键词检索
vector_task
= asyncio.create_task(
self
._vector_retrieve(query_vector, **kwargs)
)
keyword_task
= asyncio.create_task(
self
._keyword_retrieve(question, **kwargs)
)
vector_results
, keyword_results = await asyncio.gather(
vector_task
,
keyword_task
)
# 结果融合
fused_results
= self._fuse_results(
vector_results
, keyword_results,
alpha
=kwargs.get("alpha", 0.7)
)
# 重排序
if kwargs.get("use_rerank", True) and len(fused_results) > 1:
fused_results
= await self._rerank_results(
question
,
fused_results
)
return fused_results[:top_k]
def _fuse_results(self, vector_results: List[Dict],
keyword_results
: List[Dict],
alpha
: float = 0.7) -> List[Dict]:
"""融合向量检索和关键词检索结果"""
all_results
= {}
# 处理向量结果
for i, result in enumerate(vector_results):
doc_id
= result.get("id")
if doc_id not in all_results:
all_results
[doc_id] = {
**result,
"vector_score": result.get("score", 0),
"keyword_score": 0,
"combined_score": 0
}
# 处理关键词结果
for i, result in enumerate(keyword_results):
doc_id
= result.get("id")
if doc_id in all_results:
all_results
[doc_id]["keyword_score"] = result.get("score", 0)
else:
all_results
[doc_id] = {
**result,
"vector_score": 0,
"keyword_score": result.get("score", 0),
"combined_score": 0
}
# 计算综合分数
for doc_id, result in all_results.items():
result
["combined_score"] = (
alpha
* result["vector_score"] +
(1 - alpha) * result["keyword_score"]
)
# 按综合分数排序
sorted_results
= sorted(
all_results
.values(),
key
=lambda x: x["combined_score"],
reverse
=True
)
return
sorted_results
@lru_cache(maxsize=1000)
def _get_from_cache(self, question: str, params_hash: str) -> Optional[Dict]:
"""从缓存获取结果"""
if not self.enable_caching:
return None
cache_key
= f"query:{hashlib.md5(question.encode()).hexdigest()}:{params_hash}"
try:
cached_data
= self.rag_system.redis_client.get(cache_key)
if cached_data:
return json.loads(cached_data)
except:
pass
return None
def _set_to_cache(self, question: str, params: Dict, result: Dict):
"""设置缓存"""
if not self.enable_caching:
return
# 生成参数哈希
params_str
= json.dumps(params, sort_keys=True)
params_hash
= hashlib.md5(params_str.encode()).hexdigest()
cache_key
= f"query:{hashlib.md5(question.encode()).hexdigest()}:{params_hash}"
try:
# 设置缓存,带TTL
self
.rag_system.redis_client.setex(
cache_key
,
self
.cache_ttl,
json
.dumps(result)
)
except:
pass
def get_performance_report(self) -> Dict:
"""获取性能报告"""
def percentile(data, p):
if not data:
return 0
return np.percentile(data, p)
return {
"query_latency": {
"p50": percentile(self.metrics["query_latency"], 50),
"p95": percentile(self.metrics["query_latency"], 95),
"p99": percentile(self.metrics["query_latency"], 99),
"avg": np.mean(self.metrics["query_latency"]) if self.metrics["query_latency"] else 0,
"count": len(self.metrics["query_latency"])
},
"embedding_latency": {
"avg": np.mean(self.metrics["embedding_latency"]) if self.metrics["embedding_latency"] else 0,
"count": len(self.metrics["embedding_latency"])
},
"retrieval_latency": {
"avg": np.mean(self.metrics["retrieval_latency"]) if self.metrics["retrieval_latency"] else 0,
"count": len(self.metrics["retrieval_latency"])
},
"generation_latency": {
"avg": np.mean(self.metrics["generation_latency"]) if self.metrics["generation_latency"] else 0,
"count": len(self.metrics["generation_latency"])
},
"cache_metrics": {
"hit_rate": self._calculate_cache_hit_rate(),
"size": self._get_cache_size()
}
}
def _calculate_cache_hit_rate(self) -> float:
"""计算缓存命中率"""
if not self.enable_caching:
return 0.0
try:
# 这里需要实现缓存统计逻辑
# 可以使用Redis的INFO命令或自定义计数器
hits
= self.rag_system.redis_client.get("cache:hits") or 0
misses
= self.rag_system.redis_client.get("cache:misses") or 0
hits
= int(hits)
misses
= int(misses)
total
= hits +
misses
return hits / total if total > 0 else 0.0
except:
return 0.0
4.2 生产部署配置

yaml


# kubernetes/deployment.yaml - Kubernetes部署配置
apiVersion:
apps/v1
kind:
Deployment
metadata:
name: rag-
system
namespace: ai-
production
labels:
app: rag-
system
component:
api
spec:
replicas: 3
selector:
matchLabels:
app: rag-
system
component:
api
strategy:
type:
RollingUpdate
rollingUpdate:
maxSurge: 1
maxUnavailable: 0
template:
metadata:
labels:
app: rag-
system
component:
api
spec:
containers:
- name: rag-
api
image: rag-system:
1.0.0
imagePullPolicy:
IfNotPresent
ports:
- containerPort: 8000
name:
http
env:
- name:
MILVUS_HOST
valueFrom:
configMapKeyRef:
name: rag-
config
key:
milvus_host
- name:
MILVUS_PORT
value: "19530"
- name:
MINIO_ENDPOINT
valueFrom:
configMapKeyRef:
name: rag-
config
key:
minio_endpoint
- name:
REDIS_HOST
valueFrom:
configMapKeyRef:
name: rag-
config
key:
redis_host
- name:
POSTGRES_HOST
valueFrom:
configMapKeyRef:
name: rag-
config
key:
postgres_host
- name:
OPENAI_API_KEY
valueFrom:
secretKeyRef:
name: rag-
secrets
key:
openai_api_key
resources:
requests:
memory: "2Gi"
cpu: "1000m"
limits:
memory: "4Gi"
cpu: "2000m"
livenessProbe:
httpGet:
path:
/api/health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
timeoutSeconds: 5
failureThreshold: 3
readinessProbe:
httpGet:
path:
/api/health
port: 8000
initialDelaySeconds: 5
periodSeconds: 5
timeoutSeconds: 3
volumeMounts:
- name: config-
volume
mountPath:
/app/config
- name: data-
volume
mountPath:
/data
volumes:
- name: config-
volume
configMap:
name: rag-
config
- name: data-
volume
persistentVolumeClaim:
claimName: rag-data-
pvc
nodeSelector:
node-type: cpu-
optimized
tolerations:
- key: "dedicated"
operator: "Equal"
value: "ai-workload"
effect: "NoSchedule"
---
# 服务配置
apiVersion:
v1
kind:
Service
metadata:
name: rag-
service
namespace: ai-
production
spec:
selector:
app: rag-
system
component:
api
ports:
- port: 80
targetPort: 8000
name:
http
type:
ClusterIP
---
# 水平自动扩缩容
apiVersion:
autoscaling/v2
kind:
HorizontalPodAutoscaler
metadata:
name: rag-
hpa
namespace: ai-
production
spec:
scaleTargetRef:
apiVersion:
apps/v1
kind:
Deployment
name: rag-
system
minReplicas: 2
maxReplicas: 10
metrics:
- type:
Resource
resource:
name:
cpu
target:
type:
Utilization
averageUtilization: 70
- type:
Resource
resource:
name:
memory
target:
type:
Utilization
averageUtilization: 80
behavior:
scaleDown:
stabilizationWindowSeconds: 300
policies:
- type:
Percent
value: 10
periodSeconds: 60
scaleUp:
stabilizationWindowSeconds: 60
policies:
- type:
Percent
value: 100
periodSeconds: 60
第五部分:未来展望与演进方向
5.1 RAG 技术的演进趋势
基于当前技术发展,RAG 系统正在向以下方向演进:
1. 多模态 RAG:支持图像、音频、视频的检索增强
2. 自主 RAG:系统自动发现知识缺口并主动学习
3. 联邦 RAG:在保护隐私的前提下,跨组织共享知识
4. 实时 RAG:毫秒级的知识更新和检索
5.2 构建 RAG 系统的核心能力矩阵

mermaid


graph
TD
A
[RAG系统建设者] --> B[技术能力]
A
--> C[业务能力]
A
--> D[工程能力]
B
--> B1[向量数据库精通]
B
--> B2[嵌入模型调优]
B
--> B3[检索算法设计]
B
--> B4[LLM集成]
C
--> C1[领域知识理解]
C
--> C2[用户需求分析]
C
--> C3[评估指标设计]
C
--> C4[业务价值量化]
D
--> D1[系统架构设计]
D
--> D2[性能优化]
D
--> D3[监控运维]
D
--> D4[安全合规]
5.3 立即开始你的 RAG 之旅
第一周:基础搭建
1. 使用 Docker Compose 部署 Milvus + MinIO
2. 实现简单的文档上传和向量化
3. 完成基础检索功能
第一个月:系统完善
1. 添加混合检索和重排序
2. 实现完整的 Web API
3. 添加基础监控和日志
第三个月:生产就绪
1. 性能优化和缓存策略
2. 实现多租户支持
3. 建立完整的 CI/CD 流水线
长期目标:智能演进
1. 实现自主知识发现
2. 构建多模态能力
3. 探索联邦学习应用
结语:从信息检索到知识增强的范式转移
RAG 不仅仅是一种技术架构,它代表了一种全新的 AI 应用范式。在这个范式中:
1. 大模型不再是 “全知全能的神”,而是专业的 “推理引擎”
2. 向量数据库不再是 “冷冰冰的存储”,而是智能的 “记忆系统”
3. 开发者不再是 “调参工程师”,而是 **“知识架构师”**
最成功的 AI 应用,不是拥有最大参数量的模型,而是最懂得如何组织、检索、应用知识的系统。
你现在有两个选择:继续让 AI “凭空想象” 答案,或者开始为它构建一个 “真实可靠的外挂大脑”。选择很明确,但构建的过程需要耐心、智慧和持续的迭代。
知识增强的时代已经到来,而你已经掌握了构建它的钥匙。
更多推荐

所有评论(0)