rntc's picture
Upload app.py with huggingface_hub
d2d1011 verified
raw
history blame
14.4 kB
"""
Gradio app to explore pancreas cancer clinical report annotations.
Loads data from rntc/biomed-fr-pancreas-annotations on HuggingFace.
"""
import gradio as gr
from datasets import load_dataset
from difflib import SequenceMatcher
# Load the "train" split of the annotation dataset once, at import time.
# The resulting module-level `dataset` is read by display_sample(),
# search_samples() and the slider bounds of the UI further down.
# NOTE(review): presumably requires network access or a warm local cache,
# since the data comes from the HuggingFace Hub -- confirm for offline use.
print("Loading dataset from HuggingFace...")
dataset = load_dataset("rntc/biomed-fr-pancreas-annotations", split="train")
print(f"Loaded {len(dataset)} samples")
def fuzzy_find_span(text: str, span: str, threshold: float = 0.85) -> "tuple[int, int] | None":
    """
    Locate ``span`` inside ``text``, tolerating small differences.

    An exact substring search is tried first. If that fails, fixed-size
    windows of ``len(span)`` characters are compared against ``span`` with
    ``difflib.SequenceMatcher``: a coarse pass samples every quarter span
    length, then the neighbourhood of the best coarse hit is re-scanned one
    character at a time so the returned offsets line up with the best local
    alignment instead of the nearest sampled position.

    Args:
        text: Text to search in.
        span: Snippet to look for.
        threshold: Minimum similarity ratio (0..1) a window must reach during
            the coarse pass to be considered a match.

    Returns:
        ``(start, end)`` offsets into ``text`` (``end`` exclusive), or ``None``
        when nothing matches. Fuzzy matching is only attempted for spans of at
        least 10 characters that are not longer than ``text``. Empty ``text``
        or ``span`` yields ``None``.
    """
    if not text or not span:
        return None

    # First try exact match
    exact_idx = text.find(span)
    if exact_idx != -1:
        return (exact_idx, exact_idx + len(span))

    span_len = len(span)
    # Short spans produce too many spurious fuzzy hits; longer-than-text
    # spans cannot fit in any window.
    if span_len < 10 or span_len > len(text):
        return None

    def ratio_at(pos: int) -> float:
        """Similarity between ``span`` and the same-length window at ``pos``."""
        return SequenceMatcher(None, span, text[pos:pos + span_len]).ratio()

    last_start = len(text) - span_len
    stride = max(1, span_len // 4)

    # Coarse pass: sample window starts every `stride` characters.
    best_ratio = 0.0
    best_pos = None
    for pos in range(0, last_start + 1, stride):
        ratio = ratio_at(pos)
        if ratio >= threshold and ratio > best_ratio:
            best_ratio, best_pos = ratio, pos

    if best_pos is None:
        return None

    # Fine pass: the true alignment may sit between two sampled starts, so
    # check every offset strictly inside the neighbouring coarse samples.
    coarse_pos = best_pos
    lo = max(0, coarse_pos - stride + 1)
    hi = min(last_start, coarse_pos + stride - 1)
    for pos in range(lo, hi + 1):
        if pos == coarse_pos:
            continue
        ratio = ratio_at(pos)
        if ratio > best_ratio:
            best_ratio, best_pos = ratio, pos

    return (best_pos, best_pos + span_len)
def escape_html(text: str) -> str:
    """
    Escape a value for safe inclusion in HTML text or attribute values.

    ``&``, ``<``, ``>``, ``"`` and ``'`` are replaced by entities, so the
    result is safe inside both double- and single-quoted attributes.

    Args:
        text: Value to escape. Non-string values are converted with ``str``.

    Returns:
        The escaped string; ``""`` when ``text`` is ``None``. Other falsy
        values such as ``0`` are rendered rather than dropped.
    """
    # Function-scope import keeps this helper self-contained.
    from html import escape

    if text is None:
        return ""
    return escape(str(text), quote=True)
# Soft pastel colors for better readability.
# Background colors for highlighted spans: highlight_spans_in_text() gives
# each distinct variable name the next entry, wrapping around with a modulo
# once all 15 have been used.
COLORS = [
    "#FFE082", # amber
    "#A5D6A7", # green
    "#90CAF9", # blue
    "#FFAB91", # deep orange
    "#CE93D8", # purple
    "#80DEEA", # cyan
    "#C5E1A5", # light green
    "#FFCC80", # orange
    "#B39DDB", # deep purple
    "#81D4FA", # light blue
    "#EF9A9A", # red
    "#FFF59D", # yellow
    "#F48FB1", # pink
    "#80CBC4", # teal
    "#BCAAA4", # brown
]
def highlight_spans_in_text(cr_text: str, annotation: dict) -> str:
    """
    Render a clinical report as HTML with the annotated spans highlighted.

    Args:
        cr_text: Raw report text.
        annotation: Mapping of variable name to a dict that may carry a
            ``"span"`` (supporting text) and a ``"value"`` (extracted value).
            Entries that are not dicts, have no value, or whose span is not a
            string of at least 5 characters are ignored.

    Returns:
        A ``<div class='cr-text'>`` element. Every span located in the report
        (exactly or fuzzily, via ``fuzzy_find_span``) is wrapped in a
        ``<mark class="entity">`` carrying a tooltip and a nested
        ``<span class="entity-label">``; all text is HTML-escaped. When
        nothing can be highlighted the escaped report is returned as is.
    """
    def plain() -> str:
        """Report without any highlighting."""
        return f"<div class='cr-text'>{escape_html(cr_text)}</div>"

    if not cr_text or not annotation:
        return plain()

    # Collect all spans with their variable names
    candidates = []
    for var_name, var_data in annotation.items():
        if not isinstance(var_data, dict):
            continue
        span = var_data.get("span")
        value = var_data.get("value")
        # Very short spans are skipped; non-string spans cannot be searched.
        if isinstance(span, str) and len(span) >= 5 and value:
            candidates.append((span, var_name, str(value)))

    if not candidates:
        return plain()

    # Longest first: combined with the stable sort by start below, the
    # longest span wins when several begin at the same offset.
    candidates.sort(key=lambda item: len(item[0]), reverse=True)

    # Find spans in text (with fuzzy matching)
    found = []
    for span, var_name, value in candidates:
        location = fuzzy_find_span(cr_text, span)
        if location is None:
            continue
        start, end = location
        found.append({
            "start": start,
            "end": end,
            "var_name": var_name,
            "value": value,
            "span": cr_text[start:end],  # Use actual text from CR
        })

    if not found:
        return plain()

    found.sort(key=lambda item: item["start"])

    # Drop any span that starts before the previously kept one has ended.
    kept = []
    for item in found:
        if not kept or item["start"] >= kept[-1]["end"]:
            kept.append(item)

    # One color per variable, assigned in order of first appearance.
    var_colors = {}
    for item in kept:
        if item["var_name"] not in var_colors:
            var_colors[item["var_name"]] = COLORS[len(var_colors) % len(COLORS)]

    # Build HTML with highlights
    html_parts = []
    cursor = 0
    for item in kept:
        # Add text before this span
        if item["start"] > cursor:
            html_parts.append(escape_html(cr_text[cursor:item["start"]]))

        color = var_colors[item["var_name"]]
        # Underscores become spaces; runs of whitespace collapse to one space.
        var_label = " ".join(item["var_name"].replace("_", " ").split()).title()
        # A real newline (not the two characters "\n") so the title tooltip
        # can show label and value on separate lines.
        tooltip = f"{var_label}\n→ {item['value']}"
        html_parts.append(
            f'<mark class="entity" style="background-color: {color};" '
            f'title="{escape_html(tooltip)}" '
            f'data-var="{escape_html(var_label)}">'
            f'{escape_html(item["span"])}'
            f'<span class="entity-label">{escape_html(var_label[:20])}</span>'
            f'</mark>'
        )
        cursor = item["end"]

    # Add remaining text
    if cursor < len(cr_text):
        html_parts.append(escape_html(cr_text[cursor:]))

    body = "".join(html_parts)
    return f"<div class='cr-text'>{body}</div>"
def format_annotations_table(annotation: dict) -> str:
    """
    Render the extracted values as HTML tables, one per category.

    Only entries that are dicts with a truthy ``"value"`` are shown. Each
    variable is assigned to the first category owning a keyword contained in
    its lower-cased name, or to "Other". Rows show the prettified variable
    name, the value, and the supporting span cut to 80 characters.

    Returns:
        The concatenated category sections, ``"<p>No annotations</p>"`` for an
        empty ``annotation``, or a ``no-data`` paragraph when no entry has a
        value.
    """
    if not annotation:
        return "<p>No annotations</p>"

    # Category -> keywords looked up as substrings of the variable name.
    categories = {
        "Patient Info": ["date_of_birth", "age_at_cancer_diagnosis", "biological_gender", "vital_status", "date_of_death"],
        "Diagnosis": ["date_of_cancer_diagnostic", "primary_tumor_localisation", "ctnm_stage", "stage_as_per_ehr", "histological_type", "epithelial_tumor_subtype"],
        "Tumor Characteristics": ["resectability_status", "two_largest_diameters", "metastasis_localisation", "number_of_metastatic_sites"],
        "Lab Results": ["crp_at_diagnosis", "albumin_at_diagnosis", "alanine_transaminase", "aspartate_aminotransferase", "conjugated_bilirubin", "ca19_9"],
        "Treatment": ["surgery", "loco_regional_radiotherapy", "immunotherapy", "targeted_therapy", "full_course_of_initial_treatment"],
        "Molecular": ["germline_mutation", "tumor_molecular_profiling"],
        "Progression": ["date_of_first_progression", "type_of_first_progression", "treatment_at_first_progression", "best_response", "reason_for_treatment_end"],
    }

    def get_category(var_name):
        """Return the first category with a keyword found in ``var_name``."""
        lowered = var_name.lower()
        return next(
            (
                cat
                for cat, keywords in categories.items()
                if any(kw in lowered for kw in keywords)
            ),
            "Other",
        )

    # Bucket the entries that actually carry a value.
    categorized = {}
    for var_name, var_data in annotation.items():
        if not (var_data and isinstance(var_data, dict)):
            continue
        if not var_data.get("value"):
            continue
        categorized.setdefault(get_category(var_name), []).append((var_name, var_data))

    if not categorized:
        return "<p class='no-data'>No extracted values</p>"

    # Sections follow the declaration order above, with "Other" last.
    html_parts = []
    for category in [*categories, "Other"]:
        rows = categorized.get(category)
        if rows is None:
            continue

        html_parts.append(f"<div class='category'><h4>{category}</h4>")
        html_parts.append("<table class='annotations-table'>")

        for var_name, var_data in rows:
            value = var_data.get("value", "")
            span = var_data.get("span", "")
            var_label = var_name.replace("_", " ").title()

            if span and len(span) > 80:
                span_preview = span[:80] + "..."
            else:
                span_preview = span
            span_cell = escape_html(span_preview) if span_preview else '-'

            html_parts.append(
                "\n<tr>\n"
                f"<td class='var-name'>{escape_html(var_label)}</td>\n"
                f"<td class='var-value'>{escape_html(str(value))}</td>\n"
                f"<td class='var-span'>{span_cell}</td>\n"
                "</tr>\n"
            )

        html_parts.append("</table></div>")

    return "".join(html_parts)
def get_stats(annotation: dict) -> str:
    """
    Summarize how many annotation variables carry an extracted value.

    Args:
        annotation: Mapping of variable name to its annotation entry.

    Returns:
        ``"No data"`` for an empty or missing ``annotation``; otherwise a
        one-line summary with the count of dict entries holding a truthy
        ``"value"``, the total number of variables, and the floored
        percentage.
    """
    if not annotation:
        return "No data"

    total = len(annotation)  # >= 1 here, so the division below is safe
    extracted = sum(
        1
        for entry in annotation.values()
        if isinstance(entry, dict) and entry.get("value")
    )
    # The emoji is written as real UTF-8 rather than a mis-decoded sequence.
    return f"📊 Extracted: {extracted}/{total} variables ({100 * extracted // total}%)"
def display_sample(sample_idx: int):
    """
    Build the three UI outputs for one dataset sample.

    Returns a tuple ``(highlighted report HTML, annotation table HTML,
    stats line)``. An index outside ``[0, len(dataset))`` yields three
    placeholder "invalid" strings instead.
    """
    out_of_range = sample_idx < 0 or sample_idx >= len(dataset)
    if out_of_range:
        return "Invalid sample index", "<p>Invalid sample index</p>", "Invalid"

    # The index may arrive as a float, hence the int() conversion.
    record = dataset[int(sample_idx)]
    report_text = record.get("CR", "")
    annotation = record.get("annotation", {})

    return (
        highlight_spans_in_text(report_text, annotation),
        format_annotations_table(annotation),
        get_stats(annotation),
    )
def search_samples(query: str):
    """
    Search samples by text content (case-insensitive substring match).

    Args:
        query: Search text; surrounding whitespace is ignored.

    Returns:
        A list of ``[index, preview]`` rows. With fewer than 3 query
        characters the first 20 samples are listed; otherwise at most 50
        matching samples. When nothing matches, a single explanatory
        ``["No results", message]`` row is returned. Previews are the first
        80 characters of the report, with ``"..."`` appended only when the
        report was actually cut.
    """
    def preview(report) -> str:
        """Short single-cell preview; tolerates a missing report."""
        report = report or ""
        return report[:80] + "..." if len(report) > 80 else report

    query = (query or "").strip()
    if len(query) < 3:
        # Return first 20 samples
        return [
            [i, preview(dataset[i].get("CR"))]
            for i in range(min(20, len(dataset)))
        ]

    needle = query.lower()
    results = []
    for i, sample in enumerate(dataset):
        report = sample.get("CR") or ""
        if needle in report.lower():
            results.append([i, preview(report)])
            if len(results) >= 50:
                break

    if not results:
        return [["No results", f"No samples found containing '{query}'"]]
    return results
# Custom CSS for better styling, passed to gr.Blocks(css=...) below.
# - .cr-text / .entity / .entity-label style the HTML emitted by
#   highlight_spans_in_text(): the label is hidden by default and shown
#   while its parent .entity is hovered.
# - .category / .annotations-table / .var-name / .var-value / .var-span /
#   .no-data style the HTML emitted by format_annotations_table().
# - .stats-badge is attached to the stats Markdown component.
custom_css = """
.cr-text {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
font-size: 14px;
line-height: 1.8;
padding: 20px;
background: #fafafa;
border-radius: 8px;
white-space: pre-wrap;
max-height: 500px;
overflow-y: auto;
}
.entity {
padding: 2px 6px;
border-radius: 4px;
cursor: help;
position: relative;
display: inline;
transition: all 0.2s;
}
.entity:hover {
filter: brightness(0.9);
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
}
.entity-label {
display: none;
position: absolute;
bottom: 100%;
left: 0;
background: #333;
color: white;
padding: 4px 8px;
border-radius: 4px;
font-size: 11px;
white-space: nowrap;
z-index: 100;
}
.entity:hover .entity-label {
display: block;
}
.category {
margin-bottom: 20px;
}
.category h4 {
color: #1976d2;
border-bottom: 2px solid #1976d2;
padding-bottom: 8px;
margin-bottom: 12px;
}
.annotations-table {
width: 100%;
border-collapse: collapse;
font-size: 13px;
}
.annotations-table tr:nth-child(even) {
background: #f5f5f5;
}
.annotations-table td {
padding: 10px 12px;
border-bottom: 1px solid #e0e0e0;
vertical-align: top;
}
.var-name {
font-weight: 600;
color: #333;
width: 30%;
}
.var-value {
color: #1976d2;
font-weight: 500;
width: 25%;
}
.var-span {
color: #666;
font-style: italic;
font-size: 12px;
width: 45%;
}
.no-data {
color: #999;
font-style: italic;
padding: 20px;
text-align: center;
}
.stats-badge {
background: #e3f2fd;
color: #1976d2;
padding: 8px 16px;
border-radius: 20px;
font-weight: 500;
display: inline-block;
}
"""
# Build the Gradio interface: a slider and a search box to pick a sample,
# the highlighted report, and the table of extracted annotations.
with gr.Blocks(
    title="Pancreas Cancer Annotations Explorer",
    theme=gr.themes.Soft(primary_hue="blue"),
    css=custom_css
) as demo:
    gr.Markdown("""
    # 🔬 Pancreas Cancer Clinical Report Annotations Explorer

    Explore structured annotations extracted from synthetic French clinical reports about pancreas cancer.

    **How to use:**
    - Use the slider or search to navigate samples
    - Hover over highlighted text to see extracted variables
    - View the complete annotation table below
    """)

    # Sample navigation and extraction statistics
    with gr.Row():
        with gr.Column(scale=2):
            sample_slider = gr.Slider(
                minimum=0,
                maximum=len(dataset) - 1,
                step=1,
                value=0,
                label=f"📌 Sample Index (0 - {len(dataset) - 1})",
                info="Drag to browse samples"
            )
        with gr.Column(scale=1):
            stats_display = gr.Markdown("", elem_classes=["stats-badge"])

    # Full-text search over the reports
    with gr.Row():
        with gr.Column(scale=1):
            search_box = gr.Textbox(
                label="🔍 Search",
                placeholder="Type to search in clinical reports...",
                info="Min 3 characters"
            )
            search_results = gr.Dataframe(
                headers=["#", "Preview"],
                label="Results",
                interactive=False,
                height=200
            )

    gr.Markdown("---")
    gr.Markdown("### 📄 Clinical Report with Entity Highlighting")
    gr.Markdown("*Hover over colored text to see the extracted variable and value*")
    cr_display = gr.HTML()

    gr.Markdown("---")
    gr.Markdown("### 📊 Extracted Annotations")
    annotations_display = gr.HTML()

    # Event handlers
    sample_slider.change(
        fn=display_sample,
        inputs=[sample_slider],
        outputs=[cr_display, annotations_display, stats_display]
    )
    search_box.change(
        fn=search_samples,
        inputs=[search_box],
        outputs=[search_results]
    )

    def on_select(evt: gr.SelectData, data):
        """
        Map a click in the results table to the sample index of that row.

        ``data`` is the current table content. Gradio's Dataframe defaults to
        ``type="pandas"``, so it normally arrives as a pandas DataFrame, where
        ``data[row]`` would be a *column* lookup; rows are therefore read
        positionally via ``.iloc`` when available. Plain nested lists are
        still supported. Falls back to sample 0 when the clicked row holds no
        usable index (for example the "No results" placeholder row).
        """
        if data is None or len(data) == 0:
            return 0
        try:
            row = evt.index[0]
            if hasattr(data, "iloc"):
                cell = data.iloc[row, 0]
            else:
                cell = data[row][0]
            return int(cell)
        except (ValueError, IndexError, TypeError, KeyError):
            return 0

    search_results.select(
        fn=on_select,
        inputs=[search_results],
        outputs=[sample_slider]
    )

    # Load first sample on start
    demo.load(
        fn=display_sample,
        inputs=[sample_slider],
        outputs=[cr_display, annotations_display, stats_display]
    )

if __name__ == "__main__":
    demo.launch()