# app.py
# Gradio UI for PromptEnhancerV2
import os
from threading import Thread
from transformers import TextIteratorStreamer, AutoTokenizer
import time
import logging
import re
import torch
import gradio as gr
import spaces
# Try to import qwen_vl_utils; if it is unavailable, fall back to a stub that returns no image/video inputs.
try:
from qwen_vl_utils import process_vision_info
except Exception:
def process_vision_info(messages):
return None, None
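# replace_single_quotes (below) normalizes quoting in the model output: straight single
# quotes around a phrase become double quotes (e.g. "a 'red' car" -> 'a "red" car'),
# curly single quotes become curly double quotes, and apostrophes inside words such as
# "it's" are left untouched because the \B anchors require non-word boundaries.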
def replace_single_quotes(text):
pattern = r"\B'([^']*)'\B"
replaced_text = re.sub(pattern, r'"\1"', text)
replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
return replaced_text
class PromptEnhancerV2:
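    # @spaces.GPU requests a ZeroGPU allocation on Hugging Face Spaces for the duration
    # of the decorated call (model loading here, generation in predict below).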
@spaces.GPU
    def __init__(self, models_root_path, device_map="auto", torch_dtype="bfloat16"):
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
if not logging.getLogger(__name__).handlers:
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
        # Map the torch_dtype string to an actual torch dtype (defaults to float32).
if torch_dtype == "bfloat16":
dtype = torch.bfloat16
elif torch_dtype == "float16":
dtype = torch.float16
else:
dtype = torch.float32
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
models_root_path,
torch_dtype=dtype,
device_map=device_map,
)
self.processor = AutoProcessor.from_pretrained(models_root_path)
# @torch.inference_mode()
@spaces.GPU
def predict(
self,
prompt_cot,
sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
temperature=0.1,
top_p=1.0,
max_new_tokens=2048,
device="cuda",
):
org_prompt_cot = prompt_cot
try:
user_prompt_format = sys_prompt + "\n" + org_prompt_cot
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt_format},
],
}
]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(device)
            # Note: the original pipeline decodes greedily (do_sample=False) with top_k=5 and
            # top_p=0.9; temperature therefore has no effect and is passed only for API parity.
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=float(temperature),
                do_sample=False,
                top_k=5,
                top_p=0.9,
            )
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_res = output_text[0]
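            # The model is expected to wrap its chain of thought in <think>...</think>;
            # only the text after the closing tag is kept as the rewritten prompt.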
assert output_res.count("think>") == 2
prompt_cot = output_res.split("think>")[-1]
if prompt_cot.startswith("\n"):
prompt_cot = prompt_cot[1:]
prompt_cot = replace_single_quotes(prompt_cot)
except Exception as e:
prompt_cot = org_prompt_cot
print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
return prompt_cot
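# Minimal usage sketch outside Gradio (not executed on import; the checkpoint path is
# simply the default from DEFAULT_MODEL_PATH below and may need adjusting):
#   enhancer = PromptEnhancerV2("PromptEnhancer/PromptEnhancer-32B")
#   print(enhancer.predict("a cat sleeping on a sofa"))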
# -------------------------
# Gradio app helpers
# -------------------------
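# The default checkpoint can be overridden via the MODEL_OUTPUT_PATH environment variable.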
DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")
def ensure_enhancer(state, model_path, device_map, torch_dtype):
"""
state: dict or None
Returns: (state_dict)
"""
need_reload = False
if state is None or not isinstance(state, dict):
need_reload = True
else:
prev_path = state.get("model_path")
prev_map = state.get("device_map")
prev_dtype = state.get("torch_dtype")
if prev_path != model_path or prev_map != device_map or prev_dtype != torch_dtype:
need_reload = True
if need_reload:
enhancer = PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype)
return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
return state
def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
model_path, device_map, torch_dtype, state):
if not prompt or not str(prompt).strip():
return "", "请先输入提示词。", state
t0 = time.time()
state = ensure_enhancer(state, model_path, device_map, torch_dtype)
enhancer = state["enhancer"]
try:
out = enhancer.predict(
prompt_cot=prompt,
sys_prompt=sys_prompt,
temperature=temperature,
max_new_tokens=max_new_tokens,
device=device
)
dt = time.time() - t0
return out, f"耗时:{dt:.2f}s", state
except Exception as e:
return "", f"推理失败:{e}", state
# Example prompts
test_list_zh = [
"第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
"韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
"点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
"一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
"Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
"Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
"Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
"A blend of expressionist and vintage styles, drawing a building with colorful walls.",
"Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]
with gr.Blocks(title="Prompt Enhancer_V2") as demo:
gr.Markdown("## 提示词重写器")
with gr.Row():
with gr.Column(scale=2):
model_path = gr.Textbox(
label="模型路径(本地或HF地址)",
value=DEFAULT_MODEL_PATH,
placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
)
device_map = gr.Dropdown(
choices=["cuda", "cpu"],
value="cuda",
label="device_map(模型加载映射)"
)
torch_dtype = gr.Dropdown(
choices=["bfloat16", "float16", "float32"],
value="bfloat16",
label="torch_dtype"
)
with gr.Column(scale=3):
sys_prompt = gr.Textbox(
label="系统提示词(默认无需修改)",
value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
lines=3
)
with gr.Row():
temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="推理device")
state = gr.State(value=None)
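    # gr.State holds the loaded PromptEnhancerV2 per user session, so the checkpoint is
    # loaded once and reused until the model path, device_map, or dtype is changed.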
with gr.Tab("推理"):
with gr.Row():
with gr.Column(scale=2):
prompt = gr.Textbox(label="输入提示词", lines=6, placeholder="在此粘贴要改写的提示词...")
run_btn = gr.Button("生成重写", variant="primary")
gr.Examples(
examples=test_list_zh + test_list_en,
inputs=prompt,
label="示例"
)
with gr.Column(scale=3):
out_text = gr.Textbox(label="重写结果", lines=10)
out_info = gr.Markdown("准备就绪。")
# run_btn.click(
# stream_single,
# inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
# model_path, device_map, torch_dtype, state],
# outputs=[out_text, out_info, state]
# )
run_btn.click(
run_single,
inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
model_path, device_map, torch_dtype, state],
outputs=[out_text, out_info, state]
)
gr.Markdown(
"提示:如有任何问题可email联系:[email protected]"
)
# Limit concurrency to avoid exhausting GPU memory when several requests run at once.
# (queue(concurrency_count=...) is the Gradio 3 API; on Gradio 4+ use default_concurrency_limit.)
# demo.queue(concurrency_count=1, max_size=10)
if __name__ == "__main__":
# demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
demo.launch(ssr_mode=False, show_error=True, share=True)