Abs6187 commited on
Commit
d664f3f
·
verified ·
1 Parent(s): ddda78b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -3,18 +3,29 @@ from ultralytics import YOLO
3
  from PIL import Image
4
  import os
5
 
 
 
 
 
6
 
7
-
8
- # Load the trained YOLOv8 model
9
- model = YOLO("best.pt")
 
 
 
 
 
 
 
10
 
11
  # Define the prediction function
12
  def predict(image):
13
- results = model(image) # Run YOLOv8 model on the uploaded image
14
  results_img = results[0].plot() # Get image with bounding boxes
15
  return Image.fromarray(results_img)
16
 
17
- # Get example images from the images folder
18
  def get_example_images():
19
  examples = []
20
  image_folder = "images"
@@ -25,13 +36,13 @@ def get_example_images():
25
 
26
  # Create Gradio interface
27
  interface = gr.Interface(
28
- fn=predict,
29
- inputs=gr.Image(type="pil"),
30
  outputs=gr.Image(type="pil"),
31
- title="Helmet Detection with YOLO",
32
- description="Upload an image to detect helmets.",
33
  examples=get_example_images()
34
  )
35
 
36
  # Launch the interface
37
- interface.launch(share=True)
 
3
  from PIL import Image
4
  import os
5
 
6
+ # Load models with priority to YOLOv8
7
+ # Try to load YOLOv8 model first, fall back to YOLOv11 if not available
8
+ model = None
9
+ model_name = ""
10
 
11
+ if os.path.exists("best.pt"):
12
+ model = YOLO("best.pt")
13
+ model_name = "YOLOv8 (best.pt)"
14
+ print("✓ Loaded YOLOv8 model (best.pt)")
15
+ elif os.path.exists("yolov11nbest.pt"):
16
+ model = YOLO("yolov11nbest.pt")
17
+ model_name = "YOLOv11 (yolov11nbest.pt)"
18
+ print("✓ Loaded YOLOv11 model (yolov11nbest.pt)")
19
+ else:
20
+ raise FileNotFoundError("No model file found. Please ensure 'best.pt' or 'yolov11nbest.pt' exists.")
21
 
22
  # Define the prediction function
23
  def predict(image):
24
+ results = model(image) # Run YOLO model on the uploaded image
25
  results_img = results[0].plot() # Get image with bounding boxes
26
  return Image.fromarray(results_img)
27
 
28
+ # Get example images from the root folder
29
  def get_example_images():
30
  examples = []
31
  image_folder = "images"
 
36
 
37
  # Create Gradio interface
38
  interface = gr.Interface(
39
+ fn=predict,
40
+ inputs=gr.Image(type="pil"),
41
  outputs=gr.Image(type="pil"),
42
+ title=f"Helmet Detection with YOLO",
43
+ description=f"Upload an image to detect helmets. **Currently using: {model_name}**",
44
  examples=get_example_images()
45
  )
46
 
47
  # Launch the interface
48
+ interface.launch()