yhavinga committed
Commit 1734421 · 1 Parent(s): 85efe28

Improve UX: dynamic tokens, better colors, live feedback


- Use dynamic sequence length (no fixed padding)
- Show live token counter below textarea with color coding
- Switch to green gradient palette with proper text contrast
- Add token count to JSON output
- Remove 512 token limit, use full 1408 from model config

Files changed (2):
  1. PLAN.md +145 -0
  2. app.py +73 -15
PLAN.md ADDED
@@ -0,0 +1,145 @@
# WimBERT Synth v0 — Hugging Face Space Plan

This plan describes a lightweight, reliable Space to demo the dual-head multi-label classifier (onderwerp + beleving) defined by `wimbert-synth-v0/model.py`, with labels from `wimbert-synth-v0/label_names.json` and licensing in `wimbert-synth-v0/LICENSE` (Apache-2.0).

## Goals
- Input: a single Dutch "signaalbericht" (free text).
- Output: per head (onderwerp, beleving), show probabilities for all labels:
  - Visual: color-coded list/table where color intensity reflects probability.
  - Numeric: exact probability values (0-1) and a top-K summary.
  - "Predicted" set using an adjustable threshold (default 0.5).
- UX: one-click Predict button; optional "live" inference (after brief inactivity).
- Portable, reproducible, and fast enough on CPU; optionally GPU-ready.

## Toolkit Choice
- Gradio is the best fit for this demo on Spaces:
  - First-class support on Hugging Face Spaces, minimal boilerplate (`app.py`).
  - Simple event model (button click, input change) and components for text, tabs, HTML, and charts.
  - Easy to serve both a compact top-K view and a full "all labels" view with custom styling.
  - None of Streamlit's server/page lifecycle complexity for this small, single-page inference app.

## Model + License
- Model artifacts live in `wimbert-synth-v0/` under the Apache-2.0 license (redistribution permitted with attribution). Use the exact `LICENSE` in the Space repo.
- The model is large (~1.2 GB for `model.safetensors`). To keep the Space repo small and boot times predictable, prefer hosting the model as a separate Model repo on the Hub, then download/cache it in the Space at runtime.
- Recommended: publish a model repo, e.g. `UWV/wimbert-synth-v0`, containing:
  - `model.safetensors`, `config.json`, tokenizer files, `dual_head_state.pt`, `label_names.json`, `model.py`, `README.md`, `LICENSE`.
- The Space loads via `DualHeadModel.from_pretrained(<model_repo_or_local_dir>)`.

## UX & Visualization
- Input: `gr.Textbox(label="Signaalbericht", lines=6, placeholder=...)`.
- Controls:
  - `Predict` button (primary path).
  - `Auto-run` toggle for live inference: trigger after the user stops typing for ~600-800 ms (via Gradio's input event with debounce or a simple timer wrapper). If CPU performance is borderline, keep it off by default.
  - `Threshold` slider (0.0-1.0, default 0.5) to highlight predicted labels.
  - `Top-K` slider (1-15, default 5) to size the summary.
- Output: tabs per head and view:
  - Tab 1: "Samenvatting" → two columns for Onderwerp and Beleving, each listing the top-K labels with probabilities.
  - Tab 2: "Alle labels" → scrollable, color-coded tables (or HTML lists) for every label with exact probabilities.
  - Tab 3: "JSON/CSV" → exportable raw probabilities (dict of label → prob) plus the list of predicted labels at the current threshold.
- Color mapping (see the sketch below):
  - Use a light-to-dark monochrome (e.g., blue/green) where intensity ∝ probability; add a subtle border above the threshold.
  - Ensure text contrast (AA) and include the numbers so results do not rely on color alone (accessibility).
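
A minimal sketch of the intensity mapping (plain Python; the green hue and the saturation/lightness ranges mirror what `app.py` in this commit ends up using, but they are tunable):

```python
def prob_to_style(prob: float, threshold: float, hue: int = 145) -> str:
    """Map a probability in [0, 1] to an inline CSS style string.

    Higher probability -> darker, more saturated background; labels at or
    above the threshold get a stronger border so the "predicted" set stays
    visible without relying on color alone.
    """
    saturation = 30 + int(prob * 50)  # 30%..80%
    lightness = 92 - int(prob * 55)   # 92%..37%
    text_color = "#ffffff" if prob > 0.6 else "#1f2937"  # keep AA contrast
    border = "2px solid #059669" if prob >= threshold else "1px solid #d1d5db"
    return (
        f"background: hsl({hue}, {saturation}%, {lightness}%); "
        f"color: {text_color}; border: {border}; "
        "padding: 6px 12px; border-radius: 4px;"
    )
```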

## Space Layout
- Repo root (Space):
  - `app.py` — Gradio app with UI + inference.
  - `requirements.txt` — runtime deps.
  - `README.md` — usage, model card link, privacy note.
  - `LICENSE` — Apache-2.0 (from `wimbert-synth-v0/LICENSE`).
  - Optional: `assets/` (logo), `examples/` (preset texts), `.gitattributes`.
- The model is not vendored into the Space (avoids 1.2 GB of LFS); it is pulled at startup via `huggingface_hub.snapshot_download` or `from_pretrained` on the Hub repo.

## Dependencies
- `gradio>=4.0`
- `transformers>=4.40`
- `torch` (CPU is fine; GPU preferred if available)
- `safetensors`, `huggingface_hub`
- Optional perf: `accelerate` (device placement), `onnxruntime`/`optimum` (future optimization)
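
A plausible `requirements.txt` matching the list above (the pins are illustrative, not taken from a tested build):

```text
gradio>=4.0
transformers>=4.40
torch
safetensors
huggingface_hub
```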

## Inference Design
- Load once at Space start (global singleton). Warm up with a short dummy input.
- Device: choose `cuda` if available, else CPU. Cast to `float16` on GPU; keep `float32` on CPU.
- Tokenization: use `max_length` from the `dual_head_state.pt` config; allow truncation; optionally expose a compact/fast mode (e.g., cap at 512) if CPU latency needs improvement.
- Output structures (see the sketch below):
  - Dicts for each head: `[{label, prob, predicted}, ...]` with `predicted = prob >= threshold`.
  - Top-K lists derived from the sorted full list.
- Visualization adapters render the above into HTML tables (for color-coding) and JSON/CSV text.
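
A minimal sketch of these structures (pure Python; `labels` and `probs` stand in for one head's label names and sigmoid outputs):

```python
def build_rows(labels: list[str], probs: list[float], threshold: float) -> list[dict]:
    """One record per label; `predicted` marks the thresholded set."""
    return [
        {"label": lbl, "prob": float(p), "predicted": float(p) >= threshold}
        for lbl, p in zip(labels, probs)
    ]

def top_k(rows: list[dict], k: int) -> list[dict]:
    """Top-K summary derived from the sorted full list."""
    return sorted(rows, key=lambda r: r["prob"], reverse=True)[:k]
```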

## Event Flow
1. User edits text.
2. If Auto-run is enabled, debounce and run; otherwise wait for the Predict button (see the auto-run sketch below).
3. Tokenize → model.predict → probs (two tensors).
4. Sort, slice the top-K summary, and prepare the full tables.
5. Render to the tabs and to compact "Predicted labels" chips (one line per head).
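
A sketch of the auto-run wiring. It assumes Gradio 4's `trigger_mode="always_last"` on the `change` listener as a cheap stand-in for debouncing (it drops intermediate events while a run is in flight; an exact 600-800 ms delay would need a JS-side timer), and `render_prediction` is a hypothetical placeholder for the real handler:

```python
import gradio as gr

def render_prediction(text: str) -> str:
    ...  # hypothetical: tokenize, predict, render HTML

with gr.Blocks() as demo:
    text = gr.Textbox(label="Signaalbericht", lines=6)
    auto_run = gr.Checkbox(label="Auto-run", value=False)  # off by default on CPU
    out = gr.HTML()

    def maybe_predict(txt: str, enabled: bool):
        # Only run live inference when the user opted in.
        return render_prediction(txt) if enabled else gr.update()

    text.change(maybe_predict, inputs=[text, auto_run], outputs=out,
                trigger_mode="always_last")
```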

## Pseudocode Sketch (app.py)
```python
import importlib.util

import gradio as gr
import torch
from huggingface_hub import snapshot_download

MODEL_REPO = "UWV/wimbert-synth-v0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the model snapshot and import DualHeadModel from its model.py
model_dir = snapshot_download(MODEL_REPO)
spec = importlib.util.spec_from_file_location("model", f"{model_dir}/model.py")
model_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_mod)
DualHeadModel = model_mod.DualHeadModel
model, tokenizer, cfg = DualHeadModel.from_pretrained(model_dir, device=DEVICE)

# Warm-up: pass tensors explicitly rather than *tokenizer(...).values(),
# which could leak extra keys (e.g. token_type_ids) into predict()
enc = tokenizer("Hoi", return_tensors="pt", truncation=True, max_length=cfg["max_length"])
_ = model.predict(enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE))

def predict(text, threshold, topk):
    enc = tokenizer(text or "", return_tensors="pt", truncation=True,
                    max_length=cfg["max_length"])
    on_p, be_p = model.predict(enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE))
    # Convert to python lists and build views ...
    return topk_view, all_labels_html, json_text

with gr.Blocks(title="WimBERT Synth v0") as demo:
    # Inputs, controls, tabs, outputs ...
    ...

if __name__ == "__main__":
    demo.launch()
```

## Performance Notes
- CPU on free Spaces will work but can be slow for long texts (base mmBERT at `max_length≈1408`). Mitigations (see the fast-mode sketch below):
  - Warm up once; cap max length at 512 via a "fast mode" toggle; show a spinner while running.
  - Prefer a small GPU (T4 small) if available; cast to fp16 on GPU.
- Caching: `snapshot_download` uses the shared cache; subsequent restarts are faster.
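
A minimal sketch of the fast-mode cap, reusing `tokenizer` and `cfg` from the sketch above (the 512 cap and the `fast_mode` flag are this section's suggestion, not an existing switch):

```python
FAST_MAX_LENGTH = 512  # suggested CPU-friendly cap

def encode(text: str, fast_mode: bool):
    """Tokenize with the full model context, or a shorter cap in fast mode."""
    max_len = FAST_MAX_LENGTH if fast_mode else cfg["max_length"]  # 1408 full context
    return tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len)
```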

## Privacy & Safety
- The Space processes user text in memory only; no logging beyond Gradio defaults. Mention this in the Space README.
- Include a "Use responsibly" note (analytics/routing aid; no automated decisions) mirroring the model card.

## Deliverables
- `app.py` with:
  - Robust model loading (Hub), device selection, warm-up.
  - Predict function returning: top-K per head, full colored table, JSON dump.
  - UI: textbox, Predict button, Auto-run toggle (debounced), threshold & Top-K sliders, tabs per view.
  - Example(s) from the model card (`widget` example) via `gr.Examples`.
- `requirements.txt` (gradio, transformers, torch, huggingface_hub, safetensors).
- `README.md` with screenshots, hardware recommendation, and links to the model card.
- `LICENSE` copied from `wimbert-synth-v0/LICENSE`.

## Step-By-Step
1) Publish/verify the model on the Hub (`UWV/wimbert-synth-v0`), including `model.py` and the license.
2) Create the Space repo with SDK=Gradio and pick hardware (CPU → OK; GPU → faster).
3) Add the Space files (`app.py`, `requirements.txt`, `README.md`, `LICENSE`).
4) Implement and test inference locally (CPU) with a few sample texts; tune debounce/threshold defaults.
5) Push the Space; verify cold-start time and inference latency; adjust max_length and hardware if needed.
6) Polish the visuals (colors, fonts, accessibility), add screenshots, and publish.

## Nice-To-Haves (Later)
- Per-class thresholds (if you decide to introduce learned or tuned thresholds).
- An ONNX/Optimum path for CPU acceleration.
- Session-level analytics (aggregate latency, without storing user text).
- Download CSV/JSON of the current result.
- Translations for UI labels (NL/EN toggle).

Summary: Use Gradio for a single-page Space that downloads the Apache-licensed model from the Hub, offers both button-based and debounced live inference, and presents per-head probabilities as color-coded tables with numeric values, plus top-K and JSON outputs.
app.py CHANGED
```diff
@@ -14,7 +14,6 @@ from huggingface_hub import snapshot_download
 MODEL_REPO = "UWV/wimbert-synth-v0"
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
-MAX_LENGTH = 512  # Default to 512 for better CPU performance
 
 print(f"🔧 Loading model from {MODEL_REPO}...")
 print(f"🖥️ Device: {DEVICE} ({DTYPE})")
@@ -37,14 +36,14 @@ if DTYPE == torch.float16:
 
 # Warm-up inference
 with torch.no_grad():
-    dummy_input = tokenizer("Warm-up", return_tensors="pt", padding="max_length",
-                            max_length=MAX_LENGTH, truncation=True)
+    dummy_input = tokenizer("Warm-up", return_tensors="pt", truncation=True,
+                            max_length=config["max_length"])
     _ = model.predict(
         dummy_input["input_ids"].to(DEVICE),
         dummy_input["attention_mask"].to(DEVICE)
     )
 
-print(f"✅ Model loaded and warmed up")
+print(f"✅ Model loaded and warmed up (max_length: {config['max_length']})")
 
 # Extract label names
 LABELS_ONDERWERP = config["labels"]["onderwerp"]
@@ -52,10 +51,33 @@ LABELS_BELEVING = config["labels"]["beleving"]
 
 
 def prob_to_color(prob: float, threshold: float) -> str:
-    """Generate CSS style for probability visualization"""
-    lightness = 95 - int(prob * 65)
-    border = "2px solid #1e3a8a" if prob >= threshold else "1px solid #e5e7eb"
-    return f"background: hsl(210, 80%, {lightness}%); border: {border}; padding: 6px 12px; border-radius: 4px; margin: 2px 0;"
+    """Generate CSS style for probability visualization (10X UX approved)"""
+    # Green gradient: low prob = very light green, high prob = saturated green
+    # Use HSL: Hue=145 (green), Saturation increases with prob, Lightness decreases
+    saturation = 30 + int(prob * 50)  # 30% to 80%
+    lightness = 92 - int(prob * 55)  # 92% to 37%
+
+    # Text color: white for dark backgrounds (prob > 0.6), dark for light
+    text_color = "#ffffff" if prob > 0.6 else "#1f2937"
+
+    # Border: thick + accent for predicted, subtle for others
+    if prob >= threshold:
+        border = "2px solid #059669"
+        box_shadow = "0 1px 3px rgba(5, 150, 105, 0.3)"
+    else:
+        border = "1px solid #d1d5db"
+        box_shadow = "none"
+
+    return (
+        f"background: hsl(145, {saturation}%, {lightness}%); "
+        f"color: {text_color}; "
+        f"border: {border}; "
+        f"box-shadow: {box_shadow}; "
+        f"padding: 6px 12px; "
+        f"border-radius: 4px; "
+        f"margin: 2px 0; "
+        f"font-weight: 500;"
+    )
 
 
 def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
@@ -97,15 +119,17 @@ def predict(text: str, threshold: float, topk: int):
         empty_msg = "<p style='color: #666; font-style: italic;'>Voer een bericht in om te classificeren...</p>"
         return empty_msg, empty_msg, {}
 
-    # Tokenize
+    # Tokenize with dynamic length (only truncate if needed)
     inputs = tokenizer(
         text,
         return_tensors="pt",
-        padding="max_length",
-        max_length=MAX_LENGTH,
-        truncation=True
+        truncation=True,
+        max_length=config["max_length"]  # 1408 from model config
    )
 
+    # Get actual sequence length (non-padding tokens)
+    actual_length = inputs["attention_mask"].sum().item()
+
     # Move to device
     input_ids = inputs["input_ids"].to(DEVICE)
     attention_mask = inputs["attention_mask"].to(DEVICE)
@@ -132,6 +156,8 @@ def predict(text: str, threshold: float, topk: int):
     # Generate JSON output
     json_output = {
         "text": text,
+        "token_count": actual_length,
+        "max_length": config["max_length"],
         "threshold": threshold,
         "onderwerp": {
             "probabilities": {label: float(prob) for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs)},
@@ -146,6 +172,29 @@ def predict(text: str, threshold: float, topk: int):
     return summary_html, all_labels_html, json_output
 
 
+def count_tokens(text: str) -> str:
+    """Count tokens for live feedback"""
+    if not text or not text.strip():
+        return "📏 Tokens: 0 / 1408"
+
+    # Count without truncation so overflow past max_length stays detectable
+    tokens = tokenizer(text)
+    actual_length = sum(tokens["attention_mask"])
+
+    # Color code based on usage
+    if actual_length > config["max_length"]:
+        color = "#dc2626"  # Red: truncated
+        warning = " ⚠️ (truncated)"
+    elif actual_length > config["max_length"] * 0.8:
+        color = "#f59e0b"  # Orange: getting long
+        warning = ""
+    else:
+        color = "#059669"  # Green: all good
+        warning = ""
+
+    return f"<span style='color: {color}; font-size: 0.875rem; font-weight: 500;'>📏 Tokens: {actual_length} / {config['max_length']}{warning}</span>"
+
+
 def load_examples():
     """Load example texts"""
     try:
@@ -168,9 +217,9 @@ with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
             input_text = gr.Textbox(
                 label="Signaalbericht (Nederlands)",
                 lines=8,
-                placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet...",
-                info="Voer een bericht in en klik op 'Voorspel'"
+                placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet..."
             )
+            token_counter = gr.HTML(value="<span style='color: #6b7280; font-size: 0.875rem;'>📏 Tokens: 0 / 1408</span>")
             with gr.Row():
                 predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
                 clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)
@@ -195,7 +244,7 @@ with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
             gr.Markdown(f"""
             **Hardware:** {DEVICE.type.upper()}
             **Dtype:** {DTYPE}
-            **Max length:** {MAX_LENGTH}
+            **Max length:** {config['max_length']}
             """)
 
     with gr.Tabs():
@@ -225,6 +274,15 @@ with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
             """)
 
     # Event handlers
+
+    # Live token counting as user types
+    input_text.change(
+        fn=count_tokens,
+        inputs=input_text,
+        outputs=token_counter
+    )
+
+    # Prediction on button click
     predict_btn.click(
         fn=predict,
         inputs=[input_text, threshold_slider, topk_slider],
```