#!/usr/bin/env python3
"""
WimBERT Synth v0 Gradio Space
Dual-head multi-label classifier for Dutch signal messages
"""
import json
import importlib.util
import torch
import gradio as gr
from huggingface_hub import snapshot_download
# Constants
MODEL_REPO = "UWV/wimbert-synth-v0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
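# fp16 halves weight memory and speeds up GPU inference; CPU fp16 kernels
# are poorly supported, so we keep fp32 there.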
print(f"๐ง Loading model from {MODEL_REPO}...")
print(f"๐ฅ๏ธ Device: {DEVICE} ({DTYPE})")
# Download model files (uses HF cache)
model_dir = snapshot_download(MODEL_REPO)
# Dynamic import of model.py from downloaded dir
spec = importlib.util.spec_from_file_location("model", f"{model_dir}/model.py")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
DualHeadModel = model_module.DualHeadModel
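# The repo bundles its own model.py defining DualHeadModel (a dual-head
# multi-label classifier: two sets of label probabilities per input), so we
# import it from the downloaded snapshot rather than via trust_remote_code.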
# Load model + tokenizer + config
model, tokenizer, config = DualHeadModel.from_pretrained(model_dir, device=DEVICE)
# Cast to target dtype
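# (.half() converts every parameter and buffer to fp16 in place; the integer
# input_ids/attention_mask tensors need no casting)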
if DTYPE == torch.float16:
model = model.half()
# Warm-up inference
with torch.no_grad():
dummy_input = tokenizer("Warm-up", return_tensors="pt", truncation=True,
max_length=config["max_length"])
_ = model.predict(
dummy_input["input_ids"].to(DEVICE),
dummy_input["attention_mask"].to(DEVICE)
)
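# A throwaway forward pass absorbs one-time CUDA costs (kernel loading,
# memory-pool allocation) at startup instead of on the first user request.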
print(f"โ
Model loaded and warmed up (max_length: {config['max_length']})")
# Extract label names
LABELS_ONDERWERP = config["labels"]["onderwerp"]
LABELS_BELEVING = config["labels"]["beleving"]
def prob_to_color(prob: float, threshold: float) -> str:
"""Generate CSS style for probability visualization (10X UX approved)"""
# Green gradient: low prob = very light green, high prob = saturated green
# Use HSL: Hue=145 (green), Saturation increases with prob, Lightness decreases
saturation = 30 + int(prob * 50) # 30% to 80%
lightness = 92 - int(prob * 55) # 92% to 37%
# Text color: white for dark backgrounds (prob > 0.6), dark for light
text_color = "#ffffff" if prob > 0.6 else "#1f2937"
# Border: thick + accent for predicted, subtle for others
if prob >= threshold:
border = "2px solid #059669"
box_shadow = "0 1px 3px rgba(5, 150, 105, 0.3)"
else:
border = "1px solid #d1d5db"
box_shadow = "none"
return (
f"background: hsl(145, {saturation}%, {lightness}%); "
f"color: {text_color}; "
f"border: {border}; "
f"box-shadow: {box_shadow}; "
f"padding: 6px 12px; "
f"border-radius: 4px; "
f"margin: 2px 0; "
f"font-weight: 500;"
)
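# Example (values worked out from the formulas above): prob_to_color(0.85, 0.5)
# yields a dark, saturated chip with white text and the accent border:
#   background: hsl(145, 72%, 46%); color: #ffffff; border: 2px solid #059669; ...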
def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
"""Generate HTML for top-K labels"""
sorted_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
html = "
"
for idx in sorted_indices[:topk]:
label = labels[idx]
prob = probs[idx]
style = prob_to_color(prob, threshold)
predicted = " โ" if prob >= threshold else ""
html += f"
{label}: {prob:.3f}{predicted}
"
html += "
"
return html
def format_all_labels(head_name: str, labels: list, probs: list, threshold: float) -> str:
"""Generate scrollable table for all labels"""
sorted_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
html = f"{head_name}
"
html += "
"
html += ""
html += "| Label | Probability | Predicted |
"
html += ""
for idx in sorted_indices:
label = labels[idx]
prob = probs[idx]
style = prob_to_color(prob, threshold)
predicted = "โ" if prob >= threshold else ""
html += f"| {label} | {prob:.4f} | {predicted} |
"
html += "
"
return html
@torch.inference_mode()
def predict(text: str, threshold: float, topk: int):
"""Run inference and return visualizations"""
if not text or not text.strip():
empty_msg = "Voer een bericht in om te classificeren...
"
return empty_msg, empty_msg, {}
# Tokenize with dynamic length (only truncate if needed)
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=config["max_length"] # 1408 from model config
)
# Get actual sequence length (non-padding tokens)
actual_length = inputs["attention_mask"].sum().item()
# Move to device
input_ids = inputs["input_ids"].to(DEVICE)
attention_mask = inputs["attention_mask"].to(DEVICE)
# Predict
onderwerp_probs, beleving_probs = model.predict(input_ids, attention_mask)
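    # model.predict returns one probability tensor per head, shape
    # (batch, num_labels); scores are independent per label, so several
    # labels per head can exceed the threshold at once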
# Convert to lists
onderwerp_probs = onderwerp_probs[0].cpu().numpy().tolist()
beleving_probs = beleving_probs[0].cpu().numpy().tolist()
# Generate summary view (top-K for each head side by side)
summary_html = ""
summary_html += f"
Onderwerp (Top-{topk})
{format_topk(LABELS_ONDERWERP, onderwerp_probs, threshold, topk)}"
summary_html += f"
Beleving (Top-{topk})
{format_topk(LABELS_BELEVING, beleving_probs, threshold, topk)}"
summary_html += "
"
# Generate all labels view
all_labels_html = ""
all_labels_html += f"
{format_all_labels('Onderwerp', LABELS_ONDERWERP, onderwerp_probs, threshold)}
"
all_labels_html += f"
{format_all_labels('Beleving', LABELS_BELEVING, beleving_probs, threshold)}
"
all_labels_html += "
"
# Generate JSON output
json_output = {
"text": text,
"token_count": actual_length,
"max_length": config["max_length"],
"threshold": threshold,
"onderwerp": {
"probabilities": {label: float(prob) for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs)},
"predicted": [label for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs) if prob >= threshold]
},
"beleving": {
"probabilities": {label: float(prob) for label, prob in zip(LABELS_BELEVING, beleving_probs)},
"predicted": [label for label, prob in zip(LABELS_BELEVING, beleving_probs) if prob >= threshold]
}
}
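    # Resulting shape (label names below are illustrative; the real ones come
    # from config["labels"]):
    # {"text": "...", "token_count": 42, "max_length": 1408, "threshold": 0.5,
    #  "onderwerp": {"probabilities": {"Parkeren": 0.91, ...}, "predicted": ["Parkeren"]},
    #  "beleving": {"probabilities": {...}, "predicted": [...]}}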
return summary_html, all_labels_html, json_output
def count_tokens(text: str) -> str:
"""Count tokens for live feedback"""
if not text or not text.strip():
return "๐ Tokens: 0 / 1408"
    # Quick tokenization (no GPU needed, just counting). Tokenize WITHOUT
    # truncation, otherwise the length can never exceed max_length and the
    # "truncated" warning below would be dead code.
    tokens = tokenizer(text, truncation=False)
    actual_length = sum(tokens["attention_mask"])
# Color code based on usage
if actual_length > config["max_length"]:
color = "#dc2626" # Red: truncated
warning = " โ ๏ธ (truncated)"
elif actual_length > config["max_length"] * 0.8:
color = "#f59e0b" # Orange: getting long
warning = ""
else:
color = "#059669" # Green: all good
warning = ""
return f"๐ Tokens: {actual_length} / {config['max_length']}{warning}"
def load_examples():
"""Load example texts"""
try:
with open("examples.json") as f:
return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
return []
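# examples.json is assumed to be a flat JSON array of example strings, the
# shape gr.Examples expects for a single input component:
#   ["Ik kan niet parkeren bij mijn huis...", "..."]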
# Build Gradio interface
with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
    # 🏛️ WimBERT Synth v0: Multi-label Signaal Classifier
    Classificeert Nederlandse signaalberichten op **Onderwerp** (64 categorieën) en **Beleving** (33 categorieën).
""")
with gr.Row():
with gr.Column(scale=2):
input_text = gr.Textbox(
label="Signaalbericht (Nederlands)",
lines=8,
placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet..."
)
            token_counter = gr.HTML(value=f"📊 Tokens: 0 / {config['max_length']}")
with gr.Row():
                predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
                clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)
with gr.Column(scale=1):
threshold_slider = gr.Slider(
minimum=0,
maximum=1,
value=0.5,
step=0.05,
label="๐ฏ Drempel",
info="Labels boven deze waarde worden als 'voorspeld' gemarkeerd"
)
topk_slider = gr.Slider(
minimum=1,
maximum=15,
value=5,
step=1,
label="๐ Top-K",
info="Aantal top labels om te tonen in samenvatting"
)
gr.Markdown(f"""
**Hardware:** {DEVICE.type.upper()}
**Dtype:** {DTYPE}
**Max length:** {config['max_length']}
""")
with gr.Tabs():
with gr.Tab("๐ Samenvatting"):
summary_output = gr.HTML(label="Top voorspellingen per categorie")
with gr.Tab("๐ Alle labels"):
all_labels_output = gr.HTML(label="Volledige classificatie")
with gr.Tab("๐พ JSON"):
json_output = gr.JSON(label="Ruwe output")
gr.Examples(
examples=load_examples(),
inputs=input_text,
label="๐ Voorbeelden"
)
gr.Markdown("""
---
    ### ℹ️ Over dit model
- **Model:** `UWV/wimbert-synth-v0` (dual-head BERT)
- **Licentie:** Apache-2.0
- **Privacy:** Input wordt alleen in-memory verwerkt, niet opgeslagen
    [Model Card](https://huggingface.co/UWV/wimbert-synth-v0) • Gebouwd met Gradio
""")
# Event handlers
# Live token counting as user types
input_text.change(
fn=count_tokens,
inputs=input_text,
outputs=token_counter
)
# Prediction on button click
predict_btn.click(
fn=predict,
inputs=[input_text, threshold_slider, topk_slider],
outputs=[summary_output, all_labels_output, json_output]
)
    # Re-run prediction when threshold/topk changes (predict() returns a
    # placeholder when the textbox is still empty)
threshold_slider.change(
fn=predict,
inputs=[input_text, threshold_slider, topk_slider],
outputs=[summary_output, all_labels_output, json_output]
)
topk_slider.change(
fn=predict,
inputs=[input_text, threshold_slider, topk_slider],
outputs=[summary_output, all_labels_output, json_output]
)
if __name__ == "__main__":
demo.launch()