import os, json, gc, datetime from pathlib import Path from typing import List, Tuple # 性能与日志 os.environ["OMP_NUM_THREADS"] = "2" os.environ["ORT_LOG_SEVERITY_LEVEL"] = "3" import gradio as gr import onnxruntime_genai as og from huggingface_hub import snapshot_download # ================= 基础配置(中文) ================= 模型库 = "microsoft/Phi-4-mini-instruct-onnx" 子目录 = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4" 缓存目录 = "./model_cache" 模型本地目录 = "./phi4_model" 历史目录 = "./chat_history" 系统提示默认 = "你是一个友好的中文 AI 助手,请清晰、简洁地回答。" 结束标记 = "<|end|>" 上下文窗口 = 4096 # 估计值 默认回复长度 = 300 # ================ 工具函数 ================ def 确保目录(): os.makedirs(历史目录, exist_ok=True) def 下载模型() -> str: print("🔄 正在下载/准备模型...") snapshot_download( repo_id=模型库, allow_patterns=f"{子目录}/*", cache_dir=缓存目录, local_dir=模型本地目录, ) 模型路径 = os.path.join(模型本地目录, 子目录) if not os.path.exists(模型路径): raise RuntimeError(f"❌ 模型路径不存在: {模型路径}") print(f"✅ 模型就绪: {模型路径}") return 模型路径 class 历史管理: @staticmethod def 路径(会话="默认会话"): return os.path.join(历史目录, f"{会话}.json") @staticmethod def 保存(历史: List[List[str]], 会话="默认会话", 元数据: dict = None): data = {"history": 历史, "meta": 元数据 or {}, "time": datetime.datetime.now().isoformat()} with open(历史管理.路径(会话), "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False) @staticmethod def 加载(会话="默认会话") -> Tuple[List[List[str]], dict]: p = 历史管理.路径(会话) if os.path.exists(p): with open(p, "r", encoding="utf-8") as f: js = json.load(f) return js.get("history", []), js.get("meta", {}) return [], {} @staticmethod def 列表() -> List[str]: return sorted([p.stem for p in Path(历史目录).glob("*.json")], reverse=True) # ORT GenAI 兼容 def 设定长度参数(params: "og.GeneratorParams", 最大长度: int) -> bool: if hasattr(params, "set_length_options"): params.set_length_options(max_length=int(最大长度)) return True return False def 设定搜索参数(params: "og.GeneratorParams", 采样: bool, 温度: float, top_p: float, top_k: int, 重复惩罚: float, 最大长度_if_needed: int | None): kwargs = dict( do_sample=bool(采样), temperature=float(温度), top_p=float(top_p), top_k=int(top_k), repetition_penalty=float(重复惩罚), ) if 最大长度_if_needed is not None: kwargs["max_length"] = int(最大长度_if_needed) try: params.set_search_options(**kwargs) except TypeError: kwargs.pop("top_k", None) kwargs.pop("repetition_penalty", None) params.set_search_options(**kwargs) def 构建模板(系统: str, 历史: List[List[str]], 用户消息: str) -> str: parts = [f"<|system|>\n{系统}<|end|>\n"] for u, a in 历史: if u: parts.append(f"<|user|>\n{u}<|end|>\n") if a: parts.append(f"<|assistant|>\n{a}<|end|>\n") parts.append(f"<|user|>\n{用户消息}<|end|>\n<|assistant|>\n") return "".join(parts) def 按窗口裁剪(input_ids: list, 新token上限: int, 上下文上限: int) -> list: 允许输入 = max(256, 上下文上限 - 新token上限) if len(input_ids) > 允许输入: input_ids = input_ids[-允许输入:] return input_ids # ================ 模型初始化 ================ 确保目录() print("🚀 初始化中...") try: 模型路径 = 下载模型() try: import onnxruntime as ort ort.set_default_logger_severity(3) except Exception: pass 模型 = og.Model(模型路径) 分词器 = og.Tokenizer(模型) # 预热(减少首轮延迟) 预热ID = 分词器.encode("你好") 预热参数 = og.GeneratorParams(模型) 预热参数.input_ids = 预热ID 设定搜索参数(预热参数, False, 0.0, 1.0, 1, 1.0, len(预热ID) + 8) 预热生成器 = og.Generator(模型, 预热参数) 预热生成器.compute_logits() 预热生成器.generate_next_token() del 预热生成器, 预热参数 gc.collect() print("✅ 模型加载成功!") except Exception as e: print(f"❌ 模型加载失败: {e}") 模型 = None 分词器 = None # ================ 生成(流式) ================ def 流式回复(用户消息: str, 历史: List[List[str]], 回复长度: int, 温度: float, 记忆轮数: int, 上下文tokens: int): if not 模型 or not 分词器: yield "❌ 模型未加载,请稍后重试", 0 return try: if 记忆轮数 > 0 and len(历史) > 记忆轮数: 历史 = 历史[-记忆轮数:] 提示 = 构建模板(系统提示默认, 历史, 用户消息) 输入ID = 分词器.encode(提示) 输入ID = 按窗口裁剪(输入ID, 回复长度, min(上下文tokens, 上下文窗口)) params = og.GeneratorParams(模型) params.input_ids = 输入ID 总长度 = len(输入ID) + int(回复长度) 已设长度 = 设定长度参数(params, 总长度) 设定搜索参数(params, 采样=(温度 > 0), 温度=温度, top_p=0.9, top_k=40, 重复惩罚=1.05, 最大长度_if_needed=(None if 已设长度 else 总长度)) 生成器 = og.Generator(模型, params) 流 = 分词器.create_stream() 回复 = "" t = 0 while not 生成器.is_done(): 生成器.compute_logits() 生成器.generate_next_token() 新 = 生成器.get_next_tokens()[0] 片段 = 流.decode(新) if not 片段: continue 回复 += 片段 t += 1 if 结束标记 in 回复: 回复 = 回复.split(结束标记)[0].rstrip() yield 回复, t break if t % 6 == 0: yield 回复, t else: yield 回复.strip(), t del 生成器, params gc.collect() except Exception as e: yield f"❌ 生成错误: {str(e)}", 0 # ================== UI(全功能 + 自适应) ================== css = """ /* 页面:顶部标题 + 选项卡 +(聊天页:聊天+输入区网格) */ html, body { height: 100%; } .gradio-container { max-width: 1100px !important; margin: 0 auto; } /* 顶部标题条 */ #app_hdr { background: linear-gradient(135deg, #6d28d9 0%, #ec4899 100%); color: #fff; padding: 14px 16px; border-radius: 12px; box-shadow: 0 4px 20px rgba(109,40,217,.25); } #app_hdr h1 { margin: 0; font-size: 20px; } #app_hdr p { margin: 4px 0 0; font-size: 12px; opacity: .95; } /* 聊天页布局:自适应高度,输入区固定在底部,不会被遮挡 */ #chat_layout { height: calc(100dvh - 160px); /* 留出标题和tabs空间 */ display: grid; grid-template-rows: 1fr auto; gap: 8px; } #chat_scroll { min-height: 0; overflow: auto; background: #fff; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,.06); padding: 6px; } #input_bar { background: #fff; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,.06); padding: 8px; } /* 小屏优化 */ @media (max-width: 640px) { #app_hdr h1 { font-size: 18px; } #app_hdr p { font-size: 11px; } #chat_layout { height: calc(100dvh - 150px); } } """ # 初始历史 初始历史, 初始元数据 = 历史管理.加载("默认会话") with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: 会话ID = gr.State("默认会话") 系统提示状态 = gr.State(系统提示默认) # 顶部 with gr.Column(elem_id="app_hdr"): gr.Markdown("### 💬 Phi-4 中文助手") gr.Markdown("流式生成 · 自动保存 · 多会话 · 上下文可调 · 继续/停止 · 导入导出") with gr.Tabs(): # ========= 聊天 Tab ========= with gr.Tab("💬 聊天"): with gr.Column(elem_id="chat_layout"): with gr.Column(elem_id="chat_scroll"): 聊天框 = gr.Chatbot( value=初始历史, type="tuples", # 与 [[user, assistant], ...] 兼容 show_copy_button=True, height="100%" # 由外层容器控制实际高度 ) with gr.Column(elem_id="input_bar"): with gr.Row(): 消息 = gr.Textbox( placeholder="输入你的消息…(Enter 发送,Shift+Enter 换行)", scale=8, lines=1, max_lines=4, container=False ) 发送 = gr.Button("发送", variant="primary", scale=1) with gr.Row(): 清空 = gr.Button("🗑️ 清空", size="sm") 撤销 = gr.Button("↩️ 撤销", size="sm") 重试 = gr.Button("🔄 重试", size="sm") 继续 = gr.Button("⏭️ 继续", size="sm") 停止 = gr.Button("⏹️ 停止", size="sm") token计数 = gr.Markdown("Tokens: 0") # ========= 会话与设置 Tab ========= with gr.Tab("⚙️ 会话与设置"): with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### 🎯 预设") 预设精准 = gr.Button("📏 精准", size="sm") 预设平衡 = gr.Button("📘 平衡", size="sm") 预设创意 = gr.Button("🎨 创意", size="sm") gr.Markdown("#### 💬 系统提示词") 系统提示框 = gr.Textbox( label="系统提示词(影响风格)", value=系统提示默认, lines=3 ) with gr.Column(scale=2): gr.Markdown("#### 🔧 生成参数") with gr.Row(): 最大生成长度 = gr.Slider(50, 1024, value=默认回复长度, step=10, label="📝 最大生成长度 (tokens)") 温度 = gr.Slider(0.0, 1.2, value=0.7, step=0.1, label="🌡️ 温度") with gr.Row(): top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="🎲 Top-p") top_k = gr.Slider(1, 100, value=40, step=1, label="🔝 Top-k") with gr.Row(): 记忆轮数 = gr.Slider(1, 12, value=6, step=1, label="🧠 记忆轮数(保留最近N轮)") 上下文限制 = gr.Slider(512, 上下文窗口, value=上下文窗口, step=64, label=f"📚 上下文上限 (≤{上下文窗口})") gr.Markdown("#### 💾 会话管理") with gr.Row(): 会话列表 = gr.Dropdown( label="会话列表", choices=(历史管理.列表() or ["默认会话"]), value="默认会话", interactive=True ) 加载 = gr.Button("📂 加载", size="sm") 保存 = gr.Button("💾 保存", size="sm") 新建 = gr.Button("➕ 新建", size="sm") with gr.Row(): 导入文件 = gr.File(label="导入JSON", file_types=[".json"]) 导入按钮 = gr.Button("⬆️ 导入", size="sm") 导出按钮 = gr.Button("⬇️ 导出当前会话", size="sm") 导出文件 = gr.File(label="导出文件", interactive=False) # ========= 逻辑 ========= # 预设 def 用预设(模式): if 模式 == "精准": return 200, 0.2, 0.85, 20 if 模式 == "创意": return 500, 0.9, 0.95, 60 return 300, 0.7, 0.9, 40 # 平衡 预设精准.click(lambda: 用预设("精准"), outputs=[最大生成长度, 温度, top_p, top_k]) 预设平衡.click(lambda: 用预设("平衡"), outputs=[最大生成长度, 温度, top_p, top_k]) 预设创意.click(lambda: 用预设("创意"), outputs=[最大生成长度, 温度, top_p, top_k]) # 系统提示词 def 更新系统提示(s): return s.strip() if s.strip() else 系统提示默认 系统提示框.change(更新系统提示, 系统提示框, 系统提示状态) # 基础交互 def 用户提交(msg, hist): msg = (msg or "").strip() if not msg: return "", hist return "", hist + [[msg, None]] def 机器人应答(hist, sys_prompt_state, max_len, temp, tp, tk, keep_rounds, ctx_limit, sid): # 将系统提示更新为当前设置 global 系统提示默认 系统提示默认 = sys_prompt_state if not hist or hist[-1][1] is not None: return hist, gr.update(value="Tokens: 0") 用户消息 = hist[-1][0] hist[-1][1] = "" latest = "" for latest, t in 流式回复( 用户消息=用户消息, 历史=hist[:-1], 回复长度=int(max_len), 温度=float(temp), 记忆轮数=int(keep_rounds), 上下文tokens=int(ctx_limit), ): hist[-1][1] = latest yield hist, gr.update(value=f"Tokens: {t}") 历史管理.保存(hist, sid, {"system_prompt": sys_prompt_state}) # 发送 提交_evt = 消息.submit( 用户提交, [消息, 聊天框], [消息, 聊天框], queue=False ).then( 机器人应答, [聊天框, 系统提示状态, 最大生成长度, 温度, top_p, top_k, 记忆轮数, 上下文限制, 会话ID], [聊天框, token计数] ) 点击_evt = 发送.click( 用户提交, [消息, 聊天框], [消息, 聊天框], queue=False ).then( 机器人应答, [聊天框, 系统提示状态, 最大生成长度, 温度, top_p, top_k, 记忆轮数, 上下文限制, 会话ID], [聊天框, token计数] ) # 继续 def 继续输出(hist): if not hist: return hist return hist + [["请从上句继续输出。", None]] 继续.click(继续输出, 聊天框, 聊天框).then( 机器人应答, [聊天框, 系统提示状态, 最大生成长度, 温度, top_p, top_k, 记忆轮数, 上下文限制, 会话ID], [聊天框, token计数] ) # 停止(取消队列中事件) 停止.click(fn=None, inputs=None, outputs=None, cancels=[提交_evt, 点击_evt]) # 清空/撤销/重试 清空.click(lambda: [], None, 聊天框) 撤销.click(lambda h: h[:-1] if h else h, 聊天框, 聊天框) def 重试一轮(h): if not h: return h return h[:-1] + [[h[-1][0], None]] 重试.click(重试一轮, 聊天框, 聊天框).then( 机器人应答, [聊天框, 系统提示状态, 最大生成长度, 温度, top_p, top_k, 记忆轮数, 上下文限制, 会话ID], [聊天框, token计数] ) # 会话管理 def 保存当前(hist, sid, sys_prompt): 历史管理.保存(hist, sid, {"system_prompt": sys_prompt}) return gr.update(choices=历史管理.列表()) 保存.click(保存当前, [聊天框, 会话ID, 系统提示状态], 会话列表) def 加载会话(sid): h, meta = 历史管理.加载(sid) sp = meta.get("system_prompt", 系统提示默认) return h, sid, sp 加载.click(加载会话, 会话列表, [聊天框, 会话ID, 系统提示框]) def 新建会话(): sid = f"会话_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" 历史管理.保存([], sid, {"system_prompt": 系统提示默认}) return [], sid, gr.update(choices=历史管理.列表(), value=sid), 系统提示默认, gr.update(value=系统提示默认) 新建.click(新建会话, outputs=[聊天框, 会话ID, 会话列表, 系统提示状态, 系统提示框]) # 导入 / 导出(最小修复:按钮与函数避免同名) def 导出处理(hist, sid): 历史管理.保存(hist, sid, {"system_prompt": 系统提示状态.value}) return 历史管理.路径(sid) def 导入处理(file, sid): if file is None: return gr.update(), gr.update() try: with open(file.name, "r", encoding="utf-8") as f: js = json.load(f) h = js.get("history", []) meta = js.get("meta", {}) sp = meta.get("system_prompt", 系统提示默认) 历史管理.保存(h, sid, {"system_prompt": sp}) return h, sp except Exception as e: return gr.update(), gr.update(value=f"导入失败: {e}") 导出按钮.click(导出处理, [聊天框, 会话ID], 导出文件) 导入按钮.click(导入处理, [导入文件, 会话ID], [聊天框, 系统提示框]) # 自动保存 聊天框.change(lambda h, sid, sp: 历史管理.保存(h, sid, {"system_prompt": sp}) if h else None, [聊天框, 会话ID, 系统提示状态], None) # 首次加载:保证默认会话存在 + 列表更新 def 初始化(): if "默认会话" not in 历史管理.列表(): 历史管理.保存(初始历史 or [], "默认会话", {"system_prompt": 系统提示默认}) return gr.update(choices=历史管理.列表(), value="默认会话") demo.load(初始化, outputs=会话列表) if __name__ == "__main__": if 模型: print("🎉 启动服务...") demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=7860, share=False) else: print("❌ 无法启动")