Improve UX: dynamic tokens, better colors, live feedback
- Use dynamic sequence length (no fixed padding)
- Show live token counter below textarea with color coding
- Switch to green gradient palette with proper text contrast
- Add token count to JSON output
- Remove 512 token limit, use full 1408 from model config
PLAN.md
ADDED
@@ -0,0 +1,145 @@
# WimBERT Synth v0 – Hugging Face Space Plan

This plan describes a lightweight, reliable Space to demo the dual-head multi-label classifier (onderwerp + beleving) defined by `wimbert-synth-v0/model.py`, with labels from `wimbert-synth-v0/label_names.json` and licensing in `wimbert-synth-v0/LICENSE` (Apache-2.0).

## Goals
- Input: a single Dutch "signaalbericht" (free-text).
- Output: per head (onderwerp, beleving), show probabilities for all labels:
  - Visual: color-coded list/table where color intensity reflects probability.
  - Numeric: exact probability values (0–1) and a top-K summary.
  - "Predicted" set using an adjustable threshold (default 0.5).
- UX: one-click Predict button; optional "live" inference (after brief inactivity).
- Portable, reproducible, and fast enough on CPU; optionally GPU-ready.

## Toolkit Choice
- Gradio is the best fit for this demo on Spaces:
  - First-class support on Hugging Face Spaces, minimal boilerplate (`app.py`).
  - Simple event model (button click, input change) and components for text, tabs, HTML, and charts.
  - Easy to serve both a compact top-K view and a full "all labels" view with custom styling.
  - No Streamlit server/page-lifecycle complexities for this small, single-page inference app.

## Model + License
- Model artifacts live in `wimbert-synth-v0/` under the Apache-2.0 license (redistribution permitted with attribution). Use the exact `LICENSE` in the Space repo.
- The model is large (~1.2 GB for `model.safetensors`). To keep the Space repo small and boot times predictable, prefer hosting the model as a separate model repo on the Hub, then download/cache it in the Space at runtime.
- Recommended: publish a model repo, e.g. `UWV/wimbert-synth-v0`, containing:
  - `model.safetensors`, `config.json`, tokenizer files, `dual_head_state.pt`, `label_names.json`, `model.py`, `README.md`, `LICENSE`.
- The Space loads via `DualHeadModel.from_pretrained(<model_repo_or_local_dir>)`.

## UX & Visualization
- Input: `gr.Textbox(label="Signaalbericht", lines=6, placeholder=...)`.
- Controls:
  - `Predict` button (primary path).
  - `Auto-run` toggle to enable live inference: trigger after the user stops typing for ~600–800 ms (using Gradio's input event with debounce or a simple timer wrapper). If CPU performance is borderline, keep it off by default.
  - `Threshold` slider (0.0–1.0, default 0.5) to highlight predicted labels.
  - `Top-K` slider (1–15, default 5) to size the summary.
- Output: tabs per head and views:
  - Tab 1: "Samenvatting" – two columns for Onderwerp and Beleving, each listing the top-K labels with probabilities.
  - Tab 2: "Alle labels" – scrollable, color-coded tables (or HTML lists) for every label with exact probabilities.
  - Tab 3: "JSON/CSV" – exportable raw probabilities (dict of label → prob) plus the list of predicted labels at the current threshold (see the sketch after this list).
- Color mapping:
  - Use a light-to-dark monochrome scale (e.g., blue/green) where intensity ∝ probability; add a subtle border for values above the threshold.
  - Ensure AA text contrast and include the numbers themselves, so the display does not rely on color alone (accessibility).
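For concreteness, the Tab 3 export could be shaped roughly like this (an illustrative sketch only: the label names and probabilities below are made-up placeholders; the real label sets come from `label_names.json`):

```python
# Hypothetical export payload; keys mirror the plan above, values are placeholders.
example_export = {
    "text": "Ik kan niet parkeren bij mijn huis...",
    "threshold": 0.5,
    "onderwerp": {
        "probabilities": {"parkeren": 0.91, "vergunningen": 0.64},  # label -> prob, for all labels
        "predicted": ["parkeren", "vergunningen"],                  # labels with prob >= threshold
    },
    "beleving": {
        "probabilities": {"frustratie": 0.77},
        "predicted": ["frustratie"],
    },
}
```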

## Space Layout
- Repo root (Space):
  - `app.py` – Gradio app with UI + inference.
  - `requirements.txt` – runtime deps.
  - `README.md` – usage, model card link, privacy note.
  - `LICENSE` – Apache-2.0 (from `wimbert-synth-v0/LICENSE`).
  - Optional: `assets/` (logo), `examples/` (preset texts), `.gitattributes`.
- The model is not vendored into the Space, to avoid 1.2 GB of LFS; it's pulled at startup via `huggingface_hub.snapshot_download` or `from_pretrained` on the Hub repo.

## Dependencies
- `gradio>=4.0`
- `transformers>=4.40`
- `torch` (CPU is fine; GPU preferred if available)
- `safetensors`, `huggingface_hub`
- Optional perf: `accelerate` (device placement), `onnxruntime`/`optimum` (future optimization)

## Inference Design
- Load once at Space start (global singleton). Warm up with a short dummy input.
- Device: choose `cuda` if available, else CPU. Cast to `float16` on GPU; keep `float32` on CPU.
- Tokenization: use `max_length` from the `dual_head_state.pt` config; allow truncation; optionally expose a compact/fast mode (e.g., cap at 512) if CPU latency needs improvement.
- Output structures (see the sketch after this list):
  - Dicts for each head: `[ {label, prob, predicted}, ... ]` with `predicted = prob >= threshold`.
  - Top-K lists derived from the sorted full list.
- Visualization adapters render the above into HTML tables (for color coding) and JSON/CSV text.
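A minimal sketch of that per-head structure (function and variable names are illustrative, not final code):

```python
def build_head_view(labels: list[str], probs: list[float], threshold: float, topk: int):
    """Build the rows described above for one head, sorted by probability."""
    rows = [
        {"label": label, "prob": float(p), "predicted": float(p) >= threshold}
        for label, p in zip(labels, probs)
    ]
    rows.sort(key=lambda r: r["prob"], reverse=True)
    return rows, rows[:topk]  # full list for the tables, slice for the summary
```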

## Event Flow
1. User edits text.
2. If Auto-run is enabled, debounce and run; else wait for the Predict button.
3. Tokenize → model.predict → probabilities (two tensors).
4. Sort, slice to the top-K summary, and prepare the full tables.
5. Render to tabs and compact "Predicted labels" chips (one line per head).

## Pseudocode Sketch (app.py)
```python
import gradio as gr
import torch, json, importlib.util
from huggingface_hub import snapshot_download

MODEL_REPO = "UWV/wimbert-synth-v0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download/copy model folder and import DualHeadModel
model_dir = snapshot_download(MODEL_REPO)
spec = importlib.util.spec_from_file_location("model", f"{model_dir}/model.py")
model_mod = importlib.util.module_from_spec(spec); spec.loader.exec_module(model_mod)
DualHeadModel = model_mod.DualHeadModel
model, tokenizer, cfg = DualHeadModel.from_pretrained(model_dir, device=DEVICE)

# Warm-up
_ = model.predict(*tokenizer("Hoi", return_tensors="pt", padding="max_length", max_length=cfg["max_length"]).values())

def predict(text, threshold, topk):
    enc = tokenizer(text or "", truncation=True, padding="max_length", max_length=cfg["max_length"], return_tensors="pt")
    on_p, be_p = model.predict(enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE))
    # Convert to Python lists and build views ...
    return topk_view, all_labels_html, json_text

with gr.Blocks(title="WimBERT Synth v0") as demo:
    # Inputs, controls, tabs, outputs ...
    ...

if __name__ == "__main__":
    demo.launch()
```

## Performance Notes
- CPU on free Spaces will work but can be slow for long texts (base mmBERT at `max_length` ≈ 1408). Mitigations:
  - Warm up once; cap max length at 512 in a "fast mode" toggle; show a spinner while running.
  - Prefer a small GPU (T4 small) if available; cast to fp16 on GPU.
- Caching: `snapshot_download` uses the shared cache; subsequent restarts are faster.
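A minimal device/dtype sketch matching these notes (assuming the loaded model behaves like a standard `torch.nn.Module`):

```python
import torch

def place_model(model: torch.nn.Module):
    """Pick CUDA + fp16 when available; keep fp32 on CPU for numerical safety."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    return model.to(device=device, dtype=dtype).eval(), device, dtype
```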

## Privacy & Safety
- The Space processes user text in memory only; no logging beyond Gradio defaults. Mention this in the Space README.
- Include a "Use responsibly" note (an analytics/routing aid; no automated decisions) mirroring the model card.

## Deliverables
- `app.py` with:
  - Robust model loading (Hub), device selection, warm-up.
  - A predict function returning: top-K per head, the full colored table, and a JSON dump.
  - UI: textbox, Predict button, Auto-run toggle (debounced), threshold & Top-K sliders, tabs per view.
  - Example(s) from the model card (`widget` example) via `gr.Examples`.
- `requirements.txt` (gradio, transformers, torch, huggingface_hub, safetensors).
- `README.md` with screenshots, hardware recommendation, and links to the model card.
- `LICENSE` copied from `wimbert-synth-v0/LICENSE`.

## Step-By-Step
1) Publish/verify the model on the Hub (`UWV/wimbert-synth-v0`), including `model.py` and the license.
2) Create the Space repo with SDK=Gradio and pick hardware (CPU is OK; GPU is faster).
3) Add the Space files (`app.py`, `requirements.txt`, `README.md`, `LICENSE`).
4) Implement and test inference locally (CPU) with a few sample texts; tune debounce/threshold defaults.
5) Push the Space; verify cold-start time and inference latency; adjust max_length and hardware if needed.
6) Polish the visuals (colors, fonts, accessibility), add screenshots, and publish.

## Nice-To-Haves (Later)
- Per-class thresholds (if you decide to introduce learned or tuned thresholds).
- ONNX/Optimum path for CPU acceleration.
- Session-level analytics (aggregate latency, not storing user text).
- Download CSV/JSON of the current result.
- Translations for UI labels (NL/EN toggle).

Summary: Use Gradio for a single-page Space that downloads the Apache-licensed model from the Hub, offers both button-based and debounced live inference, and presents per-head probabilities as color-coded tables with numeric values, plus top-K and JSON outputs.
app.py
CHANGED

@@ -14,7 +14,6 @@ from huggingface_hub import snapshot_download
 MODEL_REPO = "UWV/wimbert-synth-v0"
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
-MAX_LENGTH = 512  # Default to 512 for better CPU performance
 
 print(f"🔧 Loading model from {MODEL_REPO}...")
 print(f"🖥️ Device: {DEVICE} ({DTYPE})")
@@ -37,14 +36,14 @@ if DTYPE == torch.float16:
 
 # Warm-up inference
 with torch.no_grad():
-    dummy_input = tokenizer("Warm-up", return_tensors="pt",
-                            max_length=MAX_LENGTH)
+    dummy_input = tokenizer("Warm-up", return_tensors="pt", truncation=True,
+                            max_length=config["max_length"])
     _ = model.predict(
         dummy_input["input_ids"].to(DEVICE),
         dummy_input["attention_mask"].to(DEVICE)
     )
 
-print(f"✅ Model loaded and warmed up")
+print(f"✅ Model loaded and warmed up (max_length: {config['max_length']})")
 
 # Extract label names
 LABELS_ONDERWERP = config["labels"]["onderwerp"]
@@ -52,10 +51,33 @@ LABELS_BELEVING = config["labels"]["beleving"]
 
 
 def prob_to_color(prob: float, threshold: float) -> str:
-    """Generate CSS style for probability visualization"""
-
-
-
+    """Generate CSS style for probability visualization (10X UX approved)"""
+    # Green gradient: low prob = very light green, high prob = saturated green
+    # Use HSL: Hue=145 (green), Saturation increases with prob, Lightness decreases
+    saturation = 30 + int(prob * 50)  # 30% to 80%
+    lightness = 92 - int(prob * 55)   # 92% to 37%
+
+    # Text color: white for dark backgrounds (prob > 0.6), dark for light
+    text_color = "#ffffff" if prob > 0.6 else "#1f2937"
+
+    # Border: thick + accent for predicted, subtle for others
+    if prob >= threshold:
+        border = "2px solid #059669"
+        box_shadow = "0 1px 3px rgba(5, 150, 105, 0.3)"
+    else:
+        border = "1px solid #d1d5db"
+        box_shadow = "none"
+
+    return (
+        f"background: hsl(145, {saturation}%, {lightness}%); "
+        f"color: {text_color}; "
+        f"border: {border}; "
+        f"box-shadow: {box_shadow}; "
+        f"padding: 6px 12px; "
+        f"border-radius: 4px; "
+        f"margin: 2px 0; "
+        f"font-weight: 500;"
    )
 
 
 def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
@@ -97,15 +119,17 @@ def predict(text, threshold, topk):
         empty_msg = "<p style='color: #666; font-style: italic;'>Voer een bericht in om te classificeren...</p>"
         return empty_msg, empty_msg, {}
 
-    # Tokenize
+    # Tokenize with dynamic length (only truncate if needed)
     inputs = tokenizer(
         text,
         return_tensors="pt",
-        padding="max_length",
-        max_length=MAX_LENGTH,
-        truncation=True
+        truncation=True,
+        max_length=config["max_length"]  # 1408 from model config
     )
 
+    # Get actual sequence length (non-padding tokens)
+    actual_length = inputs["attention_mask"].sum().item()
+
     # Move to device
     input_ids = inputs["input_ids"].to(DEVICE)
     attention_mask = inputs["attention_mask"].to(DEVICE)
@@ -132,6 +156,8 @@ def predict(text: str, threshold: float, topk: int):
     # Generate JSON output
     json_output = {
         "text": text,
+        "token_count": actual_length,
+        "max_length": config["max_length"],
         "threshold": threshold,
         "onderwerp": {
             "probabilities": {label: float(prob) for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs)},
@@ -146,6 +172,29 @@ def predict(text: str, threshold: float, topk: int):
     return summary_html, all_labels_html, json_output
 
 
+def count_tokens(text: str) -> str:
+    """Count tokens for live feedback"""
+    if not text or not text.strip():
+        return "📊 Tokens: 0 / 1408"
+
+    # Quick tokenization (no GPU needed, just counting)
+    tokens = tokenizer(text, truncation=True, max_length=config["max_length"])
+    actual_length = sum(tokens["attention_mask"])
+
+    # Color code based on usage
+    if actual_length > config["max_length"]:
+        color = "#dc2626"  # Red: truncated
+        warning = " ⚠️ (truncated)"
+    elif actual_length > config["max_length"] * 0.8:
+        color = "#f59e0b"  # Orange: getting long
+        warning = ""
+    else:
+        color = "#059669"  # Green: all good
+        warning = ""
+
+    return f"<span style='color: {color}; font-size: 0.875rem; font-weight: 500;'>📊 Tokens: {actual_length} / {config['max_length']}{warning}</span>"
+
+
 def load_examples():
     """Load example texts"""
     try:
@@ -168,9 +217,9 @@ with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
     input_text = gr.Textbox(
         label="Signaalbericht (Nederlands)",
         lines=8,
-        placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet...",
-        info="Voer een bericht in en klik op 'Voorspel'"
+        placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet..."
     )
+    token_counter = gr.HTML(value="<span style='color: #6b7280; font-size: 0.875rem;'>📊 Tokens: 0 / 1408</span>")
     with gr.Row():
         predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
         clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)
@@ -195,7 +244,7 @@ with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
     gr.Markdown(f"""
     **Hardware:** {DEVICE.type.upper()}
     **Dtype:** {DTYPE}
-    **Max length:** {MAX_LENGTH}
+    **Max length:** {config['max_length']}
     """)
 
     with gr.Tabs():
@@ -225,6 +274,15 @@ with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
     """)
 
     # Event handlers
+
+    # Live token counting as user types
+    input_text.change(
+        fn=count_tokens,
+        inputs=input_text,
+        outputs=token_counter
+    )
+
+    # Prediction on button click
     predict_btn.click(
         fn=predict,
         inputs=[input_text, threshold_slider, topk_slider],
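One subtlety in `count_tokens` above: it tokenizes with `truncation=True`, so the counted length can never exceed `config["max_length"]` and the red "truncated" branch is unreachable. A hedged sketch of an overflow-aware counter (illustrative only, not part of this commit; it reuses the module-level `tokenizer` and `config`):

```python
def token_usage(text: str) -> tuple[int, bool]:
    """Count tokens without truncation so overflow is actually detectable."""
    n = len(tokenizer(text)["input_ids"])  # includes special tokens; no truncation
    return min(n, config["max_length"]), n > config["max_length"]
```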