santi9462 commited on
Commit
392668e
·
verified ·
1 Parent(s): b016b55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -16
app.py CHANGED
@@ -1,12 +1,19 @@
 
1
  import os, csv, random
2
  from datetime import datetime
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
6
 
 
 
 
 
7
  # ---------- Paths ----------
8
  IMG_DIR = "durian_images"
9
  HIS_CSV = "history/history.csv"
 
 
10
  os.makedirs(IMG_DIR, exist_ok=True)
11
  os.makedirs(os.path.dirname(HIS_CSV), exist_ok=True)
12
 
@@ -118,14 +125,10 @@ RIPENESS_CAPTION_VARIANTS = {
118
  ],
119
  }
120
 
121
- # ---------- ความมั่นใจ ----------
122
  def adjust_confidence(raw_conf: float) -> float:
123
  p = raw_conf if raw_conf > 1 else raw_conf * 100.0
124
- if p < 30.0:
125
- return 96.0
126
- elif p < 60.0:
127
- return 85.0
128
- return round(p, 1)
129
 
130
  # ---------- เลือกแคปชั่น ----------
131
  def pick_variant(level: str):
@@ -138,7 +141,7 @@ def generate_caption(level: str, raw_conf_pct: float) -> str:
138
  body_list = pick_variant(level)
139
  return " ".join([head] + body_list)
140
 
141
- # ---------- Fallback inference ----------
142
  def _classify_4class_by_color(img: Image.Image):
143
  arr = np.array(img.convert("RGB")).reshape(-1, 3).mean(axis=0)
144
  R, G, B = arr
@@ -148,10 +151,97 @@ def _classify_4class_by_color(img: Image.Image):
148
  idx = int(np.argmax(probs))
149
  return idx, probs
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  # ---------- Core inference ----------
 
 
 
152
  def infer_ripeness_and_caption(image: Image.Image):
153
- idx, probs = _classify_4class_by_color(image)
154
- label = RIPENESS_LABELS[idx]
 
 
 
 
155
  raw_conf_pct = float(probs[idx]) * 100.0
156
  cap = generate_caption(label, raw_conf_pct)
157
  return idx, probs, cap
@@ -181,21 +271,34 @@ def analyze(image):
181
  return "กรุณาอัปโหลดภาพ", "", None, "❌ ไม่มีภาพ"
182
  try:
183
  idx, probs, caption = infer_ripeness_and_caption(image)
184
- class_name = RIPENESS_LABELS[idx]
185
- conf = f"{adjust_confidence(float(probs[idx])):.1f}%"
186
- except Exception:
187
- class_name, conf, caption = "พร้อมรับประทาน(สำรอง)", "100.0%", "เดโม แบบสำรอง"
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
190
  out_path = os.path.join(IMG_DIR, f"durian_{ts}.jpg")
191
  try:
192
  image.save(out_path, quality=90)
193
  save_history_row(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
194
- class_name, conf, caption, out_path)
195
  except Exception:
196
  pass
197
 
198
- return f"ระดับ: {class_name} (ความมั่นใจ {conf})", caption, image, " เสร็จสิ้น"
 
199
 
200
  def show_history():
201
  rows = load_history(limit=200)
@@ -233,4 +336,5 @@ with gr.Blocks(title="Durian Happiness Level") as demo:
233
 
234
  if __name__ == "__main__":
235
  random.seed()
236
- demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
 
 
1
+ # app.py
2
  import os, csv, random
3
  from datetime import datetime
4
  import numpy as np
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ # ---- PyTorch / Torchvision ----
9
+ import torch, torch.nn as nn
10
+ from torchvision import transforms, models
11
+
12
  # ---------- Paths ----------
13
  IMG_DIR = "durian_images"
14
  HIS_CSV = "history/history.csv"
15
+ CKPT_PATH = "durian_mnv2_ckpt.pth" # วางไฟล์น้ำหนักไว้โฟลเดอร์เดียวกับ app.py
16
+
17
  os.makedirs(IMG_DIR, exist_ok=True)
18
  os.makedirs(os.path.dirname(HIS_CSV), exist_ok=True)
19
 
 
125
  ],
126
  }
127
 
128
+ # ---------- ความมั่นใจ (ไม่ปั๊มตัวเลข) ----------
129
  def adjust_confidence(raw_conf: float) -> float:
130
  p = raw_conf if raw_conf > 1 else raw_conf * 100.0
131
+ return round(max(0.0, min(100.0, p)), 1)
 
 
 
 
132
 
133
  # ---------- เลือกแคปชั่น ----------
134
  def pick_variant(level: str):
 
141
  body_list = pick_variant(level)
142
  return " ".join([head] + body_list)
143
 
144
+ # ---------- Fallback 4-class by color ----------
145
  def _classify_4class_by_color(img: Image.Image):
146
  arr = np.array(img.convert("RGB")).reshape(-1, 3).mean(axis=0)
147
  R, G, B = arr
 
151
  idx = int(np.argmax(probs))
152
  return idx, probs
153
 
154
+ # ---------- Model (PyTorch) ----------
155
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
156
+ idx_to_class = {i: name for i, name in enumerate(RIPENESS_LABELS)}
157
+ model = None
158
+ temperature = 1.0
159
+
160
+ def _build_model(num_classes=4):
161
+ m = models.mobilenet_v2(weights=None)
162
+ in_f = m.classifier[1].in_features
163
+ m.classifier[1] = nn.Linear(in_f, num_classes)
164
+ return m
165
+
166
+ def _load_model():
167
+ """โหลดโมเดล + mapping จาก ckpt ถ้ามี"""
168
+ global model, idx_to_class, temperature
169
+ if not os.path.exists(CKPT_PATH):
170
+ print(f"[WARN] ckpt not found at {CKPT_PATH}. Use color fallback.")
171
+ return False
172
+
173
+ ckpt = torch.load(CKPT_PATH, map_location=device)
174
+
175
+ state = None
176
+ if isinstance(ckpt, dict):
177
+ for k in ["state_dict", "model_state_dict", "model"]:
178
+ if k in ckpt and isinstance(ckpt[k], dict):
179
+ state = ckpt[k]; break
180
+ if state is None:
181
+ # บางครั้งเซฟเป็น state_dict ตรง ๆ
182
+ # ตรวจคร่าว ๆ ว่า value เป็น tensor
183
+ if all(hasattr(v, "shape") for v in ckpt.values()):
184
+ state = ckpt
185
+ else:
186
+ # ผิดรูปแบบ
187
+ print("[WARN] Unknown ckpt format, using fallback.")
188
+ return False
189
+
190
+ model_ = _build_model(num_classes=len(RIPENESS_LABELS))
191
+ if state is not None:
192
+ state = {k.replace("module.", ""): v for k, v in state.items()}
193
+ missing, unexpected = model_.load_state_dict(state, strict=False)
194
+ if missing:
195
+ print(f"[INFO] missing keys: {missing[:5]}{'...' if len(missing)>5 else ''}")
196
+ if unexpected:
197
+ print(f"[INFO] unexpected keys: {unexpected[:5]}{'...' if len(unexpected)>5 else ''}")
198
+ else:
199
+ return False
200
+
201
+ if "class_to_idx" in ckpt and isinstance(ckpt["class_to_idx"], dict):
202
+ c2i = ckpt["class_to_idx"]
203
+ idx_to_class = {i: lbl for lbl, i in c2i.items()}
204
+
205
+ temperature = float(ckpt.get("temperature", 1.0))
206
+ model_.to(device).eval()
207
+ print(f"[OK] Model loaded. Temperature={temperature}. Classes={list(idx_to_class.values())}")
208
+ # set global
209
+ globals()["model"] = model_
210
+ globals()["idx_to_class"] = idx_to_class
211
+ globals()["temperature"] = temperature
212
+ return True
213
+
214
+ MODEL_READY = _load_model()
215
+
216
+ # Preprocess (ปรับให้ตรงกับตอนเทรน)
217
+ IM_SIZE = 224
218
+ _base_tf = transforms.Compose([
219
+ transforms.Resize((IM_SIZE, IM_SIZE)),
220
+ transforms.ToTensor(),
221
+ transforms.Normalize(mean=[0.485,0.456,0.406],
222
+ std=[0.229,0.224,0.225]),
223
+ ])
224
+
225
+ def _predict_proba_with_model(img: Image.Image):
226
+ """TTA เบา ๆ : original + flip แล้วเฉลี่ย"""
227
+ imgs = [img, img.transpose(Image.FLIP_LEFT_RIGHT)]
228
+ xs = torch.stack([_base_tf(im) for im in imgs], dim=0).to(device)
229
+ with torch.no_grad():
230
+ logits = model(xs) / temperature
231
+ probs = torch.softmax(logits, dim=1).mean(dim=0).cpu().numpy()
232
+ return probs
233
+
234
  # ---------- Core inference ----------
235
+ def label_by_idx(i: int) -> str:
236
+ return idx_to_class.get(i, RIPENESS_LABELS[i])
237
+
238
  def infer_ripeness_and_caption(image: Image.Image):
239
+ if MODEL_READY:
240
+ probs = _predict_proba_with_model(image)
241
+ idx = int(np.argmax(probs))
242
+ else:
243
+ idx, probs = _classify_4class_by_color(image)
244
+ label = label_by_idx(idx)
245
  raw_conf_pct = float(probs[idx]) * 100.0
246
  cap = generate_caption(label, raw_conf_pct)
247
  return idx, probs, cap
 
271
  return "กรุณาอัปโหลดภาพ", "", None, "❌ ไม่มีภาพ"
272
  try:
273
  idx, probs, caption = infer_ripeness_and_caption(image)
274
+
275
+ order = np.argsort(probs)[::-1]
276
+ top1, top2 = int(order[0]), int(order[1])
277
+ p1, p2 = float(probs[top1]), float(probs[top2])
278
+
279
+ class_name = label_by_idx(top1)
280
+ conf_str = f"{adjust_confidence(p1):.1f}%"
281
+
282
+ borderline = ""
283
+ if (p1 - p2) < 0.15:
284
+ borderline = f"\n⚠️ ก้ำกึ่งระหว่าง {label_by_idx(top1)} ({p1*100:.1f}%) และ {label_by_idx(top2)} ({p2*100:.1f}%)"
285
+
286
+ result_text = f"ระดับ: {class_name} (ความมั่นใจ {conf_str}){borderline}"
287
+ except Exception as e:
288
+ result_text = "พร้อมรับประทาน(สำรอง) (ความมั่นใจ 100.0%)"
289
+ caption = "เดโม แบบสำรอง"
290
 
291
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
292
  out_path = os.path.join(IMG_DIR, f"durian_{ts}.jpg")
293
  try:
294
  image.save(out_path, quality=90)
295
  save_history_row(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
296
+ class_name, conf_str, caption, out_path)
297
  except Exception:
298
  pass
299
 
300
+ status_text = " เสร็จสิ้น" if MODEL_READY else "ℹ️ ใช้โหมดสำรอง (สี)"
301
+ return result_text, caption, image, status_text
302
 
303
  def show_history():
304
  rows = load_history(limit=200)
 
336
 
337
  if __name__ == "__main__":
338
  random.seed()
339
+ port = int(os.environ.get("PORT", "7860"))
340
+ demo.launch(server_name="0.0.0.0", server_port=port, show_api=False)