#!/usr/bin/env python3 """ WimBERT Synth v0 Gradio Space Dual-head multi-label classifier for Dutch signal messages """ import json import importlib.util import torch import gradio as gr from huggingface_hub import snapshot_download # Constants MODEL_REPO = "UWV/wimbert-synth-v0" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32 print(f"๐Ÿ”ง Loading model from {MODEL_REPO}...") print(f"๐Ÿ–ฅ๏ธ Device: {DEVICE} ({DTYPE})") # Download model files (uses HF cache) model_dir = snapshot_download(MODEL_REPO, cache_dir=None) # Dynamic import of model.py from downloaded dir spec = importlib.util.spec_from_file_location("model", f"{model_dir}/model.py") model_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(model_module) DualHeadModel = model_module.DualHeadModel # Load model + tokenizer + config model, tokenizer, config = DualHeadModel.from_pretrained(model_dir, device=DEVICE) # Cast to target dtype if DTYPE == torch.float16: model = model.half() # Warm-up inference with torch.no_grad(): dummy_input = tokenizer("Warm-up", return_tensors="pt", truncation=True, max_length=config["max_length"]) _ = model.predict( dummy_input["input_ids"].to(DEVICE), dummy_input["attention_mask"].to(DEVICE) ) print(f"โœ… Model loaded and warmed up (max_length: {config['max_length']})") # Extract label names LABELS_ONDERWERP = config["labels"]["onderwerp"] LABELS_BELEVING = config["labels"]["beleving"] def prob_to_color(prob: float, threshold: float) -> str: """Generate CSS style for probability visualization (10X UX approved)""" # Green gradient: low prob = very light green, high prob = saturated green # Use HSL: Hue=145 (green), Saturation increases with prob, Lightness decreases saturation = 30 + int(prob * 50) # 30% to 80% lightness = 92 - int(prob * 55) # 92% to 37% # Text color: white for dark backgrounds (prob > 0.6), dark for light text_color = "#ffffff" if prob > 0.6 else "#1f2937" # Border: thick + accent for predicted, subtle for others if prob >= threshold: border = "2px solid #059669" box_shadow = "0 1px 3px rgba(5, 150, 105, 0.3)" else: border = "1px solid #d1d5db" box_shadow = "none" return ( f"background: hsl(145, {saturation}%, {lightness}%); " f"color: {text_color}; " f"border: {border}; " f"box-shadow: {box_shadow}; " f"padding: 6px 12px; " f"border-radius: 4px; " f"margin: 2px 0; " f"font-weight: 500;" ) def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str: """Generate HTML for top-K labels""" sorted_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True) html = "
" for idx in sorted_indices[:topk]: label = labels[idx] prob = probs[idx] style = prob_to_color(prob, threshold) predicted = " โœ“" if prob >= threshold else "" html += f"
{label}: {prob:.3f}{predicted}
" html += "
" return html def format_all_labels(head_name: str, labels: list, probs: list, threshold: float) -> str: """Generate scrollable table for all labels""" sorted_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True) html = f"

{head_name}

" html += "" html += "" html += "" html += "" for idx in sorted_indices: label = labels[idx] prob = probs[idx] style = prob_to_color(prob, threshold) predicted = "โœ“" if prob >= threshold else "" html += f"" html += "
LabelProbabilityPredicted
{label}{prob:.4f}{predicted}
" return html @torch.inference_mode() def predict(text: str, threshold: float, topk: int): """Run inference and return visualizations""" if not text or not text.strip(): empty_msg = "

Voer een bericht in om te classificeren...

" return empty_msg, empty_msg, {} # Tokenize with dynamic length (only truncate if needed) inputs = tokenizer( text, return_tensors="pt", truncation=True, max_length=config["max_length"] # 1408 from model config ) # Get actual sequence length (non-padding tokens) actual_length = inputs["attention_mask"].sum().item() # Move to device input_ids = inputs["input_ids"].to(DEVICE) attention_mask = inputs["attention_mask"].to(DEVICE) # Predict onderwerp_probs, beleving_probs = model.predict(input_ids, attention_mask) # Convert to lists onderwerp_probs = onderwerp_probs[0].cpu().numpy().tolist() beleving_probs = beleving_probs[0].cpu().numpy().tolist() # Generate summary view (top-K for each head side by side) summary_html = "
" summary_html += f"

Onderwerp (Top-{topk})

{format_topk(LABELS_ONDERWERP, onderwerp_probs, threshold, topk)}
" summary_html += f"

Beleving (Top-{topk})

{format_topk(LABELS_BELEVING, beleving_probs, threshold, topk)}
" summary_html += "
" # Generate all labels view all_labels_html = "
" all_labels_html += f"
{format_all_labels('Onderwerp', LABELS_ONDERWERP, onderwerp_probs, threshold)}
" all_labels_html += f"
{format_all_labels('Beleving', LABELS_BELEVING, beleving_probs, threshold)}
" all_labels_html += "
" # Generate JSON output json_output = { "text": text, "token_count": actual_length, "max_length": config["max_length"], "threshold": threshold, "onderwerp": { "probabilities": {label: float(prob) for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs)}, "predicted": [label for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs) if prob >= threshold] }, "beleving": { "probabilities": {label: float(prob) for label, prob in zip(LABELS_BELEVING, beleving_probs)}, "predicted": [label for label, prob in zip(LABELS_BELEVING, beleving_probs) if prob >= threshold] } } return summary_html, all_labels_html, json_output def count_tokens(text: str) -> str: """Count tokens for live feedback""" if not text or not text.strip(): return "๐Ÿ“ Tokens: 0 / 1408" # Quick tokenization (no GPU needed, just counting) tokens = tokenizer(text, truncation=True, max_length=config["max_length"]) actual_length = sum(tokens["attention_mask"]) # Color code based on usage if actual_length > config["max_length"]: color = "#dc2626" # Red: truncated warning = " โš ๏ธ (truncated)" elif actual_length > config["max_length"] * 0.8: color = "#f59e0b" # Orange: getting long warning = "" else: color = "#059669" # Green: all good warning = "" return f"๐Ÿ“ Tokens: {actual_length} / {config['max_length']}{warning}" def load_examples(): """Load example texts""" try: with open("examples.json") as f: return json.load(f) except: return [] # Build Gradio interface with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # ๐Ÿ›๏ธ WimBERT Synth v0: Multi-label Signaal Classifier Classificeert Nederlandse signaalberichten op **Onderwerp** (64 categorieรซn) en **Beleving** (33 categorieรซn). """) with gr.Row(): with gr.Column(scale=2): input_text = gr.Textbox( label="Signaalbericht (Nederlands)", lines=8, placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet..." ) token_counter = gr.HTML(value="๐Ÿ“ Tokens: 0 / 1408") with gr.Row(): predict_btn = gr.Button("๐Ÿ”ฎ Voorspel", variant="primary", scale=2) clear_btn = gr.ClearButton([input_text], value="๐Ÿ—‘๏ธ Wissen", scale=1) with gr.Column(scale=1): threshold_slider = gr.Slider( minimum=0, maximum=1, value=0.5, step=0.05, label="๐ŸŽฏ Drempel", info="Labels boven deze waarde worden als 'voorspeld' gemarkeerd" ) topk_slider = gr.Slider( minimum=1, maximum=15, value=5, step=1, label="๐Ÿ“Š Top-K", info="Aantal top labels om te tonen in samenvatting" ) gr.Markdown(f""" **Hardware:** {DEVICE.type.upper()} **Dtype:** {DTYPE} **Max length:** {config['max_length']} """) with gr.Tabs(): with gr.Tab("๐Ÿ“‹ Samenvatting"): summary_output = gr.HTML(label="Top voorspellingen per categorie") with gr.Tab("๐Ÿ“Š Alle labels"): all_labels_output = gr.HTML(label="Volledige classificatie") with gr.Tab("๐Ÿ’พ JSON"): json_output = gr.JSON(label="Ruwe output") gr.Examples( examples=load_examples(), inputs=input_text, label="๐Ÿ“ Voorbeelden" ) gr.Markdown(""" --- ### โ„น๏ธ Over dit model - **Model:** `UWV/wimbert-synth-v0` (dual-head BERT) - **Licentie:** Apache-2.0 - **Privacy:** Input wordt alleen in-memory verwerkt, niet opgeslagen [Model Card](https://huggingface.co/UWV/wimbert-synth-v0) โ€ข Gebouwd met Gradio """) # Event handlers # Live token counting as user types input_text.change( fn=count_tokens, inputs=input_text, outputs=token_counter ) # Prediction on button click predict_btn.click( fn=predict, inputs=[input_text, threshold_slider, topk_slider], outputs=[summary_output, all_labels_output, json_output] ) # Update predictions when threshold/topk changes (if there's existing output) threshold_slider.change( fn=predict, inputs=[input_text, threshold_slider, topk_slider], outputs=[summary_output, all_labels_output, json_output] ) topk_slider.change( fn=predict, inputs=[input_text, threshold_slider, topk_slider], outputs=[summary_output, all_labels_output, json_output] ) if __name__ == "__main__": demo.launch()