import html
import logging
import re
import sys
import time

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import gradio as gr
import requests
import urllib3 # Import urllib3 to handle warnings
# --- Quiet the asyncio logger (Python 3.10+ only) ---
# Raises the "asyncio" logger's threshold to WARNING, hiding its DEBUG/INFO output.
# NOTE(review): the intent was to silence noisy asyncio *errors* on shutdown, but
# ERROR-level records still pass at WARNING; confirm whether that is acceptable.
if sys.version_info[:2] >= (3, 10):
    logging.getLogger("asyncio").setLevel(logging.WARNING)

# --- Silence InsecureRequestWarning ---
# fetch_html_content() deliberately calls requests with verify=False, because
# phishing sites frequently present invalid TLS certificates and must still be scanned.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
# --- import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier # <-- your class
# --- Import RAG modules ---
from rag_engine import RAGEngine
from llm_client import LLMClient
# --------- Config ----------
# Hugging Face Hub location of the fine-tuned checkpoint.
REPO_ID: str = "dungeon29/deberta-lstm-detect-phishing"
CKPT_NAME: str = "pytorch_model.bin"
# Base model whose tokenizer (and backbone) the classifier is built on.
MODEL_NAME: str = "microsoft/deberta-base"
# Class names in model-output order: index 0 = benign, index 1 = phishing.
# Adjust this list if the checkpoint was trained with different classes.
LABELS = ["benign", "phishing"]
# Constructor hyperparameters are taken from the checkpoint's "model_args" entry
# when present and forwarded as DeBERTaLSTMClassifier(**model_args).
# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# NOTE(security): torch.load unpickles the file. Only point REPO_ID/CKPT_NAME at a
# repository you trust (recent PyTorch releases default to weights_only=True).
checkpoint = torch.load(ckpt_path, map_location=device)

# Constructor hyperparameters, if the checkpoint stored them under "model_args"
# (e.g. {"lstm_hidden": 256, "num_labels": 2, ...}); otherwise class defaults apply.
model_args = checkpoint.get("model_args", {})
model = DeBERTaLSTMClassifier(**model_args)

# Load weights
try:
    # Reuse the checkpoint loaded above instead of deserializing the file a second
    # time. A full training checkpoint nests the weights under "model_state_dict";
    # otherwise the loaded object is treated as a bare state_dict.
    state_dict = checkpoint
    if "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]

    # strict=False tolerates keys that do not match the current architecture;
    # report them so a wrong or stale checkpoint does not go unnoticed.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(
            f"⚠️ {len(incompatible.missing_keys)} missing key(s) left at initialization: "
            f"{incompatible.missing_keys[:10]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"⚠️ {len(incompatible.unexpected_keys)} unexpected key(s) ignored: "
            f"{incompatible.unexpected_keys[:10]}"
        )

    # Check whether the attention layer had weights in the checkpoint.
    if hasattr(model, 'attention') and 'attention.weight' not in state_dict:
        print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
    else:
        print("✅ Load weights successfully!")
except Exception as e:
    print(f"❌ Error when loading weights: {e}")
    raise

model.to(device).eval()
# --------- Initialize RAG & LLM ----------
# Both objects are created once at import time and used by the request handlers below.
print("Initializing RAG Engine (LangChain)...")
rag_engine = RAGEngine()
print("RAG Engine ready.")

# NOTE(review): the AI Assistant tab text says "Qwen2.5-3B" while this log line says
# "Qwen3-0.6B(GGUF)"; confirm which model LLMClient actually loads.
print("Initializing Qwen3-0.6B(GGUF) LLM (LangChain)...")
# The client receives the RAG vector store so it can run retrieval (RetrievalQA) itself.
llm_client = LLMClient(
    vector_store=rag_engine.vector_store,
)
print("LLM ready.")
# --------- Helper functions ----------
# Compiled once at import time instead of on every is_url() call.
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    # Domain labels followed by a TLD of up to 63 letters (the DNS label limit), so
    # newer gTLDs such as .services or .support are recognised.
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,63}\.?|'
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or an IPv4 literal
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$',  # optional path/query without whitespace
    re.IGNORECASE,
)


def is_url(text):
    """Return True if *text* is a single http(s) URL.

    The whole string must match: scheme ``http``/``https``, a hostname,
    ``localhost`` or an IPv4 literal, an optional port and an optional
    whitespace-free path. Non-string input returns False.
    """
    if not isinstance(text, str):
        return False
    return _URL_PATTERN.match(text) is not None
def fetch_html_content(url, timeout=10, max_bytes=1_000_000):
    """Fetch the raw HTML of *url* for the classifier.

    Args:
        url: Absolute http(s) URL supplied by the user (untrusted).
        timeout: Timeout in seconds passed to ``requests.get``.
        max_bytes: Positive cap on downloaded body bytes; anything beyond it is
            discarded so an oversized or endless response cannot exhaust memory.
            Callers in this file truncate the content further anyway.

    Returns:
        ``(html_text, status_code)`` on success, or ``(None, error_message)``
        when the request fails or returns an HTTP error status.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
        'Accept-Language': 'en-US,en;q=0.9,vi;q=0.8',
        'Referer': 'https://www.google.com/'
    }
    try:
        # verify=False is intentional for phishing detection (invalid certificates
        # are common); the InsecureRequestWarning is suppressed at module level.
        # NOTE(security): *url* is user-controlled and is_url() accepts localhost and
        # IP literals, so this request can reach internal hosts (SSRF). Consider
        # rejecting private/loopback targets when the app is publicly exposed.
        with requests.get(url, headers=headers, timeout=timeout, verify=False, stream=True) as response:
            response.raise_for_status()
            # Stream the body and stop once max_bytes have been read.
            body = bytearray()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                body += chunk
                if len(body) >= max_bytes:
                    break
            del body[max_bytes:]
            # Decode with the charset from the response headers, replacing bytes that
            # do not decode; fall back to UTF-8 when no/unknown charset is given.
            encoding = response.encoding or "utf-8"
            try:
                text = body.decode(encoding, errors="replace")
            except LookupError:
                text = body.decode("utf-8", errors="replace")
            # Return the FULL raw HTML (tags included): the model needs the markup
            # to detect hidden threats.
            return text, response.status_code
    except requests.exceptions.RequestException as e:
        return None, f"Request error: {str(e)}"
    except Exception as e:
        return None, f"General error: {str(e)}"
def predict_single_text(text, text_type="text"):
    """Score a single input string with the global DeBERTa+LSTM model.

    Args:
        text: The string to classify (a URL, raw HTML or free text); it is
            truncated to at most 512 tokens.
        text_type: Caller-supplied tag such as "URL", "HTML" or "text"; not used
            by this function.

    Returns:
        ``(probs, tokens, has_attention, attention_weights)``:
        ``probs`` is the softmax over the logits as a Python list (batch dim
        squeezed); ``tokens`` are the tokens for the encoded input ids;
        ``has_attention`` is True when the model returned a 3-tuple with
        attention weights; ``attention_weights`` is that value, else None.
    """
    # A 512-token window captures more of long HTML documents.
    encoding = tokenizer(
        text,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    # The model is run without segment ids, so drop them if the tokenizer made any.
    encoding.pop("token_type_ids", None)

    model_inputs = {}
    for name, tensor in encoding.items():
        model_inputs[name] = tensor.to(device)

    has_attention = False
    attention = None
    with torch.no_grad():
        try:
            output = model(**model_inputs, return_attention=True)
        except TypeError:
            # Older model signature without the return_attention keyword.
            logits = model(**model_inputs)
        else:
            if isinstance(output, tuple) and len(output) == 3:
                # (logits, attention weights, DeBERTa attentions)
                logits, attention, _deberta_attentions = output
                has_attention = True
            else:
                logits = output
        probabilities = F.softmax(logits, dim=-1).squeeze(0).tolist()

    # Tokens for visualization.
    ids = model_inputs["input_ids"].squeeze(0).tolist()
    tokens = tokenizer.convert_ids_to_tokens(ids)
    return probabilities, tokens, has_attention, (attention if has_attention else None)
def combine_predictions(url_probs, html_probs, url_weight=0.3, html_weight=0.7):
    """Blend URL-based and HTML-based class probabilities with a weighted sum.

    Args:
        url_probs: Per-class probabilities predicted from the URL string.
        html_probs: Per-class probabilities predicted from the page HTML, in the
            same class order as ``url_probs``.
        url_weight: Weight applied to ``url_probs`` (default 0.3).
        html_weight: Weight applied to ``html_probs`` (default 0.7).

    Returns:
        A list with ``url_weight * u + html_weight * h`` for every class. It sums
        to 1 only when both inputs do and the weights sum to 1.

    Raises:
        ValueError: If the two probability vectors have different lengths.
    """
    if len(url_probs) != len(html_probs):
        raise ValueError(
            f"Probability vectors differ in length: {len(url_probs)} vs {len(html_probs)}"
        )
    return [url_weight * u + html_weight * h for u, h in zip(url_probs, html_probs)]
# --------- Inference function ----------
# Special tokens hidden from the token lists and statistics in the report.
_SKIP_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]", "<s>", "</s>", "<pad>"})
# Common tokenizer word-boundary markers ("Ġ", "Ċ", "▁") plus spaces, removed for display.
_TOKEN_MARKERS = str.maketrans("", "", " ĠĊ▁")


def _clean_token(token):
    """Return *token* without word-boundary markers and surrounding whitespace."""
    return token.translate(_TOKEN_MARKERS).strip()


def _prob_bar(label, value, color):
    """Render *label* with *value* (a 0..1 fraction) as a percentage and an HTML bar.

    *label* is HTML-escaped because it may contain tokens from untrusted input.
    """
    width = max(0.0, min(1.0, float(value))) * 100
    return (
        '<div style="margin:6px 0;">'
        f"<div>{html.escape(str(label))}: <b>{value:.1%}</b></div>"
        '<div style="background:#444;border-radius:4px;height:10px;">'
        f'<div style="width:{width:.1f}%;background:{color};height:10px;border-radius:4px;"></div>'
        "</div></div>"
    )


def _classify_input(text):
    """Run the classifier on *text*; URLs are also scored on their fetched HTML.

    Returns ``(probs, tokens, has_attention, attention_weights, analysis_type,
    fetch_status)``. When attention weights are returned, ``tokens`` starts with
    the tokens those weights were computed for, so the two can be zipped.
    """
    url = text.strip()
    if not is_url(url):
        probs, tokens, has_attention, attention_weights = predict_single_text(text, "text")
        return probs, tokens, has_attention, attention_weights, "Text Analysis", ""

    url_probs, url_tokens, url_has_attention, url_attention = predict_single_text(url, "URL")
    html_content, status = fetch_html_content(url)
    if not html_content:
        return (url_probs, url_tokens, url_has_attention, url_attention,
                "URL-only Analysis", f"⚠️ Could not fetch HTML content: {status}")

    # Raw HTML (tags included) is scored too, then blended with the URL score
    # (combine_predictions defaults: 30% URL, 70% content).
    html_probs, html_tokens, html_has_attention, html_attention = predict_single_text(html_content, "HTML")
    probs = combine_predictions(url_probs, html_probs)
    tokens = url_tokens + ["[SEP]"] + html_tokens[:50]  # limit HTML tokens for display
    if url_has_attention:
        # URL scores line up with the leading URL tokens of the display list.
        attention_weights = url_attention
    elif html_has_attention:
        # Only the HTML pass produced scores: pair them with the HTML tokens.
        tokens, attention_weights = html_tokens, html_attention
    else:
        attention_weights = None
    return (probs, tokens, url_has_attention or html_has_attention, attention_weights,
            "Combined URL + HTML Analysis",
            f"✅ Successfully fetched HTML content (Status: {status})")


def _render_attention_section(tokens, attention_weights, predicted_class, probs):
    """Render the top-token ranking, statistics and probability bars as HTML."""
    # NOTE(review): assumes one score per token once the batch dim is squeezed.
    scores = attention_weights.squeeze(0).tolist()
    ranked = []
    for position, (token, score) in enumerate(zip(tokens, scores)):
        if token in _SKIP_TOKENS or not token.strip() or not score > 0.005:
            continue
        clean = _clean_token(token)
        if clean:  # only keep tokens that still have content after cleaning
            ranked.append({"token": clean, "importance": score, "position": position})
    ranked.sort(key=lambda item: item["importance"], reverse=True)

    color = "#ff4444" if predicted_class == "phishing" else "#44ff44"
    rows = [
        "<h3>Top important tokens:</h3>",
        f"<p><i>Analysis Info: Found {len(ranked)} important tokens out of "
        f"{len(tokens)} total tokens</i></p>",
    ]
    for rank, info in enumerate(ranked[:10], start=1):
        rows.append(_prob_bar(f"{rank}. {info['token']}", info["importance"], color))

    content_tokens = sum(1 for t in tokens if t not in _SKIP_TOKENS)
    high_impact = sum(1 for info in ranked if info["importance"] > 0.05)
    rows += [
        "<h3>Detailed analysis:</h3>",
        "<h4>Statistical Overview</h4>",
        f"<p><b>{content_tokens}</b> Total tokens<br>"
        f"<b>{high_impact}</b> High impact tokens (&gt;5%)</p>",
        "<h4>Prediction Confidence</h4>",
        _prob_bar("Phishing", probs[1], "#ff4444"),
        _prob_bar("Benign", probs[0], "#44ff44"),
    ]
    return "\n".join(rows)


def _render_basic_section(tokens):
    """Render the token list shown when no attention weights are available."""
    content = [t for t in tokens if t not in _SKIP_TOKENS]
    chips = " ".join(
        '<span style="display:inline-block;margin:2px;padding:1px 4px;'
        f'background:#3d3d3d;border-radius:3px;">{html.escape(cleaned)}</span>'
        for cleaned in (_clean_token(t) for t in content)
        if cleaned
    )
    return (
        "<h3>Basic Analysis</h3>"
        f"<p><b>{len(content)}</b> Tokens</p>"
        f"<p>🔤 Tokens in text:</p><div>{chips}</div>"
        f"<p><i>Debug info: Found {len(tokens)} total tokens, {len(content)} content tokens</i></p>"
        "<p><i>Note: Detailed attention weights analysis is not available for the current model.</i></p>"
    )


def predict_fn(text: str):
    """Gradio handler for the "Standard Detection" tab.

    Args:
        text: A URL or free text. A URL is classified on the URL string and, when
            its page can be fetched, on the page's raw HTML as well.

    Returns:
        ``(prediction_result, detailed_analysis)``: a label -> probability dict for
        a ``gr.Label`` (``None`` for empty input) and an HTML report. All user- and
        page-derived text in the report is HTML-escaped.
    """
    if not text or not text.strip():
        return None, "<p>⚠️ Please enter a URL or text.</p>"

    probs, tokens, has_attention, attention_weights, analysis_type, fetch_status = _classify_input(text)

    predicted_class = "phishing" if probs[1] > probs[0] else "benign"
    confidence = max(probs)
    accent = "#ff4444" if predicted_class == "phishing" else "#44ff44"
    verdict = ('This appears to be a phishing attempt!' if predicted_class == 'phishing'
               else '✅ This appears to be legitimate content.')

    sections = [
        f'<div style="border:2px solid {accent};border-radius:10px;padding:16px;margin-bottom:12px;">',
        f"<h3>🔍 {html.escape(analysis_type)}</h3>",
        f'<h2 style="color:{accent};margin:4px 0;">{predicted_class.upper()}</h2>',
        f"<p>Confidence: <b>{confidence:.1%}</b></p>",
        f"<p>{verdict}</p>",
        "</div>",
    ]
    if fetch_status:
        sections.append(f"<p><b>Fetch Status:</b> {html.escape(fetch_status)}</p>")
    if has_attention and attention_weights is not None:
        sections.append(_render_attention_section(tokens, attention_weights, predicted_class, probs))
    else:
        sections.append(_render_basic_section(tokens))
    detailed_analysis = "\n".join(sections)

    # Build label->prob mapping for the Gradio Label output.
    if len(LABELS) == len(probs):
        prediction_result = {label: float(p) for label, p in zip(LABELS, probs)}
    else:
        prediction_result = {f"class_{i}": float(p) for i, p in enumerate(probs)}
    return prediction_result, detailed_analysis
# --------- RAG Inference function ----------
def _llm_field_value(line):
    """Return the text after the first ':' of *line*, minus whitespace and markdown emphasis."""
    return line.split(":", 1)[1].strip().strip("*_").strip()


def _parse_llm_response(response):
    """Split the LLM reply into ``(classification, confidence, explanation)``.

    Recognised header lines (case-insensitive, optionally prefixed by ``*``,
    ``#`` or ``_``): ``CLASSIFICATION:``, ``CONFIDENCE SCORE:`` and
    ``EXPLANATION:``. The explanation is the rest of its header line plus every
    following line up to the next recognised header. Missing fields default to
    ``"UNKNOWN"`` / ``"N/A"``; a missing or empty explanation falls back to the
    whole reply.
    """
    classification = "UNKNOWN"
    confidence = "N/A"
    explanation_lines = []
    in_explanation = False
    for raw_line in response.split("\n"):
        line = raw_line.strip()
        header = line.lstrip("*#_ ").upper()
        if header.startswith("CLASSIFICATION:"):
            classification = _llm_field_value(line).upper()
            in_explanation = False
        elif header.startswith("CONFIDENCE SCORE:"):
            confidence = _llm_field_value(line)
            in_explanation = False
        elif header.startswith("EXPLANATION:"):
            explanation_lines = [_llm_field_value(line)]
            in_explanation = True
        elif in_explanation:
            explanation_lines.append(line)
    explanation = "\n".join(explanation_lines).strip()
    return classification, confidence, explanation or response


def _rag_theme(classification):
    """Map a parsed classification to ``(label, background gradient, icon, border color)``."""
    if "PHISHING" in classification:
        return "PHISHING", "linear-gradient(135deg, #ff4b1f 0%, #ff9068 100%)", "⛔", "#ff4b1f"
    if "BENIGN" in classification:
        return "BENIGN", "linear-gradient(135deg, #11998e 0%, #38ef7d 100%)", "✅", "#11998e"
    return "UNCERTAIN", "linear-gradient(135deg, #f8b500 0%, #fceabb 100%)", "⚠️", "#f8b500"


def rag_predict_fn(text: str):
    """Gradio handler for the "AI Assistant (RAG)" tab.

    A URL input has its page fetched; the first 4000 characters are passed to the
    LLM together with the URL. Any other input is sent as-is. The LLM reply is
    parsed into classification / confidence / explanation and rendered as an
    HTML card. All model- and user-derived text is HTML-escaped before rendering.

    Args:
        text: Suspicious message or URL.

    Returns:
        An HTML string, or a plain prompt message when *text* is empty.
    """
    if not text or not text.strip():
        return "Please enter text to analyze."

    start_time = time.perf_counter()
    input_text = text.strip()
    analysis_context = input_text

    if is_url(input_text):
        print(f"🌐 Detected URL: {input_text}. Fetching content...")
        fetched_content, status = fetch_html_content(input_text)
        if fetched_content:
            # Limit content length to avoid overflowing the LLM context.
            truncated_content = fetched_content[:4000]
            analysis_context = f"URL: {input_text}\n\nWebsite Content:\n{truncated_content}\n..."
            status_msg = f"✅ Successfully fetched {len(fetched_content)} chars from URL (Status: {status})."
        else:
            analysis_context = f"URL: {input_text}\n\n(Could not fetch website content. Error: {status})"
            status_msg = f"⚠️ Failed to fetch URL content: {status}"
        print(status_msg)
    else:
        status_msg = "📝 Analyzing raw text input."

    # The LLM client performs retrieval internally (LangChain).
    response = llm_client.analyze(analysis_context)
    response = "" if response is None else str(response)
    elapsed_time = time.perf_counter() - start_time

    classification, confidence, explanation = _parse_llm_response(response)
    label, color_grad, icon, border_col = _rag_theme(classification)

    # Escape before inserting into gr.HTML: the reply can echo attacker-controlled
    # page content. Newlines become <br> so multi-line explanations keep their shape.
    explanation_html = html.escape(explanation).replace("\n", "<br>")
    return (
        f'<div style="border:2px solid {border_col};border-radius:12px;overflow:hidden;">'
        f'<div style="background:{color_grad};color:#ffffff;padding:14px 18px;">'
        f'<h2 style="margin:0;">{icon} {label}</h2>'
        f"<div>Confidence: {html.escape(confidence)}</div>"
        "</div>"
        '<div style="padding:14px 18px;">'
        f"<p><b>Explanation:</b></p><p>{explanation_html}</p>"
        f"<p>⏱️ Processing Time: {elapsed_time:.2f}s</p>"
        "<p>🛡️ CyberGuard AI Analysis</p>"
        f"<p><b>Input Status:</b> {html.escape(status_msg)}</p>"
        "<p><i>AI can make mistakes. Always verify critical URLs manually.</i></p>"
        "</div></div>"
    )
# --------- Refresh Knowledge Base function ----------
def refresh_kb():
    """Click handler for "Refresh Knowledge Base".

    Delegates to ``rag_engine.refresh_knowledge_base()`` and returns its result
    unchanged; the UI shows it in the status Markdown component.
    """
    status = rag_engine.refresh_knowledge_base()
    return status
# --------- Gradio UI ----------
# Dark-theme overrides for the app container, buttons, text inputs and panels.
# The rules use !important so they win over Gradio's default theme styles.
css_style = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: #1e1e1e !important;
color: #ffffff !important;
}
/* Customize Buttons */
.gradio-container button.primary, .gradio-container button.secondary {
background-color: #4a4a4a !important;
color: #ffffff !important;
border: 1px solid #666 !important;
}
.gradio-container button.primary:hover, .gradio-container button.secondary:hover {
background-color: #5a5a5a !important;
}
/* Customize Textboxes (Inputs) */
.gradio-container textarea, .gradio-container input {
background-color: #3d3d3d !important;
color: #ffffff !important;
border: 1px solid #666 !important;
}
/* Customize Blocks/Panels */
.gradio-container .block {
background-color: #2d2d2d !important;
border: 1px solid #444 !important;
}
"""
with gr.Blocks() as demo:
    # Inject the dark-theme stylesheet defined in css_style.
    gr.HTML(f"<style>{css_style}</style>")
    gr.Markdown("# 🛡️ Phishing Detector (DeBERTa + LSTM + RAG)")
    with gr.Tabs():
        # --- Tab 1: Standard Detection ---
        with gr.TabItem("🔍 Standard Detection"):
            gr.Markdown("""
Enter a URL or text for analysis using the DeBERTa + LSTM model.
**Features:**
- **URL Analysis**: For URLs, the system will fetch HTML content and combine both URL and content analysis
- **Combined Prediction**: Uses weighted combination of URL structure and webpage content analysis
- **Visual Analysis**: Predict phishing/benign probability with visual charts
- **Token Importance**: Display the most important tokens in classification
- **Detailed Insights**: Comprehensive analysis of the impact of each token
**How it works for URLs:**
1. Analyze the URL structure itself
2. Fetch the webpage HTML content
3. Analyze the webpage content
4. Combine both results for final prediction (30% URL + 70% content)
""")
            with gr.Row():
                with gr.Column(scale=2):
                    input_box = gr.Textbox(
                        label="URL or text",
                        placeholder="Example: http://suspicious-site.example or paste any text",
                        lines=3,
                    )
                    btn_submit = gr.Button("🔍 Analyze", variant="primary")
                    gr.Examples(
                        examples=[
                            ["http://rendmoiunserviceeee.com"],
                            ["https://www.google.com"],
                            ["Dear customer, your account has been suspended. Click here to verify your identity immediately."],
                            ["https://mail-secure-login-verify.example/path?token=suspicious"],
                            ["http://paypaI-security-update.net/login"],
                            ["Your package has been delivered successfully. Thank you for using our service."],
                            ["https://github.com/user/repo"],
                            ["Dear customer, your account has been suspended. Click here to verify."],
                        ],
                        inputs=input_box,
                    )
                with gr.Column(scale=3):
                    # predict_fn returns two values (label -> probability dict, HTML
                    # report), so it needs one output component for each.
                    output_label = gr.Label(label="Prediction", num_top_classes=len(LABELS))
                    output_html = gr.HTML(label="Analysis Result")
            btn_submit.click(fn=predict_fn, inputs=input_box, outputs=[output_label, output_html])

        # --- Tab 2: LLM + RAG Analysis ---
        with gr.TabItem("🤖 AI Assistant (RAG)"):
            # NOTE(review): this text names "Qwen2.5-3B" while the startup log says
            # "Qwen3-0.6B(GGUF)"; confirm which model LLMClient loads.
            gr.Markdown("""
**AI Assistant** uses **Qwen2.5-3B** + **LangChain** to explain *why* a message is suspicious.
**Features:**
- 🌐 Multilingual support (English + Vietnamese)
- 📚 Knowledge Base retrieval (Auto-sync)
- 🚀 No rate limits (self-hosted)
""")
            with gr.Row():
                with gr.Column(scale=1):
                    rag_input = gr.Textbox(
                        label="Suspicious Text/URL",
                        placeholder="Paste the email content or URL here...",
                        lines=5,
                    )
                    with gr.Row():
                        btn_rag = gr.Button("🤖 Ask AI Assistant", variant="primary")
                        btn_refresh = gr.Button("♻️ Refresh Knowledge Base")
                    gr.Examples(
                        examples=[
                            ["Your PayPal account has been suspended. Click http://paypal-verify.com to unlock."],
                            ["Tài khoản ngân hàng của bạn bị khóa. Nhấn vào đây để mở khóa ngay."],
                            ["Your package is ready for delivery. Track here: https://fedex-track.com"],
                        ],
                        inputs=rag_input,
                    )
                with gr.Column(scale=1):
                    # gr.HTML (not Markdown) so the result card can carry custom styling.
                    rag_output = gr.HTML(label="AI Analysis")
                    refresh_output = gr.Markdown(label="Status")
            btn_rag.click(fn=rag_predict_fn, inputs=[rag_input], outputs=rag_output)
            btn_refresh.click(fn=refresh_kb, inputs=[], outputs=refresh_output)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)