| import os |
| import uuid |
| import shutil |
| import asyncio |
| import threading |
| from datetime import datetime, timedelta |
| from functools import partial |
| from pathlib import Path |
| from typing import List, Dict, Any |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import uvicorn |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse |
| from fastapi.staticfiles import StaticFiles |
|
|
| from MyModel import PollutionDifferenceModel |
| from my_Segmenter import Segmenter |
|
|
| |
| from fastapi.responses import StreamingResponse |
| import json |
| from collections import defaultdict |
|
|
| |
# In-memory registry of batch jobs, keyed by request id. Looking up an
# unknown id with ``[]`` creates an empty dict for it.
batch_progress: defaultdict = defaultdict(dict)

# Held by the batch worker whenever it writes to ``batch_progress``.
progress_lock = threading.Lock()
|
|
|
|
| |
| |
| |
# Every working directory is expressed relative to the current directory.
BASE_DIR = Path(".")
STATIC_DIR = BASE_DIR / "static"
STATIC_RESULTS_DIR = STATIC_DIR / "results"
RUNS_DIR = BASE_DIR / "runs"
AIR_STATION_DIR = BASE_DIR / "AirStationImage"
FRONTEND_DIR = BASE_DIR / "frontend"

# Make sure the writable/served directories exist at import time
# (FRONTEND_DIR is not created here).
for directory in (STATIC_DIR, STATIC_RESULTS_DIR, RUNS_DIR, AIR_STATION_DIR):
    directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
| |
| |
| |
# Single shared segmentation backend, constructed once at import time.
segmenter = Segmenter(
    dataset="cityscapes",
    task="semantic",
    device="cpu",
)
|
|
|
|
| |
| |
| |
app = FastAPI(title="香港空气污染预测")

# CORS: any origin, method and header is accepted; credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve each on-disk directory as a static tree under the URL prefix of the
# same name (/static, /runs, /AirStationImage).
for _mount_name, _mount_dir in (
    ("static", STATIC_DIR),
    ("runs", RUNS_DIR),
    ("AirStationImage", AIR_STATION_DIR),
):
    app.mount(
        f"/{_mount_name}",
        StaticFiles(directory=str(_mount_dir)),
        name=_mount_name,
    )
|
|
|
|
| |
| |
| |
| |
@app.on_event("startup")
async def startup_event():
    """Run start-up housekeeping: purge stale run folders, then warm the model cache.

    Each pollutant model is preloaded independently, so one missing or corrupt
    checkpoint no longer prevents the remaining models from being loaded.
    Load failures are reported on stdout and never abort start-up; a model
    that failed here is retried the next time ``load_pollution_model`` is
    called for it.
    """
    await cleanup_old_runs()

    failed_pollutants = []
    for pollutant in MODEL_PATHS:
        try:
            load_pollution_model(pollutant)
        except Exception as e:
            # Best-effort warm-up: record the failure and keep going.
            failed_pollutants.append(pollutant)
            print(f"⚠️ 模型预加载失败 [{pollutant}]: {e}")

    # Only claim success when every model actually loaded.
    if not failed_pollutants:
        print("✅ 所有污染预测模型预加载完成")
|
|
|
|
| |
| |
| |
@app.get("/")
async def read_index():
    """Serve the front-end entry page, ``frontend/index.html``.

    Raises:
        HTTPException: 500 when the HTML file does not exist on disk.
    """
    index_path = FRONTEND_DIR / "index.html"
    if index_path.exists():
        return FileResponse(str(index_path))
    raise HTTPException(status_code=500, detail="frontend/index.html not found")
|
|
|
|
| |
| |
| |
# Checkpoint file for each supported pollutant; all live under ./models and
# share one naming scheme.
_MODEL_DIR = BASE_DIR / "models"
MODEL_PATHS = {
    name: _MODEL_DIR / f"best_{name}_model_multiscale20251110.pth"
    for name in ("CO", "NO2", "PM25", "PM10", "O3")
}

# Cache of already-loaded models, keyed by pollutant name.
loaded_models: Dict[str, PollutionDifferenceModel] = {}
# Serialises first-time loading of a model into ``loaded_models``.
model_lock = threading.Lock()
|
|
| |
| |
| |
# Upload constraints: at most 10 MiB per file, and only these MIME types.
MAX_FILE_SIZE = 10 * 1024 ** 2
ALLOWED_CONTENT_TYPES = {"image/jpeg", "image/png", "image/webp"}

# Accepted (min, max) reference reading per pollutant, bounds inclusive.
POLLUTANT_RANGES = dict(
    CO=(0, 50),
    NO2=(0, 500),
    PM25=(0, 999),
    PM10=(0, 999),
    O3=(0, 500),
)
|
|
|
|
| |
| |
| |
def create_request_dirs() -> Dict[str, Any]:
    """Allocate a fresh working directory tree for one request.

    A random hex id names the folder ``RUNS_DIR/<id>``, inside which the
    ``input``, ``output`` and ``summary`` sub-folders are created.

    Returns:
        A dict with ``request_id`` (str) and the ``base_dir``, ``input_dir``,
        ``output_dir`` and ``summary_dir`` paths.
    """
    request_id = uuid.uuid4().hex
    base_dir = RUNS_DIR / request_id

    subdirs = {name: base_dir / name for name in ("input", "output", "summary")}
    for folder in subdirs.values():
        folder.mkdir(parents=True, exist_ok=True)

    return {
        "request_id": request_id,
        "base_dir": base_dir,
        "input_dir": subdirs["input"],
        "output_dir": subdirs["output"],
        "summary_dir": subdirs["summary"],
    }
|
|
|
|
| |
| |
| |
def _predict_single_image(
    model: PollutionDifferenceModel,
    ref_tensor: torch.Tensor,
    ref_data: float,
    image_path: Path,
) -> Dict[str, float]:
    """Blocking helper: read one query image and run the difference model on it."""
    query_tensor = preprocess_image(read_rgb_image(image_path))

    # Grad mode is thread-local in PyTorch, so it has to be disabled here,
    # inside the worker thread that actually runs the forward pass.
    with torch.no_grad():
        model_out = float(model(ref_tensor, query_tensor).item())

    # NOTE(review): the /predict endpoint computes `ref_data - model_out`
    # while this batch path computes `ref_data + model_out`. Confirm which
    # sign matches the model's training target and align the two paths.
    final_pred = ref_data + model_out
    return {
        "pred_value": round(final_pred, 4),
        "model_out": round(model_out, 4),
    }


async def batch_predict_task(
    request_id: str,
    pollutant: str,
    ref_data: float,
    ref_tensor: torch.Tensor,
    model: PollutionDifferenceModel,
    query_file_paths: list,
    batch_input_dir: Path
):
    """Background job: predict every query image of a batch and publish progress.

    Progress is written to ``batch_progress[request_id]`` after each image, so
    the ``/progress/{request_id}`` stream can report it. Per-image inference
    runs in a worker thread, which keeps the event loop free to serve the
    progress stream and other requests while the model is busy.

    Args:
        request_id: Key under which progress is stored.
        pollutant: Pollutant name (not used by the computation itself).
        ref_data: Reference reading that the model output is added to.
        ref_tensor: Pre-processed reference image.
        model: Loaded difference model.
        query_file_paths: List of ``{"path": str, "name": str}`` dicts.
        batch_input_dir: Upload directory of this batch (not used here).

    A failure on one image is recorded in the ``failed`` list and does not
    stop the batch. If the job itself dies unexpectedly (including being
    cancelled) the status becomes ``"failed"`` so that listeners terminate
    and do not wait forever on ``"processing"``.
    """
    results = []
    failed = []
    total = len(query_file_paths)

    with progress_lock:
        batch_progress[request_id] = {
            "total": total,
            "current": 0,
            "results": [],
            "failed": [],
            "status": "processing"
        }

    final_status = "failed"
    try:
        for done_count, file_info in enumerate(query_file_paths, start=1):
            safe_name = file_info["name"]
            query_path = Path(file_info["path"])

            try:
                outcome = await asyncio.to_thread(
                    _predict_single_image, model, ref_tensor, ref_data, query_path
                )
                results.append({
                    "filename": safe_name,
                    "status": "ok",
                    **outcome,
                })
            except Exception as e:
                error_msg = f"文件:{safe_name},错误:{str(e)}"
                print(f"【批量预测失败】{error_msg}")
                failed.append({
                    "filename": safe_name,
                    "status": "error",
                    "message": str(e)
                })

            # Publish progress after every image, successful or not.
            with progress_lock:
                batch_progress[request_id]["current"] = done_count
                batch_progress[request_id]["results"] = results
                batch_progress[request_id]["failed"] = failed

        final_status = "completed"
    finally:
        # Always leave a terminal status behind, even on an unexpected error.
        with progress_lock:
            batch_progress[request_id]["status"] = final_status

    print(f"【批量任务完成】{request_id} | 成功:{len(results)} 张,失败:{len(failed)} 张")
|
|
async def save_upload_file(upload_file: UploadFile, save_path: Path) -> None:
    """Validate an uploaded image and write it to ``save_path``.

    The declared content type is checked before any bytes are read, and the
    read is capped at ``MAX_FILE_SIZE + 1`` bytes. An oversized upload is
    therefore detected without copying the whole payload into memory.

    Args:
        upload_file: The incoming upload.
        save_path: Destination file; its parent directory must already exist.

    Raises:
        HTTPException: 415 for a content type outside ``ALLOWED_CONTENT_TYPES``,
            400 for an empty file, 413 for a file larger than ``MAX_FILE_SIZE``.
    """
    if upload_file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(
            status_code=415,
            detail=f"不支持的文件类型 '{upload_file.content_type}',仅支持 JPEG / PNG / WebP"
        )

    # One byte beyond the limit is enough to tell "too large" from "exactly at the limit".
    content = await upload_file.read(MAX_FILE_SIZE + 1)

    if not content:
        raise HTTPException(status_code=400, detail=f"上传文件为空: {upload_file.filename}")

    if len(content) > MAX_FILE_SIZE:
        limit_mb = MAX_FILE_SIZE // (1024 * 1024)
        raise HTTPException(
            status_code=413,
            detail=f"文件过大(最大 {limit_mb}MB): {upload_file.filename}"
        )

    save_path.write_bytes(content)
|
|
|
|
def preprocess_image(img_np: np.ndarray) -> torch.Tensor:
    """Turn an H x W x C image array into a ``(1, C, 256, 256)`` float32 tensor.

    The image is resized to 256 x 256, cast to float32, divided by 255 and
    reordered from channels-last to channels-first before the batch axis is
    added.
    """
    resized = cv2.resize(img_np, (256, 256))
    scaled = resized.astype(np.float32) / 255.0
    channels_first = np.transpose(scaled, (2, 0, 1))
    tensor = torch.from_numpy(channels_first)
    return tensor.unsqueeze(0)
|
|
|
|
def read_rgb_image(path: Path) -> np.ndarray:
    """Load an image from disk with OpenCV and return it in RGB channel order.

    Raises:
        HTTPException: 400 when OpenCV cannot decode the file.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise HTTPException(status_code=400, detail=f"无法读取图像: {path.name}")
|
|
|
|
| |
def load_pollution_model(pollutant: str) -> PollutionDifferenceModel:
    """Return the model for ``pollutant``, loading and caching it on first use.

    Loading uses double-checked locking: the cache is consulted without the
    lock first, and re-checked under ``model_lock`` so that concurrent callers
    load each checkpoint at most once.

    The model is put in eval mode and its parameters are frozen. This replaces
    the former ``torch.set_grad_enabled(False)`` call, which silently switched
    autograd off for whichever thread happened to trigger the load.

    Args:
        pollutant: One of the keys of ``MODEL_PATHS``.

    Raises:
        HTTPException: 400 for an unknown pollutant, 500 when the checkpoint
            file is missing.
    """
    if pollutant not in MODEL_PATHS:
        raise HTTPException(status_code=400, detail=f"不支持的污染物类型: {pollutant}")

    # Fast path: already cached, no lock needed.
    cached = loaded_models.get(pollutant)
    if cached is not None:
        return cached

    with model_lock:
        # Re-check: another thread may have loaded it while we waited.
        if pollutant not in loaded_models:
            model_path = MODEL_PATHS[pollutant]
            if not model_path.exists():
                raise HTTPException(status_code=500, detail=f"模型文件不存在: {model_path}")

            # weights_only=True restricts unpickling to tensors/plain containers.
            checkpoint = torch.load(
                str(model_path),
                map_location="cpu",
                weights_only=True
            )

            model = PollutionDifferenceModel(num_classes=19, pollution_dims=1)

            # Accept both a wrapped checkpoint ({"model": state_dict}) and a bare state dict.
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                model.load_state_dict(checkpoint["model"])
            else:
                model.load_state_dict(checkpoint)

            model.eval()
            # Inference only: freeze the weights locally instead of toggling
            # the thread-wide autograd switch.
            model.requires_grad_(False)

            loaded_models[pollutant] = model

        return loaded_models[pollutant]
|
|
|
|
async def run_segmentation_async(input_dir: Path, output_dir: Path, summary_dir: Path) -> None:
    """Run ``segmenter.segment`` in the default executor so the event loop stays free.

    Args:
        input_dir: Directory holding the images to segment.
        output_dir: Directory passed to the segmenter as ``dir_image_output``.
        summary_dir: Directory passed to the segmenter as ``dir_summary_output``.
    """
    # Inside a coroutine the running loop is the right one to use;
    # get_event_loop() is discouraged here.
    loop = asyncio.get_running_loop()
    segment_call = partial(
        segmenter.segment,
        dir_input=str(input_dir),
        dir_image_output=str(output_dir),
        dir_summary_output=str(summary_dir)
    )
    await loop.run_in_executor(None, segment_call)
|
|
|
|
| def find_segmented_img(output_dir: Path, base_name: str) -> Path | None: |
| """确定性地查找分割结果图像(排序后取第一个)。""" |
| candidates = sorted([ |
| f for f in output_dir.iterdir() |
| if base_name in f.name and "colored_segmented" in f.name |
| ]) |
| return candidates[0] if candidates else None |
|
|
|
|
| def find_blend_img(output_dir: Path, base_name: str) -> Path | None: |
| """确定性地查找融合结果图像(排序后取第一个)。""" |
| candidates = sorted([ |
| f for f in output_dir.iterdir() |
| if base_name in f.name and "blend" in f.name |
| ]) |
| return candidates[0] if candidates else None |
|
|
|
|
def copy_segmentation_outputs(output_dir: Path, request_id: str) -> Dict[str, str]:
    """Publish the segmentation images of one request under ``static/results``.

    The reference and query segmentation images are mandatory; blended
    overlays are copied only when they exist. Every copy is named
    ``<request_id>_<role>_<kind>.png``.

    Returns:
        URLs for ``ref_seg``, ``query_seg``, ``ref_blend`` and ``query_blend``;
        a blend entry is an empty string when no such file was published.

    Raises:
        HTTPException: 500 when either segmentation image is missing.
    """
    roles = ("ref", "query")
    seg_sources = {role: find_segmented_img(output_dir, role) for role in roles}
    blend_sources = {role: find_blend_img(output_dir, role) for role in roles}

    if not all(seg_sources.values()):
        raise HTTPException(status_code=500, detail="找不到分割结果图像")

    urls: Dict[str, str] = {}

    # Mandatory segmentation images.
    for role, source in seg_sources.items():
        published_name = f"{request_id}_{role}_seg.png"
        shutil.copy(source, STATIC_RESULTS_DIR / published_name)
        urls[f"{role}_seg"] = f"/static/results/{published_name}"

    # Optional blended overlays: the URL is only reported if the file is there.
    for role, source in blend_sources.items():
        published_name = f"{request_id}_{role}_blend.png"
        target = STATIC_RESULTS_DIR / published_name
        if source and source.exists():
            shutil.copy(source, target)
        urls[f"{role}_blend"] = (
            f"/static/results/{published_name}" if target.exists() else ""
        )

    return urls
|
|
|
|
def infer_difference(
    model: PollutionDifferenceModel,
    ref_tensor: torch.Tensor,
    query_tensor: torch.Tensor
) -> float:
    """Run ``model(ref_tensor, query_tensor)`` without autograd and return its scalar output.

    The model output must hold exactly one element, since it is read with
    ``.item()``.
    """
    with torch.no_grad():
        prediction = model(ref_tensor, query_tensor)
        return float(prediction.item())
|
|
|
|
def validate_ref_data(pollutant: str, ref_data: float) -> None:
    """Check server-side that ``ref_data`` is a plausible reading for ``pollutant``.

    Args:
        pollutant: Must be a key of ``POLLUTANT_RANGES``.
        ref_data: Reference reading; must be a real number within the
            inclusive range configured for the pollutant.

    Raises:
        HTTPException: 400 for an unknown pollutant; 422 when the value is
            NaN, below the minimum or above the maximum.
    """
    import math

    if pollutant not in POLLUTANT_RANGES:
        raise HTTPException(status_code=400, detail=f"不支持的污染物: {pollutant}")

    # NaN compares False against everything, so it would slip through both
    # range checks below; reject it explicitly.
    if math.isnan(ref_data):
        raise HTTPException(status_code=422, detail=f"{pollutant} 参考值不是有效数字")

    lo, hi = POLLUTANT_RANGES[pollutant]
    if ref_data < lo:
        raise HTTPException(status_code=422, detail=f"{pollutant} 参考值不能为负数")
    if ref_data > hi:
        raise HTTPException(
            status_code=422,
            detail=f"{pollutant} 参考值 {ref_data} 超出合理范围(最大 {hi})"
        )
|
|
|
|
async def cleanup_old_runs(max_age_hours: int = 24) -> None:
    """Delete run artefacts older than ``max_age_hours`` to reclaim disk space.

    Two locations are purged, both judged by modification time:

    * sub-directories of ``RUNS_DIR``;
    * the per-request PNG copies in ``STATIC_RESULTS_DIR`` (names ending in
      ``_ref_seg``, ``_query_seg``, ``_ref_blend`` or ``_query_blend``), which
      would otherwise accumulate forever.

    Cleanup is best-effort: a failure on one entry is reported on stdout and
    the remaining entries are still processed.
    """
    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    def _is_stale(entry: Path) -> bool:
        return datetime.fromtimestamp(entry.stat().st_mtime) < cutoff

    if RUNS_DIR.exists():
        for run_dir in RUNS_DIR.iterdir():
            if not run_dir.is_dir():
                continue
            try:
                if _is_stale(run_dir):
                    shutil.rmtree(run_dir, ignore_errors=True)
            except (OSError, ValueError, OverflowError) as e:
                print(f"cleanup_old_runs: skipped {run_dir}: {e}")

    if STATIC_RESULTS_DIR.exists():
        published_suffixes = ("_ref_seg", "_query_seg", "_ref_blend", "_query_blend")
        for result_file in STATIC_RESULTS_DIR.glob("*.png"):
            # Only touch files that follow this service's own naming scheme.
            if not result_file.stem.endswith(published_suffixes):
                continue
            try:
                if result_file.is_file() and _is_stale(result_file):
                    result_file.unlink()
            except (OSError, ValueError, OverflowError) as e:
                print(f"cleanup_old_runs: skipped {result_file}: {e}")
|
|
|
|
| |
| |
| |
def _infer_from_files(
    model: PollutionDifferenceModel,
    ref_path: Path,
    query_path: Path,
) -> float:
    """Blocking helper: decode both images and return the model's difference output."""
    ref_tensor = preprocess_image(read_rgb_image(ref_path))
    query_tensor = preprocess_image(read_rgb_image(query_path))
    return infer_difference(model, ref_tensor, query_tensor)


@app.post("/predict")
async def predict(
    pollutant: str = Form(...),
    ref_data: float = Form(...),
    ref_img: UploadFile = File(...),
    query_img: UploadFile = File(...)
):
    """Predict the pollutant level of one query image relative to a reference image.

    Steps: validate the reference reading, store both uploads, run semantic
    segmentation on them, publish the segmentation images, then run the
    difference model. Image decoding and inference run in a worker thread so
    that the event loop is not blocked while the model computes, matching how
    segmentation is already offloaded.

    Returns:
        A JSON-serialisable dict with the prediction, the raw model output
        and URLs of the segmentation artefacts.

    Raises:
        HTTPException: Validation and upload errors keep their own status;
            any other failure is reported as 500.
    """
    try:
        validate_ref_data(pollutant, ref_data)

        paths = create_request_dirs()
        request_id = paths["request_id"]
        input_dir = paths["input_dir"]
        output_dir = paths["output_dir"]
        summary_dir = paths["summary_dir"]

        # Uploads are always stored under these fixed names, whatever the
        # actual image format is.
        ref_path = input_dir / "ref.jpg"
        query_path = input_dir / "query.jpg"

        await save_upload_file(ref_img, ref_path)
        await save_upload_file(query_img, query_path)

        await run_segmentation_async(input_dir, output_dir, summary_dir)

        seg_urls = copy_segmentation_outputs(output_dir, request_id)

        # Fail fast: there is no point in running inference if the
        # segmentation summary the response links to was not produced.
        ratio_json_path = summary_dir / "pixel_ratios.json"
        if not ratio_json_path.exists():
            raise HTTPException(status_code=500, detail="分割后未找到 pixel_ratios.json")

        model = load_pollution_model(pollutant)
        model_out = await asyncio.to_thread(
            _infer_from_files, model, ref_path, query_path
        )

        # NOTE(review): the batch path (batch_predict_task) computes
        # `ref_data + model_out` while this endpoint subtracts. Confirm which
        # sign matches the model's training target and align the two paths.
        final_pred = ref_data - model_out

        return {
            "status": "ok",
            "request_id": request_id,
            "pollutant": pollutant,
            "ref_data": ref_data,
            "model_out": round(model_out, 4),
            "pred_value": round(final_pred, 4),
            "ref_seg": seg_urls["ref_seg"],
            "query_seg": seg_urls["query_seg"],
            "ref_blend": seg_urls["ref_blend"],
            "query_blend": seg_urls["query_blend"],
            "ratio_json": f"/runs/{request_id}/summary/pixel_ratios.json"
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"预测失败: {str(e)}")
|
|
|
|
| |
| |
| |
# Strong references to running batch jobs. The event loop itself only keeps
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


@app.post("/batch-predict")
async def batch_predict(
    pollutant: str = Form(...),
    ref_data: float = Form(...),
    ref_img: UploadFile = File(...),
    query_files: List[UploadFile] = File(...)
):
    """Start a background batch prediction and return its request id immediately.

    The reference image and every query image are validated and stored, then
    ``batch_predict_task`` is scheduled. Progress can be followed through
    ``/progress/{request_id}``.

    Query files are stored as ``<index>_<basename>``, so two uploads with the
    same name (or one called ``ref.jpg``) cannot overwrite each other. The
    name reported in the results is still the plain basename.

    Raises:
        HTTPException: 400 when no query file is supplied; validation and
            upload errors keep their own status; anything else becomes 500.
    """
    try:
        validate_ref_data(pollutant, ref_data)

        if not query_files:
            raise HTTPException(status_code=400, detail="未上传任何查询图像")

        paths = create_request_dirs()
        request_id = paths["request_id"]
        batch_input_dir = paths["input_dir"]

        ref_path = batch_input_dir / "ref.jpg"
        await save_upload_file(ref_img, ref_path)
        ref_tensor = preprocess_image(read_rgb_image(ref_path))
        model = load_pollution_model(pollutant)

        query_file_paths = []
        for index, file in enumerate(query_files):
            # Strip any client-side directory part, including Windows-style
            # separators, and fall back to a random name when nothing is left.
            raw_name = (file.filename or "").replace("\\", "/")
            safe_name = os.path.basename(raw_name) or f"{uuid.uuid4().hex}.jpg"
            query_path = batch_input_dir / f"{index:04d}_{safe_name}"

            await save_upload_file(file, query_path)
            query_file_paths.append({
                "path": str(query_path),
                "name": safe_name
            })

        task = asyncio.create_task(
            batch_predict_task(
                request_id=request_id,
                pollutant=pollutant,
                ref_data=ref_data,
                ref_tensor=ref_tensor,
                model=model,
                query_file_paths=query_file_paths,
                batch_input_dir=batch_input_dir
            )
        )
        # Keep the task alive until it is done, then drop the reference.
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

        return {
            "status": "processing",
            "request_id": request_id,
            "total_files": len(query_files)
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"批量预测启动失败: {str(e)}")
|
|
|
|
| |
| |
| |
@app.get("/progress/{request_id}")
async def get_batch_progress(request_id: str):
    """Server-sent-events endpoint that streams the progress of a batch job.

    One ``data:`` event carrying the current progress is emitted every 0.1 s
    until the job reports ``completed`` or ``failed``. An unknown request id
    produces a single error event and ends the stream.
    """
    terminal_states = ["completed", "failed"]

    async def event_generator():
        while True:
            job = batch_progress.get(request_id, {})
            if not job:
                yield 'data: {"error": "任务不存在"}\n\n'
                return

            payload = json.dumps({
                'total': job.get('total', 0),
                'current': job.get('current', 0),
                'status': job.get('status', 'processing'),
                'results': job.get('results', []),
                'failed': job.get('failed', [])
            })
            yield f"data: {payload}\n\n"

            # Stop once the job has reached a final state.
            if job.get("status") in terminal_states:
                return

            await asyncio.sleep(0.1)

    stream = event_generator()
    return StreamingResponse(stream, media_type="text/event-stream")
|
|
| |
| |
| |
@app.get("/health")
async def health_check():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 8000.
    uvicorn.run(app=app, host="0.0.0.0", port=8000)