Spaces:
Paused
Paused
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| import requests | |
| import re | |
| import time | |
| import sys | |
| import logging | |
| import urllib3 # Import urllib3 to handle warnings | |
# --- Quiet expected-but-noisy warnings ---
# Scanning suspicious sites routinely means connecting to hosts with invalid or
# self-signed TLS certificates (requests are made with verify=False), so the
# InsecureRequestWarning emitted for each such request is silenced globally.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

# On Python 3.10+ raise the asyncio logger's threshold to WARNING to cut
# shutdown noise. Note: ERROR-level asyncio records still pass this threshold.
if sys.version_info[:2] >= (3, 10):
    logging.getLogger("asyncio").setLevel(logging.WARNING)
| # --- import your architecture --- | |
| # Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py) | |
| # and update the import path accordingly. | |
| from model import DeBERTaLSTMClassifier # <-- your class | |
| # --- Import RAG modules --- | |
| from rag_engine import RAGEngine | |
| from llm_client import LLMClient | |
# --------- Config ----------
# Where the fine-tuned checkpoint lives on the Hugging Face Hub.
REPO_ID = "dungeon29/deberta-lstm-detect-phishing"
CKPT_NAME = "pytorch_model.bin"

# Backbone name; its tokenizer is used for every input.
MODEL_NAME = "microsoft/deberta-base"

# Class names in logit order (index 0 = benign, 1 = phishing); adjust to the
# classes the checkpoint was trained on.
LABELS = ["benign", "phishing"]

# Hyperparameters saved inside the checkpoint (e.g. under "model_args") can be
# fed to DeBERTaLSTMClassifier(**model_args) when the model is built.
# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# NOTE(review): torch.load unpickles the file, so REPO_ID must point at a
# repository you trust.
checkpoint = torch.load(ckpt_path, map_location=device)
# Optional constructor hyperparameters stored next to the weights,
# e.g. {"lstm_hidden": 256, "num_labels": 2, ...}; {} for a bare state_dict.
model_args = checkpoint.get("model_args", {})
model = DeBERTaLSTMClassifier(**model_args)

# Load weights from the checkpoint that is already in memory instead of reading
# and unpickling the same file a second time (which also kept two copies of
# the weights alive).
try:
    # A full training checkpoint wraps the weights under "model_state_dict";
    # otherwise the checkpoint itself is the state_dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    # strict=False tolerates parameters missing from, or extra in, the
    # checkpoint; report them instead of ignoring them silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"⚠️ {len(load_result.missing_keys)} model parameters were not found in the checkpoint (left at initialization)")
    if load_result.unexpected_keys:
        print(f"⚠️ Ignored {len(load_result.unexpected_keys)} checkpoint entries not used by the model")
    # Check whether the attention layer came from the checkpoint.
    if hasattr(model, "attention") and "attention.weight" not in state_dict:
        print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
    else:
        print("✅ Load weights successfully!")
except Exception as e:
    print(f"❌ Error when loading weights: {e}")
    raise  # re-raise with the original traceback
model.to(device).eval()
# --------- Initialize RAG & LLM ----------
# Retrieval engine first: the LLM client below is built on top of its vector store.
print("Initializing RAG Engine (LangChain)...")
rag_engine = RAGEngine()
print("RAG Engine ready.")

# The LLM client receives the engine's vector store so that it can perform
# retrieval itself (RetrievalQA) when analyzing input.
print("Initializing Qwen3-0.6B(GGUF) LLM (LangChain)...")
llm_client = LLMClient(
    vector_store=rag_engine.vector_store,
)
print("LLM ready.")
| # --------- Helper functions ---------- | |
# Compiled once at import time instead of on every call.
_URL_PATTERN = re.compile(
    r"^https?://"  # http:// or https://
    r"(?:"
    r"(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+"  # one or more DNS labels...
    r"(?:[A-Z]{2,63}|XN--[A-Z0-9-]{1,59})\.?"  # ...a TLD of up to 63 chars or punycode, optional root dot
    r"|localhost"
    r"|\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"  # dotted IPv4
    r")"
    r"(?::\d+)?"  # optional port
    r"(?:/?|[/?]\S+)$",  # optional path and/or query
    re.IGNORECASE,
)


def is_url(text):
    """Return True if *text* looks like an absolute http(s) URL.

    Accepts domain names (TLDs of 2-63 letters, or punycode ``xn--`` TLDs such
    as ``.xn--p1ai``), ``localhost``, and dotted IPv4 addresses, each with an
    optional port and path/query. Earlier versions capped the TLD at 6
    characters, so URLs on TLDs such as ``.example`` or ``.technology`` were
    treated as plain text. The match is case-insensitive; surrounding
    whitespace is not stripped.
    """
    return _URL_PATTERN.match(text) is not None
# Browser-like request headers; some phishing kits serve different content to
# obvious bots.
_BROWSER_HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
    'Accept-Language': 'en-US,en;q=0.9,vi;q=0.8',
    'Referer': 'https://www.google.com/'
}


def fetch_html_content(url, timeout=10, max_bytes=2_000_000):
    """Fetch the raw HTML of *url* (tags kept; the classifier uses page structure).

    The body is streamed, and at most ``max_bytes`` bytes are read. A hostile or
    huge page therefore cannot exhaust memory. Downstream users in this file only
    consume a prefix anyway: 512 tokens for the model, 4000 chars for the LLM.

    Args:
        url: Absolute http(s) URL to fetch.
        timeout: Timeout in seconds passed to ``requests`` (connect / per read).
        max_bytes: Maximum number of body bytes to read.

    Returns:
        ``(html_text, status_code)`` on success, or ``(None, error_message)``
        when the request fails or the server returns an HTTP error status.
    """
    try:
        # verify=False is intentional for phishing detection (many targets have
        # invalid certificates); the warning is suppressed globally.
        with requests.get(url, headers=_BROWSER_HEADERS, timeout=timeout,
                          verify=False, stream=True) as response:
            response.raise_for_status()
            chunks = []
            received = 0
            for chunk in response.iter_content(chunk_size=65536):
                chunks.append(chunk)
                received += len(chunk)
                if received >= max_bytes:
                    break
            body = b"".join(chunks)[:max_bytes]
            # Decode with the charset from the response headers (as response.text
            # would); fall back to UTF-8 when it is absent or unknown.
            try:
                text = body.decode(response.encoding or "utf-8", errors="replace")
            except LookupError:
                text = body.decode("utf-8", errors="replace")
            return text, response.status_code
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {str(e)}"
    except Exception as e:
        return None, f"General error: {str(e)}"
def predict_single_text(text, text_type="text"):
    """Classify one string with the DeBERTa+LSTM model.

    Args:
        text: Input string (a URL, raw HTML, or free text). It is tokenized with
            truncation to 512 tokens so that more of long HTML pages is covered.
        text_type: Descriptive tag passed by callers ("URL", "HTML", "text");
            it does not affect the computation.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)``:

        - ``probs``: per-class softmax probabilities as a list.
        - ``tokens``: tokenizer tokens for the (truncated) input.
        - ``has_attention``: whether the model returned attention weights.
        - ``attention_weights``: those weights, or ``None``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        try:
            # Ask for attention weights when the model supports it.
            result = model(**inputs, return_attention=True)
        except TypeError:
            # Fallback for older models whose forward() has no return_attention.
            result = model(**inputs)

    attention_weights = None
    if isinstance(result, tuple):
        # A 3-tuple is (logits, attention_weights, deberta_attentions). For any
        # other tuple, only the first element is used, assumed to be the logits
        # (TODO confirm against model.py). Previously such tuples reached
        # softmax and crashed.
        logits = result[0]
        if len(result) == 3:
            attention_weights = result[1]
    else:
        logits = result
    has_attention = attention_weights is not None

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()
    # Tokens for visualization, aligned with the model input.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0).tolist())
    return probs, tokens, has_attention, attention_weights
def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Blend URL-based and HTML-based class probabilities element-wise.

    Each output entry is ``url_weight * url_probs[i] + html_weight * html_probs[i]``.
    With the default weights (which sum to 1) the result is again a probability
    distribution. The weights are not renormalized. Any number of classes is
    supported; for the two-class model, index 0 is benign and 1 is phishing.

    Raises:
        ValueError: if the two probability vectors differ in length.
    """
    if len(url_probs) != len(html_probs):
        raise ValueError(
            f"Probability vectors differ in length: {len(url_probs)} vs {len(html_probs)}"
        )
    return [url_weight * u + html_weight * h for u, h in zip(url_probs, html_probs)]
| # --------- Inference function ---------- | |
# Special tokens excluded from token counts and the token list.
_SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")
# Markers skipped when ranking tokens by attention.
_ATTENTION_SKIP_TOKENS = ("[CLS]", "[SEP]", "[PAD]", "<s>", "</s>")


def _run_analysis(text):
    """Classify *text* as a URL (URL + fetched HTML) or as plain text.

    Returns ``(probs, tokens, attention_weights, attention_tokens, analysis_type,
    fetch_status)``. ``attention_weights`` is ``None`` when no attention is
    available. ``attention_tokens`` are the tokens those weights were computed
    over, which may differ from the display ``tokens``.
    """
    if not is_url(text.strip()):
        probs, tokens, has_attention, attention = predict_single_text(text, "text")
        return probs, tokens, (attention if has_attention else None), tokens, "Text Analysis", ""

    url = text.strip()
    url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
    html_content, status = fetch_html_content(url)
    if not html_content:
        # Fallback for URL-only analysis
        return (url_probs, url_tokens, (url_attention if url_has_attention else None), url_tokens,
                "URL-only Analysis", f"⚠️ Could not fetch HTML content: {status}")

    html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
    probs = combine_predictions(url_probs, html_probs)
    # Display the URL tokens followed by the first 50 HTML tokens.
    tokens = url_tokens + ["[SEP]"] + html_tokens[:50]
    # Pair attention scores with the tokens they were computed over (HTML
    # attention used to be paired with the URL tokens).
    if url_has_attention:
        attention, attention_tokens = url_attention, url_tokens
    elif html_has_attention:
        attention, attention_tokens = html_attention, html_tokens
    else:
        attention, attention_tokens = None, tokens
    return (probs, tokens, attention, attention_tokens, "Combined URL + HTML Analysis",
            f"✅ Successfully fetched HTML content (Status: {status})")


def _rank_tokens(tokens, scores, min_score=0.005):
    """Return tokens with attention above *min_score*, most important first.

    Each entry is ``{"token", "importance", "position"}``. Special markers and
    tokens that are empty after removing spaces and the BPE prefix ``Ġ`` are
    skipped.
    """
    ranked = []
    for position, (token, score) in enumerate(zip(tokens, scores)):
        if token in _ATTENTION_SKIP_TOKENS or not token.strip() or score <= min_score:
            continue
        clean_token = token.replace(" ", "").replace("Ġ", "").strip()
        if clean_token:
            ranked.append({"token": clean_token, "importance": score, "position": position})
    ranked.sort(key=lambda item: item["importance"], reverse=True)
    return ranked


def predict_fn(text: str):
    """Classify a URL or free text and render an HTML report.

    For URLs, the URL string is classified on its own, the page is fetched and
    its raw HTML is classified, and the two results are blended with
    ``combine_predictions`` (30% URL / 70% HTML). If the fetch fails, the
    URL-only result is used. Any other input is classified as plain text.

    Returns:
        ``(prediction_result, detailed_analysis)``. ``prediction_result`` maps
        label -> probability, or is ``{"error": ...}`` for empty input.
        ``detailed_analysis`` is an HTML string. Untrusted text (tokens from
        fetched pages, fetch error messages) is HTML-escaped.
    """
    import html

    if not text or not text.strip():
        message = "Please enter a URL or text."
        return {"error": message}, f'<div style="color: #ffcc02; padding: 10px;">{message}</div>'

    (probs, tokens, attention_weights, attention_tokens,
     analysis_type, fetch_status) = _run_analysis(text)

    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    is_phishing = predicted_class == "phishing"
    confidence = max(probs)
    content_token_count = sum(1 for t in tokens if t not in _SPECIAL_TOKENS)

    detailed_analysis = f"""
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; background: #1e1e1e; padding: 20px; border-radius: 15px;">
<div style="background: linear-gradient(135deg, {'#8b0000' if is_phishing else '#006400'} 0%, {'#dc143c' if is_phishing else '#228b22'} 100%); padding: 25px; border-radius: 20px; color: white; text-align: center; margin-bottom: 20px; box-shadow: 0 8px 32px rgba(0,0,0,0.5); border: 2px solid {'#ff4444' if is_phishing else '#44ff44'};">
<h2 style="margin: 0 0 10px 0; font-size: 28px; color: white;">🔍 {analysis_type}</h2>
<div style="font-size: 36px; font-weight: bold; margin: 10px 0; color: white;">
{predicted_class.upper()}
</div>
<div style="font-size: 18px; color: #f0f0f0;">
Confidence: {confidence:.1%}
</div>
<div style="margin-top: 15px; font-size: 14px; color: #e0e0e0;">
{'This appears to be a phishing attempt!' if is_phishing else '✅ This appears to be legitimate content.'}
</div>
</div>
"""
    if fetch_status:
        # fetch_status can embed the user-supplied URL / server error text.
        detailed_analysis += f"""
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
<strong>Fetch Status:</strong> {html.escape(fetch_status)}
</div>
"""

    if attention_weights is not None:
        attention_scores = attention_weights.squeeze(0).tolist()
        token_analysis = _rank_tokens(attention_tokens, attention_scores)
        high_impact_count = sum(1 for t in token_analysis if t["importance"] > 0.05)
        bar_color = "#ff4444" if is_phishing else "#44ff44"
        detailed_analysis += f"""
<h3 style="color: #ffffff; margin: 15px 0 10px 0;">Top important tokens:</h3>
<div style="background: #2d2d2d; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #4caf50; color: #e0e0e0;">
<strong>Analysis Info:</strong> Found {len(token_analysis)} important tokens out of {len(tokens)} total tokens
</div>
<div style="font-family: Arial, sans-serif;">
"""
        for rank, token_info in enumerate(token_analysis[:10], start=1):  # Top 10 tokens
            bar_width = int(token_info["importance"] * 100)
            detailed_analysis += f"""
<div style="margin: 8px 0; display: flex; align-items: center; background: #2d2d2d; padding: 8px; border-radius: 8px; border-left: 4px solid {bar_color};">
<div style="width: 30px; text-align: right; margin-right: 10px; font-weight: bold; color: #ffffff;">
{rank}.
</div>
<div style="width: 120px; margin-right: 10px; font-weight: bold; color: #e0e0e0; text-align: right;">
{html.escape(token_info['token'])}
</div>
<div style="width: 300px; background-color: #404040; border-radius: 10px; overflow: hidden; margin-right: 10px; border: 1px solid #555;">
<div style="width: {bar_width}%; background-color: {bar_color}; height: 20px; border-radius: 10px; transition: width 0.3s ease;"></div>
</div>
<div style="color: #cccccc; font-size: 12px; font-weight: bold;">
{token_info['importance']:.1%}
</div>
</div>
"""
        detailed_analysis += "</div>\n"
        detailed_analysis += f"""
<h3 style="color: #ffffff; margin: 15px 0 10px 0;">Detailed analysis:</h3>
<div style="font-family: Arial, sans-serif; background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
<h3 style="margin: 0 0 15px 0; color: white;">Statistical Overview</h3>
<div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{content_token_count}</div>
<div style="font-size: 14px; color: #e0e0e0;">Total tokens</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{high_impact_count}</div>
<div style="font-size: 14px; color: #e0e0e0;">High impact tokens (>5%)</div>
</div>
</div>
</div>
<div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
<h3 style="color: #ffffff; margin-bottom: 15px;"> Prediction Confidence</h3>
<div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
<span style="font-weight: bold; color: #ff4444;">Phishing</span>
<span style="font-weight: bold; color: #44ff44;">Benign</span>
</div>
<div style="width: 100%; background-color: #404040; border-radius: 25px; overflow: hidden; height: 30px; border: 1px solid #666;">
<div style="width: {probs[1]*100:.1f}%; background: linear-gradient(90deg, #ff4444 0%, #ff6666 100%); height: 100%; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 14px;">
{probs[1]:.1%}
</div>
</div>
<div style="margin-top: 10px; text-align: center; color: #cccccc; font-size: 14px;">
Benign: {probs[0]:.1%}
</div>
</div>
"""
    else:
        # Fallback analysis without attention weights
        token_chips = "".join(
            f'<span style="background: #404040; color: #64b5f6; padding: 4px 8px; border-radius: 15px; font-size: 12px; border: 1px solid #666;">{html.escape(token.replace(" ", ""))}</span>'
            for token in tokens
            if token not in _SPECIAL_TOKENS
        )
        detailed_analysis += f"""
<div style="background: linear-gradient(135deg, #1a237e 0%, #3949ab 100%); padding: 20px; border-radius: 15px; color: white; margin: 15px 0; border: 2px solid #3f51b5;">
<h3 style="margin: 0 0 15px 0; color: white;">Basic Analysis</h3>
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 15px;">
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[1]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Phishing</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{probs[0]:.1%}</div>
<div style="font-size: 14px; color: #e0e0e0;">Benign</div>
</div>
<div style="background: rgba(255,255,255,0.1); padding: 15px; border-radius: 10px; text-align: center; border: 1px solid rgba(255,255,255,0.2);">
<div style="font-size: 24px; font-weight: bold; color: white;">{content_token_count}</div>
<div style="font-size: 14px; color: #e0e0e0;">Tokens</div>
</div>
</div>
</div>
<div style="font-family: Arial, sans-serif; margin: 15px 0; background: #2d2d2d; padding: 20px; border-radius: 15px; border: 1px solid #555;">
<h3 style="color: #ffffff; margin: 0 0 15px 0;">🔤 Tokens in text:</h3>
<div style="display: flex; flex-wrap: wrap; gap: 8px;">{token_chips}</div>
<div style="margin-top: 15px; padding: 10px; background: #3d2914; border-radius: 8px; border-left: 4px solid #ff9800;">
<strong style="color: #ffcc02;">Debug info:</strong> <span style="color: #e0e0e0;">Found {len(tokens)} total tokens, {content_token_count} content tokens</span>
</div>
</div>
<div style="background: #3d2914; padding: 15px; border-radius: 10px; border-left: 4px solid #ff9800; margin: 15px 0;">
<p style="margin: 0; color: #ffcc02; font-size: 14px;">
<strong>Note:</strong> Detailed attention weights analysis is not available for the current model.
</p>
</div>
"""
    # Close the outer report container opened at the top.
    detailed_analysis += "</div>\n"

    # Build label->prob mapping for Gradio Label output
    if len(LABELS) == len(probs):
        prediction_result = {label: float(p) for label, p in zip(LABELS, probs)}
    else:
        prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)}
    return prediction_result, detailed_analysis
| # --------- RAG Inference function ---------- | |
# Matches negated verdicts such as "NOT PHISHING" / "NON-PHISHING" so they are
# not mistaken for a phishing verdict.
_NEGATED_PHISHING = re.compile(r"\bNO[NT][\s_-]*PHISHING\b")


def _parse_llm_response(response):
    """Extract ``(classification, confidence, explanation)`` from an LLM reply.

    Expected (not guaranteed) layout, one field per line::

        Classification: PHISHING
        Confidence Score: 92%
        Explanation: ...

    ``**`` emphasis around field names is tolerated, and ``Confidence:`` is
    accepted as an alternative to ``Confidence Score:``. The explanation runs
    from the ``Explanation:`` line until the next recognized field, so
    multi-line explanations are kept. Missing fields default to ``"UNKNOWN"``,
    ``"N/A"``, and the full raw reply, respectively.
    """
    classification = "UNKNOWN"
    confidence = "N/A"
    explanation_lines = []
    collecting = False
    for raw_line in response.split("\n"):
        key_line = raw_line.replace("*", "").strip()
        upper = key_line.upper()
        if upper.startswith("CLASSIFICATION:"):
            classification = key_line.split(":", 1)[1].strip().upper()
            collecting = False
        elif upper.startswith(("CONFIDENCE SCORE:", "CONFIDENCE:")):
            confidence = key_line.split(":", 1)[1].strip()
            collecting = False
        elif upper.startswith("EXPLANATION:"):
            explanation_lines = [key_line.split(":", 1)[1].strip()]
            collecting = True
        elif collecting:
            explanation_lines.append(raw_line.rstrip())
    explanation = "\n".join(explanation_lines).strip()
    return classification, confidence, (explanation or response)


def _verdict_style(classification):
    """Map a parsed classification to ``(label, gradient, icon, border_color)``."""
    if "PHISHING" in _NEGATED_PHISHING.sub("", classification):
        return "PHISHING", "linear-gradient(135deg, #ff4b1f 0%, #ff9068 100%)", "⛔", "#ff4b1f"
    if "BENIGN" in classification or _NEGATED_PHISHING.search(classification):
        return "BENIGN", "linear-gradient(135deg, #11998e 0%, #38ef7d 100%)", "✅", "#11998e"
    return "UNCERTAIN", "linear-gradient(135deg, #f8b500 0%, #fceabb 100%)", "⚠️", "#f8b500"


def rag_predict_fn(text: str):
    """Ask the RAG-backed LLM to classify *text* (or the page at a URL) and render HTML.

    For URLs, the page HTML is fetched and its first 4000 characters are passed
    to the LLM together with the URL. Retrieval is handled inside
    ``llm_client.analyze``. LLM output and status text are HTML-escaped
    before rendering, since they can echo attacker-controlled page content.
    """
    import html

    if not text or not text.strip():
        return "Please enter text to analyze."

    start_time = time.time()
    input_text = text.strip()
    analysis_context = input_text

    if is_url(input_text):
        print(f"🌐 Detected URL: {input_text}. Fetching content...")
        fetched_content, status = fetch_html_content(input_text)
        if fetched_content:
            # Limit content length to avoid token overflow
            truncated_content = fetched_content[:4000]
            analysis_context = f"URL: {input_text}\n\nWebsite Content:\n{truncated_content}\n..."
            status_msg = f"✅ Successfully fetched {len(fetched_content)} chars from URL (Status: {status})."
        else:
            analysis_context = f"URL: {input_text}\n\n(Could not fetch website content. Error: {status})"
            status_msg = f"⚠️ Failed to fetch URL content: {status}"
        print(status_msg)
    else:
        status_msg = "📝 Analyzing raw text input."

    # Call LLM (which handles retrieval internally via LangChain)
    response = llm_client.analyze(analysis_context)
    elapsed_time = time.time() - start_time

    response = "" if response is None else str(response)
    classification, confidence, explanation = _parse_llm_response(response)
    label, color_grad, icon, border_col = _verdict_style(classification)
    safe_explanation = html.escape(explanation).replace("\n", "<br>")

    html_output = f"""
<div style="font-family: 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; max-width: 800px; margin: 0 auto; background: #1e1e1e; padding: 25px; border-radius: 16px; box-shadow: 0 10px 30px rgba(0,0,0,0.5); border: 1px solid #333;">
<div style="background: {color_grad}; padding: 30px; border-radius: 12px; color: white; text-align: center; margin-bottom: 25px; box-shadow: 0 4px 15px rgba(0,0,0,0.3); position: relative; overflow: hidden;">
<div style="position: relative; z-index: 2;">
<h2 style="margin: 0 0 5px 0; font-size: 42px; font-weight: 800; letter-spacing: 1px; text-shadow: 0 2px 4px rgba(0,0,0,0.2);">{icon} {label}</h2>
<div style="font-size: 24px; font-weight: 600; opacity: 0.95; margin-bottom: 15px;">Confidence: {html.escape(confidence)}</div>
<div style="background: rgba(0,0,0,0.2); padding: 15px; border-radius: 8px; text-align: left; font-size: 16px; line-height: 1.5; backdrop-filter: blur(5px);">
<strong>Explanation:</strong><br>
{safe_explanation}
</div>
</div>
</div>
<div style="display: flex; justify-content: space-between; align-items: center; color: #888; font-size: 13px; padding: 0 10px;">
<div>
⏱️ Processing Time: <b>{elapsed_time:.2f}s</b>
</div>
<div>
🛡️ CyberGuard AI Analysis
</div>
</div>
<div style="background: #2d2d2d; padding: 15px; border-radius: 8px; margin-top: 20px; border-left: 4px solid {border_col}; color: #ccc; font-size: 14px;">
<strong>Input Status:</strong> {html.escape(status_msg)}<br>
<span style="font-size: 12px; opacity: 0.7;">AI can make mistakes. Always verify critical URLs manually.</span>
</div>
</div>
"""
    return html_output
| # --------- Refresh Knowledge Base function ---------- | |
def refresh_kb():
    """Re-sync the RAG knowledge base.

    Returns whatever ``rag_engine.refresh_knowledge_base()`` reports. The UI
    shows that value in the "Status" Markdown component.
    """
    status = rag_engine.refresh_knowledge_base()
    return status
| # --------- Gradio UI ---------- | |
css_style = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: #1e1e1e !important;
color: #ffffff !important;
}
/* Customize Buttons */
.gradio-container button.primary, .gradio-container button.secondary {
background-color: #4a4a4a !important;
color: #ffffff !important;
border: 1px solid #666 !important;
}
.gradio-container button.primary:hover, .gradio-container button.secondary:hover {
background-color: #5a5a5a !important;
}
/* Customize Textboxes (Inputs) */
.gradio-container textarea, .gradio-container input {
background-color: #3d3d3d !important;
color: #ffffff !important;
border: 1px solid #666 !important;
}
/* Customize Blocks/Panels */
.gradio-container .block {
background-color: #2d2d2d !important;
border: 1px solid #444 !important;
}
"""


def _predict_html(text):
    """Return only the HTML report produced by ``predict_fn``.

    ``predict_fn`` returns ``(label_probabilities, html_report)``, but the
    Standard Detection tab has a single ``gr.HTML`` output. Wiring
    ``predict_fn`` directly made Gradio hand the whole tuple to that one
    component. The probabilities are already shown inside the report.
    """
    return predict_fn(text)[1]


with gr.Blocks() as demo:
    gr.HTML(f"<style>{css_style}</style>")
    gr.Markdown("# 🛡️ Phishing Detector (DeBERTa + LSTM + RAG)")
    with gr.Tabs():
        # --- Tab 1: Standard Detection ---
        with gr.TabItem("🔍 Standard Detection"):
            gr.Markdown("""
            Enter a URL or text for analysis using the DeBERTa + LSTM model.
            **Features:**
            - **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
            - **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
            - **Visual Analysis**: Predict phishing/benign probability with visual charts
            - **Token Importance**: Display the most important tokens in classification
            - **Detailed Insights**: Comprehensive analysis of the impact of each token
            **How it works for URLs:**
            1. Analyze the URL structure itself
            2. Fetch the webpage HTML content
            3. Analyze the webpage content
            4. Combine both results for final prediction (30% URL + 70% content)
            """)
            with gr.Row():
                with gr.Column(scale=2):
                    input_box = gr.Textbox(
                        label="URL or text",
                        placeholder="Example: http://suspicious-site.example or paste any text",
                        lines=3
                    )
                    btn_submit = gr.Button("🔍 Analyze", variant="primary")
                    gr.Examples(
                        examples=[
                            ["http://rendmoiunserviceeee.com"],
                            ["https://www.google.com"],
                            ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
                            ["https://mail-secure-login-verify.example/path?token=suspicious"],
                            ["http://paypaI-security-update.net/login"],
                            ["Your package has been delivered successfully. Thank you for using our service."],
                            ["https://github.com/user/repo"],
                            ["Dear customer, your account has been suspended. Click here to verify."],
                        ],
                        inputs=input_box
                    )
                with gr.Column(scale=3):
                    output_html = gr.HTML(label="Analysis Result")
            # Single HTML output -> pass only the report, not the (probs, html) tuple.
            btn_submit.click(fn=_predict_html, inputs=input_box, outputs=output_html)

        # --- Tab 2: LLM + RAG Analysis ---
        with gr.TabItem("🤖 AI Assistant (RAG)"):
            # NOTE(review): this text names Qwen2.5-3B, while the startup log names
            # Qwen3-0.6B (GGUF); confirm which model LLMClient actually loads.
            gr.Markdown("""
            **AI Assistant** uses **Qwen2.5-3B** + **LangChain** to explain *why* a message is suspicious.
            **Features:**
            - 🌐 Multilingual support (English + Vietnamese)
            - 📚 Knowledge Base retrieval (Auto-sync)
            - 🚀 No rate limits (self-hosted)
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    rag_input = gr.Textbox(
                        label="Suspicious Text/URL",
                        placeholder="Paste the email content or URL here...",
                        lines=5
                    )
                    with gr.Row():
                        btn_rag = gr.Button("🤖 Ask AI Assistant", variant="primary")
                        btn_refresh = gr.Button("♻️ Refresh Knowledge Base")
                    gr.Examples(
                        examples=[
                            ["Your PayPal account has been suspended. Click http://paypal-verify.com to unlock."],
                            ["Tài khoản ngân hàng của bạn bị khóa. Nhấn vào đây để mở khóa ngay."],
                            ["Your package is ready for delivery. Track here: https://fedex-track.com"],
                        ],
                        inputs=rag_input
                    )
                with gr.Column(scale=1):
                    # gr.HTML (not gr.Markdown) so the custom-styled report renders.
                    rag_output = gr.HTML(label="AI Analysis")
                    refresh_output = gr.Markdown(label="Status")
            btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
            btn_refresh.click(fn=refresh_kb, inputs=[], outputs=refresh_output)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)