William941008 committed on
Commit
fcfed3d
·
verified ·
1 Parent(s): 2fb298b

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +571 -577
  2. my_Segmenter.py +1189 -0
main.py CHANGED
@@ -1,578 +1,572 @@
1
- import os
2
- import uuid
3
- import shutil
4
- import asyncio
5
- import threading
6
- from datetime import datetime, timedelta
7
- from functools import partial
8
- from pathlib import Path
9
- from typing import List, Dict, Any
10
-
11
- # 固定 zensvi 服务器崩溃问题
12
- import os
13
- os.environ["QT_QPA_PLATFORM"] = "offscreen"
14
- import matplotlib
15
- matplotlib.use("Agg")
16
-
17
- import cv2
18
- import numpy as np
19
- import torch
20
- import uvicorn
21
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
22
- from fastapi.middleware.cors import CORSMiddleware
23
- from fastapi.responses import FileResponse
24
- from fastapi.staticfiles import StaticFiles
25
-
26
- from MyModel import PollutionDifferenceModel
27
- from zensvi.cv import Segmenter
28
-
29
- # 新增:SSE进度条依赖
30
- from fastapi.responses import StreamingResponse
31
- import json
32
- from collections import defaultdict
33
-
34
- # 全局进度存储(线程安全)
35
- batch_progress = defaultdict(dict) # key: request_id, value: {total, current, results, failed}
36
- progress_lock = threading.Lock()
37
-
38
-
39
- # =========================
40
- # 基础目录初始化
41
- # =========================
42
- BASE_DIR = Path(".")
43
- STATIC_DIR = BASE_DIR / "static"
44
- STATIC_RESULTS_DIR = STATIC_DIR / "results"
45
- RUNS_DIR = BASE_DIR / "runs"
46
- AIR_STATION_DIR = BASE_DIR / "AirStationImage"
47
- FRONTEND_DIR = BASE_DIR / "frontend"
48
-
49
- for d in [STATIC_DIR, STATIC_RESULTS_DIR, RUNS_DIR, AIR_STATION_DIR]:
50
- d.mkdir(parents=True, exist_ok=True)
51
-
52
-
53
- # =========================
54
- # 初始化分割模型
55
- # =========================
56
- segmenter = Segmenter(dataset="cityscapes", task="semantic", device="cpu")
57
-
58
-
59
- # =========================
60
- # FastAPI 初始化
61
- # =========================
62
- app = FastAPI(title="香港空气污染预测")
63
-
64
- app.add_middleware(
65
- CORSMiddleware,
66
- allow_origins=["*"],
67
- allow_credentials=False,
68
- allow_methods=["*"],
69
- allow_headers=["*"],
70
- )
71
-
72
- # 静态目录挂载
73
- app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
74
- app.mount("/runs", StaticFiles(directory=str(RUNS_DIR)), name="runs")
75
- app.mount("/AirStationImage", StaticFiles(directory=str(AIR_STATION_DIR)), name="AirStationImage")
76
-
77
-
78
- # =========================
79
- # 启动时清理旧的运行目录
80
- # =========================
81
- # 在 startup_event 中添加
82
- @app.on_event("startup")
83
- async def startup_event():
84
- # 清理旧文件
85
- await cleanup_old_runs()
86
- # 预加载所有模型
87
- try:
88
- for pollutant in MODEL_PATHS.keys():
89
- load_pollution_model(pollutant)
90
- print("✅ 所有污染预测模型预加载完成")
91
- except Exception as e:
92
- print(f"⚠️ 模型预加载失败: {e}")
93
-
94
-
95
- # =========================
96
- # 首页
97
- # =========================
98
- @app.get("/")
99
- async def read_index():
100
- index_path = FRONTEND_DIR / "index.html"
101
- if not index_path.exists():
102
- raise HTTPException(status_code=500, detail="frontend/index.html not found")
103
- return FileResponse(str(index_path))
104
-
105
-
106
- # =========================
107
- # 模型路径
108
- # =========================
109
- MODEL_PATHS = {
110
- "CO": BASE_DIR / "models" / "CO.pth",
111
- "NO2": BASE_DIR / "models" / "NO2.pth",
112
- "PM25": BASE_DIR / "models" / "PM25.pth",
113
- "PM10": BASE_DIR / "models" / "PM10.pth",
114
- "O3": BASE_DIR / "models" / "O3.pth",
115
- }
116
-
117
- loaded_models: Dict[str, PollutionDifferenceModel] = {}
118
- model_lock = threading.Lock()
119
-
120
- # =========================
121
- # 文件上传限制
122
- # =========================
123
- MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
124
- ALLOWED_CONTENT_TYPES = {"image/jpeg", "image/png", "image/webp"}
125
-
126
- # =========================
127
- # 污染物合理范围校验
128
- # =========================
129
- POLLUTANT_RANGES = {
130
- "CO": (0, 50),
131
- "NO2": (0, 500),
132
- "PM25": (0, 999),
133
- "PM10": (0, 999),
134
- "O3": (0, 500),
135
- }
136
-
137
-
138
- # =========================
139
- # 工具函数
140
- # =========================
141
- def create_request_dirs() -> Dict[str, Any]:
142
- request_id = uuid.uuid4().hex # 统一使用 str,不再包装成 Path
143
- base_dir = RUNS_DIR / request_id
144
- input_dir = base_dir / "input"
145
- output_dir = base_dir / "output"
146
- summary_dir = base_dir / "summary"
147
-
148
- for d in [input_dir, output_dir, summary_dir]:
149
- d.mkdir(parents=True, exist_ok=True)
150
-
151
- return {
152
- "request_id": request_id, # str
153
- "base_dir": base_dir,
154
- "input_dir": input_dir,
155
- "output_dir": output_dir,
156
- "summary_dir": summary_dir,
157
- }
158
-
159
-
160
- # =========================
161
- # 异步批量处理函数(修复 I/O 错误:读本地路径)
162
- # =========================
163
- async def batch_predict_task(
164
- request_id: str,
165
- pollutant: str,
166
- ref_data: float,
167
- ref_tensor: torch.Tensor,
168
- model: PollutionDifferenceModel,
169
- query_file_paths: list, # 改为接收路径列表
170
- batch_input_dir: Path
171
- ):
172
- results = []
173
- failed = []
174
- total = len(query_file_paths)
175
-
176
- # 初始化进度
177
- with progress_lock:
178
- batch_progress[request_id] = {
179
- "total": total,
180
- "current": 0,
181
- "results": [],
182
- "failed": [],
183
- "status": "processing"
184
- }
185
-
186
- with torch.no_grad():
187
- for idx, file_info in enumerate(query_file_paths):
188
- safe_name = file_info["name"]
189
- query_path = Path(file_info["path"])
190
-
191
- try:
192
- # 🔥 直接读本地已保存的文件,不会有 I/O 错误
193
- query_np = read_rgb_image(query_path)
194
- query_tensor = preprocess_image(query_np)
195
-
196
- # 模型推理
197
- out = model(ref_tensor, query_tensor)
198
- model_out = float(out.item())
199
- final_pred = ref_data + model_out
200
-
201
- results.append({
202
- "filename": safe_name,
203
- "status": "ok",
204
- "pred_value": round(final_pred, 4),
205
- "model_out": round(model_out, 4),
206
- })
207
-
208
- except Exception as e:
209
- error_msg = f"文件:{safe_name},错误:{str(e)}"
210
- print(f"【批量预测失败】{error_msg}")
211
-
212
- failed.append({
213
- "filename": safe_name,
214
- "status": "error",
215
- "message": str(e)
216
- })
217
-
218
- # 更新进度
219
- with progress_lock:
220
- batch_progress[request_id]["current"] = idx + 1
221
- batch_progress[request_id]["results"] = results
222
- batch_progress[request_id]["failed"] = failed
223
-
224
- await asyncio.sleep(0.05)
225
-
226
- # 标记完成
227
- with progress_lock:
228
- batch_progress[request_id]["status"] = "completed"
229
- print(f"【批量任务完成】{request_id} | 成功:{len(results)} 张,失败:{len(failed)} 张")
230
-
231
- async def save_upload_file(upload_file: UploadFile, save_path: Path) -> None:
232
- """保存上传文件,同时校验大小与类型。"""
233
- content = await upload_file.read()
234
-
235
- if not content:
236
- raise HTTPException(status_code=400, detail=f"上传文件为空: {upload_file.filename}")
237
-
238
- if len(content) > MAX_FILE_SIZE:
239
- raise HTTPException(status_code=413, detail=f"文件过大(最大 10MB): {upload_file.filename}")
240
-
241
- if upload_file.content_type not in ALLOWED_CONTENT_TYPES:
242
- raise HTTPException(
243
- status_code=415,
244
- detail=f"不支持的文件类型 '{upload_file.content_type}',仅支持 JPEG / PNG / WebP"
245
- )
246
-
247
- save_path.write_bytes(content)
248
-
249
-
250
- def preprocess_image(img_np: np.ndarray) -> torch.Tensor:
251
- img = cv2.resize(img_np, (256, 256))
252
- img = img.astype(np.float32) / 255.0
253
- img = img.transpose(2, 0, 1)
254
- return torch.from_numpy(img).unsqueeze(0)
255
-
256
-
257
- def read_rgb_image(path: Path) -> np.ndarray:
258
- img = cv2.imread(str(path))
259
- if img is None:
260
- raise HTTPException(status_code=400, detail=f"无法读取图像: {path.name}")
261
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
262
-
263
-
264
- # 替换原 load_pollution_model 函数
265
- def load_pollution_model(pollutant: str) -> PollutionDifferenceModel:
266
- """线程安全的模型加载(双重检查锁定)+ 非阻塞优化"""
267
- if pollutant not in MODEL_PATHS:
268
- raise HTTPException(status_code=400, detail=f"不支持的污染物类型: {pollutant}")
269
-
270
- if pollutant in loaded_models:
271
- return loaded_models[pollutant]
272
-
273
- with model_lock:
274
- if pollutant not in loaded_models:
275
- model_path = MODEL_PATHS[pollutant]
276
- if not model_path.exists():
277
- raise HTTPException(status_code=500, detail=f"模型文件不存在: {model_path}")
278
-
279
- # ✅ 关键修复:weights_only=True 防止安全问题+加速加载
280
- checkpoint = torch.load(
281
- str(model_path),
282
- map_location="cpu",
283
- weights_only=True
284
- )
285
-
286
- model = PollutionDifferenceModel(num_classes=19, pollution_dims=1)
287
-
288
- # ✅ 兼容模型加载
289
- if isinstance(checkpoint, dict) and "model" in checkpoint:
290
- model.load_state_dict(checkpoint["model"])
291
- else:
292
- model.load_state_dict(checkpoint)
293
-
294
- model.eval()
295
-
296
- # ✅ 优化推理速度:启用推理模式
297
- torch.set_grad_enabled(False)
298
- loaded_models[pollutant] = model
299
-
300
- return loaded_models[pollutant]
301
-
302
-
303
- async def run_segmentation_async(input_dir: Path, output_dir: Path, summary_dir: Path) -> None:
304
- """在线程池中异步执行语义分割,避免阻塞事件循环。"""
305
- loop = asyncio.get_event_loop()
306
- await loop.run_in_executor(
307
- None,
308
- partial(
309
- segmenter.segment,
310
- dir_input=str(input_dir),
311
- dir_image_output=str(output_dir),
312
- dir_summary_output=str(summary_dir)
313
- )
314
- )
315
-
316
-
317
- def find_segmented_img(output_dir: Path, base_name: str) -> Path | None:
318
- """确定性地查找分割结果图像(排序后取第一个)。"""
319
- candidates = sorted([
320
- f for f in output_dir.iterdir()
321
- if base_name in f.name and "colored_segmented" in f.name
322
- ])
323
- return candidates[0] if candidates else None
324
-
325
-
326
- def find_blend_img(output_dir: Path, base_name: str) -> Path | None:
327
- """确定性地查找融合结果图像(排序后取第一个)。"""
328
- candidates = sorted([
329
- f for f in output_dir.iterdir()
330
- if base_name in f.name and "blend" in f.name
331
- ])
332
- return candidates[0] if candidates else None
333
-
334
-
335
- def copy_segmentation_outputs(output_dir: Path, request_id: str) -> Dict[str, str]:
336
- ref_seg_path = find_segmented_img(output_dir, "ref")
337
- query_seg_path = find_segmented_img(output_dir, "query")
338
- ref_blend_path = find_blend_img(output_dir, "ref")
339
- query_blend_path = find_blend_img(output_dir, "query")
340
-
341
- if not ref_seg_path or not query_seg_path:
342
- raise HTTPException(status_code=500, detail="找不到分割结果图像")
343
-
344
- target_ref = STATIC_RESULTS_DIR / f"{request_id}_ref_seg.png"
345
- target_query = STATIC_RESULTS_DIR / f"{request_id}_query_seg.png"
346
- target_ref_blend = STATIC_RESULTS_DIR / f"{request_id}_ref_blend.png"
347
- target_query_blend = STATIC_RESULTS_DIR / f"{request_id}_query_blend.png"
348
-
349
- shutil.copy(ref_seg_path, target_ref)
350
- shutil.copy(query_seg_path, target_query)
351
-
352
- if ref_blend_path and ref_blend_path.exists():
353
- shutil.copy(ref_blend_path, target_ref_blend)
354
- if query_blend_path and query_blend_path.exists():
355
- shutil.copy(query_blend_path, target_query_blend)
356
-
357
- return {
358
- "ref_seg": f"/static/results/{request_id}_ref_seg.png",
359
- "query_seg": f"/static/results/{request_id}_query_seg.png",
360
- "ref_blend": f"/static/results/{request_id}_ref_blend.png" if target_ref_blend.exists() else "",
361
- "query_blend": f"/static/results/{request_id}_query_blend.png" if target_query_blend.exists() else "",
362
- }
363
-
364
-
365
- def infer_difference(
366
- model: PollutionDifferenceModel,
367
- ref_tensor: torch.Tensor,
368
- query_tensor: torch.Tensor
369
- ) -> float:
370
- with torch.no_grad():
371
- out = model(ref_tensor, query_tensor)
372
- return float(out.item())
373
-
374
-
375
- def validate_ref_data(pollutant: str, ref_data: float) -> None:
376
- """服务端校验参考值合理范围。"""
377
- if pollutant not in POLLUTANT_RANGES:
378
- raise HTTPException(status_code=400, detail=f"不支持的污染物: {pollutant}")
379
-
380
- lo, hi = POLLUTANT_RANGES[pollutant]
381
- if ref_data < lo:
382
- raise HTTPException(status_code=422, detail=f"{pollutant} 参考值不能为负数")
383
- if ref_data > hi:
384
- raise HTTPException(
385
- status_code=422,
386
- detail=f"{pollutant} 参考值 {ref_data} 超出合理范围(最大 {hi})"
387
- )
388
-
389
-
390
- async def cleanup_old_runs(max_age_hours: int = 24) -> None:
391
- """清理超过指定小时数的旧运行目录,释放磁盘空间。"""
392
- cutoff = datetime.now() - timedelta(hours=max_age_hours)
393
- if not RUNS_DIR.exists():
394
- return
395
- for run_dir in RUNS_DIR.iterdir():
396
- if run_dir.is_dir():
397
- try:
398
- mtime = datetime.fromtimestamp(run_dir.stat().st_mtime)
399
- if mtime < cutoff:
400
- shutil.rmtree(run_dir, ignore_errors=True)
401
- except Exception:
402
- pass # 跳过无法访问的目录
403
-
404
-
405
- # =========================
406
- # 单图预测
407
- # =========================
408
- @app.post("/predict")
409
- async def predict(
410
- pollutant: str = Form(...),
411
- ref_data: float = Form(...),
412
- ref_img: UploadFile = File(...),
413
- query_img: UploadFile = File(...)
414
- ):
415
- try:
416
- # 服务端输入校验
417
- validate_ref_data(pollutant, ref_data)
418
-
419
- paths = create_request_dirs()
420
- request_id = paths["request_id"] # 现在是纯 str
421
- input_dir = paths["input_dir"]
422
- output_dir = paths["output_dir"]
423
- summary_dir = paths["summary_dir"]
424
-
425
- ref_path = input_dir / "ref.jpg"
426
- query_path = input_dir / "query.jpg"
427
-
428
- await save_upload_file(ref_img, ref_path)
429
- await save_upload_file(query_img, query_path)
430
-
431
- # 异步语义分割(不阻塞事件循环)
432
- await run_segmentation_async(input_dir, output_dir, summary_dir)
433
-
434
- # 复制结果图到 static/results
435
- seg_urls = copy_segmentation_outputs(output_dir, request_id)
436
-
437
- # 读取图像并推理
438
- ref_tensor = preprocess_image(read_rgb_image(ref_path))
439
- query_tensor = preprocess_image(read_rgb_image(query_path))
440
-
441
- model = load_pollution_model(pollutant)
442
- model_out = infer_difference(model, ref_tensor, query_tensor)
443
- final_pred = ref_data + model_out
444
-
445
- ratio_json_path = summary_dir / "pixel_ratios.json"
446
- if not ratio_json_path.exists():
447
- raise HTTPException(status_code=500, detail="分割后未找到 pixel_ratios.json")
448
-
449
- return {
450
- "status": "ok",
451
- "request_id": request_id,
452
- "pollutant": pollutant,
453
- "ref_data": ref_data,
454
- "model_out": round(model_out, 4),
455
- "pred_value": round(final_pred, 4),
456
- "ref_seg": seg_urls["ref_seg"],
457
- "query_seg": seg_urls["query_seg"],
458
- "ref_blend": seg_urls["ref_blend"],
459
- "query_blend": seg_urls["query_blend"],
460
- "ratio_json": f"/runs/{request_id}/summary/pixel_ratios.json"
461
- }
462
-
463
- except HTTPException:
464
- raise
465
- except Exception as e:
466
- raise HTTPException(status_code=500, detail=f"预测失败: {str(e)}")
467
-
468
-
469
- # =========================
470
- # 批量预测(修复 I/O 错误:先保存所有文件)
471
- # =========================
472
- @app.post("/batch-predict")
473
- async def batch_predict(
474
- pollutant: str = Form(...),
475
- ref_data: float = Form(...),
476
- ref_img: UploadFile = File(...),
477
- query_files: List[UploadFile] = File(...)
478
- ):
479
- try:
480
- validate_ref_data(pollutant, ref_data)
481
-
482
- if not query_files:
483
- raise HTTPException(status_code=400, detail="未上传任何查询像")
484
-
485
- paths = create_request_dirs()
486
- request_id = paths["request_id"]
487
- batch_input_dir = paths["input_dir"]
488
-
489
- # 1. 先保存参考图
490
- ref_path = batch_input_dir / "ref.jpg"
491
- await save_upload_file(ref_img, ref_path)
492
- ref_tensor = preprocess_image(read_rgb_image(ref_path))
493
- model = load_pollution_model(pollutant)
494
-
495
- # 🔥 核心修复:接口返回前,先把所有查询图片保存到本地
496
- query_file_paths = []
497
- for file in query_files:
498
- safe_name = os.path.basename(file.filename) if file.filename else f"{uuid.uuid4().hex}.jpg"
499
- query_path = batch_input_dir / safe_name
500
- # 提前保存文件
501
- await save_upload_file(file, query_path)
502
- query_file_paths.append({
503
- "path": str(query_path), # 只传路径,不传文件对象
504
- "name": safe_name
505
- })
506
-
507
- # 2. 启动后台任务(只传路径,不传 UploadFile)
508
- asyncio.create_task(
509
- batch_predict_task(
510
- request_id=request_id,
511
- pollutant=pollutant,
512
- ref_data=ref_data,
513
- ref_tensor=ref_tensor,
514
- model=model,
515
- query_file_paths=query_file_paths, # 传路径列表
516
- batch_input_dir=batch_input_dir
517
- )
518
- )
519
-
520
- return {
521
- "status": "processing",
522
- "request_id": request_id,
523
- "total_files": len(query_files)
524
- }
525
-
526
- except HTTPException:
527
- raise
528
- except Exception as e:
529
- raise HTTPException(status_code=500, detail=f"批量预测启动失败: {str(e)}")
530
-
531
-
532
- # =========================
533
- # 批量预测进度推送(SSE)【修复语法错误版】
534
- # =========================
535
- @app.get("/progress/{request_id}")
536
- async def get_batch_progress(request_id: str):
537
- """SSE接口:前端监听此接口获取实时进度"""
538
- async def event_generator():
539
- while True:
540
- # 获取进度
541
- progress = batch_progress.get(request_id, {})
542
- if not progress:
543
- yield 'data: {"error": "任务不存在"}\n\n'
544
- break
545
-
546
- # 【修复】把json提出来,避免f-string换行语法错误
547
- progress_data = {
548
- 'total': progress.get('total', 0),
549
- 'current': progress.get('current', 0),
550
- 'status': progress.get('status', 'processing'),
551
- 'results': progress.get('results', []),
552
- 'failed': progress.get('failed', [])
553
- }
554
- # 【关键修复】一行写完,不换行!
555
- yield f"data: {json.dumps(progress_data)}\n\n"
556
-
557
- # 任务完成/失败,停止推送
558
- if progress.get("status") in ["completed", "failed"]:
559
- break
560
-
561
- # 每100ms推送1次
562
- await asyncio.sleep(0.1)
563
-
564
- return StreamingResponse(event_generator(), media_type="text/event-stream")
565
-
566
- # =========================
567
- # 健康检查
568
- # =========================
569
- @app.get("/health")
570
- async def health_check():
571
- return {"status": "ok"}
572
-
573
-
574
- # =========================
575
- # 启动
576
- # =========================
577
- if __name__ == "__main__":
578
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import os
2
+ import uuid
3
+ import shutil
4
+ import asyncio
5
+ import threading
6
+ from datetime import datetime, timedelta
7
+ from functools import partial
8
+ from pathlib import Path
9
+ from typing import List, Dict, Any
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ import uvicorn
15
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.responses import FileResponse
18
+ from fastapi.staticfiles import StaticFiles
19
+
20
+ from MyModel import PollutionDifferenceModel
21
+ from my_Segmenter import Segmenter
22
+
23
+ # 新增:SSE进度条依赖
24
+ from fastapi.responses import StreamingResponse
25
+ import json
26
+ from collections import defaultdict
27
+
28
# Global progress store for batch jobs.
# key: request_id -> {"total", "current", "results", "failed", "status"}
# Writers update it while holding progress_lock.
batch_progress = defaultdict(dict)
progress_lock = threading.Lock()


# =========================
# Base directories
# =========================
BASE_DIR = Path(".")  # relative to the process working directory
STATIC_DIR = BASE_DIR / "static"
STATIC_RESULTS_DIR = STATIC_DIR / "results"
RUNS_DIR = BASE_DIR / "runs"
AIR_STATION_DIR = BASE_DIR / "AirStationImage"
FRONTEND_DIR = BASE_DIR / "frontend"

# Create the served directories up front, before they are mounted below.
# FRONTEND_DIR is deliberately not created here; read_index() reports it missing.
for d in [STATIC_DIR, STATIC_RESULTS_DIR, RUNS_DIR, AIR_STATION_DIR]:
    d.mkdir(parents=True, exist_ok=True)


# =========================
# Segmentation model
# =========================
# One shared Segmenter instance (Cityscapes, semantic task, CPU) for all requests.
segmenter = Segmenter(dataset="cityscapes", task="semantic", device="cpu")


# =========================
# FastAPI application
# =========================
app = FastAPI(title="香港空气污染预测")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static mounts.
# NOTE(review): "/runs" serves every request's working directory, including the
# uploaded input images, to anyone who knows the request id - confirm intended.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
app.mount("/runs", StaticFiles(directory=str(RUNS_DIR)), name="runs")
app.mount("/AirStationImage", StaticFiles(directory=str(AIR_STATION_DIR)), name="AirStationImage")
70
+
71
+
72
+ # =========================
73
+ # 启动时清理旧的运行目录
74
+ # =========================
75
+ # 在 startup_event 中添加
76
@app.on_event("startup")
async def startup_event():
    """Run start-up housekeeping: purge old run directories, then preload models.

    Each model is preloaded independently so that one missing or corrupt
    checkpoint does not prevent the remaining models from being loaded.
    Preload failures are reported but never abort start-up.
    """
    # Remove stale run directories first.
    await cleanup_old_runs()

    # Preload every configured pollutant model.
    failed_pollutants = []
    for pollutant in MODEL_PATHS:
        try:
            load_pollution_model(pollutant)
        except Exception as e:
            failed_pollutants.append(pollutant)
            print(f"⚠️ 模型预加载失败 ({pollutant}): {e}")

    if not failed_pollutants:
        print("✅ 所有污染预测模型预加载完成")
87
+
88
+
89
+ # =========================
90
+ # 首页
91
+ # =========================
92
@app.get("/")
async def read_index():
    """Serve the front-end entry page.

    Returns:
        A FileResponse for ``frontend/index.html``.

    Raises:
        HTTPException: 500 when the index file does not exist.
    """
    index_path = FRONTEND_DIR / "index.html"
    if not index_path.exists():
        raise HTTPException(status_code=500, detail="frontend/index.html not found")
    return FileResponse(str(index_path))
98
+
99
+
100
+ # =========================
101
+ # 模型路径
102
+ # =========================
103
# Checkpoint file for each supported pollutant.
MODEL_PATHS = {
    "CO": BASE_DIR / "models" / "CO.pth",
    "NO2": BASE_DIR / "models" / "NO2.pth",
    "PM25": BASE_DIR / "models" / "PM25.pth",
    "PM10": BASE_DIR / "models" / "PM10.pth",
    "O3": BASE_DIR / "models" / "O3.pth",
}

# Cache of already-loaded models, keyed by pollutant; filled under model_lock.
loaded_models: Dict[str, PollutionDifferenceModel] = {}
model_lock = threading.Lock()

# =========================
# Upload limits
# =========================
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
ALLOWED_CONTENT_TYPES = {"image/jpeg", "image/png", "image/webp"}

# =========================
# Accepted (low, high) range of the reference value per pollutant
# =========================
POLLUTANT_RANGES = {
    "CO": (0, 50),
    "NO2": (0, 500),
    "PM25": (0, 999),
    "PM10": (0, 999),
    "O3": (0, 500),
}
130
+
131
+
132
+ # =========================
133
+ # 工具函数
134
+ # =========================
135
def create_request_dirs() -> Dict[str, Any]:
    """Create a fresh working directory tree for one request.

    A new hex UUID is used as the request id; ``input``, ``output`` and
    ``summary`` sub-directories are created under ``RUNS_DIR/<request_id>``.

    Returns:
        A dict with ``request_id`` (str) and the ``base_dir``, ``input_dir``,
        ``output_dir`` and ``summary_dir`` paths.
    """
    request_id = uuid.uuid4().hex  # plain str, not wrapped in Path
    base_dir = RUNS_DIR / request_id
    input_dir = base_dir / "input"
    output_dir = base_dir / "output"
    summary_dir = base_dir / "summary"

    for d in [input_dir, output_dir, summary_dir]:
        d.mkdir(parents=True, exist_ok=True)

    return {
        "request_id": request_id,  # str
        "base_dir": base_dir,
        "input_dir": input_dir,
        "output_dir": output_dir,
        "summary_dir": summary_dir,
    }
152
+
153
+
154
+ # =========================
155
+ # 异步批量处理函数(修复 I/O 错误:读本地路径)
156
+ # =========================
157
async def batch_predict_task(
    request_id: str,
    pollutant: str,
    ref_data: float,
    ref_tensor: torch.Tensor,
    model: PollutionDifferenceModel,
    query_file_paths: list,  # list of {"path": str, "name": str}
    batch_input_dir: Path
):
    """Background job: predict every saved query image against the reference.

    Progress is published in ``batch_progress[request_id]`` after each image.
    A failure on one image is recorded in ``failed`` and does not stop the
    batch. The final status is ``"completed"`` when the loop finishes, or
    ``"failed"`` if the task is cancelled or hits an unexpected error, so a
    progress listener always sees a terminal state.

    Args:
        request_id: Key under which progress is stored.
        pollutant: Pollutant name (kept for interface compatibility; unused here).
        ref_data: Reference value added to the model output.
        ref_tensor: Preprocessed reference image tensor.
        model: Loaded model, called as ``model(ref_tensor, query_tensor)``.
        query_file_paths: Already-saved query files as ``{"path", "name"}`` dicts.
        batch_input_dir: Input directory (kept for interface compatibility; unused here).
    """
    results = []
    failed = []
    total = len(query_file_paths)

    # Publish the initial progress state.
    with progress_lock:
        batch_progress[request_id] = {
            "total": total,
            "current": 0,
            "results": [],
            "failed": [],
            "status": "processing"
        }

    def _predict_one(query_path: Path) -> float:
        # Runs in a worker thread so decoding and the forward pass do not block
        # the event loop. Grad mode is disabled here, in the thread that runs
        # the forward pass, instead of being held across awaits.
        with torch.no_grad():
            query_tensor = preprocess_image(read_rgb_image(query_path))
            out = model(ref_tensor, query_tensor)
            return float(out.item())

    final_status = "failed"
    try:
        for idx, file_info in enumerate(query_file_paths):
            safe_name = file_info["name"]
            query_path = Path(file_info["path"])

            try:
                model_out = await asyncio.to_thread(_predict_one, query_path)
                final_pred = ref_data + model_out

                results.append({
                    "filename": safe_name,
                    "status": "ok",
                    "pred_value": round(final_pred, 4),
                    "model_out": round(model_out, 4),
                })

            except Exception as e:
                error_msg = f"文件:{safe_name},错误:{str(e)}"
                print(f"【批量预测失败】{error_msg}")

                failed.append({
                    "filename": safe_name,
                    "status": "error",
                    "message": str(e)
                })

            # Publish progress after every image.
            with progress_lock:
                batch_progress[request_id]["current"] = idx + 1
                batch_progress[request_id]["results"] = results
                batch_progress[request_id]["failed"] = failed

        final_status = "completed"
    finally:
        # Always leave a terminal status, even on cancellation.
        with progress_lock:
            batch_progress[request_id]["status"] = final_status

    print(f"【批量任务完成】{request_id} | 成功:{len(results)} 张,失败:{len(failed)} 张")
224
+
225
async def save_upload_file(upload_file: UploadFile, save_path: Path) -> None:
    """Validate an uploaded file and write it to ``save_path``.

    The declared content type is checked before any bytes are read, and at
    most ``MAX_FILE_SIZE + 1`` bytes are read, so an oversized upload is
    rejected without copying the whole body into memory.

    Args:
        upload_file: The incoming upload.
        save_path: Destination file path.

    Raises:
        HTTPException: 415 for a content type outside ALLOWED_CONTENT_TYPES,
            400 for an empty file, 413 when the file exceeds MAX_FILE_SIZE.
    """
    if upload_file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(
            status_code=415,
            detail=f"不支持的文件类型 '{upload_file.content_type}',仅支持 JPEG / PNG / WebP"
        )

    # One extra byte is enough to tell "exactly at the limit" from "over it".
    content = await upload_file.read(MAX_FILE_SIZE + 1)

    if not content:
        raise HTTPException(status_code=400, detail=f"上传文件为空: {upload_file.filename}")

    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail=f"文件过大(最大 10MB): {upload_file.filename}")

    save_path.write_bytes(content)
242
+
243
+
244
def preprocess_image(img_np: np.ndarray) -> torch.Tensor:
    """Convert an image array into a model input tensor.

    The image is resized to 256x256, cast to float32, divided by 255, moved
    from channel-last to channel-first order and given a leading batch axis.

    Args:
        img_np: Image array; assumed to be H x W x C with 8-bit values
            (as returned by read_rgb_image) - TODO confirm for other callers.

    Returns:
        A float32 tensor with a leading batch dimension of 1.
    """
    img = cv2.resize(img_np, (256, 256))
    img = img.astype(np.float32) / 255.0
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(img).unsqueeze(0)
249
+
250
+
251
def read_rgb_image(path: Path) -> np.ndarray:
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Image file to read.

    Raises:
        HTTPException: 400 when ``cv2.imread`` cannot read the file.
    """
    img = cv2.imread(str(path))
    if img is None:
        raise HTTPException(status_code=400, detail=f"无法读取图像: {path.name}")
    # OpenCV loads BGR; convert to RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
256
+
257
+
258
+ # 替换原 load_pollution_model 函数
259
def load_pollution_model(pollutant: str) -> PollutionDifferenceModel:
    """Return the model for ``pollutant``, loading and caching it on first use.

    The cache is checked once without the lock and again under ``model_lock``,
    so concurrent callers load each checkpoint at most once.

    This function does not change the gradient mode; inference call sites
    wrap the forward pass in ``torch.no_grad()`` themselves.

    Args:
        pollutant: Key into MODEL_PATHS.

    Raises:
        HTTPException: 400 for an unknown pollutant, 500 when the checkpoint
            file does not exist.
    """
    if pollutant not in MODEL_PATHS:
        raise HTTPException(status_code=400, detail=f"不支持的污染物类型: {pollutant}")

    # Fast path: already cached.
    if pollutant in loaded_models:
        return loaded_models[pollutant]

    with model_lock:
        # Re-check under the lock: another thread may have loaded it meanwhile.
        if pollutant not in loaded_models:
            model_path = MODEL_PATHS[pollutant]
            if not model_path.exists():
                raise HTTPException(status_code=500, detail=f"模型文件不存在: {model_path}")

            # weights_only=True restricts unpickling to tensor/primitive data.
            checkpoint = torch.load(
                str(model_path),
                map_location="cpu",
                weights_only=True
            )

            model = PollutionDifferenceModel(num_classes=19, pollution_dims=1)

            # Accept both {"model": state_dict} and a bare state_dict.
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                model.load_state_dict(checkpoint["model"])
            else:
                model.load_state_dict(checkpoint)

            model.eval()
            loaded_models[pollutant] = model

    return loaded_models[pollutant]
295
+
296
+
297
async def run_segmentation_async(input_dir: Path, output_dir: Path, summary_dir: Path) -> None:
    """Run ``segmenter.segment`` in the default executor so the event loop stays free.

    Args:
        input_dir: Directory with the images to segment.
        output_dir: Directory passed as ``dir_image_output``.
        summary_dir: Directory passed as ``dir_summary_output``.
    """
    # NOTE(review): the shared segmenter may be invoked from several worker
    # threads at once under concurrent requests - confirm it tolerates that.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None,
        partial(
            segmenter.segment,
            dir_input=str(input_dir),
            dir_image_output=str(output_dir),
            dir_summary_output=str(summary_dir)
        )
    )
309
+
310
+
311
+ def find_segmented_img(output_dir: Path, base_name: str) -> Path | None:
312
+ """确定性地查找分割结果图像(排序后取第一个)。"""
313
+ candidates = sorted([
314
+ f for f in output_dir.iterdir()
315
+ if base_name in f.name and "colored_segmented" in f.name
316
+ ])
317
+ return candidates[0] if candidates else None
318
+
319
+
320
+ def find_blend_img(output_dir: Path, base_name: str) -> Path | None:
321
+ """确定性地查找融合结果图像(排序后取第一个)。"""
322
+ candidates = sorted([
323
+ f for f in output_dir.iterdir()
324
+ if base_name in f.name and "blend" in f.name
325
+ ])
326
+ return candidates[0] if candidates else None
327
+
328
+
329
def copy_segmentation_outputs(output_dir: Path, request_id: str) -> Dict[str, str]:
    """Copy the segmentation images of one request into ``static/results``.

    The segmented images for "ref" and "query" are required; the blend images
    are copied only when they were found.

    Args:
        output_dir: Directory the segmenter wrote its images to.
        request_id: Used as the filename prefix of the copied files.

    Returns:
        URL paths under ``/static/results`` for ``ref_seg``, ``query_seg``,
        ``ref_blend`` and ``query_blend``; a blend entry is ``""`` when its
        target file does not exist.

    Raises:
        HTTPException: 500 when either segmented image is missing.
    """
    ref_seg_path = find_segmented_img(output_dir, "ref")
    query_seg_path = find_segmented_img(output_dir, "query")
    ref_blend_path = find_blend_img(output_dir, "ref")
    query_blend_path = find_blend_img(output_dir, "query")

    if not ref_seg_path or not query_seg_path:
        raise HTTPException(status_code=500, detail="找不到分割结果图像")

    target_ref = STATIC_RESULTS_DIR / f"{request_id}_ref_seg.png"
    target_query = STATIC_RESULTS_DIR / f"{request_id}_query_seg.png"
    target_ref_blend = STATIC_RESULTS_DIR / f"{request_id}_ref_blend.png"
    target_query_blend = STATIC_RESULTS_DIR / f"{request_id}_query_blend.png"

    shutil.copy(ref_seg_path, target_ref)
    shutil.copy(query_seg_path, target_query)

    # Blend overlays are optional.
    if ref_blend_path and ref_blend_path.exists():
        shutil.copy(ref_blend_path, target_ref_blend)
    if query_blend_path and query_blend_path.exists():
        shutil.copy(query_blend_path, target_query_blend)

    return {
        "ref_seg": f"/static/results/{request_id}_ref_seg.png",
        "query_seg": f"/static/results/{request_id}_query_seg.png",
        "ref_blend": f"/static/results/{request_id}_ref_blend.png" if target_ref_blend.exists() else "",
        "query_blend": f"/static/results/{request_id}_query_blend.png" if target_query_blend.exists() else "",
    }
357
+
358
+
359
def infer_difference(
    model: PollutionDifferenceModel,
    ref_tensor: torch.Tensor,
    query_tensor: torch.Tensor
) -> float:
    """Run the model on a reference/query pair without gradient tracking.

    Returns:
        The model output as a Python float; ``.item()`` requires the output
        to hold exactly one element.
    """
    with torch.no_grad():
        out = model(ref_tensor, query_tensor)
    return float(out.item())
367
+
368
+
369
def validate_ref_data(pollutant: str, ref_data: float) -> None:
    """Validate the reference value against the pollutant's accepted range.

    NaN is rejected explicitly: every comparison with NaN is False, so it
    would otherwise pass both range checks.

    Args:
        pollutant: Key into POLLUTANT_RANGES.
        ref_data: Reference value supplied by the client.

    Raises:
        HTTPException: 400 for an unknown pollutant; 422 when the value is
            NaN, below the lower bound or above the upper bound.
    """
    import math  # local import: only needed for the NaN check

    if pollutant not in POLLUTANT_RANGES:
        raise HTTPException(status_code=400, detail=f"不支持的污染物: {pollutant}")

    if math.isnan(ref_data):
        raise HTTPException(status_code=422, detail=f"{pollutant} 参考值不是有效数字")

    lo, hi = POLLUTANT_RANGES[pollutant]
    if ref_data < lo:
        raise HTTPException(status_code=422, detail=f"{pollutant} 参考值不能为负数")
    if ref_data > hi:
        raise HTTPException(
            status_code=422,
            detail=f"{pollutant} 参考值 {ref_data} 超出合理范围(最大 {hi})"
        )
382
+
383
+
384
async def cleanup_old_runs(max_age_hours: int = 24) -> None:
    """Delete run directories whose mtime is older than ``max_age_hours``.

    Best effort: a directory whose timestamp cannot be read is skipped, and
    removal errors are ignored.

    Args:
        max_age_hours: Age threshold in hours.
    """
    cutoff = datetime.now() - timedelta(hours=max_age_hours)
    if not RUNS_DIR.exists():
        return

    for run_dir in RUNS_DIR.iterdir():
        if not run_dir.is_dir():
            continue
        try:
            mtime = datetime.fromtimestamp(run_dir.stat().st_mtime)
        except (OSError, OverflowError, ValueError):
            continue  # unreadable or nonsensical timestamp: leave it alone
        if mtime < cutoff:
            shutil.rmtree(run_dir, ignore_errors=True)
397
+
398
+
399
+ # =========================
400
+ # 单图预测
401
+ # =========================
402
@app.post("/predict")
async def predict(
    pollutant: str = Form(...),
    ref_data: float = Form(...),
    ref_img: UploadFile = File(...),
    query_img: UploadFile = File(...)
):
    """Predict the pollutant value for one query image relative to a reference.

    Steps: validate the reference value, save both uploads, run segmentation,
    copy the segmentation images to ``static/results``, run the difference
    model and add its output to ``ref_data``.

    Returns:
        A dict with the prediction, the raw model output, the segmentation
        image URLs and the URL of ``pixel_ratios.json``.

    Raises:
        HTTPException: Re-raised as-is from validation and helpers; any other
            exception is wrapped in a 500.
    """
    try:
        # Server-side input validation.
        validate_ref_data(pollutant, ref_data)

        paths = create_request_dirs()
        request_id = paths["request_id"]
        input_dir = paths["input_dir"]
        output_dir = paths["output_dir"]
        summary_dir = paths["summary_dir"]

        ref_path = input_dir / "ref.jpg"
        query_path = input_dir / "query.jpg"

        await save_upload_file(ref_img, ref_path)
        await save_upload_file(query_img, query_path)

        # Segmentation runs in the executor.
        await run_segmentation_async(input_dir, output_dir, summary_dir)

        # Publish the segmentation images under static/results.
        seg_urls = copy_segmentation_outputs(output_dir, request_id)

        def _infer() -> float:
            # Image decoding, model loading and the forward pass are blocking;
            # keep them off the event loop, in the original order.
            ref_tensor = preprocess_image(read_rgb_image(ref_path))
            query_tensor = preprocess_image(read_rgb_image(query_path))
            model = load_pollution_model(pollutant)
            return infer_difference(model, ref_tensor, query_tensor)

        model_out = await asyncio.to_thread(_infer)
        final_pred = ref_data + model_out

        ratio_json_path = summary_dir / "pixel_ratios.json"
        if not ratio_json_path.exists():
            raise HTTPException(status_code=500, detail="分割后未找到 pixel_ratios.json")

        return {
            "status": "ok",
            "request_id": request_id,
            "pollutant": pollutant,
            "ref_data": ref_data,
            "model_out": round(model_out, 4),
            "pred_value": round(final_pred, 4),
            "ref_seg": seg_urls["ref_seg"],
            "query_seg": seg_urls["query_seg"],
            "ref_blend": seg_urls["ref_blend"],
            "query_blend": seg_urls["query_blend"],
            "ratio_json": f"/runs/{request_id}/summary/pixel_ratios.json"
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"预测失败: {str(e)}")
461
+
462
+
463
+ # =========================
464
+ # 批量预测(修复 I/O 错误:先保存所有文件)
465
+ # =========================
466
# Strong references to in-flight background tasks. The event loop only keeps
# weak references, so a task nobody references can be garbage-collected
# before it finishes.
_background_tasks: set = set()


@app.post("/batch-predict")
async def batch_predict(
    pollutant: str = Form(...),
    ref_data: float = Form(...),
    ref_img: UploadFile = File(...),
    query_files: List[UploadFile] = File(...)
):
    """Start a batch prediction and return immediately with a request id.

    All uploads are saved to disk before the response is sent; the background
    task only receives file paths, never UploadFile objects. Progress can be
    followed via ``/progress/{request_id}``.

    Returns:
        ``{"status": "processing", "request_id": ..., "total_files": ...}``.

    Raises:
        HTTPException: Re-raised as-is from validation and helpers; any other
            exception is wrapped in a 500.
    """
    try:
        validate_ref_data(pollutant, ref_data)

        if not query_files:
            raise HTTPException(status_code=400, detail="未上传任何查询图像")

        paths = create_request_dirs()
        request_id = paths["request_id"]
        batch_input_dir = paths["input_dir"]

        # 1. Save and preprocess the reference image.
        ref_path = batch_input_dir / "ref.jpg"
        await save_upload_file(ref_img, ref_path)
        ref_tensor = preprocess_image(read_rgb_image(ref_path))
        model = load_pollution_model(pollutant)

        # 2. Save every query image before returning.
        query_file_paths = []
        for idx, file in enumerate(query_files):
            safe_name = os.path.basename(file.filename) if file.filename else f"{uuid.uuid4().hex}.jpg"
            # The index prefix keeps stored names unique, so two uploads with the
            # same filename (or one named "ref.jpg") cannot overwrite each other.
            # The original name is still what gets reported to the client.
            query_path = batch_input_dir / f"query_{idx:04d}_{safe_name}"
            await save_upload_file(file, query_path)
            query_file_paths.append({
                "path": str(query_path),
                "name": safe_name
            })

        # 3. Register progress before scheduling, so a /progress request that
        #    arrives before the task's first step still finds the job.
        with progress_lock:
            batch_progress[request_id] = {
                "total": len(query_file_paths),
                "current": 0,
                "results": [],
                "failed": [],
                "status": "processing"
            }

        # 4. Start the background task and keep a reference until it is done.
        task = asyncio.create_task(
            batch_predict_task(
                request_id=request_id,
                pollutant=pollutant,
                ref_data=ref_data,
                ref_tensor=ref_tensor,
                model=model,
                query_file_paths=query_file_paths,
                batch_input_dir=batch_input_dir
            )
        )
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

        return {
            "status": "processing",
            "request_id": request_id,
            "total_files": len(query_files)
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"批量预测启动失败: {str(e)}")
524
+
525
+
526
+ # =========================
527
+ # 批量预测进度推送(SSE)【修复语法错误版】
528
+ # =========================
529
@app.get("/progress/{request_id}")
async def get_batch_progress(request_id: str):
    """Server-sent events endpoint streaming the progress of one batch job.

    The store is polled every 100 ms. An event is emitted only when the
    serialized progress changed since the last one, because each event carries
    the full result lists. The stream ends after a ``completed``/``failed``
    status has been sent, or with an error event when the job id is unknown.
    """
    async def event_generator():
        last_payload = None
        while True:
            # Snapshot under the same lock the writers use.
            with progress_lock:
                progress = batch_progress.get(request_id, {})
                if progress:
                    status = progress.get('status', 'processing')
                    payload = json.dumps({
                        'total': progress.get('total', 0),
                        'current': progress.get('current', 0),
                        'status': status,
                        'results': progress.get('results', []),
                        'failed': progress.get('failed', [])
                    })

            if not progress:
                yield 'data: {"error": "任务不存在"}\n\n'
                break

            if payload != last_payload:
                yield f"data: {payload}\n\n"
                last_payload = payload

            # Terminal state has been delivered: stop streaming.
            if status in ["completed", "failed"]:
                break

            await asyncio.sleep(0.1)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
559
+
560
+ # =========================
561
+ # 健康检查
562
+ # =========================
563
@app.get("/health")
async def health_check():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}
566
+
567
+
568
+ # =========================
569
+ # 启动
570
+ # =========================
571
if __name__ == "__main__":
    # Serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
my_Segmenter.py ADDED
@@ -0,0 +1,1189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import shutil
4
+ from collections import defaultdict, namedtuple
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ from math import ceil
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Tuple, Union
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from tqdm.contrib.concurrent import thread_map
16
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
17
+
18
+ from zensvi.utils.log import verbosity_tqdm
19
+
20
+ # a label and all meta information
21
+ _Label = namedtuple(
22
+ "_Label",
23
+ [
24
+ "name", # The identifier of this label, e.g. 'car', 'person', ... .
25
+ # We use them to uniquely name a class
26
+ "id", # An integer ID that is associated with this label.
27
+ # The IDs are used to represent the label in ground truth images
28
+ # An ID of -1 means that this label does not have an ID and thus
29
+ # is ignored when creating ground truth images (e.g. license plate).
30
+ # Do not modify these IDs, since exactly these IDs are expected by the
31
+ # evaluation server.
32
+ "trainId", # Feel free to modify these IDs as suitable for your method. Then create
33
+ # ground truth images with train IDs, using the tools provided in the
34
+ # 'preparation' folder. However, make sure to validate or submit results
35
+ # to our evaluation server using the regular IDs above!
36
+ # For trainIds, multiple labels might have the same ID. Then, these labels
37
+ # are mapped to the same class in the ground truth images. For the inverse
38
+ # mapping, we use the label that is defined first in the list below.
39
+ # For example, mapping all void-type classes to the same ID in training,
40
+ # might make sense for some approaches.
41
+ # Max value is 255!
42
+ "category", # The name of the category that this label belongs to
43
+ "categoryId", # The ID of this category. Used to create ground truth images
44
+ # on category level.
45
+ "hasInstances", # Whether this label distinguishes between single instances or not
46
+ "ignoreInEval", # Whether pixels having this class as ground truth label are ignored
47
+ # during evaluations or not
48
+ "color", # The color of this label
49
+ ],
50
+ )
51
+
52
+
53
def _create_cityscapes_label_colormap() -> List[_Label]:
    """Build the label table used for the Cityscapes segmentation benchmark.

    Returns:
        List[_Label]: One ``_Label`` per Cityscapes class, in definition order.
        Rows whose ``trainId`` is 255 (and the ``license plate`` row with -1)
        are the ones marked ``ignoreInEval=True``.
    """
    # Columns: name, id, trainId, category, categoryId, hasInstances, ignoreInEval, color
    rows = (
        ("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        ("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        ("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        ("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        ("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        ("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        ("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        ("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        ("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        ("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        ("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        ("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        ("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        ("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        ("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        ("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        ("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        ("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        ("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        ("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        ("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        ("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        ("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        ("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        ("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        ("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        ("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        ("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        ("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        ("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        ("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        ("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        ("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        ("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        ("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    )
    return [_Label(*row) for row in rows]
101
+
102
+
103
def _create_mapillary_vistas_label_colormap() -> List[_Label]:
    """Build the label table used for the Mapillary Vistas segmentation benchmark.

    Returns:
        List[_Label]: One ``_Label`` per class, in definition order. In this
        table ``id`` and ``trainId`` are equal for every row (0 through 64).
    """
    # Columns: name, id, trainId, category, categoryId, hasInstances, ignoreInEval, color
    rows = (
        ("Bird", 0, 0, "animal", 0, True, False, (165, 42, 42)),
        ("Ground Animal", 1, 1, "animal", 0, True, False, (0, 192, 0)),
        ("Curb", 2, 2, "construction", 1, False, False, (196, 196, 196)),
        ("Fence", 3, 3, "construction", 1, False, False, (190, 153, 153)),
        ("Guard Rail", 4, 4, "construction", 1, False, False, (180, 165, 180)),
        ("Barrier", 5, 5, "construction", 1, False, False, (102, 102, 156)),
        ("Wall", 6, 6, "construction", 1, False, False, (102, 102, 156)),
        ("Bike Lane", 7, 7, "flat", 2, False, False, (128, 64, 255)),
        ("Crosswalk - Plain", 8, 8, "flat", 2, False, False, (140, 140, 200)),
        ("Curb Cut", 9, 9, "flat", 2, False, False, (170, 170, 170)),
        ("Parking", 10, 10, "flat", 2, False, False, (250, 170, 160)),
        ("Pedestrian Area", 11, 11, "flat", 2, False, False, (96, 96, 96)),
        ("Rail Track", 12, 12, "flat", 2, False, False, (230, 150, 140)),
        ("Road", 13, 13, "flat", 2, False, False, (128, 64, 128)),
        ("Service Lane", 14, 14, "flat", 2, False, False, (110, 110, 110)),
        ("Sidewalk", 15, 15, "flat", 2, False, False, (244, 35, 232)),
        ("Bridge", 16, 16, "construction", 1, False, False, (150, 100, 100)),
        ("Building", 17, 17, "construction", 1, False, False, (70, 70, 70)),
        ("Tunnel", 18, 18, "construction", 1, False, False, (150, 120, 90)),
        ("Person", 19, 19, "human", 3, True, False, (220, 20, 60)),
        ("Bicyclist", 20, 20, "human", 3, True, False, (255, 0, 0)),
        ("Motorcyclist", 21, 21, "human", 3, True, False, (255, 0, 0)),
        ("Other Rider", 22, 22, "human", 3, True, False, (255, 0, 0)),
        ("Lane Marking - Crosswalk", 23, 23, "marking", 4, False, True, (200, 128, 128)),
        ("Lane Marking - General", 24, 24, "marking", 4, True, False, (255, 255, 255)),
        ("Mountain", 25, 25, "nature", 5, False, False, (64, 170, 64)),
        ("Sand", 26, 26, "nature", 5, False, False, (230, 160, 50)),
        ("Sky", 27, 27, "sky", 6, False, False, (70, 130, 180)),
        ("Snow", 28, 28, "nature", 5, False, False, (190, 255, 255)),
        ("Terrain", 29, 29, "nature", 5, False, False, (152, 251, 152)),
        ("Vegetation", 30, 30, "nature", 5, False, False, (107, 142, 35)),
        ("Water", 31, 31, "water", 7, False, False, (0, 170, 30)),
        ("Banner", 32, 32, "object", 8, False, False, (255, 220, 0)),
        ("Bench", 33, 33, "object", 8, False, False, (255, 0, 0)),
        ("Bike Rack", 34, 34, "object", 8, False, False, (255, 0, 0)),
        ("Billboard", 35, 35, "object", 8, False, False, (255, 0, 0)),
        ("Catch Basin", 36, 36, "object", 8, False, False, (255, 0, 0)),
        ("CCTV Camera", 37, 37, "object", 8, False, False, (255, 0, 0)),
        ("Fire Hydrant", 38, 38, "object", 8, False, False, (255, 0, 0)),
        ("Junction Box", 39, 39, "object", 8, False, False, (255, 0, 0)),
        ("Mailbox", 40, 40, "object", 8, False, False, (255, 0, 0)),
        ("Manhole", 41, 41, "object", 8, False, False, (255, 0, 0)),
        ("Phone Booth", 42, 42, "object", 8, False, False, (255, 0, 0)),
        ("Pothole", 43, 43, "object", 8, False, False, (255, 0, 0)),
        ("Street Light", 44, 44, "object", 8, False, False, (255, 0, 0)),
        ("Pole", 45, 45, "object", 8, False, False, (255, 0, 0)),
        ("Traffic Sign Frame", 46, 46, "object", 8, False, False, (255, 0, 0)),
        ("Utility Pole", 47, 47, "object", 8, False, False, (255, 0, 0)),
        ("Traffic Light", 48, 48, "object", 8, False, False, (255, 0, 0)),
        ("Traffic Sign (Back)", 49, 49, "object", 8, False, False, (255, 0, 0)),
        ("Traffic Sign (Front)", 50, 50, "object", 8, False, False, (255, 0, 0)),
        ("Trash Can", 51, 51, "object", 8, False, False, (255, 0, 0)),
        ("Bicycle", 52, 52, "vehicle", 9, True, False, (119, 11, 32)),
        ("Boat", 53, 53, "vehicle", 9, False, False, (0, 0, 142)),
        ("Bus", 54, 54, "vehicle", 9, True, False, (0, 60, 100)),
        ("Car", 55, 55, "vehicle", 9, True, False, (0, 0, 142)),
        ("Caravan", 56, 56, "vehicle", 9, True, False, (0, 0, 90)),
        ("Motorcycle", 57, 57, "vehicle", 9, True, False, (0, 0, 230)),
        ("On Rails", 58, 58, "vehicle", 9, False, False, (0, 80, 100)),
        ("Other Vehicle", 59, 59, "vehicle", 9, True, False, (128, 64, 64)),
        ("Trailer", 60, 60, "vehicle", 9, True, False, (0, 0, 110)),
        ("Truck", 61, 61, "vehicle", 9, True, False, (0, 0, 70)),
        ("Wheeled Slow", 62, 62, "vehicle", 9, False, False, (0, 0, 192)),
        ("Car Mount", 63, 63, "vehicle", 9, True, False, (32, 32, 32)),
        ("Ego Vehicle", 64, 64, "vehicle", 9, True, False, (120, 10, 10)),
    )
    return [_Label(*row) for row in rows]
189
+
190
+
191
+ def _get_resized_dimensions(width: int, height: int, max_size: int = 2048) -> Tuple[int, int]:
192
+ """Calculate the new dimensions of an image to maintain aspect ratio.
193
+
194
+ If both dimensions are less than or equal to max_size, the original dimensions are returned.
195
+
196
+ Args:
197
+ width (int): The original width of the image.
198
+ height (int): The original height of the image.
199
+ max_size (int, optional): The maximum size for either dimension. Defaults to 2048.
200
+
201
+ Returns:
202
+ Tuple[int, int]: The new dimensions (width, height) of the image.
203
+ """
204
+ if max(width, height) > max_size:
205
+ scaling_factor = max_size / max(width, height)
206
+ new_width = int(width * scaling_factor)
207
+ new_height = int(height * scaling_factor)
208
+ return new_width, new_height
209
+ else:
210
+ # Return original dimensions if resizing is not necessary
211
+ return width, height
212
+
213
+
214
class ImageDataset(Dataset):
    """Dataset that reads image files with OpenCV and prepares them for inference.

    Images are loaded from the given paths, shrunk so that neither side exceeds
    ``max_size`` (aspect ratio preserved) and optionally converted from BGR to RGB.

    Args:
        image_files (List[Path]): Candidate image paths. Only ``.jpg``, ``.jpeg``
            and ``.png`` files (case-insensitive) whose name does not start with
            a dot are kept.
        max_size (int, optional): Maximum size of either image side. Defaults to 2048.
        rgb (bool, optional): Convert loaded images from BGR to RGB. Defaults to True.
    """

    def __init__(self, image_files: List[Path], max_size: int = 2048, rgb: bool = True) -> None:
        """Filter the candidate paths and store the loading options.

        Args:
            image_files (List[Path]): Candidate image paths.
            max_size (int, optional): Maximum size of either image side. Defaults to 2048.
            rgb (bool, optional): Convert loaded images from BGR to RGB. Defaults to True.
        """
        supported_suffixes = [".jpg", ".jpeg", ".png"]
        kept: List[Path] = []
        for candidate in image_files:
            # Hidden files (dot-prefixed names) are never treated as images.
            if candidate.name.startswith("."):
                continue
            if candidate.suffix.lower() in supported_suffixes:
                kept.append(candidate)
        self.image_files = kept
        self.max_size = max_size
        self.rgb = rgb

    def __len__(self) -> int:
        """Return how many image files passed the filter in ``__init__``."""
        return len(self.image_files)

    def __getitem__(self, idx: int) -> Tuple[str, cv2.Mat, Tuple[int, int]]:
        """Load, resize and color-convert the image at position ``idx``.

        Args:
            idx (int): Index into the filtered file list.

        Returns:
            Tuple[str, cv2.Mat, Tuple[int, int]]: The file path as a string, the
            image data, and the ``(height, width)`` of the returned (possibly
            resized) image.

        Raises:
            ValueError: If the image cannot be read.
        """
        image_file = self.image_files[idx]
        img = cv2.imread(str(image_file))
        if img is None:
            raise ValueError(f"Unable to read image at {image_file}")

        src_height, src_width = img.shape[:2]
        dst_width, dst_height = _get_resized_dimensions(src_width, src_height, self.max_size)

        # Only touch the pixels when the target size actually differs.
        if (src_width, src_height) != (dst_width, dst_height):
            img = cv2.resize(img, (dst_width, dst_height))

        if self.rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return str(image_file), img, (dst_height, dst_width)

    def collate_fn(
        self, data: List[Tuple[str, cv2.Mat, Tuple[int, int]]]
    ) -> Tuple[List[str], List[cv2.Mat], List[Tuple[int, int]]]:
        """Regroup a batch of samples into three parallel lists.

        Args:
            data (List[Tuple[str, cv2.Mat, Tuple[int, int]]]): Samples as produced
                by ``__getitem__``.

        Returns:
            Tuple[List[str], List[cv2.Mat], List[Tuple[int, int]]]: The file paths,
            the images and the image shapes of the batch, each as a list.
        """
        paths, images, shapes = zip(*data)
        return list(paths), list(images), list(shapes)
298
+
299
+
300
+ class Segmenter:
301
+ """A class for performing semantic and panoptic segmentation on images.
302
+
303
+ The models used are from the Mask2Former (https://huggingface.co/docs/transformers/model_doc/mask2former).
304
+
305
+ Attributes:
306
+ device (str): The device to run the model on (e.g., "cuda" or "cpu").
307
+ dataset (str): The name of the dataset (e.g., "cityscapes" or "mapillary").
308
+ task (str): The type of segmentation task (e.g., "semantic" or "panoptic").
309
+ model_name (str): The name of the pre-trained model corresponding to the dataset and task.
310
+ model: The segmentation model.
311
+ processor: The image processor for the model.
312
+ color_map: A mapping of class IDs to colors.
313
+ label_map: A mapping of class IDs to labels.
314
+ id_to_name_map: A mapping of label IDs to label names.
315
+ verbosity (int): Level of verbosity for progress bars (0=no progress, 1=outer loops only, 2=all loops)
316
+
317
+ Args:
318
+ dataset (str): The name of the dataset (default is "cityscapes").
319
+ task (str): The type of task (default is "semantic").
320
+ device (str, optional): The device to run the model on (e.g., "cuda" or "cpu"). If None, the default device will be used.
321
+ verbosity (int, optional): Level of verbosity for progress bars (0=no progress, 1=outer loops only, 2=all loops). Default is 1.
322
+
323
+ Returns:
324
+ None
325
+ """
326
+
327
def __init__(
    self, dataset: str = "cityscapes", task: str = "semantic", device: Optional[str] = None, verbosity: int = 1
) -> None:
    """Set up the segmentation model, its processor and the label lookup tables.

    Args:
        dataset (str): The name of the dataset (default is "cityscapes").
        task (str): The type of task (default is "semantic").
        device (str, optional): The device to run the model on (e.g., "cuda" or "cpu").
            If None, a device is selected by ``_get_device``.
        verbosity (int, optional): Level of verbosity for progress bars
            (0=no progress, 1=outer loops only, 2=all loops). Default is 1.

    Returns:
        None
    """
    # The device must be resolved first: loading the model moves it there.
    self.device = self._get_device(device)
    self.dataset, self.task = dataset, task

    self.model_name = self._get_model_name(self.dataset, self.task)
    loaded_model, loaded_processor = self._get_model_processor(self.model_name)
    self.model = loaded_model
    self.processor = loaded_processor

    # Lookup tables derived from the dataset's label definitions.
    self.color_map = self._create_color_map(dataset)
    self.label_map = self._create_label_map(dataset)
    self.id_to_name_map = self._create_id_to_name_map(dataset)

    self.verbosity = verbosity
350
+
351
+ def _get_model_name(self, dataset: str, task: str) -> str:
352
+ """Get the model name based on the dataset and task.
353
+
354
+ Args:
355
+ dataset (str): The name of the dataset (e.g., "cityscapes", "mapillary").
356
+ task (str): The type of task (e.g., "semantic", "panoptic").
357
+
358
+ Returns:
359
+ str: The name of the pre-trained model corresponding to the dataset and task.
360
+
361
+ Raises:
362
+ ValueError: If the dataset is unknown.
363
+
364
+ """
365
+ if dataset == "cityscapes":
366
+ if task == "semantic":
367
+ return "facebook/mask2former-swin-tiny-cityscapes-semantic"
368
+ elif task == "panoptic":
369
+ return "facebook/mask2former-swin-tiny-cityscapes-panoptic"
370
+ elif dataset == "mapillary":
371
+ if task == "semantic":
372
+ return "facebook/mask2former-swin-large-mapillary-vistas-semantic"
373
+ elif task == "panoptic":
374
+ return "facebook/mask2former-swin-large-mapillary-vistas-panoptic"
375
+ else:
376
+ raise ValueError(f"Unknown dataset: {dataset}")
377
+
378
+ def _get_model_processor(self, model_name: str) -> Tuple:
379
+ """Get the model and processor for the given model name.
380
+
381
+ Args:
382
+ model_name(str): The name of the pre-trained model.
383
+
384
+ Returns:
385
+ Tuple: The model and processor.
386
+
387
+ """
388
+ # Add other models in the future
389
+ if "mask2former" in model_name:
390
+ processor = AutoImageProcessor.from_pretrained(model_name)
391
+ model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name).to(self.device)
392
+ return model, processor
393
+
394
def _create_color_map(self, dataset: str) -> np.ndarray:
    """Build a ``trainId -> RGB color`` lookup array for the given dataset.

    Labels whose ``trainId`` is ``None``, negative, or 255 and above are left
    out (for Cityscapes, 255 is the ignore label and must not be treated as a
    normal semantic class). As a side effect, ``self.train_id_to_name`` is set
    to the matching ``trainId -> name`` dictionary.

    Args:
        dataset (str): The name of the dataset ("cityscapes" or "mapillary").

    Returns:
        np.ndarray: ``uint8`` array with one RGB row per trainId from 0 to the
        largest valid trainId; rows without a label stay zero.

    Raises:
        ValueError: If the dataset is unknown or has no valid trainIds.
    """
    label_factories = {
        "cityscapes": _create_cityscapes_label_colormap,
        "mapillary": _create_mapillary_vistas_label_colormap,
    }
    factory = label_factories.get(dataset)
    if factory is None:
        raise ValueError(f"Unknown dataset: {dataset}")

    usable = [
        label
        for label in factory()
        if label.trainId is not None and 0 <= label.trainId < 255
    ]
    if not usable:
        raise ValueError(
            f"No valid trainIds found for dataset={dataset}. "
            "Please check the label definitions."
        )

    ids = np.array([label.trainId for label in usable], dtype=np.int64)
    rgb = np.array([label.color for label in usable], dtype=np.uint8)

    table = np.zeros((int(np.max(ids)) + 1, 3), dtype=np.uint8)
    table[ids] = rgb

    self.train_id_to_name = {int(label.trainId): label.name for label in usable}
    return table
430
+
431
def _create_label_map(self, dataset: str) -> Dict[Tuple, _Label]:
    """Build a ``color -> label`` dictionary for the given dataset.

    When several labels share the same color, the label defined last wins.

    Args:
        dataset (str): The name of the dataset ("cityscapes" or "mapillary").

    Returns:
        Dict[Tuple, _Label]: A dictionary mapping colors to labels.

    Raises:
        ValueError: If the dataset is unknown.
    """
    if dataset == "cityscapes":
        return {label.color: label for label in _create_cityscapes_label_colormap()}
    if dataset == "mapillary":
        return {label.color: label for label in _create_mapillary_vistas_label_colormap()}
    raise ValueError(f"Unknown dataset: {dataset}")
454
+
455
def _create_id_to_name_map(self, dataset: str) -> Dict[int, str]:
    """Build a ``trainId -> label name`` dictionary for the given dataset.

    Labels whose ``trainId`` is ``None``, negative, or 255 and above are skipped.

    Args:
        dataset (str): The name of the dataset ("cityscapes" or "mapillary").

    Returns:
        Dict[int, str]: Mapping from train ID to label name.

    Raises:
        ValueError: If the dataset is unknown.
    """
    if dataset == "cityscapes":
        all_labels = _create_cityscapes_label_colormap()
    elif dataset == "mapillary":
        all_labels = _create_mapillary_vistas_label_colormap()
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    id_to_name: Dict[int, str] = {}
    for label in all_labels:
        train_id = label.trainId
        if train_id is None or train_id < 0 or train_id >= 255:
            continue
        id_to_name[int(train_id)] = label.name
    return id_to_name
474
+
475
+ def _get_device(self, device: Optional[str]) -> torch.device:
476
+ """Get the appropriate device for running the model.
477
+
478
+ Args:
479
+ device (str or None): The device to use (e.g., "cpu", "cuda", "mps"). If None, the function will select the best available device.
480
+
481
+ Returns:
482
+ torch.device: The device to use for running the model.
483
+
484
+ Raises:
485
+ ValueError: If the provided device is not recognized.
486
+ """
487
+ if device is not None:
488
+ print(f"Using {device.upper()}")
489
+ return torch.device(device)
490
+ if torch.cuda.is_available():
491
+ print("Using GPU")
492
+ return torch.device("cuda")
493
+ else:
494
+ print("Using CPU")
495
+ return torch.device("cpu")
496
+
497
+ def _calculate_pixel_ratios(self, segmented_img: np.ndarray) -> Dict[str, float]:
498
+ """Calculate pixel ratios for each class in the segmented image."""
499
+
500
+ unique, counts = np.unique(segmented_img, return_counts=True)
501
+ total_pixels = np.sum(counts)
502
+
503
+ pixel_ratios = {}
504
+
505
+ for train_id, count in zip(unique, counts):
506
+ train_id = int(train_id)
507
+
508
+ # Skip ignored or unknown labels
509
+ if train_id not in self.train_id_to_name:
510
+ continue
511
+
512
+ pixel_ratios[self.train_id_to_name[train_id]] = count / total_pixels
513
+
514
+ return pixel_ratios
515
+
516
+ def _save_as_csv(self, input_dict: dict, dir_output: Path, value_name: str, csv_format: str) -> None:
517
+ """Save pixel ratios as a CSV file.
518
+
519
+ This function takes a dictionary of pixel ratios and saves it to a CSV file in either long or wide format.
520
+
521
+ Args:
522
+ input_dict (dict): A dictionary containing pixel ratios for each image and label.
523
+ dir_output (Path): The directory where the CSV file will be saved.
524
+ value_name (str): The name of the value to be saved in the CSV.
525
+ csv_format (str): The format of the CSV file, either 'long' or 'wide'.
526
+
527
+ Returns:
528
+ None: This function does not return any value but saves the CSV file to the specified directory.
529
+ """
530
+ if csv_format == "long":
531
+ df_list = [
532
+ pd.DataFrame(
533
+ {
534
+ "filename_key": [filename_key],
535
+ "label_name": [key],
536
+ value_name: [value] if value is not None else [0],
537
+ }
538
+ )
539
+ for filename_key, inner_dict in input_dict.items()
540
+ for key, value in inner_dict.items()
541
+ ]
542
+
543
+ pixel_ratios_df = pd.concat(df_list, ignore_index=True)
544
+
545
+ elif csv_format == "wide":
546
+ pixel_ratios_df = pd.DataFrame(input_dict).transpose().fillna(0)
547
+ pixel_ratios_df.index.names = ["filename_key"]
548
+
549
+ pixel_ratios_df.to_csv(dir_output / Path(value_name + ".csv"))
550
+
551
+ def _panoptic_segmentation(self, images: List[np.ndarray], original_img_shape: List[Tuple[int, int]]) -> list:
552
+ """Perform panoptic segmentation on the given images.
553
+
554
+ Args:
555
+ images(list): List of input images.
556
+ original_img_shape(tuple): Original image shape.
557
+
558
+ Returns:
559
+ list: List of panoptic segmentation outputs.
560
+
561
+ """
562
+ inputs = self.processor(images=images, return_tensors="pt").to(self.model.device)
563
+ outputs = self.model(**inputs)
564
+ return self.processor.post_process_panoptic_segmentation(
565
+ outputs, target_sizes=original_img_shape, label_ids_to_fuse=set([])
566
+ )
567
+
568
+ def _semantic_segmentation(self, images: List[np.ndarray], original_img_shape: List[Tuple[int, int]]) -> list:
569
+ """Perform semantic segmentation on the given images.
570
+
571
+ Args:
572
+ images(list): List of input images.
573
+ original_img_shape(tuple): Original image shape.
574
+
575
+ Returns:
576
+ tuple: Tuple containing list of semantic segmentation outputs and list of pixel ratios.
577
+
578
+ """
579
+ inputs = self.processor(images=images, return_tensors="pt").to(self.model.device)
580
+ with torch.no_grad():
581
+ outputs = self.model(**inputs)
582
+ segmentations = self.processor.post_process_semantic_segmentation(outputs, target_sizes=original_img_shape)
583
+ return segmentations
584
+
585
+ def _trainid_to_color(self, segmented_img: np.ndarray) -> np.ndarray:
586
+ """Convert segmented image with train IDs to a colored image.
587
+
588
+ Args:
589
+ segmented_img(numpy.ndarray): Segmented image with train IDs.
590
+
591
+ Returns:
592
+ numpy.ndarray: Colored segmented image.
593
+
594
+ """
595
+ colored_img = self.color_map[segmented_img]
596
+ return colored_img
597
+
598
def _save_panoptic_segmentation_image(
    self, image_file: str, img: np.ndarray, dir_output: Path, output: dict
) -> None:
    """Save the colored panoptic segmentation and/or its annotated blend with the input image.

    Every segment is annotated on the blended image with ``"<label>-<score>"``
    placed at the centroid of the segment's pixels. Which files are written is
    controlled by ``self.save_image_options`` ("segmented_image" and/or
    "blend_image").

    Args:
        image_file (str): The input image file path; its name determines the output names.
        img (np.ndarray): The input image as a NumPy array.
        dir_output (Path): The output directory for the saved images.
        output (dict): Segmentation data with the keys ``"label_segmentation"``
            and ``"segmentation"`` (tensors) and ``"segments_info"`` (a list of
            dicts with ``"id"``, ``"label_id"`` and ``"score"``).

    Returns:
        None: The selected images are written to ``dir_output``.
    """
    colored_segmented_img = self._trainid_to_color(output["label_segmentation"].cpu().numpy())
    alpha = 0.5
    blend_img = cv2.addWeighted(img, alpha, colored_segmented_img, 1 - alpha, 0)

    # Text size scales with the image size (base font size and thickness of 1).
    height, width, _ = img.shape
    scale_factor = np.sqrt(height * width) / 1000
    font_scale = 1 * scale_factor
    thickness = ceil(1 * scale_factor)

    # Convert the segment-id map once; it is the same for every segment.
    segment_map = output["segmentation"].cpu().numpy()

    for segment_info in output["segments_info"]:
        ys, xs = np.where(segment_map == segment_info["id"])
        if xs.size == 0:
            # No pixels for this segment: there is no centroid to label, and
            # the mean of an empty array would be NaN.
            continue

        # Use the label name instead of the label_id.
        label_name = self.id_to_name_map.get(segment_info["label_id"])
        score = segment_info["score"]
        center_x, center_y = np.mean(xs), np.mean(ys)

        cv2.putText(
            blend_img,
            f"{label_name}-{score:.2f}",
            (int(center_x), int(center_y)),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            thickness,
            cv2.LINE_AA,
        )

    output_file = dir_output / Path(image_file).name

    # Save images based on specified options.
    if "segmented_image" in self.save_image_options:
        cv2.imwrite(
            str(output_file.with_name(output_file.stem + "_colored_segmented.png")),
            cv2.cvtColor(colored_segmented_img, cv2.COLOR_RGB2BGR),
        )
    if "blend_image" in self.save_image_options:
        cv2.imwrite(
            str(output_file.with_name(output_file.stem + "_blend.png")),
            cv2.cvtColor(blend_img, cv2.COLOR_RGB2BGR),
        )
660
+
661
def _save_semantic_segmentation_image(
    self, image_file: str, img: np.ndarray, dir_output: Path, output: torch.Tensor
) -> None:
    """Save the colored semantic segmentation and/or its blend with the input image.

    Which files are written is controlled by ``self.save_image_options``
    ("segmented_image" and/or "blend_image").

    Args:
        image_file (str): The input image file path; its name determines the output names.
        img (np.array): The input image as a NumPy array.
        dir_output (Path): The output directory for the saved images.
        output (Tensor): The tensor containing the semantic segmentation data.

    Returns:
        None
    """
    seg_rgb = self._trainid_to_color(output.cpu().numpy())
    alpha = 0.5
    blended_rgb = cv2.addWeighted(img, alpha, seg_rgb, 1 - alpha, 0)

    target = dir_output / Path(image_file).name

    # (option name, file-name suffix, RGB image) for every image that can be saved.
    candidates = (
        ("segmented_image", "_colored_segmented.png", seg_rgb),
        ("blend_image", "_blend.png", blended_rgb),
    )
    for option, suffix, rgb_image in candidates:
        if option in self.save_image_options:
            cv2.imwrite(
                str(target.with_name(target.stem + suffix)),
                cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR),
            )
693
+
694
def _panoptic_count_labels(self, output: dict) -> Dict[str, int]:
    """Count how many segments of each label a panoptic output contains.

    Args:
        output (dict): Panoptic segmentation output. It must have a
            "segments_info" key holding a list of dictionaries, each with a
            "label_id" entry.

    Returns:
        dict: Label name -> number of segments carrying that label, in the
        order the labels are first encountered. Label ids that are missing
        from ``self.id_to_name_map`` are counted under the key ``None``.
    """
    # Local import so the block does not depend on the module's import list.
    from collections import Counter

    name_for = self.id_to_name_map.get
    counts = Counter(name_for(segment["label_id"]) for segment in output["segments_info"])
    # Return a plain dict, as callers have always received.
    return dict(counts)
719
+
720
def _panoptic_segment_to_label(self, output: dict) -> torch.Tensor:
    """Re-express a panoptic segmentation map in label ids instead of segment ids.

    Args:
        output: Dictionary with a "segmentation" tensor of segment ids and a
            "segments_info" list whose entries carry "id" and "label_id".

    Returns:
        A clone of the segmentation tensor in which every pixel belonging to a
        segment listed in "segments_info" holds that segment's label id. The
        input tensor is left unmodified.
    """
    segment_map = output["segmentation"]
    relabeled = segment_map.clone()

    # Masks are always computed on the untouched input, so a label id written
    # for one segment can never be mistaken for the segment id of another.
    # NOTE(review): pixels whose segment id is absent from segments_info keep
    # their segment id in the result -- confirm downstream code expects that.
    for info in output["segments_info"]:
        relabeled[segment_map == info["id"]] = info["label_id"]

    return relabeled
746
+
747
def _process_images(
    self,
    task: str,
    image_files: List[str],
    images: List[np.ndarray],
    dir_output: Optional[Path],
    pixel_ratio_dict: Dict[str, Dict[str, float]],
    original_img_shape: List[Tuple[int, int]],
    panoptic_dict: Optional[Dict[str, Dict[str, int]]] = None,
) -> None:
    """Segment one batch of images, optionally save them, and record statistics.

    Results are written into the dictionaries passed in by the caller, keyed
    by the stem of each image file name.

    Args:
        task (str): The segmentation task, either "panoptic" or "semantic".
            Any other value leaves the dictionaries untouched.
        image_files (List[str]): File paths of the input images.
        images (List[np.ndarray]): The input images as NumPy arrays.
        dir_output (Optional[Path]): Directory for the rendered images. Images
            are saved only when this is not None and
            ``self.save_image_options`` is non-empty.
        pixel_ratio_dict (dict): Receives the pixel ratios of each image.
        original_img_shape (List[Tuple[int, int]]): Original shapes of the
            input images, forwarded to the segmentation call.
        panoptic_dict (Optional[dict]): Receives the per-label segment counts
            of each image for the panoptic task. When None (the default) the
            counts are not recorded.

    Returns:
        None
    """
    if task == "panoptic":
        outputs = self._panoptic_segmentation(images, original_img_shape)
        if outputs is None:
            return
        save_images = len(self.save_image_options) > 0 and dir_output is not None
        for image_file, img, output in zip(image_files, images, outputs):
            # Add a segmentation expressed in label ids instead of segment ids.
            output["label_segmentation"] = self._panoptic_segment_to_label(output)
            if save_images:
                self._save_panoptic_segmentation_image(image_file, img, dir_output, output)
            pixel_ratio = self._calculate_pixel_ratios(output["label_segmentation"].cpu().numpy())
            image_file_key = Path(image_file).stem
            pixel_ratio_dict[image_file_key] = pixel_ratio
            # panoptic_dict is optional; indexing None would raise TypeError.
            if panoptic_dict is not None:
                panoptic_dict[image_file_key] = self._panoptic_count_labels(output)

    elif task == "semantic":
        segmentations = self._semantic_segmentation(images, original_img_shape)
        if segmentations is None:
            return
        save_images = len(self.save_image_options) > 0 and dir_output is not None
        for image_file, img, segmentation in zip(image_files, images, segmentations):
            if save_images:
                self._save_semantic_segmentation_image(image_file, img, dir_output, segmentation)
            pixel_ratio = self._calculate_pixel_ratios(segmentation.cpu().numpy())
            pixel_ratio_dict[Path(image_file).stem] = pixel_ratio
796
+
797
+ # Modify the segment method inside the Segmenter class
798
def segment(
    self,
    dir_input: Union[str, Path],
    dir_image_output: Union[str, Path, None] = None,
    dir_summary_output: Union[str, Path, None] = None,
    batch_size: int = 1,
    save_image_options: str = "segmented_image blend_image",
    save_format: str = "json csv",
    csv_format: str = "long",  # "long" or "wide"
    max_workers: Optional[int] = None,
) -> None:
    """Segment one image or a directory of images and save images and summaries.

    Images are processed in outer batches of ``1000 * batch_size`` files. When
    a summary directory is given, the statistics of every outer batch are
    written to a "pixel_ratio_checkpoints" sub-directory, merged into the
    final summary files once all batches are done, and the checkpoint
    directory is then removed. Images that already have an output PNG or an
    entry in a checkpoint are skipped when ``dir_input`` is a directory.

    Args:
        dir_input: Input directory or path to a single image file.
        dir_image_output: Output directory where segmented images are saved.
        dir_summary_output: Output directory where summary files are saved.
        batch_size: Batch size for processing images (Default: 1).
        save_image_options: Options for saving images
            ("segmented_image blend_image"). Stored on the instance as
            ``self.save_image_options``.
        save_format: Format for saving summary files ("json csv").
        csv_format: Format for CSV summary files ("long" or "wide").
        max_workers: Maximum number of worker threads.

    Returns:
        None: Results are written to the specified directories.

    Raises:
        ValueError: If neither dir_image_output nor dir_summary_output is provided.
        ValueError: If the input path is neither a file nor a directory.
    """
    if dir_image_output is None and dir_summary_output is None:
        raise ValueError("At least one of dir_image_output and dir_summary_output must not be None.")

    # Skip everything when each summary file requested through save_format
    # already exists in dir_summary_output.
    if dir_summary_output is not None:
        expected_summaries = [
            Path(dir_summary_output) / f"pixel_ratios.{fmt}" for fmt in ("json", "csv") if fmt in save_format
        ]
        if expected_summaries and all(path.exists() for path in expected_summaries):
            print("Segmentation summary already exists. Skipping segmentation.")
            return

    # The save helpers read the options from the instance.
    self.save_image_options = save_image_options

    dir_input = Path(dir_input)

    # Stems of images that were handled by an earlier (possibly interrupted) run.
    completed_image_files = set()
    if dir_image_output is not None:
        dir_image_output = Path(dir_image_output)
        dir_image_output.mkdir(parents=True, exist_ok=True)
        completed_image_files.update(
            f.stem.replace("_blend", "").replace("_colored_segmented", "") for f in dir_image_output.glob("*.png")
        )

    if dir_summary_output is not None:
        dir_summary_output = Path(dir_summary_output)
        dir_summary_output.mkdir(parents=True, exist_ok=True)
        dir_cache_segmentation_summary = dir_summary_output / "pixel_ratio_checkpoints"
        dir_cache_segmentation_summary.mkdir(parents=True, exist_ok=True)

        # Every JSON file in the checkpoint directory (including a
        # pixel_ratios.json, if one is there) lists completed images.
        checkpoints = list(dir_cache_segmentation_summary.glob("*.json"))
        # New checkpoint files are numbered after the existing ones.
        checkpoint_start_index = len(checkpoints)
        for checkpoint in checkpoints:
            with open(checkpoint, "r") as f:
                completed_image_files.update(json.load(f).keys())

    if dir_input.is_file():
        image_file_list = [dir_input]
    elif dir_input.is_dir():
        image_extensions = {
            ".jpg",
            ".jpeg",
            ".png",
            ".tif",
            ".tiff",
            ".bmp",
            ".dib",
            ".pbm",
            ".pgm",
            ".ppm",
            ".sr",
            ".ras",
            ".exr",
            ".jp2",
        }
        # Compare the suffix case-insensitively so that e.g. ".JPG" files are
        # not silently ignored.
        image_file_list = [
            f
            for f in dir_input.iterdir()
            if f.suffix.lower() in image_extensions and f.stem not in completed_image_files
        ]
    else:
        raise ValueError("dir_input must be either a file or a directory.")

    if not image_file_list:
        print("No image files to process. Skipping segmentation.")
        return

    outer_batch_size = 1000  # Number of inner batches in one outer batch
    images_per_outer_batch = outer_batch_size * batch_size
    num_outer_batches = (len(image_file_list) + images_per_outer_batch - 1) // images_per_outer_batch

    for i in verbosity_tqdm(
        range(num_outer_batches),
        desc=f"Processing outer batches of size {min(outer_batch_size * batch_size, len(image_file_list))}",
        verbosity=self.verbosity,
        level=1,
    ):
        outer_batch_image_file_list = image_file_list[i * images_per_outer_batch : (i + 1) * images_per_outer_batch]

        dataset = ImageDataset(outer_batch_image_file_list)
        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)

        # Fresh result containers for every outer batch; worker threads fill them.
        pixel_ratio_dict = defaultdict(dict)
        panoptic_dict = defaultdict(dict)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # panoptic_dict is always passed; _process_images only writes to
            # it for the panoptic task.
            futures = [
                executor.submit(
                    self._process_images,
                    self.task,
                    image_files,
                    images,
                    dir_image_output,
                    pixel_ratio_dict,
                    original_img_shape,
                    panoptic_dict,
                )
                for image_files, images, original_img_shape in dataloader
            ]

            for completed_future in verbosity_tqdm(
                as_completed(futures),
                total=len(futures),
                desc=f"Processing outer batch #{i+1}",
                verbosity=self.verbosity,
                level=2,
            ):
                # Re-raises any exception that occurred in the worker.
                completed_future.result()

        if dir_summary_output is not None:
            # Save a checkpoint for each outer batch.
            checkpoint_id = checkpoint_start_index + i + 1
            with open(
                dir_cache_segmentation_summary / f"checkpoint_batch_{checkpoint_id}_pixel_ratio.json", "w"
            ) as f:
                json.dump(pixel_ratio_dict, f)

            if self.task == "panoptic":
                with open(
                    dir_cache_segmentation_summary / f"checkpoint_batch_{checkpoint_id}_panoptic.json", "w"
                ) as f:
                    json.dump(panoptic_dict, f)

    if dir_summary_output is None:
        return

    # Merge all checkpoints into a single pixel_ratio_dict.
    pixel_ratio_dict = defaultdict(dict)
    for checkpoint in dir_cache_segmentation_summary.glob("*_pixel_ratio.json"):
        with open(checkpoint, "r") as f:
            pixel_ratio_dict.update(json.load(f))

    # Merge all checkpoints into a single panoptic_dict.
    if self.task == "panoptic":
        panoptic_dict = defaultdict(dict)
        for checkpoint in dir_cache_segmentation_summary.glob("*_panoptic.json"):
            with open(checkpoint, "r") as f:
                panoptic_dict.update(json.load(f))

    # Fold in summaries from earlier runs. Entries already stored in the
    # existing files take precedence over entries with the same key from the
    # checkpoints.
    if (dir_summary_output / "pixel_ratios.json").exists():
        with open(dir_summary_output / "pixel_ratios.json", "r") as f:
            pixel_ratio_dict.update(json.load(f))

    if self.task == "panoptic" and (dir_summary_output / "label_counts.json").exists():
        with open(dir_summary_output / "label_counts.json", "r") as f:
            panoptic_dict.update(json.load(f))

    # Save the merged results in the requested formats.
    if "json" in save_format:
        with open(dir_summary_output / "pixel_ratios.json", "w") as f:
            json.dump(pixel_ratio_dict, f)
        if self.task == "panoptic":
            with open(dir_summary_output / "label_counts.json", "w") as f:
                json.dump(panoptic_dict, f)
    if "csv" in save_format:
        self._save_as_csv(pixel_ratio_dict, dir_summary_output, "pixel_ratios", csv_format)
        if self.task == "panoptic":
            self._save_as_csv(panoptic_dict, dir_summary_output, "label_counts", csv_format)

    # Delete the "pixel_ratio_checkpoints" directory.
    shutil.rmtree(dir_cache_segmentation_summary, ignore_errors=True)
1052
+
1053
def calculate_pixel_ratio_post_process(
    self, dir_input: Union[str, Path], dir_output: Union[str, Path], save_format: str = "json csv"
) -> None:
    """Derive per-class pixel ratios from already-saved colored segmentation images.

    Args:
        dir_input: A single segmented image, or a directory that is searched
            recursively for ".jpg"/".png" files whose stem contains
            "_colored_segmented".
        dir_output: Directory (created if necessary) that receives
            "pixel_ratios.json" and/or "pixel_ratios.csv".
        save_format: String naming the formats to write; "json" and "csv" are
            recognized. The default is "json csv".

    Returns:
        None

    Raises:
        ValueError: If dir_input is neither a file nor a directory.
    """

    def calculate_label_ratios(image, label_map):
        """Return the share of each label among the pixels matching a known color.

        Args:
            image: A numpy array representing an RGB image.
            label_map: Mapping from color to a label object exposing ``name``.

        Returns:
            Dictionary label name -> ratio. When no pixel matches any color
            the raw (zero) counts are returned unchanged.
        """
        per_label = {}
        matched_total = 0
        for rgb, label in label_map.items():
            hits = np.count_nonzero(np.all(image == rgb, axis=-1))
            per_label[label.name] = hits
            matched_total += hits

        # Nothing to normalize by: hand back the counts as they are.
        if not matched_total > 0:
            return per_label

        for name in per_label:
            per_label[name] = per_label[name] / matched_total
        return per_label

    def process_image_file(image_file, label_map):
        """Load one segmented image and compute its label ratios.

        Args:
            image_file: Path of the segmented image file.
            label_map: Mapping from color to label object.

        Returns:
            Tuple of (file key without the "_colored_segmented" marker,
            label-ratio dictionary).
        """
        key = str(Path(image_file).stem).replace("_colored_segmented", "")
        # The file is read as BGR and converted to RGB before colors are compared.
        rgb_image = cv2.cvtColor(cv2.imread(str(image_file)), cv2.COLOR_BGR2RGB)
        return key, calculate_label_ratios(rgb_image, label_map)

    def results_to_nested_dict(results):
        """Turn (key, ratios) pairs into a dictionary keyed by the string key."""
        return {str(key): ratios for key, ratios in results}

    def results_to_dataframe(results):
        """Turn (key, ratios) pairs into a DataFrame with one row per image."""
        frame = pd.DataFrame(results_to_nested_dict(results)).transpose()
        frame.fillna(0, inplace=True)
        frame.index.names = ["filename_key"]
        return frame

    # The output directory is created before the input is validated.
    dir_output = Path(dir_output)
    dir_output.mkdir(parents=True, exist_ok=True)

    if isinstance(dir_input, str):
        dir_input = Path(dir_input)

    image_extensions = [".jpg", ".png"]

    if dir_input.is_file():
        image_files = [dir_input]
    elif dir_input.is_dir():
        image_files = []
        for candidate in dir_input.rglob("*"):
            if candidate.suffix.lower() not in image_extensions:
                continue
            if "_colored_segmented" in candidate.stem:
                image_files.append(candidate)
    else:
        raise ValueError("dir_input must be either a file or a directory.")

    # One label_map argument per image file for the threaded map.
    label_maps = [self.label_map] * len(image_files)
    results = thread_map(process_image_file, image_files, label_maps)

    if "json" in save_format:
        with open(Path(dir_output) / "pixel_ratios.json", "w") as f:
            json.dump(results_to_nested_dict(results), f, indent=2)

    if "csv" in save_format:
        results_to_dataframe(results).to_csv(Path(dir_output) / "pixel_ratios.csv")