"""Phishing Detector Gradio app (DeBERTa + LSTM classifier, plus an LLM + RAG tab).

The classifier tab scores a URL (and, when the page can be fetched, its raw
HTML) or free text with a DeBERTa + LSTM model. The AI-assistant tab sends the
input (plus fetched page content for URLs) to an LLM client that is given the
RAG engine's vector store.
"""

import html
import logging
import re
import sys
import time

import gradio as gr
import requests
import torch
import torch.nn.functional as F
import urllib3  # used below to silence InsecureRequestWarning
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# --- Quiet asyncio's logger (noisy on shutdown) ---
# NOTE(review): WARNING still lets ERROR records through; use logging.CRITICAL
# if shutdown *errors* really need to be hidden.
if sys.version_info >= (3, 10):
    logging.getLogger("asyncio").setLevel(logging.WARNING)

# --- Suppress InsecureRequestWarning ---
# Expected for a phishing detector: scanned sites often have invalid TLS
# certificates, so pages are fetched with verify=False.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# --- Import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier  # noqa: E402  <-- your class

# --- Import RAG modules ---
from rag_engine import RAGEngine  # noqa: E402
from llm_client import LLMClient  # noqa: E402

# --------- Config ----------
REPO_ID = "dungeon29/deberta-lstm-detect-phishing"
CKPT_NAME = "pytorch_model.bin"
MODEL_NAME = "microsoft/deberta-base"  # base tokenizer/backbone
# Index order must match the model's logits: predict_fn treats index 1 as phishing.
LABELS = ["benign", "phishing"]  # adjust to your classes

# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# SECURITY NOTE(review): torch.load unpickles the file and can execute code
# embedded in it. If the checkpoint only holds tensors and plain containers,
# pass weights_only=True (the default since PyTorch 2.6).
checkpoint = torch.load(ckpt_path, map_location=device)

# The file is either a bare state_dict or a checkpoint dict that wraps it under
# "model_state_dict" and may carry constructor kwargs under "model_args"
# (e.g. {"lstm_hidden": 256, "num_labels": 2, ...}).
model_args = checkpoint.get("model_args") or {}
model = DeBERTaLSTMClassifier(**model_args)

# Load weights, reusing the checkpoint already in memory instead of reading
# the file a second time.
try:
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    # strict=False lets checkpoints without newer layers (e.g. the attention
    # layer) still load; any key mismatch is reported below instead.
    incompatible = model.load_state_dict(state_dict, strict=False)
    # Check whether the attention layer's weights were present in the file.
    if hasattr(model, "attention") and "attention.weight" not in state_dict:
        print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
    else:
        print("✅ Load weights successfully!")
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"⚠️ Missing keys: {incompatible.missing_keys}; "
            f"unexpected keys: {incompatible.unexpected_keys}"
        )
except Exception as e:
    print(f"❌ Error when loading weights: {e}")
    raise

model.to(device).eval()

# --------- Initialize RAG & LLM ----------
print("Initializing RAG Engine (LangChain)...")
rag_engine = RAGEngine()
print("RAG Engine ready.")

print("Initializing Qwen3-0.6B(GGUF) LLM (LangChain)...")
# Pass vector_store to LLMClient for RetrievalQA
llm_client = LLMClient(vector_store=rag_engine.vector_store)
print("LLM ready.")

# --------- Helper functions ----------
# Compiled once at import instead of on every is_url() call.
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,63}\.?|'  # domain (TLDs up to 63 chars)
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or IPv4
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$',
    re.IGNORECASE,
)

# Browser-like headers so sites serve the same page a real visitor would see.
_BROWSER_HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
    'Accept-Language': 'en-US,en;q=0.9,vi;q=0.8',
    'Referer': 'https://www.google.com/',
}

# Tokens hidden from token-level displays: common special-token spellings plus
# whatever the loaded tokenizer itself declares as special.
_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]", "<s>", "</s>", "<pad>"}) | frozenset(
    tokenizer.all_special_tokens
)


def is_url(text):
    """Return True if *text* is an http(s) URL whose host is a domain, localhost or an IPv4 address."""
    return _URL_PATTERN.match(text) is not None


def fetch_html_content(url, timeout=10):
    """Fetch the raw HTML of *url* for the classifier.

    The full markup is returned (not stripped text) because the model needs
    HTML tags/structure to detect hidden threats.

    Returns:
        (html_text, status_code) on success, or (None, error_message) on failure.
    """
    # SECURITY NOTE(review): is_url() accepts localhost and raw IPs, so a
    # publicly deployed instance can be used to reach internal hosts (SSRF).
    try:
        # verify=False is intentional for phishing detection, warning suppressed globally
        response = requests.get(url, headers=_BROWSER_HEADERS, timeout=timeout, verify=False)
        response.raise_for_status()
        return response.text, response.status_code
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {e}"
    except Exception as e:
        return None, f"General error: {e}"


def predict_single_text(text, text_type="text"):
    """Classify one piece of text with the DeBERTa + LSTM model.

    Args:
        text: URL, raw HTML or free text; truncated to 512 tokens.
        text_type: Descriptive tag ("URL", "HTML", "text"); currently unused.

    Returns:
        (probs, tokens, has_attention, attention_weights): per-class softmax
        probabilities as a list of floats, the tokenizer's tokens, whether the
        model returned attention weights, and those weights (None otherwise).
    """
    # Increased max_length to 512 to capture more HTML content.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    attention_weights = None
    has_attention = False
    with torch.no_grad():
        try:
            # Newer models return (logits, attention_weights, deberta_attentions).
            result = model(**inputs, return_attention=True)
        except TypeError:
            # Fallback for older model without the return_attention parameter.
            logits = model(**inputs)
        else:
            if isinstance(result, tuple) and len(result) == 3:
                logits, attention_weights, _ = result
                has_attention = True
            else:
                logits = result

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    # Tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0).tolist())
    return probs, tokens, has_attention, attention_weights if has_attention else None


def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Return the class-wise weighted sum of URL and HTML-content probabilities.

    Works for any number of classes; both lists must use the same class order.
    """
    return [url_weight * u + html_weight * h for u, h in zip(url_probs, html_probs)]


def _clean_token(token):
    """Strip whitespace and BPE/sentencepiece word-boundary markers from a token."""
    return token.replace(" ", "").replace("Ġ", "").replace("▁", "").strip()


def _content_tokens(tokens):
    """Return *tokens* without special tokens."""
    return [t for t in tokens if t not in _SPECIAL_TOKENS]


def _analyze_input(text):
    """Run the classifier on *text*; for URLs also on the fetched page HTML.

    Returns:
        (probs, tokens, attention_weights, attention_tokens, analysis_type,
        fetch_status). ``tokens`` is what gets displayed; ``attention_tokens``
        are the tokens ``attention_weights`` were computed over (the weights
        are None when the model provides none).
    """
    url = text.strip()
    if not is_url(url):
        probs, tokens, _, attention = predict_single_text(text, "text")
        return probs, tokens, attention, tokens, "Text Analysis", ""

    url_probs, url_tokens, _, url_attention = predict_single_text(url, "URL")
    html_content, status = fetch_html_content(url)
    if not html_content:
        # Fallback for URL-only analysis
        return (
            url_probs, url_tokens, url_attention, url_tokens,
            "URL-only Analysis", f"⚠️ Could not fetch HTML content: {status}",
        )

    html_probs, html_tokens, _, html_attention = predict_single_text(html_content, "HTML")
    probs = combine_predictions(url_probs, html_probs)
    tokens = url_tokens + ["[SEP]"] + html_tokens[:50]  # Limit HTML tokens for display
    # Prefer the URL's attention, but always pair weights with the tokens they
    # were computed on (the combined display list starts with URL tokens, so
    # zipping HTML weights against it would misalign them).
    if url_attention is not None:
        attention, attention_tokens = url_attention, url_tokens
    else:
        attention, attention_tokens = html_attention, html_tokens
    return (
        probs, tokens, attention, attention_tokens,
        "Combined URL + HTML Analysis", f"✅ Successfully fetched HTML content (Status: {status})",
    )


def _render_attention_section(attention_weights, attention_tokens, tokens, probs, predicted_class):
    """Return HTML listing the tokens with the highest attention weight."""
    # Assumes one attention score per token (e.g. shape (1, seq)) — TODO confirm.
    scores = attention_weights.reshape(-1).tolist()
    token_analysis = []
    for position, (token, score) in enumerate(zip(attention_tokens, scores)):
        # Lenient threshold so short texts still surface some tokens.
        if token in _SPECIAL_TOKENS or score <= 0.005:
            continue
        clean = _clean_token(token)
        if clean:  # Only add if token has content after cleaning
            token_analysis.append({"token": clean, "importance": score, "position": position})
    token_analysis.sort(key=lambda item: item["importance"], reverse=True)

    color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
    parts = [
        "<h4>Top important tokens:</h4>",
        f"<p>Analysis Info: Found {len(token_analysis)} important tokens out of {len(tokens)} total tokens</p>",
        "<ol>",
    ]
    for info in token_analysis[:10]:  # Top 10 tokens
        bar_width = min(int(info["importance"] * 100), 100)
        parts.append(
            f'<li><code>{html.escape(info["token"])}</code> '
            f'<span style="display:inline-block;width:100px;height:8px;background:#444;">'
            f'<span style="display:block;width:{bar_width}%;height:100%;background:{color};"></span></span> '
            f'{info["importance"]:.1%}</li>'
        )
    high_impact = sum(1 for item in token_analysis if item["importance"] > 0.05)
    parts += [
        "</ol>",
        "<h4>Detailed analysis:</h4>",
        "<p><b>Statistical Overview</b></p>",
        f"<p>{len(_content_tokens(tokens))} Total tokens<br>"
        f"{high_impact} High impact tokens (&gt;5%)</p>",
        "<p><b>Prediction Confidence</b></p>",
        f"<p>Phishing: {probs[1]:.1%}<br>Benign: {probs[0]:.1%}</p>",
    ]
    return "\n".join(parts)


def _render_basic_section(tokens, probs):
    """Return fallback HTML used when the model provides no attention weights."""
    content = _content_tokens(tokens)
    token_text = " ".join(html.escape(token.replace(" ", "")) for token in content)
    return f"""
<h4>Basic Analysis</h4>
<p>{probs[1]:.1%} Phishing<br>{probs[0]:.1%} Benign<br>{len(content)} Tokens</p>
<p><b>🔤 Tokens in text:</b></p>
<p>{token_text}</p>
<p><small>Debug info: Found {len(tokens)} total tokens, {len(content)} content tokens</small></p>
<p><i>Note: Detailed attention weights analysis is not available for the current model.</i></p>
"""


# --------- Inference function ----------
def predict_fn(text: str):
    """Classify a URL or text for the Standard Detection tab.

    Returns:
        (prediction, analysis_html): a label -> probability dict for gr.Label
        (None for empty input) and an HTML report for gr.HTML.
    """
    if not text or not text.strip():
        return None, "Please enter a URL or text."

    probs, tokens, attention_weights, attention_tokens, analysis_type, fetch_status = _analyze_input(text)

    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    confidence = max(probs)
    verdict = (
        "This appears to be a phishing attempt!"
        if predicted_class == "phishing"
        else "✅ This appears to be legitimate content."
    )
    detailed_analysis = f"""<div class="analysis">
<h3>🔍 {analysis_type}</h3>
<h2>{predicted_class.upper()}</h2>
<p>Confidence: {confidence:.1%}</p>
<p>{verdict}</p>
"""
    if fetch_status:
        # Escaped: the status may embed server/exception text.
        detailed_analysis += f"<p><b>Fetch Status:</b> {html.escape(fetch_status)}</p>\n"

    if attention_weights is not None:
        detailed_analysis += _render_attention_section(
            attention_weights, attention_tokens, tokens, probs, predicted_class
        )
    else:
        detailed_analysis += _render_basic_section(tokens, probs)
    detailed_analysis += "</div>\n"

    # Build label->prob mapping for Gradio Label output
    if len(LABELS) == len(probs):
        prediction_result = {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    else:
        prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)}
    return prediction_result, detailed_analysis


# --------- RAG Inference function ----------
# (label, background gradient, icon, border color); first substring match wins.
_VERDICT_STYLES = (
    ("PHISHING", "linear-gradient(135deg, #ff4b1f 0%, #ff9068 100%)", "⛔", "#ff4b1f"),
    ("BENIGN", "linear-gradient(135deg, #11998e 0%, #38ef7d 100%)", "✅", "#11998e"),
)
_UNCERTAIN_STYLE = ("UNCERTAIN", "linear-gradient(135deg, #f8b500 0%, #fceabb 100%)", "⚠️", "#f8b500")


def _parse_llm_response(response):
    """Extract (classification, confidence, explanation) from the LLM's reply.

    Looks for "CLASSIFICATION:", "CONFIDENCE SCORE:" and "EXPLANATION:" lines
    (case-insensitive). Lines after "EXPLANATION:" that are not another field
    are kept as part of a multi-line explanation. Missing fields default to
    "UNKNOWN", "N/A" and the whole response respectively.
    """
    classification = "UNKNOWN"
    confidence = "N/A"
    explanation_lines = None
    for raw_line in response.split("\n"):
        line = raw_line.strip()
        upper = line.upper()
        if upper.startswith("CLASSIFICATION:"):
            classification = line.split(":", 1)[1].strip().upper()
        elif upper.startswith("CONFIDENCE SCORE:"):
            confidence = line.split(":", 1)[1].strip()
        elif upper.startswith("EXPLANATION:"):
            explanation_lines = [line.split(":", 1)[1].strip()]
        elif explanation_lines is not None:
            explanation_lines.append(line)
    if explanation_lines is None:
        return classification, confidence, response
    return classification, confidence, "\n".join(explanation_lines).strip()


def rag_predict_fn(text: str):
    """Ask the LLM (with RAG retrieval) to classify and explain *text*.

    URLs are fetched first and up to 4000 characters of the page are passed
    along. Returns an HTML card for gr.HTML.
    """
    if not text or not text.strip():
        return "Please enter text to analyze."

    start_time = time.perf_counter()
    input_text = text.strip()
    analysis_context = input_text

    if is_url(input_text):
        print(f"🌐 Detected URL: {input_text}. Fetching content...")
        fetched_content, status = fetch_html_content(input_text)
        if fetched_content:
            # Limit content length to avoid token overflow
            truncated_content = fetched_content[:4000]
            analysis_context = f"URL: {input_text}\n\nWebsite Content:\n{truncated_content}\n..."
            status_msg = f"✅ Successfully fetched {len(fetched_content)} chars from URL (Status: {status})."
        else:
            analysis_context = f"URL: {input_text}\n\n(Could not fetch website content. Error: {status})"
            status_msg = f"⚠️ Failed to fetch URL content: {status}"
        print(status_msg)
    else:
        status_msg = "📝 Analyzing raw text input."

    # Call LLM (which handles retrieval internally via LangChain)
    response = llm_client.analyze(analysis_context)
    elapsed_time = time.perf_counter() - start_time

    classification, confidence, explanation = _parse_llm_response(response)
    label, color_grad, icon, border_col = next(
        (style for style in _VERDICT_STYLES if style[0] in classification),
        _UNCERTAIN_STYLE,
    )

    # LLM output can echo content from the scanned (untrusted) page: escape it.
    explanation_html = html.escape(explanation).replace("\n", "<br>")
    html_output = f"""
<div style="border: 2px solid {border_col}; border-radius: 12px; overflow: hidden;">
  <div style="background: {color_grad}; padding: 16px; color: #ffffff;">
    <h2 style="margin: 0;">{icon} {label}</h2>
  </div>
  <div style="padding: 16px;">
    <p><b>Confidence:</b> {html.escape(confidence)}</p>
    <p><b>Explanation:</b></p>
    <p>{explanation_html}</p>
    <p>⏱️ Processing Time: {elapsed_time:.2f}s</p>
    <p>🛡️ CyberGuard AI Analysis</p>
    <p><small>Input Status: {html.escape(status_msg)}</small></p>
    <p><small>AI can make mistakes. Always verify critical URLs manually.</small></p>
  </div>
</div>
"""
    return html_output


# --------- Refresh Knowledge Base function ----------
def refresh_kb():
    """Rebuild the RAG knowledge base and return the engine's status message."""
    return rag_engine.refresh_knowledge_base()


# --------- Gradio UI ----------
css_style = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: #1e1e1e !important;
    color: #ffffff !important;
}
/* Customize Buttons */
.gradio-container button.primary, .gradio-container button.secondary {
    background-color: #4a4a4a !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
.gradio-container button.primary:hover, .gradio-container button.secondary:hover {
    background-color: #5a5a5a !important;
}
/* Customize Textboxes (Inputs) */
.gradio-container textarea, .gradio-container input {
    background-color: #3d3d3d !important;
    color: #ffffff !important;
    border: 1px solid #666 !important;
}
/* Customize Blocks/Panels */
.gradio-container .block {
    background-color: #2d2d2d !important;
    border: 1px solid #444 !important;
}
"""

with gr.Blocks() as demo:
    # Inject the dark-theme CSS through a <style> tag.
    gr.HTML(f"<style>{css_style}</style>")
    gr.Markdown("# 🛡️ Phishing Detector (DeBERTa + LSTM + RAG)")

    with gr.Tabs():
        # --- Tab 1: Standard Detection ---
        with gr.TabItem("🔍 Standard Detection"):
            gr.Markdown(
                """
Enter a URL or text for analysis using the DeBERTa + LSTM model.

**Features:**
- **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
- **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Predict phishing/benign probability with visual charts
- **Token Importance**: Display the most important tokens in classification
- **Detailed Insights**: Comprehensive analysis of the impact of each token

**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results for final prediction (30% URL + 70% content)
"""
            )

            with gr.Row():
                with gr.Column(scale=2):
                    input_box = gr.Textbox(
                        label="URL or text",
                        placeholder="Example: http://suspicious-site.example or paste any text",
                        lines=3,
                    )
                    btn_submit = gr.Button("🔍 Analyze", variant="primary")

                    gr.Examples(
                        examples=[
                            ["http://rendmoiunserviceeee.com"],
                            ["https://www.google.com"],
                            ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
                            ["https://mail-secure-login-verify.example/path?token=suspicious"],
                            ["http://paypaI-security-update.net/login"],
                            ["Your package has been delivered successfully. Thank you for using our service."],
                            ["https://github.com/user/repo"],
                            ["Dear customer, your account has been suspended. Click here to verify."],
                        ],
                        inputs=input_box,
                    )

                with gr.Column(scale=3):
                    output_label = gr.Label(label="Prediction", num_top_classes=len(LABELS))
                    output_html = gr.HTML(label="Analysis Result")

            # predict_fn returns (label -> probability dict, analysis HTML).
            btn_submit.click(fn=predict_fn, inputs=input_box, outputs=[output_label, output_html])

        # --- Tab 2: LLM + RAG Analysis ---
        with gr.TabItem("🤖 AI Assistant (RAG)"):
            # NOTE(review): the startup log says "Qwen3-0.6B(GGUF)" but this text
            # says "Qwen2.5-3B" — confirm which model LLMClient actually loads.
            gr.Markdown(
                """
**AI Assistant** uses **Qwen2.5-3B** + **LangChain** to explain *why* a message is suspicious.

**Features:**
- 🌐 Multilingual support (English + Vietnamese)
- 📚 Knowledge Base retrieval (Auto-sync)
- 🚀 No rate limits (self-hosted)
"""
            )

            with gr.Row():
                with gr.Column(scale=1):
                    rag_input = gr.Textbox(
                        label="Suspicious Text/URL",
                        placeholder="Paste the email content or URL here...",
                        lines=5,
                    )
                    with gr.Row():
                        btn_rag = gr.Button("🤖 Ask AI Assistant", variant="primary")
                        btn_refresh = gr.Button("♻️ Refresh Knowledge Base")

                    gr.Examples(
                        examples=[
                            ["Your PayPal account has been suspended. Click http://paypal-verify.com to unlock."],
                            ["Tài khoản ngân hàng của bạn bị khóa. Nhấn vào đây để mở khóa ngay."],
                            ["Your package is ready for delivery. Track here: https://fedex-track.com"],
                        ],
                        inputs=rag_input,
                    )

                with gr.Column(scale=1):
                    # gr.HTML (not gr.Markdown) so the result card can be custom-styled.
                    rag_output = gr.HTML(label="AI Analysis")
                    refresh_output = gr.Markdown(label="Status")

            btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
            btn_refresh.click(fn=refresh_kb, inputs=[], outputs=refresh_output)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)