William941008's picture
Update main.py
2f4dcae verified
import os
import uuid
import shutil
import asyncio
import threading
from datetime import datetime, timedelta
from functools import partial
from pathlib import Path
from typing import List, Dict, Any
import cv2
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from MyModel import PollutionDifferenceModel
from my_Segmenter import Segmenter
# 新增:SSE进度条依赖
from fastapi.responses import StreamingResponse
import json
from collections import defaultdict
# ---------------------------------------------------------------------------
# Batch-progress store shared by the batch worker and the SSE endpoint.
# The dict is not thread-safe by itself; writers hold `progress_lock`.
#   key:   request_id (str)
#   value: {"total", "current", "results", "failed", "status"}
# ---------------------------------------------------------------------------
batch_progress: defaultdict = defaultdict(dict)
progress_lock = threading.Lock()

# ---------------------------------------------------------------------------
# Base directories (relative to the current working directory)
# ---------------------------------------------------------------------------
BASE_DIR = Path(".")
STATIC_DIR = BASE_DIR / "static"
STATIC_RESULTS_DIR = STATIC_DIR / "results"
RUNS_DIR = BASE_DIR / "runs"
AIR_STATION_DIR = BASE_DIR / "AirStationImage"
FRONTEND_DIR = BASE_DIR / "frontend"  # not created here; read_index serves index.html from it

# Create the writable directories up front; the StaticFiles mounts below
# expect their directories to exist.
for _required_dir in (STATIC_DIR, STATIC_RESULTS_DIR, RUNS_DIR, AIR_STATION_DIR):
    _required_dir.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# Semantic segmentation model (Cityscapes label set, CPU)
# ---------------------------------------------------------------------------
segmenter = Segmenter(dataset="cityscapes", task="semantic", device="cpu")

# ---------------------------------------------------------------------------
# FastAPI application: permissive CORS (any origin, no credentials)
# ---------------------------------------------------------------------------
app = FastAPI(title="香港空气污染预测")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# Static mounts: (URL prefix, directory on disk, route name)
for _url_prefix, _served_dir, _mount_name in (
    ("/static", STATIC_DIR, "static"),
    ("/runs", RUNS_DIR, "runs"),
    ("/AirStationImage", AIR_STATION_DIR, "AirStationImage"),
):
    app.mount(_url_prefix, StaticFiles(directory=str(_served_dir)), name=_mount_name)
# =========================
# Startup: purge stale run artifacts, then preload the pollution models
# =========================
@app.on_event("startup")
async def startup_event():
    """Purge old run directories, then preload every pollution model.

    Each pollutant is loaded independently, so one missing or corrupt
    checkpoint no longer prevents the remaining models from being cached.
    Failures are reported and never abort startup. A pollutant that failed
    here is attempted again by ``load_pollution_model`` when a request for
    it arrives, because that function loads anything not yet cached.
    """
    await cleanup_old_runs()

    failed = []
    for pollutant in MODEL_PATHS:
        try:
            load_pollution_model(pollutant)
        except Exception as e:  # best-effort: keep preloading the others
            failed.append(pollutant)
            print(f"⚠️ 模型预加载失败: {pollutant}: {e}")
    if not failed:
        print("✅ 所有污染预测模型预加载完成")
# =========================
# Home page
# =========================
@app.get("/")
async def read_index():
    """Serve the frontend's index.html.

    Responds with HTTP 500 when the file has not been deployed.
    """
    index_file = FRONTEND_DIR / "index.html"
    if index_file.exists():
        return FileResponse(str(index_file))
    raise HTTPException(status_code=500, detail="frontend/index.html not found")
# =========================
# Model checkpoints (one per pollutant) and the in-process model cache
# =========================
MODEL_PATHS = {
    name: BASE_DIR / "models" / f"best_{name}_model_multiscale20251110.pth"
    for name in ("CO", "NO2", "PM25", "PM10", "O3")
}
# Loaded models keyed by pollutant name; filled lazily by load_pollution_model.
loaded_models: Dict[str, PollutionDifferenceModel] = {}
# Serializes first-time loading so concurrent callers read each checkpoint once.
model_lock = threading.Lock()
# =========================
# Upload limits
# =========================
MAX_FILE_SIZE = 10 << 20  # 10 MiB, in bytes
ALLOWED_CONTENT_TYPES = {"image/png", "image/jpeg", "image/webp"}
# =========================
# Accepted reference-value range per pollutant, as inclusive (min, max)
# =========================
POLLUTANT_RANGES = dict(
    CO=(0, 50),
    NO2=(0, 500),
    PM25=(0, 999),
    PM10=(0, 999),
    O3=(0, 500),
)
# =========================
# Helpers
# =========================
def create_request_dirs() -> Dict[str, Any]:
    """Create a fresh working tree for one request under RUNS_DIR.

    Layout: ``RUNS_DIR/<request_id>/{input,output,summary}``.

    Returns:
        A mapping with ``request_id`` (a plain hex ``str``) and the
        ``base_dir``, ``input_dir``, ``output_dir`` and ``summary_dir`` paths.
    """
    request_id = uuid.uuid4().hex
    base_dir = RUNS_DIR / request_id
    sub_dirs = {f"{part}_dir": base_dir / part for part in ("input", "output", "summary")}
    for directory in sub_dirs.values():
        directory.mkdir(parents=True, exist_ok=True)
    return {"request_id": request_id, "base_dir": base_dir, **sub_dirs}
# =========================
# Background batch worker (reads query images already saved to local disk)
# =========================
async def batch_predict_task(
    request_id: str,
    pollutant: str,
    ref_data: float,
    ref_tensor: torch.Tensor,
    model: PollutionDifferenceModel,
    query_file_paths: list,  # [{"path": str, "name": str}, ...]
    batch_input_dir: Path
):
    """Score every saved query image against the reference and publish progress.

    Progress is written to ``batch_progress[request_id]`` under
    ``progress_lock`` after each image, for ``/progress/{request_id}`` to
    stream. A per-image error is recorded in ``failed`` and does not stop
    the batch. If the task itself is cancelled or fails unexpectedly, the
    status becomes ``"failed"`` so SSE listeners stop instead of polling
    forever.

    Args:
        request_id: Key into ``batch_progress``.
        pollutant: Pollutant name (unused here; kept for the call signature).
        ref_data: Measured value for the reference image.
        ref_tensor: Preprocessed reference image.
        model: Loaded pollution-difference model.
        query_file_paths: Dicts with ``path`` (file on disk) and ``name``
            (filename reported back to the client).
        batch_input_dir: Request input directory (unused here).
    """
    results = []
    failed = []
    total = len(query_file_paths)

    # Publish the initial state before the first await so /progress finds it.
    with progress_lock:
        batch_progress[request_id] = {
            "total": total,
            "current": 0,
            "results": [],
            "failed": [],
            "status": "processing"
        }

    def _infer_one(query_path: Path) -> float:
        """Decode, preprocess and score one query image (runs in a worker thread)."""
        query_tensor = preprocess_image(read_rgb_image(query_path))
        # Grad mode is thread-local in PyTorch, so no_grad must be entered
        # inside the worker thread rather than around the await.
        with torch.no_grad():
            out = model(ref_tensor, query_tensor)
        return float(out.item())

    try:
        for idx, file_info in enumerate(query_file_paths, start=1):
            safe_name = file_info["name"]
            try:
                # Off-load the blocking decode and forward pass so the event
                # loop, including the SSE progress stream, keeps running.
                model_out = await asyncio.to_thread(_infer_one, Path(file_info["path"]))
                # NOTE(review): /predict computes ref_data - model_out, while
                # this worker computes ref_data + model_out. One of them has
                # the wrong sign for the model's output convention. Confirm
                # which is correct and align them.
                final_pred = ref_data + model_out
                results.append({
                    "filename": safe_name,
                    "status": "ok",
                    "pred_value": round(final_pred, 4),
                    "model_out": round(model_out, 4),
                })
            except Exception as e:
                error_msg = f"文件:{safe_name},错误:{str(e)}"
                print(f"【批量预测失败】{error_msg}")
                failed.append({
                    "filename": safe_name,
                    "status": "error",
                    "message": str(e)
                })
            with progress_lock:
                entry = batch_progress[request_id]
                entry["current"] = idx
                entry["results"] = results
                entry["failed"] = failed
    except BaseException:
        # Cancellation or an unexpected error. "failed" is one of the states
        # the SSE endpoint treats as terminal.
        with progress_lock:
            batch_progress[request_id]["status"] = "failed"
        raise

    with progress_lock:
        batch_progress[request_id]["status"] = "completed"
    print(f"【批量任务完成】{request_id} | 成功:{len(results)} 张,失败:{len(failed)} 张")
async def save_upload_file(upload_file: UploadFile, save_path: Path) -> None:
    """Validate an uploaded image and write its bytes to ``save_path``.

    Raises:
        HTTPException: 400 if the upload is empty, 413 if it exceeds
            MAX_FILE_SIZE, 415 if its declared content type is not allowed.
    """
    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without pulling an arbitrarily large body into memory.
    content = await upload_file.read(MAX_FILE_SIZE + 1)
    if not content:
        raise HTTPException(status_code=400, detail=f"上传文件为空: {upload_file.filename}")
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=413,
            detail=f"文件过大(最大 {MAX_FILE_SIZE // (1024 * 1024)}MB): {upload_file.filename}",
        )
    # content_type is declared by the client. Actual decodability is
    # checked later by read_rgb_image.
    if upload_file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(
            status_code=415,
            detail=f"不支持的文件类型 '{upload_file.content_type}',仅支持 JPEG / PNG / WebP"
        )
    save_path.write_bytes(content)
def preprocess_image(img_np: np.ndarray) -> torch.Tensor:
    """Turn an HxWxC image array into a batched (1, C, 256, 256) float32 tensor.

    Steps: resize to 256x256, divide by 255, reorder HWC to CHW, and add a
    leading batch axis. Values land in [0, 1] assuming 8-bit input (TODO confirm).
    """
    resized = cv2.resize(img_np, (256, 256))
    chw = np.transpose(resized.astype(np.float32) / 255.0, (2, 0, 1))
    return torch.from_numpy(chw)[None]
def read_rgb_image(path: Path) -> np.ndarray:
    """Load an image with OpenCV and return it in RGB channel order.

    Raises:
        HTTPException: 400 when OpenCV cannot decode the file.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise HTTPException(status_code=400, detail=f"无法读取图像: {path.name}")
# Cached, thread-safe model loading
def load_pollution_model(pollutant: str) -> PollutionDifferenceModel:
    """Return the cached model for ``pollutant``, loading it on first use.

    Uses double-checked locking. The unlocked membership test serves the
    common cached case, and ``model_lock`` ensures that concurrent first
    calls read the checkpoint only once.

    Raises:
        HTTPException: 400 for an unknown pollutant, 500 if the checkpoint
            file is missing.
    """
    if pollutant not in MODEL_PATHS:
        raise HTTPException(status_code=400, detail=f"不支持的污染物类型: {pollutant}")
    if pollutant in loaded_models:
        return loaded_models[pollutant]
    with model_lock:
        if pollutant not in loaded_models:
            model_path = MODEL_PATHS[pollutant]
            if not model_path.exists():
                raise HTTPException(status_code=500, detail=f"模型文件不存在: {model_path}")
            # weights_only=True restricts unpickling to tensors and primitive
            # containers, so a tampered checkpoint cannot execute code.
            checkpoint = torch.load(
                str(model_path),
                map_location="cpu",
                weights_only=True
            )
            # NOTE(review): num_classes=19 presumably matches the Cityscapes
            # segmenter (dataset="cityscapes"). Confirm against the training config.
            model = PollutionDifferenceModel(num_classes=19, pollution_dims=1)
            # Accept either a bare state_dict or a {"model": state_dict} wrapper.
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                model.load_state_dict(checkpoint["model"])
            else:
                model.load_state_dict(checkpoint)
            model.eval()
            # Freeze the parameters instead of calling torch.set_grad_enabled(False).
            # That call switches grad mode off for the whole calling thread (the
            # event-loop thread at startup) as a hidden side effect. Inference
            # call sites already use torch.no_grad(); freezing keeps the model
            # graph-free without touching thread-wide state.
            model.requires_grad_(False)
            loaded_models[pollutant] = model
    return loaded_models[pollutant]
async def run_segmentation_async(input_dir: Path, output_dir: Path, summary_dir: Path) -> None:
    """Run ``segmenter.segment`` in the default thread pool without blocking the event loop.

    The three directories are passed as ``dir_input``, ``dir_image_output``
    and ``dir_summary_output`` respectively.
    """
    # get_running_loop() is the documented way to reach the loop from inside
    # a coroutine; get_event_loop() is intended for code outside one.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None,  # default ThreadPoolExecutor
        partial(
            segmenter.segment,
            dir_input=str(input_dir),
            dir_image_output=str(output_dir),
            dir_summary_output=str(summary_dir)
        )
    )
def find_segmented_img(output_dir: Path, base_name: str) -> Path | None:
"""确定性地查找分割结果图像(排序后取第一个)。"""
candidates = sorted([
f for f in output_dir.iterdir()
if base_name in f.name and "colored_segmented" in f.name
])
return candidates[0] if candidates else None
def find_blend_img(output_dir: Path, base_name: str) -> Path | None:
"""确定性地查找融合结果图像(排序后取第一个)。"""
candidates = sorted([
f for f in output_dir.iterdir()
if base_name in f.name and "blend" in f.name
])
return candidates[0] if candidates else None
def copy_segmentation_outputs(output_dir: Path, request_id: str) -> Dict[str, str]:
    """Copy the segmenter's ref/query images into STATIC_RESULTS_DIR and return their URLs.

    Both colorized segmentations are required. Blend overlays are copied
    only when the segmenter produced them.

    Returns:
        URLs keyed ``ref_seg``, ``query_seg``, ``ref_blend``, ``query_blend``.
        A blend URL is ``""`` when its published file does not exist.

    Raises:
        HTTPException: 500 when either colorized segmentation is missing.
    """
    roles = ("ref", "query")
    seg_sources = {role: find_segmented_img(output_dir, role) for role in roles}
    blend_sources = {role: find_blend_img(output_dir, role) for role in roles}
    if not all(seg_sources.values()):
        raise HTTPException(status_code=500, detail="找不到分割结果图像")

    # Copy all segmentations first, then whichever blends exist.
    for role in roles:
        shutil.copy(seg_sources[role], STATIC_RESULTS_DIR / f"{request_id}_{role}_seg.png")
    for role in roles:
        source = blend_sources[role]
        if source and source.exists():
            shutil.copy(source, STATIC_RESULTS_DIR / f"{request_id}_{role}_blend.png")

    urls = {f"{role}_seg": f"/static/results/{request_id}_{role}_seg.png" for role in roles}
    for role in roles:
        published = (STATIC_RESULTS_DIR / f"{request_id}_{role}_blend.png").exists()
        urls[f"{role}_blend"] = f"/static/results/{request_id}_{role}_blend.png" if published else ""
    return urls
def infer_difference(
    model: PollutionDifferenceModel,
    ref_tensor: torch.Tensor,
    query_tensor: torch.Tensor
) -> float:
    """Run one forward pass with autograd disabled and return the scalar output.

    ``.item()`` requires the model output to contain exactly one element.
    """
    with torch.no_grad():
        prediction = model(ref_tensor, query_tensor)
    return float(prediction.item())
def validate_ref_data(pollutant: str, ref_data: float) -> None:
    """Reject a reference value outside the pollutant's inclusive plausible range.

    Raises:
        HTTPException: 400 for an unknown pollutant. 422 when ``ref_data``
            is NaN, below the range minimum, or above the range maximum.
    """
    import math  # local import keeps this helper self-contained

    if pollutant not in POLLUTANT_RANGES:
        raise HTTPException(status_code=400, detail=f"不支持的污染物: {pollutant}")
    lo, hi = POLLUTANT_RANGES[pollutant]
    # NaN is false under every comparison, so it would slip past both range
    # checks below and propagate into the prediction.
    if math.isnan(ref_data):
        raise HTTPException(status_code=422, detail=f"{pollutant} 参考值不能为 NaN")
    if ref_data < lo:
        # "cannot be negative" is only accurate when the lower bound is 0.
        detail = (
            f"{pollutant} 参考值不能为负数"
            if lo == 0
            else f"{pollutant} 参考值 {ref_data} 低于合理范围(最小 {lo})"
        )
        raise HTTPException(status_code=422, detail=detail)
    if ref_data > hi:
        raise HTTPException(
            status_code=422,
            detail=f"{pollutant} 参考值 {ref_data} 超出合理范围(最大 {hi})"
        )
async def cleanup_old_runs(max_age_hours: int = 24) -> None:
    """Best-effort removal of per-request artifacts older than ``max_age_hours``.

    Deletes stale ``RUNS_DIR/<request_id>`` directories. It also deletes the
    ``STATIC_RESULTS_DIR/<request_id>_*.png`` copies made for /predict,
    which nothing else in this file removes. In STATIC_RESULTS_DIR only
    names that start with a 32-hex-digit request id and an underscore are
    touched. Errors on individual entries are logged and skipped.

    The filesystem calls are blocking. In this file the coroutine is awaited
    only from the startup hook.
    """
    cutoff = datetime.now() - timedelta(hours=max_age_hours)
    hex_digits = set("0123456789abcdef")

    def is_stale(entry: Path) -> bool:
        return datetime.fromtimestamp(entry.stat().st_mtime) < cutoff

    def is_request_artifact(name: str) -> bool:
        # Names written by copy_segmentation_outputs: "<uuid4 hex>_<role>_<kind>.png".
        request_id, sep, _ = name.partition("_")
        return bool(sep) and len(request_id) == 32 and set(request_id) <= hex_digits

    if RUNS_DIR.exists():
        for run_dir in RUNS_DIR.iterdir():
            try:
                if run_dir.is_dir() and is_stale(run_dir):
                    shutil.rmtree(run_dir, ignore_errors=True)
            except Exception as e:  # best-effort sweep: skip unreadable entries
                print(f"⚠️ 清理旧运行目录失败: {run_dir}: {e}")

    if STATIC_RESULTS_DIR.exists():
        for result_file in STATIC_RESULTS_DIR.iterdir():
            try:
                if (
                    result_file.is_file()
                    and is_request_artifact(result_file.name)
                    and is_stale(result_file)
                ):
                    result_file.unlink()
            except Exception as e:  # best-effort sweep: skip unreadable entries
                print(f"⚠️ 清理旧结果图像失败: {result_file}: {e}")
# =========================
# Single-image prediction
# =========================
@app.post("/predict")
async def predict(
    pollutant: str = Form(...),
    ref_data: float = Form(...),
    ref_img: UploadFile = File(...),
    query_img: UploadFile = File(...)
):
    """Predict ``pollutant`` for ``query_img`` given ``ref_img`` and its measured ``ref_data``.

    Steps:
        1. Validate the inputs and save both uploads.
        2. Run semantic segmentation and publish its images under /static/results.
        3. Run the pollution-difference model and combine its output with ``ref_data``.

    Raises:
        HTTPException: 4xx for invalid input or uploads. 500 for missing
            models or segmentation outputs and for any unexpected failure.
    """
    try:
        validate_ref_data(pollutant, ref_data)
        # Resolve the (normally preloaded) model before the expensive
        # segmentation so that a missing checkpoint fails fast.
        model = load_pollution_model(pollutant)

        paths = create_request_dirs()
        request_id = paths["request_id"]
        input_dir = paths["input_dir"]
        output_dir = paths["output_dir"]
        summary_dir = paths["summary_dir"]

        ref_path = input_dir / "ref.jpg"
        query_path = input_dir / "query.jpg"
        await save_upload_file(ref_img, ref_path)
        await save_upload_file(query_img, query_path)

        # Segmentation runs in a worker thread (see run_segmentation_async).
        await run_segmentation_async(input_dir, output_dir, summary_dir)
        ratio_json_path = summary_dir / "pixel_ratios.json"
        if not ratio_json_path.exists():
            raise HTTPException(status_code=500, detail="分割后未找到 pixel_ratios.json")
        seg_urls = copy_segmentation_outputs(output_dir, request_id)

        ref_tensor = preprocess_image(read_rgb_image(ref_path))
        query_tensor = preprocess_image(read_rgb_image(query_path))
        # The forward pass is CPU-bound. Run it off the event loop so other
        # requests and SSE progress streams are not stalled.
        model_out = await asyncio.to_thread(infer_difference, model, ref_tensor, query_tensor)
        # NOTE(review): the batch worker (batch_predict_task) computes
        # ref_data + model_out, but this endpoint subtracts. Confirm the
        # model's output convention and make both endpoints agree.
        final_pred = ref_data - model_out

        return {
            "status": "ok",
            "request_id": request_id,
            "pollutant": pollutant,
            "ref_data": ref_data,
            "model_out": round(model_out, 4),
            "pred_value": round(final_pred, 4),
            "ref_seg": seg_urls["ref_seg"],
            "query_seg": seg_urls["query_seg"],
            "ref_blend": seg_urls["ref_blend"],
            "query_blend": seg_urls["query_blend"],
            "ratio_json": f"/runs/{request_id}/summary/pixel_ratios.json"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"预测失败: {str(e)}") from e
# =========================
# Batch prediction: persist every upload, then infer in a background task
# =========================
# Strong references to in-flight batch tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


@app.post("/batch-predict")
async def batch_predict(
    pollutant: str = Form(...),
    ref_data: float = Form(...),
    ref_img: UploadFile = File(...),
    query_files: List[UploadFile] = File(...)
):
    """Start batch prediction for ``query_files`` against one reference image.

    All uploads are written to disk before the response is sent. The
    inference then runs in a background task whose progress is available
    from ``GET /progress/{request_id}``.

    Returns:
        ``{"status": "processing", "request_id": ..., "total_files": ...}``

    Raises:
        HTTPException: 4xx for invalid input or uploads, 500 otherwise.
    """
    try:
        validate_ref_data(pollutant, ref_data)
        if not query_files:
            raise HTTPException(status_code=400, detail="未上传任何查询图像")

        paths = create_request_dirs()
        request_id = paths["request_id"]
        batch_input_dir = paths["input_dir"]

        # 1. Save and preprocess the reference image.
        ref_path = batch_input_dir / "ref.jpg"
        await save_upload_file(ref_img, ref_path)
        ref_tensor = preprocess_image(read_rgb_image(ref_path))
        model = load_pollution_model(pollutant)

        # 2. Save every query image now. The background task receives only
        #    paths, never UploadFile objects, which are closed when the
        #    request ends.
        query_file_paths = []
        for idx, file in enumerate(query_files):
            # Display name reported back to the client.
            safe_name = os.path.basename(file.filename or "") or f"{uuid.uuid4().hex}.jpg"
            # Use an index-prefixed on-disk name. This stops uploads that share
            # a basename (or one named "ref.jpg") from overwriting each other,
            # and stops names like ".." from resolving to a directory.
            query_path = batch_input_dir / f"{idx:04d}_{safe_name}"
            await save_upload_file(file, query_path)
            query_file_paths.append({
                "path": str(query_path),
                "name": safe_name
            })

        # 3. Launch the worker and keep a reference until it finishes.
        task = asyncio.create_task(
            batch_predict_task(
                request_id=request_id,
                pollutant=pollutant,
                ref_data=ref_data,
                ref_tensor=ref_tensor,
                model=model,
                query_file_paths=query_file_paths,
                batch_input_dir=batch_input_dir
            )
        )
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

        return {
            "status": "processing",
            "request_id": request_id,
            "total_files": len(query_files)
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"批量预测启动失败: {str(e)}") from e
# =========================
# Batch progress stream (Server-Sent Events)
# =========================
@app.get("/progress/{request_id}")
async def get_batch_progress(request_id: str):
    """Stream ``batch_progress[request_id]`` to the client as SSE, polling every 100 ms.

    Each event is a single ``data: <json>`` frame carrying total, current,
    status, results and failed. The stream ends after the first frame whose
    status is ``"completed"`` or ``"failed"``. For an unknown request_id it
    sends one error frame and ends immediately.
    """
    terminal_states = ("completed", "failed")
    fields = (
        ("total", 0),
        ("current", 0),
        ("status", "processing"),
        ("results", []),
        ("failed", []),
    )

    async def stream():
        while True:
            # .get() rather than [] so polling never inserts into the defaultdict.
            snapshot = batch_progress.get(request_id, {})
            if not snapshot:
                yield 'data: {"error": "任务不存在"}\n\n'
                return
            payload = {key: snapshot.get(key, default) for key, default in fields}
            # One line per frame: SSE ends an event at the first blank line.
            yield f"data: {json.dumps(payload)}\n\n"
            if snapshot.get("status") in terminal_states:
                return
            await asyncio.sleep(0.1)

    return StreamingResponse(stream(), media_type="text/event-stream")
# =========================
# Health check
# =========================
@app.get("/health")
async def health_check():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    return dict(status="ok")
# =========================
# Entry point when executed directly: serve on all interfaces, port 8000
# =========================
if __name__ == "__main__":
    uvicorn.run(app, port=8000, host="0.0.0.0")