|
|
import torch |
|
|
import faiss |
|
|
import numpy as np |
|
|
from llama_index.core.node_parser import SentenceSplitter |
|
|
import re |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
import json |
|
|
import shutil |
|
|
|
|
from openai import OpenAI |
|
|
import gradio as gr |
|
|
import os |
|
|
import fitz |
|
|
import chardet |
|
|
import traceback |
|
|
from config import Config |
|
|
|
|
|
|
|
|
KB_BASE_DIR = Config.kb_base_dir |
|
|
os.makedirs(KB_BASE_DIR, exist_ok=True) |
|
|
|
|
|
|
|
|
DEFAULT_KB = Config.default_kb |
|
|
DEFAULT_KB_DIR = os.path.join(KB_BASE_DIR, DEFAULT_KB) |
|
|
os.makedirs(DEFAULT_KB_DIR, exist_ok=True) |
|
|
|
|
|
|
|
|
OUTPUT_DIR = Config.output_dir |
|
|
os.makedirs(OUTPUT_DIR, exist_ok=True) |
|
|
|
|
|
client = OpenAI( |
|
|
api_key=Config.llm_api_key, |
|
|
base_url=Config.llm_base_url |
|
|
) |
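
# 说明:这里通过 OpenAI 兼容接口访问大模型服务,密钥、地址与模型名
# 均由 config.Config 提供,便于在不同服务商之间切换。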
|
|
|
|
|
class DeepSeekClient: |
|
|
def generate_answer(self, system_prompt, user_prompt, model=Config.llm_model): |
|
|
response = client.chat.completions.create( |
|
|
            model=model,
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
], |
|
|
stream=False |
|
|
) |
|
|
return response.choices[0].message.content.strip() |
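
# 使用示例(示意;假设 config.Config 已正确配置 llm_api_key / llm_base_url / llm_model):
#     ds = DeepSeekClient()
#     print(ds.generate_answer("你是医疗助手。", "什么是高血压?"))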
|
|
|
|
|
|
|
|
def get_knowledge_bases() -> List[str]: |
|
|
"""获取所有知识库名称""" |
|
|
try: |
|
|
if not os.path.exists(KB_BASE_DIR): |
|
|
os.makedirs(KB_BASE_DIR, exist_ok=True) |
|
|
|
|
|
kb_dirs = [d for d in os.listdir(KB_BASE_DIR) |
|
|
if os.path.isdir(os.path.join(KB_BASE_DIR, d))] |
|
|
|
|
|
|
|
|
if DEFAULT_KB not in kb_dirs: |
|
|
os.makedirs(os.path.join(KB_BASE_DIR, DEFAULT_KB), exist_ok=True) |
|
|
kb_dirs.append(DEFAULT_KB) |
|
|
|
|
|
return sorted(kb_dirs) |
|
|
except Exception as e: |
|
|
print(f"获取知识库列表失败: {str(e)}") |
|
|
return [DEFAULT_KB] |
|
|
|
|
|
|
|
|
def create_knowledge_base(kb_name: str) -> str: |
|
|
"""创建新的知识库""" |
|
|
try: |
|
|
if not kb_name or not kb_name.strip(): |
|
|
return "错误:知识库名称不能为空" |
|
|
|
|
|
|
|
|
kb_name = re.sub(r'[^\w\u4e00-\u9fff]', '_', kb_name.strip()) |
|
|
|
|
|
kb_path = os.path.join(KB_BASE_DIR, kb_name) |
|
|
if os.path.exists(kb_path): |
|
|
return f"知识库 '{kb_name}' 已存在" |
|
|
|
|
|
os.makedirs(kb_path, exist_ok=True) |
|
|
return f"知识库 '{kb_name}' 创建成功" |
|
|
except Exception as e: |
|
|
return f"创建知识库失败: {str(e)}" |
|
|
|
|
|
|
|
|
def delete_knowledge_base(kb_name: str) -> str: |
|
|
"""删除指定的知识库""" |
|
|
try: |
|
|
if kb_name == DEFAULT_KB: |
|
|
return f"无法删除默认知识库 '{DEFAULT_KB}'" |
|
|
|
|
|
kb_path = os.path.join(KB_BASE_DIR, kb_name) |
|
|
if not os.path.exists(kb_path): |
|
|
return f"知识库 '{kb_name}' 不存在" |
|
|
|
|
|
shutil.rmtree(kb_path) |
|
|
return f"知识库 '{kb_name}' 已删除" |
|
|
except Exception as e: |
|
|
return f"删除知识库失败: {str(e)}" |
|
|
|
|
|
|
|
|
def get_kb_files(kb_name: str) -> List[str]: |
|
|
"""获取指定知识库中的文件列表""" |
|
|
try: |
|
|
kb_path = os.path.join(KB_BASE_DIR, kb_name) |
|
|
if not os.path.exists(kb_path): |
|
|
return [] |
|
|
|
|
|
|
|
|
files = [f for f in os.listdir(kb_path) |
|
|
if os.path.isfile(os.path.join(kb_path, f)) and |
|
|
not f.endswith(('.index', '.json'))] |
|
|
|
|
|
return sorted(files) |
|
|
except Exception as e: |
|
|
print(f"获取知识库文件列表失败: {str(e)}") |
|
|
return [] |
|
|
|
|
|
|
|
|
def semantic_chunk(text: str, chunk_size=800, chunk_overlap=20) -> List[dict]:
    """将文本先按段落聚合,再做句子级语义分块,返回带 id 的分块字典列表"""
|
|
class EnhancedSentenceSplitter(SentenceSplitter): |
|
|
def __init__(self, *args, **kwargs): |
|
|
custom_seps = [";", "!", "?", "\n"] |
|
|
separators = [kwargs.get("separator", "。")] + custom_seps |
|
|
kwargs["separator"] = '|'.join(map(re.escape, separators)) |
|
|
super().__init__(*args, **kwargs) |
|
|
|
|
|
def _split_text(self, text: str, **kwargs) -> List[str]: |
|
|
splits = re.split(f'({self.separator})', text) |
|
|
chunks = [] |
|
|
current_chunk = [] |
|
|
for part in splits: |
|
|
part = part.strip() |
|
|
if not part: |
|
|
continue |
|
|
if re.fullmatch(self.separator, part): |
|
|
if current_chunk: |
|
|
chunks.append("".join(current_chunk)) |
|
|
current_chunk = [] |
|
|
else: |
|
|
current_chunk.append(part) |
|
|
if current_chunk: |
|
|
chunks.append("".join(current_chunk)) |
|
|
return [chunk.strip() for chunk in chunks if chunk.strip()] |
|
|
|
|
|
text_splitter = EnhancedSentenceSplitter( |
|
|
separator="。", |
|
|
chunk_size=chunk_size, |
|
|
chunk_overlap=chunk_overlap, |
|
|
paragraph_separator="\n\n" |
|
|
) |
|
|
|
|
|
paragraphs = [] |
|
|
current_para = [] |
|
|
current_len = 0 |
|
|
|
|
|
for para in text.split("\n\n"): |
|
|
para = para.strip() |
|
|
para_len = len(para) |
|
|
if para_len == 0: |
|
|
continue |
|
|
if current_len + para_len <= chunk_size: |
|
|
current_para.append(para) |
|
|
current_len += para_len |
|
|
else: |
|
|
if current_para: |
|
|
paragraphs.append("\n".join(current_para)) |
|
|
current_para = [para] |
|
|
current_len = para_len |
|
|
|
|
|
if current_para: |
|
|
paragraphs.append("\n".join(current_para)) |
|
|
|
|
|
chunk_data_list = [] |
|
|
chunk_id = 0 |
|
|
for para in paragraphs: |
|
|
chunks = text_splitter.split_text(para) |
|
|
for chunk in chunks: |
|
|
if len(chunk) < 20: |
|
|
continue |
|
|
chunk_data_list.append({ |
|
|
"id": f'chunk{chunk_id}', |
|
|
"chunk": chunk, |
|
|
"method": "semantic_chunk" |
|
|
}) |
|
|
chunk_id += 1 |
|
|
return chunk_data_list |
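
# semantic_chunk 返回值结构示例(示意,非真实数据):
#     [{"id": "chunk0", "chunk": "高血压是一种常见慢性病……", "method": "semantic_chunk"},
#      {"id": "chunk1", "chunk": "……", "method": "semantic_chunk"}]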
|
|
|
|
|
|
|
|
def build_faiss_index(vector_file, index_path, metadata_path):
    """从向量 JSON 文件构建 FAISS 索引,并将分块元数据写入单独的 JSON 文件"""
|
|
try: |
|
|
with open(vector_file, 'r', encoding='utf-8') as f: |
|
|
data = json.load(f) |
|
|
|
|
|
if not data: |
|
|
raise ValueError("向量数据为空,请检查输入文件。") |
|
|
|
|
|
|
|
|
valid_data = [] |
|
|
for item in data: |
|
|
if 'vector' in item and item['vector']: |
|
|
valid_data.append(item) |
|
|
else: |
|
|
print(f"警告: 跳过没有向量的数据项 ID: {item.get('id', '未知')}") |
|
|
|
|
|
if not valid_data: |
|
|
raise ValueError("没有找到任何有效的向量数据。") |
|
|
|
|
|
|
|
|
vectors = [item['vector'] for item in valid_data] |
|
|
vectors = np.array(vectors, dtype=np.float32) |
|
|
|
|
|
if vectors.size == 0: |
|
|
raise ValueError("向量数组为空,转换失败。") |
|
|
|
|
|
|
|
|
dim = vectors.shape[1] |
|
|
n_vectors = vectors.shape[0] |
|
|
print(f"构建索引: {n_vectors} 个向量,每个向量维度: {dim}") |
|
|
|
|
|
|
|
|
max_nlist = n_vectors // 39 |
|
|
nlist = min(max_nlist, 128) if max_nlist >= 1 else 1 |
|
|
|
|
|
        if max_nlist >= 1:  # 至少能为每个聚类中心提供约 39 个训练向量
|
|
print(f"使用 IndexIVFFlat 索引,nlist={nlist}") |
|
|
quantizer = faiss.IndexFlatIP(dim) |
|
|
            index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
|
|
if not index.is_trained: |
|
|
index.train(vectors) |
|
|
index.add(vectors) |
|
|
else: |
|
|
print(f"使用 IndexFlatIP 索引") |
|
|
index = faiss.IndexFlatIP(dim) |
|
|
index.add(vectors) |
|
|
|
|
|
faiss.write_index(index, index_path) |
|
|
print(f"成功写入索引到 {index_path}") |
|
|
|
|
|
|
|
|
metadata = [{'id': item['id'], 'chunk': item['chunk'], 'method': item['method']} for item in valid_data] |
|
|
with open(metadata_path, 'w', encoding='utf-8') as f: |
|
|
json.dump(metadata, f, ensure_ascii=False, indent=4) |
|
|
print(f"成功写入元数据到 {metadata_path}") |
|
|
|
|
|
return True |
|
|
except Exception as e: |
|
|
print(f"构建索引失败: {str(e)}") |
|
|
traceback.print_exc() |
|
|
raise |
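
# build_faiss_index 期望的 vector_file JSON 结构示例(示意):
#     [{"id": "chunk0", "chunk": "文本内容……", "method": "semantic_chunk",
#       "vector": [0.012, -0.034, ...]}, ...]
# 其中 vector 的维度需与 Config.dimensions 一致。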
|
|
|
|
|
|
|
|
def vectorize_file(data_list, output_file_path, field_name="chunk"): |
|
|
"""向量化文件内容,处理长度限制并确保输入有效""" |
|
|
if not data_list: |
|
|
print("警告: 没有数据需要向量化") |
|
|
with open(output_file_path, 'w', encoding='utf-8') as outfile: |
|
|
json.dump([], outfile, ensure_ascii=False, indent=4) |
|
|
return |
|
|
|
|
|
|
|
|
valid_data = [] |
|
|
valid_texts = [] |
|
|
|
|
|
for data in data_list: |
|
|
text = data.get(field_name, "") |
|
|
|
|
|
if text and 1 <= len(text) <= 8000: |
|
|
valid_data.append(data) |
|
|
valid_texts.append(text) |
|
|
else: |
|
|
|
|
|
if len(text) > 8000: |
|
|
truncated_text = text[:8000] |
|
|
print(f"警告: 文本过长,已截断至8000字符。原始长度: {len(text)}") |
|
|
data[field_name] = truncated_text |
|
|
valid_data.append(data) |
|
|
valid_texts.append(truncated_text) |
|
|
else: |
|
|
print(f"警告: 跳过空文本或长度为0的文本") |
|
|
|
|
|
if not valid_texts: |
|
|
print("错误: 所有文本都无效,无法进行向量化") |
|
|
with open(output_file_path, 'w', encoding='utf-8') as outfile: |
|
|
json.dump([], outfile, ensure_ascii=False, indent=4) |
|
|
return |
|
|
|
|
|
|
|
|
vectors = vectorize_query(valid_texts) |
|
|
|
|
|
|
|
|
if vectors.size == 0 or len(vectors) != len(valid_data): |
|
|
print(f"错误: 向量化失败或向量数量({len(vectors) if vectors.size > 0 else 0})与数据条目({len(valid_data)})不匹配") |
|
|
|
|
|
with open(output_file_path, 'w', encoding='utf-8') as outfile: |
|
|
json.dump(valid_data, outfile, ensure_ascii=False, indent=4) |
|
|
return |
|
|
|
|
|
|
|
|
for data, vector in zip(valid_data, vectors): |
|
|
data['vector'] = vector.tolist() |
|
|
|
|
|
|
|
|
with open(output_file_path, 'w', encoding='utf-8') as outfile: |
|
|
json.dump(valid_data, outfile, ensure_ascii=False, indent=4) |
|
|
|
|
|
print(f"成功向量化 {len(valid_data)} 条数据并保存到 {output_file_path}") |
|
|
|
|
|
|
|
|
|
|
|
def vectorize_query(query, model_name=Config.model_name, batch_size=Config.batch_size) -> np.ndarray: |
|
|
"""向量化文本查询,返回嵌入向量,改进错误处理和批处理""" |
|
|
embedding_client = OpenAI( |
|
|
api_key=Config.api_key, |
|
|
base_url=Config.base_url |
|
|
) |
|
|
|
|
|
if not query: |
|
|
print("警告: 传入向量化的查询为空") |
|
|
return np.array([]) |
|
|
|
|
|
if isinstance(query, str): |
|
|
query = [query] |
|
|
|
|
|
|
|
|
valid_queries = [] |
|
|
for q in query: |
|
|
if not q or not isinstance(q, str): |
|
|
print(f"警告: 跳过无效查询: {type(q)}") |
|
|
continue |
|
|
|
|
|
|
|
|
clean_q = clean_text(q) |
|
|
if not clean_q: |
|
|
print("警告: 清理后的查询文本为空") |
|
|
continue |
|
|
|
|
|
|
|
|
if len(clean_q) > 8000: |
|
|
print(f"警告: 查询文本过长 ({len(clean_q)} 字符),截断至 8000 字符") |
|
|
clean_q = clean_q[:8000] |
|
|
|
|
|
valid_queries.append(clean_q) |
|
|
|
|
|
if not valid_queries: |
|
|
print("错误: 所有查询都无效,无法进行向量化") |
|
|
return np.array([]) |
|
|
|
|
|
|
|
|
all_vectors = [] |
|
|
for i in range(0, len(valid_queries), batch_size): |
|
|
batch = valid_queries[i:i + batch_size] |
|
|
try: |
|
|
|
|
|
print(f"正在向量化批次 {i//batch_size + 1}/{(len(valid_queries)-1)//batch_size + 1}, " |
|
|
f"包含 {len(batch)} 个文本,第一个文本长度: {len(batch[0][:50])}...") |
|
|
|
|
|
completion = embedding_client.embeddings.create( |
|
|
model=model_name, |
|
|
input=batch, |
|
|
dimensions=Config.dimensions, |
|
|
encoding_format="float" |
|
|
) |
|
|
vectors = [embedding.embedding for embedding in completion.data] |
|
|
all_vectors.extend(vectors) |
|
|
print(f"批次 {i//batch_size + 1} 向量化成功,获得 {len(vectors)} 个向量") |
|
|
except Exception as e: |
|
|
print(f"向量化批次 {i//batch_size + 1} 失败:{str(e)}") |
|
|
print(f"问题批次中的第一个文本: {batch[0][:100]}...") |
|
|
traceback.print_exc() |
|
|
|
|
|
if i == 0: |
|
|
return np.array([]) |
|
|
|
|
|
break |
|
|
|
|
|
|
|
|
if not all_vectors: |
|
|
print("错误: 向量化过程没有产生任何向量") |
|
|
return np.array([]) |
|
|
|
|
|
return np.array(all_vectors) |
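
# 使用示例(示意;假设 embedding 服务可用):
#     vecs = vectorize_query(["高血压的症状", "糖尿病饮食注意事项"])
#     print(vecs.shape)  # 预期为 (2, Config.dimensions)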
|
|
|
|
|
|
|
|
def vector_search(query, index_path, metadata_path, limit): |
|
|
"""基本向量搜索函数""" |
|
|
query_vector = vectorize_query(query) |
|
|
if query_vector.size == 0: |
|
|
return [] |
|
|
|
|
|
query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1) |
|
|
|
|
|
index = faiss.read_index(index_path) |
|
|
try: |
|
|
with open(metadata_path, 'r', encoding='utf-8') as f: |
|
|
metadata = json.load(f) |
|
|
except UnicodeDecodeError: |
|
|
print(f"警告:{metadata_path} 包含非法字符,使用 UTF-8 忽略错误重新加载") |
|
|
with open(metadata_path, 'rb') as f: |
|
|
content = f.read().decode('utf-8', errors='ignore') |
|
|
metadata = json.loads(content) |
|
|
|
|
|
D, I = index.search(query_vector, limit) |
|
|
    # FAISS 在结果不足时返回 -1,需过滤,避免负索引取到错误条目
    results = [metadata[i] for i in I[0] if 0 <= i < len(metadata)]
|
|
return results |
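
# 注意:上述索引均使用内积(IP)度量。若希望内积等价于余弦相似度,
# 需要在建索引和查询前对向量做 L2 归一化(FAISS 提供 faiss.normalize_L2);
# 当前实现假设 embedding 服务返回的向量可直接用内积比较。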
|
|
|
|
|
def clean_text(text): |
|
|
"""清理文本中的非法字符,控制文本长度""" |
|
|
if not text: |
|
|
return "" |
|
|
|
|
|
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text) |
|
|
|
|
|
text = re.sub(r'\s+', ' ', text) |
|
|
|
|
|
return text.strip() |
|
|
|
|
|
|
|
|
def extract_text_from_pdf(pdf_path):
    """使用 PyMuPDF (fitz) 逐页提取 PDF 文本"""
|
|
try: |
|
|
doc = fitz.open(pdf_path) |
|
|
text = "" |
|
|
for page in doc: |
|
|
page_text = page.get_text() |
|
|
|
|
|
text += page_text.encode('utf-8', errors='ignore').decode('utf-8') |
|
|
        if not text.strip():
            print(f"警告:PDF文件 {pdf_path} 提取内容为空")
        doc.close()
        return text
|
|
except Exception as e: |
|
|
print(f"PDF文本提取失败:{str(e)}") |
|
|
return "" |
|
|
|
|
|
|
|
|
def process_single_file(file_path: str) -> str:
    """读取单个文件并返回清理后的文本;PDF 走文本提取,其余按检测到的编码解码"""
|
|
try: |
|
|
if file_path.lower().endswith('.pdf'): |
|
|
text = extract_text_from_pdf(file_path) |
|
|
if not text: |
|
|
return f"PDF文件 {file_path} 内容为空或无法提取" |
|
|
else: |
|
|
with open(file_path, "rb") as f: |
|
|
content = f.read() |
|
|
result = chardet.detect(content) |
|
|
detected_encoding = result['encoding'] |
|
|
confidence = result['confidence'] |
|
|
|
|
|
|
|
|
if detected_encoding and confidence > 0.7: |
|
|
try: |
|
|
text = content.decode(detected_encoding) |
|
|
print(f"文件 {file_path} 使用检测到的编码 {detected_encoding} 解码成功") |
|
|
except UnicodeDecodeError: |
|
|
text = content.decode('utf-8', errors='ignore') |
|
|
print(f"文件 {file_path} 使用 {detected_encoding} 解码失败,强制使用 UTF-8 忽略非法字符") |
|
|
else: |
|
|
|
|
|
encodings = ['utf-8', 'gbk', 'gb18030', 'gb2312', 'latin-1', 'utf-16', 'cp936', 'big5'] |
|
|
text = None |
|
|
for encoding in encodings: |
|
|
try: |
|
|
text = content.decode(encoding) |
|
|
print(f"文件 {file_path} 使用 {encoding} 解码成功") |
|
|
break |
|
|
except UnicodeDecodeError: |
|
|
continue |
|
|
|
|
|
|
|
|
if text is None: |
|
|
text = content.decode('utf-8', errors='ignore') |
|
|
print(f"警告:文件 {file_path} 使用 UTF-8 忽略非法字符") |
|
|
|
|
|
|
|
|
text = clean_text(text) |
|
|
return text |
|
|
except Exception as e: |
|
|
print(f"处理文件 {file_path} 时出错: {str(e)}") |
|
|
traceback.print_exc() |
|
|
return f"处理文件 {file_path} 失败:{str(e)}" |
|
|
|
|
|
|
|
|
def process_and_index_files(file_objs: List, kb_name: str = DEFAULT_KB) -> str: |
|
|
"""处理并索引文件到指定的知识库""" |
|
|
|
|
|
kb_dir = os.path.join(KB_BASE_DIR, kb_name) |
|
|
os.makedirs(kb_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
semantic_chunk_output = os.path.join(OUTPUT_DIR, "semantic_chunk_output.json") |
|
|
semantic_chunk_vector = os.path.join(OUTPUT_DIR, "semantic_chunk_vector.json") |
|
|
|
|
|
|
|
|
semantic_chunk_index = os.path.join(kb_dir, "semantic_chunk.index") |
|
|
semantic_chunk_metadata = os.path.join(kb_dir, "semantic_chunk_metadata.json") |
|
|
|
|
|
all_chunks = [] |
|
|
error_messages = [] |
|
|
try: |
|
|
if not file_objs or len(file_objs) == 0: |
|
|
return "错误:没有选择任何文件" |
|
|
|
|
|
print(f"开始处理 {len(file_objs)} 个文件,目标知识库: {kb_name}...") |
|
|
with ThreadPoolExecutor(max_workers=4) as executor: |
|
|
future_to_file = {executor.submit(process_single_file, file_obj.name): file_obj for file_obj in file_objs} |
|
|
for future in as_completed(future_to_file): |
|
|
result = future.result() |
|
|
file_obj = future_to_file[future] |
|
|
file_name = file_obj.name |
|
|
|
|
|
if isinstance(result, str) and result.startswith("处理文件"): |
|
|
error_messages.append(result) |
|
|
print(result) |
|
|
continue |
|
|
|
|
|
|
|
|
if not result or not isinstance(result, str) or len(result.strip()) == 0: |
|
|
error_messages.append(f"文件 {file_name} 处理后内容为空") |
|
|
print(f"警告: 文件 {file_name} 处理后内容为空") |
|
|
continue |
|
|
|
|
|
print(f"对文件 {file_name} 进行语义分块...") |
|
|
chunks = semantic_chunk(result) |
|
|
|
|
|
if not chunks or len(chunks) == 0: |
|
|
error_messages.append(f"文件 {file_name} 无法生成任何分块") |
|
|
print(f"警告: 文件 {file_name} 无法生成任何分块") |
|
|
continue |
|
|
|
|
|
|
|
|
file_basename = os.path.basename(file_name) |
|
|
dest_file_path = os.path.join(kb_dir, file_basename) |
|
|
try: |
|
|
shutil.copy2(file_name, dest_file_path) |
|
|
print(f"已将文件 {file_basename} 复制到知识库 {kb_name}") |
|
|
except Exception as e: |
|
|
print(f"复制文件到知识库失败: {str(e)}") |
|
|
|
|
|
all_chunks.extend(chunks) |
|
|
print(f"文件 {file_name} 处理完成,生成 {len(chunks)} 个分块") |
|
|
|
|
|
if not all_chunks: |
|
|
return "所有文件处理失败或内容为空\n" + "\n".join(error_messages) |
|
|
|
|
|
|
|
|
valid_chunks = [] |
|
|
for chunk in all_chunks: |
|
|
|
|
|
clean_chunk_text = clean_text(chunk["chunk"]) |
|
|
|
|
|
|
|
|
if clean_chunk_text and 1 <= len(clean_chunk_text) <= 8000: |
|
|
chunk["chunk"] = clean_chunk_text |
|
|
valid_chunks.append(chunk) |
|
|
elif len(clean_chunk_text) > 8000: |
|
|
|
|
|
chunk["chunk"] = clean_chunk_text[:8000] |
|
|
valid_chunks.append(chunk) |
|
|
print(f"警告: 分块 {chunk['id']} 过长已被截断") |
|
|
else: |
|
|
print(f"警告: 跳过无效分块 {chunk['id']}") |
|
|
|
|
|
if not valid_chunks: |
|
|
return "所有生成的分块内容无效或为空\n" + "\n".join(error_messages) |
|
|
|
|
|
print(f"处理了 {len(all_chunks)} 个分块,有效分块数: {len(valid_chunks)}") |
|
|
|
|
|
|
|
|
with open(semantic_chunk_output, 'w', encoding='utf-8') as json_file: |
|
|
json.dump(valid_chunks, json_file, ensure_ascii=False, indent=4) |
|
|
print(f"语义分块完成: {semantic_chunk_output}") |
|
|
|
|
|
|
|
|
print(f"开始向量化 {len(valid_chunks)} 个分块...") |
|
|
vectorize_file(valid_chunks, semantic_chunk_vector) |
|
|
print(f"语义分块向量化完成: {semantic_chunk_vector}") |
|
|
|
|
|
|
|
|
try: |
|
|
with open(semantic_chunk_vector, 'r', encoding='utf-8') as f: |
|
|
vector_data = json.load(f) |
|
|
|
|
|
if not vector_data or len(vector_data) == 0: |
|
|
return f"向量化失败: 生成的向量文件为空\n" + "\n".join(error_messages) |
|
|
|
|
|
|
|
|
if 'vector' not in vector_data[0]: |
|
|
return f"向量化失败: 数据中缺少向量字段\n" + "\n".join(error_messages) |
|
|
|
|
|
print(f"成功生成 {len(vector_data)} 个向量") |
|
|
except Exception as e: |
|
|
return f"读取向量文件失败: {str(e)}\n" + "\n".join(error_messages) |
|
|
|
|
|
|
|
|
print(f"开始为知识库 {kb_name} 构建索引...") |
|
|
build_faiss_index(semantic_chunk_vector, semantic_chunk_index, semantic_chunk_metadata) |
|
|
print(f"知识库 {kb_name} 索引构建完成: {semantic_chunk_index}") |
|
|
|
|
|
status = f"知识库 {kb_name} 更新成功!共处理 {len(valid_chunks)} 个有效分块。\n" |
|
|
if error_messages: |
|
|
status += "以下文件处理过程中出现问题:\n" + "\n".join(error_messages) |
|
|
return status |
|
|
except Exception as e: |
|
|
error = f"知识库 {kb_name} 索引构建过程中出错:{str(e)}" |
|
|
print(error) |
|
|
traceback.print_exc() |
|
|
return error + "\n" + "\n".join(error_messages) |
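
# 索引构建流水线小结:并发读取文件 → 语义分块 → 清洗/截断 → 向量化
# → 构建 FAISS 索引;中间产物写入 OUTPUT_DIR,索引与元数据写入知识库目录。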
|
|
|
|
|
|
|
|
def get_search_background(query: str, max_length: int = 1500) -> str:
    """联网检索查询背景信息,清理多余空白后截断至 max_length 字符"""
|
|
try: |
|
|
from retrievor import q_searching |
|
|
search_results = q_searching(query) |
|
|
cleaned_results = re.sub(r'\s+', ' ', search_results).strip() |
|
|
return cleaned_results[:max_length] |
|
|
except Exception as e: |
|
|
print(f"联网搜索失败:{str(e)}") |
|
|
return "" |
|
|
|
|
|
|
|
|
def generate_answer_from_deepseek(question: str, system_prompt: str = "你是一名专业医疗助手,请根据背景知识回答问题。", background_info: Optional[str] = None) -> str: |
|
|
deepseek_client = DeepSeekClient() |
|
|
user_prompt = f"问题:{question}" |
|
|
if background_info: |
|
|
user_prompt = f"背景知识:{background_info}\n\n{user_prompt}" |
|
|
try: |
|
|
answer = deepseek_client.generate_answer(system_prompt, user_prompt) |
|
|
return answer |
|
|
except Exception as e: |
|
|
return f"生成回答时出错:{str(e)}" |
|
|
|
|
|
|
|
|
class ReasoningRAG: |
|
|
""" |
|
|
多跳推理RAG系统,通过迭代式的检索和推理过程回答问题,支持流式响应 |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
index_path: str, |
|
|
metadata_path: str, |
|
|
max_hops: int = 3, |
|
|
initial_candidates: int = 5, |
|
|
refined_candidates: int = 3, |
|
|
reasoning_model: str = Config.llm_model, |
|
|
verbose: bool = False): |
|
|
""" |
|
|
初始化推理RAG系统 |
|
|
|
|
|
参数: |
|
|
index_path: FAISS索引的路径 |
|
|
metadata_path: 元数据JSON文件的路径 |
|
|
max_hops: 最大推理-检索跳数 |
|
|
initial_candidates: 初始检索候选数量 |
|
|
refined_candidates: 精炼检索候选数量 |
|
|
reasoning_model: 用于推理步骤的LLM模型 |
|
|
verbose: 是否打印详细日志 |
|
|
""" |
|
|
self.index_path = index_path |
|
|
self.metadata_path = metadata_path |
|
|
self.max_hops = max_hops |
|
|
self.initial_candidates = initial_candidates |
|
|
self.refined_candidates = refined_candidates |
|
|
self.reasoning_model = reasoning_model |
|
|
self.verbose = verbose |
|
|
|
|
|
|
|
|
self._load_resources() |
|
|
|
|
|
def _load_resources(self): |
|
|
"""加载FAISS索引和元数据""" |
|
|
if os.path.exists(self.index_path) and os.path.exists(self.metadata_path): |
|
|
self.index = faiss.read_index(self.index_path) |
|
|
try: |
|
|
with open(self.metadata_path, 'r', encoding='utf-8') as f: |
|
|
self.metadata = json.load(f) |
|
|
except UnicodeDecodeError: |
|
|
with open(self.metadata_path, 'rb') as f: |
|
|
content = f.read().decode('utf-8', errors='ignore') |
|
|
self.metadata = json.loads(content) |
|
|
else: |
|
|
raise FileNotFoundError(f"Index or metadata not found at {self.index_path} or {self.metadata_path}") |
|
|
|
|
|
def _vectorize_query(self, query: str) -> np.ndarray: |
|
|
"""将查询转换为向量""" |
|
|
return vectorize_query(query).reshape(1, -1) |
|
|
|
|
|
def _retrieve(self, query_vector: np.ndarray, limit: int) -> List[Dict[str, Any]]: |
|
|
"""使用向量相似性检索块""" |
|
|
if query_vector.size == 0: |
|
|
return [] |
|
|
|
|
|
D, I = self.index.search(query_vector, limit) |
|
|
        # FAISS 在结果不足时返回 -1,需过滤,避免负索引取到错误条目
        results = [self.metadata[i] for i in I[0] if 0 <= i < len(self.metadata)]
|
|
return results |
|
|
|
|
|
def _generate_reasoning(self, |
|
|
query: str, |
|
|
retrieved_chunks: List[Dict[str, Any]], |
|
|
previous_queries: List[str] = None, |
|
|
hop_number: int = 0) -> Dict[str, Any]: |
|
|
""" |
|
|
为检索到的信息生成推理分析并识别信息缺口 |
|
|
|
|
|
返回包含以下字段的字典: |
|
|
- analysis: 对当前信息的推理分析 |
|
|
- missing_info: 已识别的缺失信息 |
|
|
- follow_up_queries: 填补信息缺口的后续查询列表 |
|
|
- is_sufficient: 表示信息是否足够的布尔值 |
|
|
""" |
|
|
if previous_queries is None: |
|
|
previous_queries = [] |
|
|
|
|
|
|
|
|
chunks_text = "\n\n".join([f"[Chunk {i+1}]: {chunk['chunk']}" |
|
|
for i, chunk in enumerate(retrieved_chunks)]) |
|
|
|
|
|
previous_queries_text = "\n".join([f"Q{i+1}: {q}" for i, q in enumerate(previous_queries)]) |
|
|
|
|
|
system_prompt = """ |
|
|
你是医疗信息检索的专家分析系统。 |
|
|
你的任务是分析检索到的信息块,识别缺失的内容,并提出有针对性的后续查询来填补信息缺口。 |
|
|
|
|
|
重点关注医疗领域知识,如: |
|
|
- 疾病诊断和症状 |
|
|
- 治疗方法和药物 |
|
|
- 医学研究和临床试验 |
|
|
- 患者护理和康复 |
|
|
- 医疗法规和伦理 |
|
|
""" |
|
|
|
|
|
user_prompt = f""" |
|
|
## 原始查询 |
|
|
{query} |
|
|
|
|
|
## 先前查询(如果有) |
|
|
{previous_queries_text if previous_queries else "无"} |
|
|
|
|
|
## 检索到的信息(跳数 {hop_number}) |
|
|
{chunks_text if chunks_text else "未检索到信息。"} |
|
|
|
|
|
## 你的任务 |
|
|
1. 分析已检索到的信息与原始查询的关系 |
|
|
2. 确定能够更完整回答查询的特定缺失信息 |
|
|
3. 提出1-3个针对性的后续查询,以检索缺失信息 |
|
|
4. 确定当前信息是否足够回答原始查询 |
|
|
|
|
|
以JSON格式回答,包含以下字段: |
|
|
- analysis: 对当前信息的详细分析 |
|
|
- missing_info: 特定缺失信息的列表 |
|
|
- follow_up_queries: 1-3个具体的后续查询 |
|
|
- is_sufficient: 表示信息是否足够的布尔值 |
|
|
""" |
|
|
|
|
|
try: |
|
|
response = client.chat.completions.create( |
|
|
model=Config.llm_model, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
], |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
reasoning_text = response.choices[0].message.content.strip() |
|
|
|
|
|
|
|
|
try: |
|
|
reasoning = json.loads(reasoning_text) |
|
|
|
|
|
                # 为缺失字段补上与类型匹配的默认值
                required_fields = {"analysis": "", "missing_info": [],
                                   "follow_up_queries": [], "is_sufficient": False}
                for key, default in required_fields.items():
                    if key not in reasoning:
                        reasoning[key] = default
|
|
return reasoning |
|
|
except json.JSONDecodeError: |
|
|
|
|
|
if self.verbose: |
|
|
print(f"无法从模型输出解析JSON: {reasoning_text[:100]}...") |
|
|
return { |
|
|
"analysis": "无法分析检索到的信息。", |
|
|
"missing_info": ["无法识别缺失信息"], |
|
|
"follow_up_queries": [], |
|
|
"is_sufficient": False |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
if self.verbose: |
|
|
print(f"推理生成错误: {e}") |
|
|
print(traceback.format_exc()) |
|
|
return { |
|
|
"analysis": "分析过程出错。", |
|
|
"missing_info": [], |
|
|
"follow_up_queries": [], |
|
|
"is_sufficient": False |
|
|
} |
|
|
|
|
|
def _synthesize_answer(self, |
|
|
query: str, |
|
|
all_chunks: List[Dict[str, Any]], |
|
|
reasoning_steps: List[Dict[str, Any]], |
|
|
use_table_format: bool = False) -> str: |
|
|
"""从所有检索到的块和推理步骤中合成最终答案""" |
|
|
|
|
|
unique_chunks = [] |
|
|
chunk_ids = set() |
|
|
for chunk in all_chunks: |
|
|
if chunk["id"] not in chunk_ids: |
|
|
unique_chunks.append(chunk) |
|
|
chunk_ids.add(chunk["id"]) |
|
|
|
|
|
|
|
|
chunks_text = "\n\n".join([f"[Chunk {i+1}]: {chunk['chunk']}" |
|
|
for i, chunk in enumerate(unique_chunks)]) |
|
|
|
|
|
|
|
|
reasoning_trace = "" |
|
|
for i, step in enumerate(reasoning_steps): |
|
|
reasoning_trace += f"\n\n推理步骤 {i+1}:\n" |
|
|
reasoning_trace += f"分析: {step['analysis']}\n" |
|
|
reasoning_trace += f"缺失信息: {', '.join(step['missing_info'])}\n" |
|
|
reasoning_trace += f"后续查询: {', '.join(step['follow_up_queries'])}" |
|
|
|
|
|
system_prompt = """ |
|
|
你是医疗领域的专家。基于检索到的信息块,为用户的查询合成一个全面的答案。 |
|
|
|
|
|
重点提供有关医疗和健康的准确、基于证据的信息,包括诊断、治疗、预防和医学研究等方面。 |
|
|
|
|
|
逻辑地组织你的答案,并在适当时引用块中的具体信息。如果信息不完整,请承认限制。 |
|
|
""" |
|
|
|
|
|
output_format_instruction = "" |
|
|
if use_table_format: |
|
|
output_format_instruction = """ |
|
|
请尽可能以Markdown表格格式组织你的回答。如果信息适合表格形式展示,请使用表格; |
|
|
如果不适合表格形式,可以先用文本介绍,然后再使用表格总结关键信息。 |
|
|
|
|
|
表格语法示例: |
|
|
| 标题1 | 标题2 | 标题3 | |
|
|
| ----- | ----- | ----- | |
|
|
| 内容1 | 内容2 | 内容3 | |
|
|
|
|
|
确保表格格式符合Markdown标准,以便正确渲染。 |
|
|
""" |
|
|
|
|
|
user_prompt = f""" |
|
|
## 原始查询 |
|
|
{query} |
|
|
|
|
|
## 检索到的信息块 |
|
|
{chunks_text} |
|
|
|
|
|
## 推理过程 |
|
|
{reasoning_trace} |
|
|
|
|
|
## 你的任务 |
|
|
使用提供的信息块为原始查询合成一个全面的答案。你的答案应该: |
|
|
|
|
|
1. 直接回应查询 |
|
|
2. 结构清晰,易于理解 |
|
|
3. 基于检索到的信息 |
|
|
4. 承认可用信息中的任何重大缺口 |
|
|
|
|
|
{output_format_instruction} |
|
|
|
|
|
以直接回应提出原始查询的用户的方式呈现你的答案。 |
|
|
""" |
|
|
|
|
|
try: |
|
|
response = client.chat.completions.create( |
|
|
model=Config.llm_model, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
] |
|
|
) |
|
|
return response.choices[0].message.content.strip() |
|
|
except Exception as e: |
|
|
if self.verbose: |
|
|
print(f"答案合成错误: {e}") |
|
|
print(traceback.format_exc()) |
|
|
return "由于出错,无法生成答案。" |
|
|
|
|
|
def stream_retrieve_and_answer(self, query: str, use_table_format: bool = False): |
|
|
""" |
|
|
执行多跳检索和回答生成的流式方法,逐步返回结果 |
|
|
|
|
|
这是一个生成器函数,会在处理的每个阶段产生中间结果 |
|
|
""" |
|
|
all_chunks = [] |
|
|
all_queries = [query] |
|
|
reasoning_steps = [] |
|
|
|
|
|
|
|
|
yield { |
|
|
"status": "正在将查询向量化...", |
|
|
"reasoning_display": "", |
|
|
"answer": None, |
|
|
"all_chunks": [], |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
|
|
|
|
|
|
try: |
|
|
query_vector = self._vectorize_query(query) |
|
|
if query_vector.size == 0: |
|
|
yield { |
|
|
"status": "向量化失败", |
|
|
"reasoning_display": "由于嵌入错误,无法处理查询。", |
|
|
"answer": "由于嵌入错误,无法处理查询。", |
|
|
"all_chunks": [], |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
return |
|
|
|
|
|
yield { |
|
|
"status": "正在执行初始检索...", |
|
|
"reasoning_display": "", |
|
|
"answer": None, |
|
|
"all_chunks": [], |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
|
|
|
initial_chunks = self._retrieve(query_vector, self.initial_candidates) |
|
|
all_chunks.extend(initial_chunks) |
|
|
|
|
|
if not initial_chunks: |
|
|
yield { |
|
|
"status": "未找到相关信息", |
|
|
"reasoning_display": "未找到与您的查询相关的信息。", |
|
|
"answer": "未找到与您的查询相关的信息。", |
|
|
"all_chunks": [], |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
return |
|
|
|
|
|
|
|
|
chunks_preview = "\n".join([f"- {chunk['chunk'][:100]}..." for chunk in initial_chunks[:2]]) |
|
|
yield { |
|
|
"status": f"找到 {len(initial_chunks)} 个相关信息块,正在生成初步分析...", |
|
|
"reasoning_display": f"### 检索到的初始信息\n{chunks_preview}\n\n### 正在分析...", |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
|
|
|
|
|
|
reasoning = self._generate_reasoning(query, initial_chunks, hop_number=0) |
|
|
reasoning_steps.append(reasoning) |
|
|
|
|
|
|
|
|
reasoning_display = "### 多跳推理过程\n" |
|
|
reasoning_display += f"**推理步骤 1**\n" |
|
|
reasoning_display += f"- 分析: {reasoning['analysis'][:200]}...\n" |
|
|
reasoning_display += f"- 缺失信息: {', '.join(reasoning['missing_info'])}\n" |
|
|
if reasoning['follow_up_queries']: |
|
|
reasoning_display += f"- 后续查询: {', '.join(reasoning['follow_up_queries'])}\n" |
|
|
reasoning_display += f"- 信息是否足够: {'是' if reasoning['is_sufficient'] else '否'}\n\n" |
|
|
|
|
|
yield { |
|
|
"status": "初步分析完成", |
|
|
"reasoning_display": reasoning_display, |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
|
|
|
hop = 1 |
|
|
while (hop < self.max_hops and |
|
|
not reasoning["is_sufficient"] and |
|
|
reasoning["follow_up_queries"]): |
|
|
|
|
|
follow_up_status = f"执行跳数 {hop},正在处理 {len(reasoning['follow_up_queries'])} 个后续查询..." |
|
|
yield { |
|
|
"status": follow_up_status, |
|
|
"reasoning_display": reasoning_display + f"\n\n### {follow_up_status}", |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
hop_chunks = [] |
|
|
|
|
|
|
|
|
for i, follow_up_query in enumerate(reasoning["follow_up_queries"]): |
|
|
all_queries.append(follow_up_query) |
|
|
|
|
|
query_status = f"处理后续查询 {i+1}/{len(reasoning['follow_up_queries'])}: {follow_up_query}" |
|
|
yield { |
|
|
"status": query_status, |
|
|
"reasoning_display": reasoning_display + f"\n\n### {query_status}", |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
|
|
|
follow_up_vector = self._vectorize_query(follow_up_query) |
|
|
if follow_up_vector.size > 0: |
|
|
follow_up_chunks = self._retrieve(follow_up_vector, self.refined_candidates) |
|
|
hop_chunks.extend(follow_up_chunks) |
|
|
all_chunks.extend(follow_up_chunks) |
|
|
|
|
|
|
|
|
yield { |
|
|
"status": f"查询 '{follow_up_query}' 找到了 {len(follow_up_chunks)} 个相关块", |
|
|
"reasoning_display": reasoning_display + f"\n\n为查询 '{follow_up_query}' 找到了 {len(follow_up_chunks)} 个相关块", |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
|
|
|
yield { |
|
|
"status": f"正在为跳数 {hop} 生成推理分析...", |
|
|
"reasoning_display": reasoning_display + f"\n\n### 正在为跳数 {hop} 生成推理分析...", |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
reasoning = self._generate_reasoning( |
|
|
query, |
|
|
hop_chunks, |
|
|
previous_queries=all_queries[:-1], |
|
|
hop_number=hop |
|
|
) |
|
|
reasoning_steps.append(reasoning) |
|
|
|
|
|
|
|
|
reasoning_display += f"\n**推理步骤 {hop+1}**\n" |
|
|
reasoning_display += f"- 分析: {reasoning['analysis'][:200]}...\n" |
|
|
reasoning_display += f"- 缺失信息: {', '.join(reasoning['missing_info'])}\n" |
|
|
if reasoning['follow_up_queries']: |
|
|
reasoning_display += f"- 后续查询: {', '.join(reasoning['follow_up_queries'])}\n" |
|
|
reasoning_display += f"- 信息是否足够: {'是' if reasoning['is_sufficient'] else '否'}\n" |
|
|
|
|
|
yield { |
|
|
"status": f"跳数 {hop} 完成", |
|
|
"reasoning_display": reasoning_display, |
|
|
"answer": None, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
hop += 1 |
|
|
|
|
|
|
|
|
yield { |
|
|
"status": "正在合成最终答案...", |
|
|
"reasoning_display": reasoning_display + "\n\n### 正在合成最终答案...", |
|
|
"answer": "正在处理您的问题,请稍候...", |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
answer = self._synthesize_answer(query, all_chunks, reasoning_steps, use_table_format) |
|
|
|
|
|
|
|
|
all_chunks_summary = "\n\n".join([f"**检索块 {i+1}**:\n{chunk['chunk']}" |
|
|
for i, chunk in enumerate(all_chunks[:10])]) |
|
|
|
|
|
if len(all_chunks) > 10: |
|
|
all_chunks_summary += f"\n\n...以及另外 {len(all_chunks) - 10} 个块(总计 {len(all_chunks)} 个)" |
|
|
|
|
|
enhanced_display = reasoning_display + "\n\n### 检索到的内容\n" + all_chunks_summary + "\n\n### 回答已生成" |
|
|
|
|
|
yield { |
|
|
"status": "回答已生成", |
|
|
"reasoning_display": enhanced_display, |
|
|
"answer": answer, |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"处理过程中出错: {str(e)}" |
|
|
if self.verbose: |
|
|
print(error_msg) |
|
|
print(traceback.format_exc()) |
|
|
|
|
|
yield { |
|
|
"status": "处理出错", |
|
|
"reasoning_display": error_msg, |
|
|
"answer": f"处理您的问题时遇到错误: {str(e)}", |
|
|
"all_chunks": all_chunks, |
|
|
"reasoning_steps": reasoning_steps |
|
|
} |
|
|
|
|
|
def retrieve_and_answer(self, query: str, use_table_format: bool = False) -> Tuple[str, Dict[str, Any]]: |
|
|
""" |
|
|
执行多跳检索和回答生成的主要方法 |
|
|
|
|
|
返回: |
|
|
包含以下内容的元组: |
|
|
- 最终答案 |
|
|
- 包含推理步骤和所有检索到的块的调试字典 |
|
|
""" |
|
|
all_chunks = [] |
|
|
all_queries = [query] |
|
|
reasoning_steps = [] |
|
|
debug_info = {"reasoning_steps": [], "all_chunks": [], "all_queries": all_queries} |
|
|
|
|
|
|
|
|
query_vector = self._vectorize_query(query) |
|
|
if query_vector.size == 0: |
|
|
return "由于嵌入错误,无法处理查询。", debug_info |
|
|
|
|
|
initial_chunks = self._retrieve(query_vector, self.initial_candidates) |
|
|
all_chunks.extend(initial_chunks) |
|
|
debug_info["all_chunks"].extend(initial_chunks) |
|
|
|
|
|
if not initial_chunks: |
|
|
return "未找到与您的查询相关的信息。", debug_info |
|
|
|
|
|
|
|
|
reasoning = self._generate_reasoning(query, initial_chunks, hop_number=0) |
|
|
reasoning_steps.append(reasoning) |
|
|
debug_info["reasoning_steps"].append(reasoning) |
|
|
|
|
|
|
|
|
hop = 1 |
|
|
while (hop < self.max_hops and |
|
|
not reasoning["is_sufficient"] and |
|
|
reasoning["follow_up_queries"]): |
|
|
|
|
|
if self.verbose: |
|
|
print(f"开始跳数 {hop},有 {len(reasoning['follow_up_queries'])} 个后续查询") |
|
|
|
|
|
hop_chunks = [] |
|
|
|
|
|
|
|
|
for follow_up_query in reasoning["follow_up_queries"]: |
|
|
all_queries.append(follow_up_query) |
|
|
debug_info["all_queries"].append(follow_up_query) |
|
|
|
|
|
|
|
|
follow_up_vector = self._vectorize_query(follow_up_query) |
|
|
if follow_up_vector.size > 0: |
|
|
follow_up_chunks = self._retrieve(follow_up_vector, self.refined_candidates) |
|
|
hop_chunks.extend(follow_up_chunks) |
|
|
all_chunks.extend(follow_up_chunks) |
|
|
debug_info["all_chunks"].extend(follow_up_chunks) |
|
|
|
|
|
|
|
|
reasoning = self._generate_reasoning( |
|
|
query, |
|
|
hop_chunks, |
|
|
previous_queries=all_queries[:-1], |
|
|
hop_number=hop |
|
|
) |
|
|
reasoning_steps.append(reasoning) |
|
|
debug_info["reasoning_steps"].append(reasoning) |
|
|
|
|
|
hop += 1 |
|
|
|
|
|
|
|
|
answer = self._synthesize_answer(query, all_chunks, reasoning_steps, use_table_format) |
|
|
|
|
|
return answer, debug_info |
|
|
|
|
|
|
|
|
def get_kb_paths(kb_name: str) -> Dict[str, str]: |
|
|
"""获取指定知识库的索引文件路径""" |
|
|
kb_dir = os.path.join(KB_BASE_DIR, kb_name) |
|
|
return { |
|
|
"index_path": os.path.join(kb_dir, "semantic_chunk.index"), |
|
|
"metadata_path": os.path.join(kb_dir, "semantic_chunk_metadata.json") |
|
|
} |
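
# 示例(示意):get_kb_paths("default") 大致返回
#     {"index_path": "<KB_BASE_DIR>/default/semantic_chunk.index",
#      "metadata_path": "<KB_BASE_DIR>/default/semantic_chunk_metadata.json"}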
|
|
|
|
|
def multi_hop_generate_answer(query: str, kb_name: str, use_table_format: bool = False) -> Tuple[str, Dict]:
|
|
"""使用多跳推理RAG生成答案,基于指定知识库""" |
|
|
kb_paths = get_kb_paths(kb_name) |
|
|
|
|
|
reasoning_rag = ReasoningRAG( |
|
|
index_path=kb_paths["index_path"], |
|
|
metadata_path=kb_paths["metadata_path"], |
|
|
max_hops=3, |
|
|
initial_candidates=5, |
|
|
refined_candidates=3, |
|
|
reasoning_model=Config.llm_model, |
|
|
verbose=True |
|
|
) |
|
|
|
|
|
answer, debug_info = reasoning_rag.retrieve_and_answer(query, use_table_format) |
|
|
return answer, debug_info |
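
# 使用示例(示意;假设知识库 "default" 已建立索引):
#     answer, debug = multi_hop_generate_answer("2型糖尿病如何确诊?", "default")
#     print(answer)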
|
|
|
|
|
|
|
|
def simple_generate_answer(query: str, kb_name: str, use_table_format: bool = False) -> str: |
|
|
"""使用简单的向量检索生成答案,不使用多跳推理""" |
|
|
try: |
|
|
kb_paths = get_kb_paths(kb_name) |
|
|
|
|
|
|
|
|
search_results = vector_search(query, kb_paths["index_path"], kb_paths["metadata_path"], limit=5) |
|
|
|
|
|
if not search_results: |
|
|
return "未找到相关信息。" |
|
|
|
|
|
|
|
|
background_chunks = "\n\n".join([f"[相关信息 {i+1}]: {result['chunk']}" |
|
|
for i, result in enumerate(search_results)]) |
|
|
|
|
|
|
|
|
system_prompt = "你是一名医疗专家。基于提供的背景信息回答用户的问题。" |
|
|
|
|
|
if use_table_format: |
|
|
system_prompt += "请尽可能以Markdown表格的形式呈现结构化信息。" |
|
|
|
|
|
user_prompt = f""" |
|
|
问题:{query} |
|
|
|
|
|
背景信息: |
|
|
{background_chunks} |
|
|
|
|
|
请基于以上背景信息回答用户的问题。 |
|
|
""" |
|
|
|
|
|
response = client.chat.completions.create( |
|
|
model=Config.llm_model, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
] |
|
|
) |
|
|
|
|
|
return response.choices[0].message.content.strip() |
|
|
|
|
|
except Exception as e: |
|
|
return f"生成答案时出错:{str(e)}" |
|
|
|
|
|
|
|
|
def ask_question_parallel(question: str, kb_name: str = DEFAULT_KB, use_search: bool = True, use_table_format: bool = False, multi_hop: bool = False) -> str: |
|
|
"""基于指定知识库回答问题""" |
|
|
try: |
|
|
kb_paths = get_kb_paths(kb_name) |
|
|
index_path = kb_paths["index_path"] |
|
|
metadata_path = kb_paths["metadata_path"] |
|
|
|
|
|
search_background = "" |
|
|
local_answer = "" |
|
|
debug_info = {} |
|
|
|
|
|
|
|
|
with ThreadPoolExecutor(max_workers=2) as executor: |
|
|
futures = {} |
|
|
|
|
|
if use_search: |
|
|
futures[executor.submit(get_search_background, question)] = "search" |
|
|
|
|
|
if os.path.exists(index_path): |
|
|
if multi_hop: |
|
|
|
|
|
futures[executor.submit(multi_hop_generate_answer, question, kb_name, use_table_format)] = "rag" |
|
|
else: |
|
|
|
|
|
futures[executor.submit(simple_generate_answer, question, kb_name, use_table_format)] = "simple" |
|
|
|
|
|
for future in as_completed(futures): |
|
|
result = future.result() |
|
|
if futures[future] == "search": |
|
|
search_background = result or "" |
|
|
elif futures[future] == "rag": |
|
|
local_answer, debug_info = result |
|
|
elif futures[future] == "simple": |
|
|
local_answer = result |
|
|
|
|
|
|
|
|
if search_background and local_answer: |
|
|
system_prompt = "你是一名医疗专家,请整合网络搜索和本地知识库提供全面的解答。" |
|
|
|
|
|
table_instruction = "" |
|
|
if use_table_format: |
|
|
table_instruction = """ |
|
|
请尽可能以Markdown表格的形式呈现你的回答,特别是对于症状、治疗方法、药物等结构化信息。 |
|
|
|
|
|
请确保你的表格遵循正确的Markdown语法: |
|
|
| 列标题1 | 列标题2 | 列标题3 | |
|
|
| ------- | ------- | ------- | |
|
|
| 数据1 | 数据2 | 数据3 | |
|
|
""" |
|
|
|
|
|
user_prompt = f""" |
|
|
问题:{question} |
|
|
|
|
|
网络搜索结果:{search_background} |
|
|
|
|
|
本地知识库分析:{local_answer} |
|
|
|
|
|
{table_instruction} |
|
|
|
|
|
请根据以上信息,提供一个综合的回答。 |
|
|
""" |
|
|
|
|
|
try: |
|
|
response = client.chat.completions.create( |
|
|
model="qwen-plus", |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
] |
|
|
) |
|
|
combined_answer = response.choices[0].message.content.strip() |
|
|
return combined_answer |
|
|
except Exception as e: |
|
|
|
|
|
return local_answer |
|
|
elif local_answer: |
|
|
return local_answer |
|
|
elif search_background: |
|
|
|
|
|
system_prompt = "你是一名医疗专家。" |
|
|
if use_table_format: |
|
|
system_prompt += "请尽可能以Markdown表格的形式呈现结构化信息。" |
|
|
return generate_answer_from_deepseek(question, system_prompt=system_prompt, background_info=f"[联网搜索结果]:{search_background}") |
|
|
else: |
|
|
return "未找到相关信息。" |
|
|
|
|
|
except Exception as e: |
|
|
return f"查询失败:{str(e)}" |
|
|
|
|
|
|
|
|
def process_question_with_reasoning(question: str, kb_name: str = DEFAULT_KB, use_search: bool = True, use_table_format: bool = False, multi_hop: bool = False, chat_history: List = None): |
|
|
"""增强版process_question,支持流式响应,实时显示检索和推理过程,支持多知识库和对话历史""" |
|
|
try: |
|
|
kb_paths = get_kb_paths(kb_name) |
|
|
index_path = kb_paths["index_path"] |
|
|
metadata_path = kb_paths["metadata_path"] |
|
|
|
|
|
|
|
|
if chat_history and len(chat_history) > 0: |
|
|
|
|
|
context = "之前的对话内容:\n" |
|
|
for user_msg, assistant_msg in chat_history[-3:]: |
|
|
context += f"用户:{user_msg}\n" |
|
|
context += f"助手:{assistant_msg}\n" |
|
|
context += f"\n当前问题:{question}" |
|
|
enhanced_question = f"基于以下对话历史,回答用户的当前问题。\n{context}" |
|
|
else: |
|
|
enhanced_question = question |
|
|
|
|
|
|
|
|
search_result = "联网搜索进行中..." if use_search else "未启用联网搜索" |
|
|
|
|
|
if multi_hop: |
|
|
reasoning_status = f"正在准备对知识库 '{kb_name}' 进行多跳推理检索..." |
|
|
search_display = f"### 联网搜索结果\n{search_result}\n\n### 推理状态\n{reasoning_status}" |
|
|
yield search_display, "正在启动多跳推理流程..." |
|
|
else: |
|
|
reasoning_status = f"正在准备对知识库 '{kb_name}' 进行向量检索..." |
|
|
search_display = f"### 联网搜索结果\n{search_result}\n\n### 检索状态\n{reasoning_status}" |
|
|
yield search_display, "正在启动简单检索流程..." |
|
|
|
|
|
|
|
|
search_future = None |
|
|
with ThreadPoolExecutor(max_workers=1) as executor: |
|
|
if use_search: |
|
|
search_future = executor.submit(get_search_background, question) |
|
|
|
|
|
|
|
|
if not (os.path.exists(index_path) and os.path.exists(metadata_path)): |
|
|
|
|
|
if search_future: |
|
|
|
|
|
search_result = "等待联网搜索结果..." |
|
|
search_display = f"### 联网搜索结果\n{search_result}\n\n### 检索状态\n知识库 '{kb_name}' 中未找到索引" |
|
|
yield search_display, "等待联网搜索结果..." |
|
|
|
|
|
search_result = search_future.result() or "未找到相关网络信息" |
|
|
system_prompt = "你是一名医疗专家。请考虑对话历史并回答用户的问题。" |
|
|
if use_table_format: |
|
|
system_prompt += "请尽可能以Markdown表格的形式呈现结构化信息。" |
|
|
answer = generate_answer_from_deepseek(enhanced_question, system_prompt=system_prompt, background_info=f"[联网搜索结果]:{search_result}") |
|
|
|
|
|
search_display = f"### 联网搜索结果\n{search_result}\n\n### 检索状态\n无法在知识库 '{kb_name}' 中进行本地检索(未找到索引)" |
|
|
yield search_display, answer |
|
|
else: |
|
|
yield f"知识库 '{kb_name}' 中未找到索引,且未启用联网搜索", "无法回答您的问题。请先上传文件到该知识库或启用联网搜索。" |
|
|
return |
|
|
|
|
|
|
|
|
current_answer = "正在分析您的问题..." |
|
|
|
|
|
if multi_hop: |
|
|
|
|
|
reasoning_rag = ReasoningRAG( |
|
|
index_path=index_path, |
|
|
metadata_path=metadata_path, |
|
|
max_hops=3, |
|
|
initial_candidates=5, |
|
|
refined_candidates=3, |
|
|
verbose=True |
|
|
) |
|
|
|
|
|
|
|
|
for step_result in reasoning_rag.stream_retrieve_and_answer(enhanced_question, use_table_format): |
|
|
|
|
|
status = step_result["status"] |
|
|
reasoning_display = step_result["reasoning_display"] |
|
|
|
|
|
|
|
|
if step_result["answer"]: |
|
|
current_answer = step_result["answer"] |
|
|
|
|
|
|
|
|
if search_future and search_future.done(): |
|
|
search_result = search_future.result() or "未找到相关网络信息" |
|
|
|
|
|
|
|
|
current_display = f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 推理状态\n{status}\n\n{reasoning_display}" |
|
|
yield current_display, current_answer |
|
|
else: |
|
|
|
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 检索状态\n正在执行向量相似度搜索...", "正在检索相关信息..." |
|
|
|
|
|
|
|
|
try: |
|
|
search_results = vector_search(enhanced_question, index_path, metadata_path, limit=5) |
|
|
|
|
|
if not search_results: |
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 检索状态\n未找到相关信息", f"知识库 '{kb_name}' 中未找到相关信息。" |
|
|
current_answer = f"知识库 '{kb_name}' 中未找到相关信息。" |
|
|
else: |
|
|
|
|
|
chunks_detail = "\n\n".join([f"**相关信息 {i+1}**:\n{result['chunk']}" for i, result in enumerate(search_results[:5])]) |
|
|
                    chunks_preview = "\n".join([f"- {result['chunk'][:100]}..." for result in search_results[:3]])
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 检索状态\n找到 {len(search_results)} 个相关信息块\n\n### 检索到的信息预览\n{chunks_preview}", "正在生成答案..." |
|
|
|
|
|
|
|
|
background_chunks = "\n\n".join([f"[相关信息 {i+1}]: {result['chunk']}" |
|
|
for i, result in enumerate(search_results)]) |
|
|
|
|
|
system_prompt = "你是一名医疗专家。基于提供的背景信息和对话历史回答用户的问题。" |
|
|
if use_table_format: |
|
|
system_prompt += "请尽可能以Markdown表格的形式呈现结构化信息。" |
|
|
|
|
|
user_prompt = f""" |
|
|
{enhanced_question} |
|
|
|
|
|
背景信息: |
|
|
{background_chunks} |
|
|
|
|
|
请基于以上背景信息和对话历史回答用户的问题。 |
|
|
""" |
|
|
|
|
|
response = client.chat.completions.create( |
|
|
model=Config.llm_model, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
] |
|
|
) |
|
|
|
|
|
current_answer = response.choices[0].message.content.strip() |
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 检索状态\n检索完成,已生成答案\n\n### 检索到的内容\n{chunks_detail}", current_answer |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"检索过程中出错: {str(e)}" |
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 检索状态\n{error_msg}", f"检索过程中出错: {str(e)}" |
|
|
current_answer = f"检索过程中出错: {str(e)}" |
|
|
|
|
|
|
|
|
if search_future and search_future.done(): |
|
|
search_result = search_future.result() or "未找到相关网络信息" |
|
|
|
|
|
|
|
|
        not_found_msgs = ["正在分析您的问题...", f"知识库 '{kb_name}' 中未找到相关信息。", "未找到与您的查询相关的信息。"]
        if search_result and current_answer and current_answer not in not_found_msgs:
|
|
status_text = "正在合并联网搜索和知识库结果..." |
|
|
if multi_hop: |
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 推理状态\n{status_text}", current_answer |
|
|
else: |
|
|
yield f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 检索状态\n{status_text}", current_answer |
|
|
|
|
|
|
|
|
system_prompt = "你是一名医疗专家,请整合网络搜索和本地知识库提供全面的解答。请考虑对话历史。" |
|
|
|
|
|
if use_table_format: |
|
|
system_prompt += "请尽可能以Markdown表格的形式呈现结构化信息。" |
|
|
|
|
|
user_prompt = f""" |
|
|
{enhanced_question} |
|
|
|
|
|
网络搜索结果:{search_result} |
|
|
|
|
|
本地知识库分析:{current_answer} |
|
|
|
|
|
请根据以上信息和对话历史,提供一个综合的回答。确保使用Markdown表格来呈现适合表格形式的信息。 |
|
|
""" |
|
|
|
|
|
try: |
|
|
response = client.chat.completions.create( |
|
|
model="qwen-plus", |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": user_prompt} |
|
|
] |
|
|
) |
|
|
combined_answer = response.choices[0].message.content.strip() |
|
|
|
|
|
final_status = "已整合联网和知识库结果" |
|
|
if multi_hop: |
|
|
final_display = f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 本地知识库分析\n已完成多跳推理分析,检索到的内容已在上方显示\n\n### 综合分析\n{final_status}" |
|
|
else: |
|
|
|
|
|
chunks_info = "".join([part.split("### 检索到的内容\n")[-1] if "### 检索到的内容\n" in part else "" for part in search_display.split("### 联网搜索结果")]) |
|
|
if not chunks_info.strip(): |
|
|
chunks_info = "检索内容已在上方显示" |
|
|
final_display = f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 本地知识库分析\n已完成向量检索分析\n\n### 检索到的内容\n{chunks_info}\n\n### 综合分析\n{final_status}" |
|
|
|
|
|
yield final_display, combined_answer |
|
|
except Exception as e: |
|
|
|
|
|
error_status = f"合并结果失败: {str(e)}" |
|
|
if multi_hop: |
|
|
final_display = f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 本地知识库分析\n已完成多跳推理分析,检索到的内容已在上方显示\n\n### 综合分析\n{error_status}" |
|
|
else: |
|
|
|
|
|
chunks_info = "".join([part.split("### 检索到的内容\n")[-1] if "### 检索到的内容\n" in part else "" for part in search_display.split("### 联网搜索结果")]) |
|
|
if not chunks_info.strip(): |
|
|
chunks_info = "检索内容已在上方显示" |
|
|
final_display = f"### 联网搜索结果\n{search_result}\n\n### 知识库: {kb_name}\n### 本地知识库分析\n已完成向量检索分析\n\n### 检索到的内容\n{chunks_info}\n\n### 综合分析\n{error_status}" |
|
|
|
|
|
yield final_display, current_answer |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"处理失败:{str(e)}\n{traceback.format_exc()}" |
|
|
yield f"### 错误信息\n{error_msg}", f"处理您的问题时遇到错误:{str(e)}" |
|
|
|
|
|
|
|
|
def batch_upload_to_kb(file_objs: List, kb_name: str) -> str: |
|
|
"""批量上传文件到指定知识库并进行处理""" |
|
|
try: |
|
|
if not kb_name or not kb_name.strip(): |
|
|
return "错误:未指定知识库" |
|
|
|
|
|
|
|
|
kb_dir = os.path.join(KB_BASE_DIR, kb_name) |
|
|
if not os.path.exists(kb_dir): |
|
|
os.makedirs(kb_dir, exist_ok=True) |
|
|
|
|
|
if not file_objs or len(file_objs) == 0: |
|
|
return "错误:未选择任何文件" |
|
|
|
|
|
return process_and_index_files(file_objs, kb_name) |
|
|
except Exception as e: |
|
|
return f"上传文件到知识库失败: {str(e)}" |
|
|
|
|
|
|
|
|
custom_css = """ |
|
|
.web-search-toggle .form { display: flex !important; align-items: center !important; } |
|
|
.web-search-toggle .form > label { order: 2 !important; margin-left: 10px !important; } |
|
|
.web-search-toggle .checkbox-wrap { order: 1 !important; background: #d4e4d4 !important; border-radius: 15px !important; padding: 2px !important; width: 50px !important; height: 28px !important; } |
|
|
.web-search-toggle .checkbox-wrap .checkbox-container { width: 24px !important; height: 24px !important; transition: all 0.3s ease !important; } |
|
|
.web-search-toggle input:checked + .checkbox-wrap { background: #2196F3 !important; } |
|
|
.web-search-toggle input:checked + .checkbox-wrap .checkbox-container { transform: translateX(22px) !important; } |
|
|
#search-results { max-height: 400px; overflow-y: auto; border: 1px solid #2196F3; border-radius: 5px; padding: 10px; background-color: #e7f0f9; } |
|
|
#question-input { border-color: #2196F3 !important; } |
|
|
#answer-output { background-color: #f0f7f0; border-color: #2196F3 !important; max-height: 400px; overflow-y: auto; } |
|
|
.submit-btn { background-color: #2196F3 !important; border: none !important; } |
|
|
.reasoning-steps { background-color: #f0f7f0; border: 1px dashed #2196F3; padding: 10px; margin-top: 10px; border-radius: 5px; } |
|
|
.loading-spinner { display: inline-block; width: 20px; height: 20px; border: 3px solid rgba(33, 150, 243, 0.3); border-radius: 50%; border-top-color: #2196F3; animation: spin 1s ease-in-out infinite; } |
|
|
@keyframes spin { to { transform: rotate(360deg); } } |
|
|
.stream-update { animation: fade 0.5s ease-in-out; } |
|
|
@keyframes fade { from { background-color: rgba(33, 150, 243, 0.1); } to { background-color: transparent; } } |
|
|
.status-box { padding: 10px; border-radius: 5px; margin-bottom: 10px; font-weight: bold; } |
|
|
.status-processing { background-color: #e3f2fd; color: #1565c0; border-left: 4px solid #2196F3; } |
|
|
.status-success { background-color: #e8f5e9; color: #2e7d32; border-left: 4px solid #4CAF50; } |
|
|
.status-error { background-color: #ffebee; color: #c62828; border-left: 4px solid #f44336; } |
|
|
.multi-hop-toggle .form { display: flex !important; align-items: center !important; } |
|
|
.multi-hop-toggle .form > label { order: 2 !important; margin-left: 10px !important; } |
|
|
.multi-hop-toggle .checkbox-wrap { order: 1 !important; background: #d4e4d4 !important; border-radius: 15px !important; padding: 2px !important; width: 50px !important; height: 28px !important; } |
|
|
.multi-hop-toggle .checkbox-wrap .checkbox-container { width: 24px !important; height: 24px !important; transition: all 0.3s ease !important; } |
|
|
.multi-hop-toggle input:checked + .checkbox-wrap { background: #4CAF50 !important; } |
|
|
.multi-hop-toggle input:checked + .checkbox-wrap .checkbox-container { transform: translateX(22px) !important; } |
|
|
.kb-management { border: 1px solid #2196F3; border-radius: 5px; padding: 15px; margin-bottom: 15px; background-color: #f0f7ff; } |
|
|
.kb-selector { margin-bottom: 10px; } |
|
|
/* 缩小文件上传区域高度 */ |
|
|
.compact-upload { |
|
|
margin-bottom: 10px; |
|
|
} |
|
|
|
|
|
.file-upload.compact { |
|
|
padding: 10px; /* 减小内边距 */ |
|
|
min-height: 120px; /* 减小最小高度 */ |
|
|
margin-bottom: 10px; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
/* 确保右侧列有足够空间 */ |
|
|
#kb-files-group { |
|
|
height: 100%; |
|
|
display: flex; |
|
|
flex-direction: column; |
|
|
} |
|
|
.kb-files-list { max-height: 250px; overflow-y: auto; border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-top: 10px; background-color: #f9f9f9; } |
|
|
#kb-management-container { |
|
|
max-width: 800px !important; |
|
|
margin: 0 !important; /* 移除自动边距,靠左对齐 */ |
|
|
margin-left: 20px !important; /* 添加左边距 */ |
|
|
} |
|
|
.container { |
|
|
max-width: 1200px !important; |
|
|
margin: 0 auto !important; |
|
|
} |
|
|
.file-upload { |
|
|
border: 2px dashed #2196F3; |
|
|
padding: 15px; |
|
|
border-radius: 10px; |
|
|
background-color: #f0f7ff; |
|
|
margin-bottom: 15px; |
|
|
} |
|
|
.tabs.tab-selected { |
|
|
background-color: #e3f2fd; |
|
|
border-bottom: 3px solid #2196F3; |
|
|
} |
|
|
.group { |
|
|
border: 1px solid #e0e0e0; |
|
|
border-radius: 8px; |
|
|
padding: 10px; |
|
|
margin-bottom: 15px; |
|
|
background-color: #fafafa; |
|
|
} |
|
|
|
|
|
/* 添加更多针对知识库管理页面的样式 */ |
|
|
#kb-controls, #kb-file-upload, #kb-files-group { |
|
|
width: 100% !important; |
|
|
max-width: 800px !important; |
|
|
margin-right: auto !important; |
|
|
} |
|
|
|
|
|
/* 修改Gradio默认的标签页样式以支持左对齐 */ |
|
|
.tabs > .tab-nav > button { |
|
|
flex: 0 1 auto !important; /* 修改为不自动扩展,只占用必要空间 */ |
|
|
} |
|
|
.tabs > .tabitem { |
|
|
padding-left: 0 !important; /* 移除左边距,使内容靠左 */ |
|
|
} |
|
|
/* 对于首页的顶部标题部分 */ |
|
|
#app-container h1, #app-container h2, #app-container h3, |
|
|
#app-container > .prose { |
|
|
text-align: left !important; |
|
|
padding-left: 20px !important; |
|
|
} |
|
|
""" |
|
|
|
|
|
custom_theme = gr.themes.Soft( |
|
|
primary_hue="blue", |
|
|
secondary_hue="blue", |
|
|
neutral_hue="gray", |
|
|
text_size="lg", |
|
|
spacing_size="md", |
|
|
radius_size="md" |
|
|
) |
|
|
|
|
|
|
|
|
js_code = """ |
|
|
<script> |
|
|
document.addEventListener('DOMContentLoaded', function() { |
|
|
// 当页面加载完毕后,找到提交按钮,并为其添加点击事件 |
|
|
const observer = new MutationObserver(function(mutations) { |
|
|
// 找到提交按钮 |
|
|
const submitButton = document.querySelector('button[data-testid="submit"]'); |
|
|
if (submitButton) { |
|
|
submitButton.addEventListener('click', function() { |
|
|
// 找到检索标签页按钮并点击它 |
|
|
setTimeout(function() { |
|
|
const retrievalTab = document.querySelector('[data-testid="tab-button-retrieval-tab"]'); |
|
|
if (retrievalTab) retrievalTab.click(); |
|
|
}, 100); |
|
|
}); |
|
|
observer.disconnect(); // 一旦找到并设置事件,停止观察 |
|
|
} |
|
|
}); |
|
|
|
|
|
// 开始观察文档变化 |
|
|
observer.observe(document.body, { childList: true, subtree: true }); |
|
|
}); |
|
|
</script> |
|
|
""" |
|
|
|
|
|
with gr.Blocks(title="医疗知识问答系统", theme=custom_theme, css=custom_css, elem_id="app-container") as demo: |
|
|
with gr.Column(elem_id="header-container"): |
|
|
gr.Markdown(""" |
|
|
# 🏥 医疗知识问答系统 |
|
|
**智能医疗助手,支持多知识库管理、多轮对话、普通语义检索和高级多跳推理** |
|
|
本系统支持创建多个知识库,上传TXT或PDF文件,通过语义向量检索或创新的多跳推理机制提供医疗信息查询服务。 |
|
|
""") |
|
|
|
|
|
|
|
|
gr.HTML(js_code, visible=False) |
|
|
|
|
|
|
|
|
chat_history_state = gr.State([]) |
|
|
|
|
|
|
|
|
with gr.Tabs() as tabs: |
|
|
|
|
|
with gr.TabItem("知识库管理"): |
|
|
with gr.Row(): |
|
|
|
|
|
with gr.Column(scale=1, min_width=400): |
|
|
gr.Markdown("### 📚 知识库管理与构建") |
|
|
|
|
|
with gr.Row(elem_id="kb-controls"): |
|
|
with gr.Column(scale=1): |
|
|
new_kb_name = gr.Textbox( |
|
|
label="新知识库名称", |
|
|
placeholder="输入新知识库名称", |
|
|
lines=1 |
|
|
) |
|
|
create_kb_btn = gr.Button("创建知识库", variant="primary", scale=1) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
current_kbs = get_knowledge_bases() |
|
|
kb_dropdown = gr.Dropdown( |
|
|
label="选择知识库", |
|
|
choices=current_kbs, |
|
|
value=DEFAULT_KB if DEFAULT_KB in current_kbs else (current_kbs[0] if current_kbs else None), |
|
|
elem_classes="kb-selector" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
refresh_kb_btn = gr.Button("刷新列表", size="sm", scale=1) |
|
|
delete_kb_btn = gr.Button("删除知识库", size="sm", variant="stop", scale=1) |
|
|
|
|
|
kb_status = gr.Textbox(label="知识库状态", interactive=False, placeholder="选择或创建知识库") |
|
|
|
|
|
with gr.Group(elem_id="kb-file-upload", elem_classes="compact-upload"): |
|
|
gr.Markdown("### 📄 上传文件到知识库") |
|
|
file_upload = gr.File( |
|
|
label="选择文件(支持多选TXT/PDF)", |
|
|
type="filepath", |
|
|
file_types=[".txt", ".pdf"], |
|
|
file_count="multiple", |
|
|
elem_classes="file-upload compact" |
|
|
) |
|
|
upload_status = gr.Textbox(label="上传状态", interactive=False, placeholder="上传后显示状态") |
|
|
|
|
|
|
|
|
|
|
with gr.Column(scale=1, min_width=400): |
|
|
with gr.Group(elem_id="kb-files-group"): |
|
|
gr.Markdown("### 📋 知识库内容") |
|
|
kb_files_list = gr.Markdown( |
|
|
value="选择知识库查看文件...", |
|
|
elem_classes="kb-files-list" |
|
|
) |
|
|
|
|
|
|
|
|
kb_select_for_chat = gr.Dropdown( |
|
|
label="为对话选择知识库", |
|
|
choices=current_kbs, |
|
|
value=DEFAULT_KB if DEFAULT_KB in current_kbs else (current_kbs[0] if current_kbs else None), |
|
|
visible=False |
|
|
) |
|
|
|
|
|
|
|
|
        with gr.TabItem("对话交互"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### ⚙️ 对话设置")

                    kb_dropdown_chat = gr.Dropdown(
                        label="选择知识库进行对话",
                        choices=current_kbs,
                        value=DEFAULT_KB if DEFAULT_KB in current_kbs else (current_kbs[0] if current_kbs else None),
                    )

                    with gr.Row():
                        web_search_toggle = gr.Checkbox(
                            label="🌐 启用联网搜索",
                            value=True,
                            info="获取最新医疗动态",
                            elem_classes="web-search-toggle"
                        )
                        table_format_toggle = gr.Checkbox(
                            label="📊 表格格式输出",
                            value=True,
                            info="使用Markdown表格展示结构化回答",
                            elem_classes="web-search-toggle"
                        )

                    multi_hop_toggle = gr.Checkbox(
                        label="🔄 启用多跳推理",
                        value=False,
                        info="使用高级多跳推理机制(较慢但更全面)",
                        elem_classes="multi-hop-toggle"
                    )

                    with gr.Accordion("显示检索进展", open=False):
                        search_results_output = gr.Markdown(
                            label="检索过程",
                            elem_id="search-results",
                            value="等待提交问题..."
                        )
|
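                # Right-hand column: chat history, question input and action buttons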
                with gr.Column(scale=3):
                    gr.Markdown("### 💬 对话历史")
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="对话历史",
                        height=550
                    )

                    with gr.Row():
                        question_input = gr.Textbox(
                            label="输入医疗健康相关问题",
                            placeholder="例如:2型糖尿病的症状和治疗方法有哪些?",
                            lines=2,
                            elem_id="question-input"
                        )

                    with gr.Row(elem_classes="submit-row"):
                        submit_btn = gr.Button("提交问题", variant="primary", elem_classes="submit-btn")
                        clear_btn = gr.Button("清空输入", variant="secondary")
                        clear_history_btn = gr.Button("清空对话历史", variant="secondary", elem_classes="clear-history-btn")

                    status_box = gr.HTML(
                        value='<div class="status-box status-processing">准备就绪,等待您的问题</div>',
                        visible=True
                    )

                    gr.Examples(
                        examples=[
                            ["2型糖尿病的症状和治疗方法有哪些?"],
                            ["高血压患者的日常饮食应该注意什么?"],
                            ["肺癌的早期症状和筛查方法是什么?"],
                            ["新冠肺炎后遗症有哪些表现?如何缓解?"],
                            ["儿童过敏性鼻炎的诊断标准和治疗方案有哪些?"]
                        ],
                        inputs=question_input,
                        label="示例问题(点击尝试)"
                    )
|
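    # ---- Callback helpers used by the event wiring below ----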
    def create_kb_and_refresh(kb_name):
        result = create_knowledge_base(kb_name)
        kbs = get_knowledge_bases()
        # create_knowledge_base() sanitizes the name, so select the sanitized
        # form (not the raw input) when creation succeeded
        new_value = re.sub(r'[^\w\u4e00-\u9fff]', '_', kb_name.strip()) if "创建成功" in result else None
        return result, gr.update(choices=kbs, value=new_value), gr.update(choices=kbs, value=new_value)
|
    def refresh_kb_list():
        kbs = get_knowledge_bases()
        return gr.update(choices=kbs, value=kbs[0] if kbs else None), gr.update(choices=kbs, value=kbs[0] if kbs else None)
|
    def delete_kb_and_refresh(kb_name):
        result = delete_knowledge_base(kb_name)
        kbs = get_knowledge_bases()
        return result, gr.update(choices=kbs, value=kbs[0] if kbs else None), gr.update(choices=kbs, value=kbs[0] if kbs else None)
|
    def update_kb_files_list(kb_name):
        if not kb_name:
            return "未选择知识库"

        files = get_kb_files(kb_name)
        kb_dir = os.path.join(KB_BASE_DIR, kb_name)
        has_index = os.path.exists(os.path.join(kb_dir, "semantic_chunk.index"))

        if not files:
            files_str = "知识库中暂无文件"
        else:
            files_str = "**文件列表:**\n\n" + "\n".join([f"- {file}" for file in files])

        index_status = "\n\n**索引状态:** " + ("✅ 已建立索引" if has_index else "❌ 未建立索引")

        return f"### 知识库: {kb_name}\n\n{files_str}{index_status}"
|
    def sync_kb_to_chat(kb_name):
        return gr.update(value=kb_name)
|
    def sync_chat_to_kb(kb_name):
        return gr.update(value=kb_name), update_kb_files_list(kb_name)
|
    def process_upload_to_kb(files, kb_name):
        if not kb_name:
            # Two outputs are wired up, so also return a no-op update for the file list
            return "错误:未选择知识库", gr.update()

        result = batch_upload_to_kb(files, kb_name)
        files_list = update_kb_files_list(kb_name)
        return result, files_list
|
    def on_kb_change(kb_name):
        if not kb_name:
            return "未选择知识库", "选择知识库查看文件..."

        kb_dir = os.path.join(KB_BASE_DIR, kb_name)
        has_index = os.path.exists(os.path.join(kb_dir, "semantic_chunk.index"))
        status = f"已选择知识库: {kb_name}" + (" (已建立索引)" if has_index else " (未建立索引)")
        files_list = update_kb_files_list(kb_name)
        return status, files_list
|
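    # ---- Event wiring: connect buttons, dropdowns and uploads to the helpers above ----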
    create_kb_btn.click(
        fn=create_kb_and_refresh,
        inputs=[new_kb_name],
        outputs=[kb_status, kb_dropdown, kb_dropdown_chat]
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[new_kb_name]
    )

    refresh_kb_btn.click(
        fn=refresh_kb_list,
        inputs=[],
        outputs=[kb_dropdown, kb_dropdown_chat]
    )

    delete_kb_btn.click(
        fn=delete_kb_and_refresh,
        inputs=[kb_dropdown],
        outputs=[kb_status, kb_dropdown, kb_dropdown_chat]
    ).then(
        fn=update_kb_files_list,
        inputs=[kb_dropdown],
        outputs=[kb_files_list]
    )
|
    # Keep the two knowledge-base selectors in sync; the mutual updates
    # converge once both dropdowns hold the same value
    kb_dropdown.change(
        fn=on_kb_change,
        inputs=[kb_dropdown],
        outputs=[kb_status, kb_files_list]
    ).then(
        fn=sync_kb_to_chat,
        inputs=[kb_dropdown],
        outputs=[kb_dropdown_chat]
    )

    kb_dropdown_chat.change(
        fn=sync_chat_to_kb,
        inputs=[kb_dropdown_chat],
        outputs=[kb_dropdown, kb_files_list]
    )
|
    file_upload.upload(
        fn=process_upload_to_kb,
        inputs=[file_upload, kb_dropdown],
        outputs=[upload_status, kb_files_list]
    )

    clear_btn.click(
        fn=lambda: "",
        inputs=[],
        outputs=[question_input]
    )
|
    def clear_history():
        return [], []

    clear_history_btn.click(
        fn=clear_history,
        inputs=[],
        outputs=[chatbot, chat_history_state]
    )
|
    def update_status(is_processing=True, is_error=False):
        if is_processing:
            return '<div class="status-box status-processing">正在处理您的问题...</div>'
        elif is_error:
            return '<div class="status-box status-error">处理过程中出现错误</div>'
        else:
            return '<div class="status-box status-success">回答已生成完毕</div>'
|
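    # process_and_update_chat is a generator: each yield streams an updated
    # (chatbot, status, retrieval-progress) triple to the UI while the answer
    # is still being produced.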
    def process_and_update_chat(question, kb_name, use_search, use_table_format, multi_hop, chat_history):
        if not question.strip():
            # Early exit must also yield: a plain `return value` inside a
            # generator never reaches the UI
            yield chat_history, update_status(False, True), "等待提交问题..."
            return

        try:
            # Show the question immediately with a placeholder answer
            chat_history.append([question, "正在思考..."])
            yield chat_history, update_status(True), f"开始处理您的问题,使用知识库: {kb_name}..."

            last_search_display = ""
            last_answer = ""

            # Stream partial retrieval progress and answers as they arrive
            for search_display, answer in process_question_with_reasoning(question, kb_name, use_search, use_table_format, multi_hop, chat_history[:-1]):
                last_search_display = search_display
                last_answer = answer
                if chat_history:
                    chat_history[-1][1] = answer
                yield chat_history, update_status(True), search_display

            yield chat_history, update_status(False), last_search_display

        except Exception as e:
            error_msg = f"处理问题时出错: {str(e)}"
            if chat_history:
                chat_history[-1][1] = error_msg
            yield chat_history, update_status(False, True), f"### 错误\n{error_msg}"
|
    submit_btn.click(
        fn=process_and_update_chat,
        inputs=[question_input, kb_dropdown_chat, web_search_toggle, table_format_toggle, multi_hop_toggle, chat_history_state],
        outputs=[chatbot, status_box, search_results_output],
        queue=True
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[question_input]
    ).then(
        fn=lambda h: h,
        inputs=[chatbot],
        outputs=[chat_history_state]
    )

    question_input.submit(
        fn=process_and_update_chat,
        inputs=[question_input, kb_dropdown_chat, web_search_toggle, table_format_toggle, multi_hop_toggle, chat_history_state],
        outputs=[chatbot, status_box, search_results_output],
        queue=True
    ).then(
        fn=lambda: "",
        inputs=[],
        outputs=[question_input]
    ).then(
        fn=lambda h: h,
        inputs=[chatbot],
        outputs=[chat_history_state]
    )
|
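# Generator callbacks rely on Gradio's request queue; share=True additionally
# opens a public tunnel. A local-only launch would be, for example:
#     demo.queue().launch(server_name="127.0.0.1", server_port=7860)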
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)