Ahmed-El-Sharkawy commited on
Commit
21bd14f
·
verified ·
1 Parent(s): 13e4b9b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (2).py +77 -0
  2. requirements.txt +6 -0
app (2).py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ import os
8
+ import torch
9
+
10
+ # Set device
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # Load the main classifier (Detector_best_model.pth)
14
+ main_model = models.mobilenet_v3_large(weights=None) # Updated: weights=None
15
+
16
+ #num_ftrs = main_model.fc.in_features
17
+ # main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
18
+
19
+ num_ftrs = main_model.classifier[3].in_features
20
+ main_model.classifier[3] = nn.Linear(num_ftrs, 2)
21
+
22
+ # main_model.fc = nn.Sequential(
23
+ # nn.Dropout(p=0.5), # Match the training architecture
24
+ # nn.Linear(num_ftrs, 2) # 2 classes: AI-generated Image, Real Image
25
+ # )
26
+
27
+ main_model.load_state_dict(torch.load('best_model3_mobilenetv3_large.pth', map_location=device, weights_only=True)) # Updated: weights_only=True
28
+ main_model = main_model.to(device)
29
+ main_model.eval()
30
+
31
+ # Define class names for the classifier based on the Folder structure
32
+ classes_name = ['AI-generated Image', 'Real Image']
33
+
34
+ def convert_to_rgb(image):
35
+ """
36
+ Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
37
+ This is to avoid transparency issues during model training.
38
+ """
39
+ if image.mode in ('P', 'RGBA'):
40
+ return image.convert('RGB')
41
+ return image
42
+
43
+ # Define preprocessing transformations (same used during training)
44
+ preprocess = transforms.Compose([
45
+ transforms.Lambda(convert_to_rgb),
46
+ transforms.Resize((224, 224)), # Resize here, no need for shape argument in gr.Image
47
+ transforms.ToTensor(),
48
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
49
+ ])
50
+
51
+ def classify_image(image):
52
+ # Open the image using PIL
53
+ image = Image.fromarray(image)
54
+
55
+ # Preprocess the image
56
+ input_image = preprocess(image).unsqueeze(0).to(device)
57
+
58
+ # Perform inference with the main classifier
59
+ with torch.no_grad():
60
+ output = main_model(input_image)
61
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
62
+ confidence, predicted_class = torch.max(probabilities, 0)
63
+
64
+ # Main classifier result
65
+ main_prediction = classes_name[predicted_class]
66
+ main_confidence = confidence.item()
67
+
68
+ return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"
69
+
70
+ # Gradio interface (updated)
71
+ image_input = gr.Image(image_mode="RGB") # Removed shape argument
72
+ output_text = gr.Textbox()
73
+
74
+ gr.Interface(fn=classify_image, inputs=image_input, outputs=[output_text],
75
+ title="Detect AI-generated Image ",
76
+ description="Upload an image to Detected AI-generated Image .",
77
+ theme="default").launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Pillow==10.3.0
2
+ torch==2.4.1
3
+ torchvision==0.19.1
4
+ numpy==1.26.4
5
+ gradio
6
+ transformers