"""Gradio app for multimodal (image + text) fake news detection.

A DistilBERT + ResNet-18 fusion classifier is built from ``config.json`` and,
when ``model_weights.pth`` exists, loaded with trained weights. Without the
weights file, predictions come from a keyword-based heuristic instead.
"""

import inspect
import json
import logging
import os
import signal
import sys
import time
import traceback
import warnings
from typing import Literal

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Use lighter models to reduce storage
LIGHTWEIGHT_TEXT_MODEL = "distilbert-base-uncased"  # Much smaller than BERT
LIGHTWEIGHT_IMAGE_MODEL = "microsoft/resnet-18"  # Smaller than ResNet-34

# Channel width of the last ResNet-18/34 stage. Only used when the image
# encoder's config does not expose ``hidden_sizes``.
DEFAULT_IMAGE_FEATURE_DIM = 512

# Set environment variables to prevent protocol errors
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
os.environ["GRADIO_SERVER_PORT"] = "7860"


class MultimodalClassifier(nn.Module):
    """Fusion classifier over a Hugging Face text encoder and image encoder.

    The first-token text embedding and the spatially averaged image feature
    map are each projected to ``projection_dim``. They are then fused, either
    by concatenation (``"concat"``) or by element-wise product (any other
    ``fusion_method``), passed through a Dropout/Linear/GELU block, and mapped
    to ``num_classes`` logits.
    """

    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        self.image_encoder = AutoModel.from_pretrained(image_encoder_id_or_path, trust_remote_code=True)
        # NOTE(review): the bare AutoModel backbone presumably has no
        # classification head, so this likely just registers an unused,
        # parameter-free module. Kept so attribute layout is unchanged.
        self.image_encoder.classifier = nn.Identity()

        # The pooled image feature has as many channels as the encoder's last
        # stage (512 for ResNet-18/34). Read it from the config when available
        # so other backbones do not need a code change.
        hidden_sizes = getattr(self.image_encoder.config, "hidden_sizes", None)
        image_feature_dim = hidden_sizes[-1] if hidden_sizes else DEFAULT_IMAGE_FEATURE_DIM
        self.image_projection = nn.Sequential(
            nn.Linear(image_feature_dim, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # Concatenation doubles the width; the element-wise product keeps it.
        fusion_input_dim = self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )
        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return raw classification logits of shape (batch, num_classes).

        ``pixel_values`` is expected to be a (batch, 3, H, W) image batch, as
        produced by ``image_transform`` below; ``input_ids``/``attention_mask``
        come from the matching tokenizer.
        """
        text_hidden = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        # First-token ([CLS]) embedding as the sentence representation.
        text_features = self.text_projection(text_hidden[:, 0, :])

        image_maps = self.image_encoder(pixel_values=pixel_values).last_hidden_state
        # Global average pool over the two trailing (spatial) dimensions.
        image_features = self.image_projection(image_maps.mean(dim=[-2, -1]))

        if self.fusion_method == "concat":
            fused_features = torch.cat([text_features, image_features], dim=-1)
        else:
            fused_features = text_features * image_features

        fused_features = self.fusion_layer(fused_features)
        return self.classifier(fused_features)


def load_model():
    """Build the classifier from ``config.json`` and load weights if present.

    Returns:
        The constructed ``MultimodalClassifier`` (not yet in eval mode).

    Raises:
        FileNotFoundError: ``config.json`` does not exist.
        ValueError: ``config.json`` is not valid JSON.
        KeyError: a required key is missing from ``config.json``.
        Exception: any other failure is logged with a traceback and re-raised.
    """
    try:
        if not os.path.exists("config.json"):
            raise FileNotFoundError("config.json file not found. Please ensure it exists in the current directory.")

        logger.info("Loading configuration from config.json...")
        with open("config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        # text_encoder_id_or_path is validated but not used: the lightweight
        # models below replace the configured encoders.
        required_keys = ["text_encoder_id_or_path", "projection_dim", "fusion_method", "proj_dropout", "fusion_dropout", "num_classes"]
        for key in required_keys:
            if key not in config:
                raise KeyError(f"Missing required key '{key}' in config.json")

        logger.info("Initializing MultimodalClassifier with lightweight models...")
        model = MultimodalClassifier(
            text_encoder_id_or_path=LIGHTWEIGHT_TEXT_MODEL,  # Use DistilBERT instead of BERT
            image_encoder_id_or_path=LIGHTWEIGHT_IMAGE_MODEL,  # Use ResNet-18 instead of ResNet-34
            projection_dim=config["projection_dim"],
            fusion_method=config["fusion_method"],
            proj_dropout=config["proj_dropout"],
            fusion_dropout=config["fusion_dropout"],
            num_classes=config["num_classes"],
        )

        if os.path.exists("model_weights.pth"):
            logger.info("Loading model weights...")
            # SECURITY NOTE: torch.load unpickles the file and can execute
            # arbitrary code. Only ship weight files from a trusted source
            # (consider weights_only=True where the installed torch supports it).
            checkpoint = torch.load("model_weights.pth", map_location=torch.device('cpu'))
            # strict=False tolerates key mismatches (e.g. weights trained with a
            # different backbone). Surface them rather than silently running a
            # partially initialised model.
            incompatible = model.load_state_dict(checkpoint, strict=False)
            if incompatible.missing_keys:
                logger.warning(
                    "%d parameter(s) missing from model_weights.pth and left untrained, e.g. %s",
                    len(incompatible.missing_keys), incompatible.missing_keys[:5],
                )
            if incompatible.unexpected_keys:
                logger.warning(
                    "%d unexpected key(s) in model_weights.pth were ignored, e.g. %s",
                    len(incompatible.unexpected_keys), incompatible.unexpected_keys[:5],
                )
        else:
            logger.warning("model_weights.pth not found. Using untrained model for demonstration.")
            logger.warning("For best results, please provide the trained model weights.")

        logger.info("Model loaded successfully!")
        return model

    except FileNotFoundError as e:
        logger.error(f"File error: {e}")
        raise
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error in config.json: {e}")
        raise ValueError(f"Invalid JSON format in config.json: {e}") from e
    except KeyError as e:
        logger.error(f"Configuration error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error loading model: {e}")
        logger.error(traceback.format_exc())
        raise


def initialize_components():
    """Load the model and tokenizer into the module globals.

    Returns:
        True on success; False if anything failed (the error is logged).
    """
    global model, text_tokenizer
    try:
        logger.info("Initializing model and tokenizer...")
        model = load_model()
        model.eval()

        logger.info("Loading DistilBERT tokenizer...")
        text_tokenizer = AutoTokenizer.from_pretrained(LIGHTWEIGHT_TEXT_MODEL)

        logger.info("All components initialized successfully!")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize components: {e}")
        logger.error(traceback.format_exc())
        return False


model = None
text_tokenizer = None
initialization_success = initialize_components()

# ImageNet mean/std normalisation at 224x224.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def validate_inputs(image, text):
    """Raise if the app is not ready or the request inputs are unusable.

    Validation never modifies ``image``; ``predict`` does its own RGB conversion.

    Raises:
        RuntimeError: initialization failed or model/tokenizer are missing.
        ValueError: no image; empty, too short (<10 non-blank-edge chars) or
            too long (>2000 chars) text; or an image PIL cannot convert to RGB.
    """
    if not initialization_success:
        raise RuntimeError("Model initialization failed. Please check the logs for details.")

    if model is None or text_tokenizer is None:
        raise RuntimeError("Model or tokenizer not loaded. Please restart the application.")

    if image is None:
        raise ValueError("Please upload an image.")

    if not text or not text.strip():
        raise ValueError("Please enter some text for analysis.")

    if len(text.strip()) < 10:
        raise ValueError("Text is too short. Please provide at least 10 characters.")

    if len(text) > 2000:
        raise ValueError("Text is too long. Please provide text with less than 2000 characters.")

    try:
        # Probe conversion only, to reject images PIL cannot turn into RGB;
        # the result is intentionally discarded.
        if image.mode not in ['RGB', 'RGBA', 'L']:
            image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Invalid image format: {e}") from e


def simple_fake_news_detector(text: str) -> str:
    """Simple rule-based fake news detector as fallback.

    Counts how many "fake" and "real" indicator phrases occur in the
    lower-cased text (each phrase counted at most once) and returns the
    label of the larger count, or "Uncertain" on a tie.
    """
    fake_indicators = [
        "breaking news", "shocking", "you won't believe", "doctors hate",
        "one weird trick", "click here", "urgent", "exclusive", "leaked",
        "nazi salute", "hail trump", "take our country back", "step on toes",
    ]

    real_indicators = [
        "according to", "reported by", "official statement", "confirmed",
        "study shows", "research indicates", "data reveals", "analysis",
    ]

    text_lower = text.lower()

    fake_score = sum(1 for indicator in fake_indicators if indicator in text_lower)
    real_score = sum(1 for indicator in real_indicators if indicator in text_lower)

    if fake_score > real_score:
        return "Fake News (Rule-based)"
    elif real_score > fake_score:
        return "Real News (Rule-based)"
    else:
        return "Uncertain (Rule-based)"


def _logits_to_class(logits: torch.Tensor) -> int:
    """Map the logits of a single example to a class index.

    A single logit is treated as a binary head (sigmoid rounded to 0/1, as
    before). With several logits the arg-max index is returned instead of
    failing on ``.item()``.
    NOTE(review): for multi-logit heads this assumes index 1 means "fake",
    matching the binary mapping in predict() -- confirm against training labels.
    """
    if logits.numel() == 1:
        return int(torch.sigmoid(logits).round().item())
    return int(logits.argmax(dim=-1).item())


def predict(image: Image.Image, text: str) -> str:
    """Classify an (image, text) pair and return a human-readable label.

    Falls back to ``simple_fake_news_detector`` when ``model_weights.pth`` is
    absent or model inference raises. All errors are returned as
    ``"Error: ..."`` strings rather than raised, so the UI always gets text.
    """
    try:
        logger.info("Starting prediction...")
        validate_inputs(image, text)

        # If model weights are not available, use simple rule-based detection
        if not os.path.exists("model_weights.pth"):
            logger.info("Using rule-based fallback detection...")
            return simple_fake_news_detector(text)

        logger.info("Processing text input...")
        # A single sequence needs no padding; padding to 512 tokens only adds
        # attention-masked positions and wasted compute.
        text_inputs = text_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        logger.info("Processing image input...")
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_input = image_transform(image).unsqueeze(0)
        except Exception as e:
            raise ValueError(f"Failed to process image: {e}") from e

        logger.info("Running model inference...")
        with torch.no_grad():
            try:
                classification_output = model(
                    pixel_values=image_input,
                    input_ids=text_inputs["input_ids"],
                    attention_mask=text_inputs["attention_mask"],
                )
                predicted_class = _logits_to_class(classification_output)
            except Exception as e:
                logger.warning(f"Model inference failed, using fallback: {e}")
                return simple_fake_news_detector(text)

        result = "Fake News" if predicted_class == 1 else "Real News"
        logger.info(f"Prediction completed: {result}")
        return result

    except ValueError as e:
        logger.warning(f"Input validation error: {e}")
        return f"Error: {e}"
    except RuntimeError as e:
        logger.error(f"Runtime error: {e}")
        return f"Error: {e}"
    except Exception as e:
        logger.error(f"Unexpected error during prediction: {e}")
        logger.error(traceback.format_exc())
        return "Error: An unexpected error occurred. Please try again."


def _no_flagging_kwargs() -> dict:
    """Return the ``gr.Interface`` keyword that disables flagging, if any.

    Different Gradio releases accept ``flagging_mode`` and/or
    ``allow_flagging``. Passing a keyword the installed version does not accept
    would make every interface builder below fail the same way, defeating the
    fallback chain in ``main``.
    """
    try:
        params = inspect.signature(gr.Interface).parameters
    except (TypeError, ValueError):
        return {"allow_flagging": "never"}
    if "flagging_mode" in params:
        return {"flagging_mode": "never"}
    if "allow_flagging" in params:
        return {"allow_flagging": "never"}
    return {}


def create_interface():
    """Create the main interface, or an error-reporting one if init failed."""
    try:
        if not initialization_success:
            error_msg = "Failed to initialize the model. Please check that config.json and model_weights.pth files exist and are valid."
            logger.error(error_msg)

            def error_function(image, text):
                return error_msg

            return gr.Interface(
                fn=error_function,
                inputs=[
                    gr.Image(type="pil", label="Upload Related Image"),
                    gr.Textbox(lines=2, placeholder="Enter news text for classification...", label="Input Text"),
                ],
                outputs=gr.Label(label="Error"),
                title="Fake News Detector - Initialization Error",
                description=error_msg,
            )

        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil"),
                gr.Textbox(lines=2),
            ],
            outputs=gr.Textbox(lines=1),
            title="Fake News Detector",
            **_no_flagging_kwargs(),
        )
    except Exception as e:
        logger.error(f"Failed to create interface: {e}")
        logger.error(traceback.format_exc())
        raise


def create_simple_interface():
    """Create a minimal interface to avoid protocol errors"""
    try:
        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil"),
                gr.Textbox(lines=2),
            ],
            outputs=gr.Textbox(lines=1),
            title="Fake News Detector",
            **_no_flagging_kwargs(),
        )
    except Exception as e:
        logger.error(f"Failed to create simple interface: {e}")
        raise


def create_ultra_minimal_interface():
    """Create the most minimal interface possible"""
    try:
        return gr.Interface(
            fn=predict,
            # predict() needs a PIL image (it reads image.mode); without
            # type="pil", gr.Image hands over a NumPy array by default and
            # every request would fail validation.
            inputs=[gr.Image(type="pil"), gr.Textbox()],
            outputs=gr.Textbox(),
            **_no_flagging_kwargs(),
        )
    except Exception as e:
        logger.error(f"Failed to create ultra minimal interface: {e}")
        raise


def main():
    """Build the best interface that can be constructed and launch it.

    Interface construction degrades from full -> simple -> ultra minimal on
    failure. Launching is retried once on port 7861 if port 7860 fails. Fatal
    errors are logged and printed rather than raised.
    """
    try:
        logger.info("Starting Fake News Detector application...")

        interface = None
        try:
            interface = create_interface()
        except Exception as e:
            logger.warning(f"Failed to create full interface, trying simple version: {e}")
            try:
                interface = create_simple_interface()
            except Exception as e2:
                logger.warning(f"Failed to create simple interface, using ultra minimal: {e2}")
                interface = create_ultra_minimal_interface()

        logger.info("Launching Gradio interface...")
        # share=True asks Gradio for a public share link in addition to the
        # local server.
        launch_kwargs = dict(server_name="0.0.0.0", share=True, quiet=False, inbrowser=False)
        try:
            interface.launch(server_port=7860, **launch_kwargs)
        except Exception as launch_error:
            logger.warning(f"Launch failed on port 7860, trying port 7861: {launch_error}")
            interface.launch(server_port=7861, **launch_kwargs)

    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        logger.error(traceback.format_exc())
        print(f"Error starting application: {e}")
        print("Please check the logs for more details.")


if __name__ == "__main__":
    main()