Spaces:
Running
on
Zero
Running
on
Zero
Add active selections box
#2
by
songs1
- opened
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
import requests
|
|
@@ -11,9 +12,9 @@ from guidance import json as gen_json
|
|
| 11 |
from guidance.models import Transformers
|
| 12 |
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
|
| 13 |
|
| 14 |
-
from schema import GDCCohortSchema
|
| 15 |
|
| 16 |
-
DEBUG =
|
| 17 |
EXAMPLE_INPUTS = [
|
| 18 |
"bam files for TCGA-BRCA",
|
| 19 |
"kidney or adrenal gland cancers with alcohol history",
|
|
@@ -23,7 +24,7 @@ EXAMPLE_INPUTS = [
|
|
| 23 |
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
|
| 24 |
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
|
| 25 |
TOKENIZER_NAME = MODEL_NAME
|
| 26 |
-
AUTH_TOKEN = os.environ.get("HF_TOKEN", False)
|
| 27 |
|
| 28 |
with open("config.yaml", "r") as f:
|
| 29 |
CONFIG = yaml.safe_load(f)
|
|
@@ -380,6 +381,79 @@ def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
|
|
| 380 |
return card_updates + [gr.update(value=f"{case_count} Cases")]
|
| 381 |
|
| 382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
def prepare_value_count(value, count):
    """Return ``value`` followed by ``count`` in square brackets.

    Example: ``prepare_value_count("female", 12)`` gives ``"female [12]"``.
    """
    return f"{value} [{count}]"
|
| 385 |
|
|
@@ -448,6 +522,12 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 448 |
with gr.Column(scale=7):
|
| 449 |
text_input = gr.Textbox(
|
| 450 |
label="Describe the cohort you're looking for:",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
submit_btn="Generate Cohort",
|
| 452 |
elem_id="description-input",
|
| 453 |
placeholder="Enter a cohort description to begin...",
|
|
@@ -483,9 +563,21 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 483 |
elem_id="json-output",
|
| 484 |
)
|
| 485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
with gr.Row():
|
| 487 |
gr.Markdown(
|
| 488 |
"The generated cohort filter will autopopulate into the filter cards below. "
|
|
|
|
| 489 |
"Refine your search using the interactive checkboxes. "
|
| 490 |
"Note that many other options can be found by selecting the different tabs on the left."
|
| 491 |
)
|
|
@@ -576,6 +668,10 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 576 |
fn=process_query,
|
| 577 |
inputs=text_input,
|
| 578 |
outputs=filter_cards + [json_output],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
)
|
| 580 |
|
| 581 |
# Update JSON based on cards
|
|
@@ -587,14 +683,33 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 587 |
fn=update_json_from_cards,
|
| 588 |
inputs=filter_cards,
|
| 589 |
outputs=json_output,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
)
|
| 591 |
else:
|
| 592 |
filter_card.input(
|
| 593 |
fn=update_json_from_cards,
|
| 594 |
inputs=filter_cards,
|
| 595 |
outputs=json_output,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
)
|
| 597 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
# Update checkboxes after executing filter query
|
| 599 |
json_output.change(
|
| 600 |
fn=update_cards_with_counts,
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import requests
|
|
|
|
| 12 |
from guidance.models import Transformers
|
| 13 |
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
|
| 14 |
|
| 15 |
+
from schema import GDCCohortSchema # isort: skip
|
| 16 |
|
| 17 |
+
DEBUG = "DEBUG" in os.environ
|
| 18 |
EXAMPLE_INPUTS = [
|
| 19 |
"bam files for TCGA-BRCA",
|
| 20 |
"kidney or adrenal gland cancers with alcohol history",
|
|
|
|
| 24 |
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
|
| 25 |
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
|
| 26 |
TOKENIZER_NAME = MODEL_NAME
|
| 27 |
+
AUTH_TOKEN = os.environ.get("HF_TOKEN", False) # HF_TOKEN must be set to use auth
|
| 28 |
|
| 29 |
with open("config.yaml", "r") as f:
|
| 30 |
CONFIG = yaml.safe_load(f)
|
|
|
|
| 381 |
return card_updates + [gr.update(value=f"{case_count} Cases")]
|
| 382 |
|
| 383 |
|
| 384 |
+
def update_active_selections(*selected_filters_per_card):
    """Build the flattened "active selections" list from the filter cards.

    Args:
        *selected_filters_per_card: Current value of every filter card, in the
            same order as ``CARD_NAMES``. Checkbox cards supply an iterable of
            selected labels; range cards supply a ``(lo, hi)`` pair.

    Returns:
        A ``gr.update`` whose ``choices`` and ``value`` are both the collected
        ``"CARD: value"`` labels, so every listed entry is shown as checked.

    Raises:
        ValueError: If a card's entry in ``CARD_2_VALUES`` is neither a list
            nor a dict whose ``"type"`` is ``"range"``.
    """
    choices = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        label_prefix = card_name.upper()
        # use the default values to determine card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            # checkbox: one entry per ticked option
            # NOTE(review): get_base_value presumably strips the " [count]"
            # suffix from the displayed label -- confirm against its definition
            choices.extend(
                f"{label_prefix}: {get_base_value(selected_value)}"
                for selected_value in selected_filters
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            # raise instead of assert so the check survives `python -O`
            if default_values["type"] != "range":
                raise ValueError(f"Expected range slider for card {card_name}")
            lo, hi = selected_filters
            if lo != default_values["min"] or hi != default_values["max"]:
                # only add range filter if not default
                choices.append(f"{label_prefix}: {int(lo)}-{int(hi)}")
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(choices=choices, value=choices)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def update_cards_from_active(current_selections, *selected_filters_per_card):
    """Push edits made in the active-selections box back onto the filter cards.

    Args:
        current_selections: The ``"CARD: value"`` labels still checked in the
            active-selections component.
        *selected_filters_per_card: Current value of every filter card, in the
            same order as ``CARD_NAMES``.

    Returns:
        A list of ``gr.update`` objects: first the update for the
        active-selections component (its choices are narrowed to what is still
        checked), then one update per filter card in ``CARD_NAMES`` order.

    Raises:
        ValueError: If a card's entry in ``CARD_2_VALUES`` is neither a list
            nor a dict whose ``"type"`` is ``"range"``.
    """
    # active selector uses a flattened list so re-agg values under card groups
    grouped_selections = defaultdict(set)
    for k_v in current_selections:
        # split on the first ": " only, so values containing ": " stay intact
        k, sep, v = k_v.partition(": ")
        if not sep:
            # not a "CARD: value" label, so it cannot be mapped to any card
            continue
        grouped_selections[k].add(v)

    card_updates = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        group_key = card_name.upper()
        # use the default values to determine card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            # checkbox: keep only the options still present in the active box.
            # .get() is used because indexing a defaultdict would insert an
            # empty group and make later `in grouped_selections` checks lie.
            still_active = grouped_selections.get(group_key, ())
            updated_values = [
                selected_value
                for selected_value in selected_filters
                if get_base_value(selected_value) in still_active
            ]
            update_obj = gr.update(value=updated_values)
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            # raise instead of assert so the check survives `python -O`
            if default_values["type"] != "range":
                raise ValueError(f"Expected range slider for card {card_name}")
            # the active selector cannot change range values
            # so if present as an active selection, no action is needed
            # otherwise, reset entire range selector
            if group_key in grouped_selections:
                update_obj = gr.update()
            else:
                update_obj = gr.update(
                    value=(
                        default_values["min"],
                        default_values["max"],
                    )
                )
        else:
            raise ValueError(f"Unknown values for card {card_name}")

        card_updates.append(update_obj)

    # also remove unselected value as possible choice
    active_selection_update = gr.update(choices=current_selections)
    return [active_selection_update] + card_updates
|
| 455 |
+
|
| 456 |
+
|
| 457 |
def prepare_value_count(value, count):
    """Return ``value`` followed by ``count`` in square brackets.

    Example: ``prepare_value_count("female", 12)`` gives ``"female [12]"``.
    """
    return f"{value} [{count}]"
|
| 459 |
|
|
|
|
| 522 |
with gr.Column(scale=7):
|
| 523 |
text_input = gr.Textbox(
|
| 524 |
label="Describe the cohort you're looking for:",
|
| 525 |
+
info=(
|
| 526 |
+
"Only provide the cohort characteristics. "
|
| 527 |
+
"Do not include extraneous text. "
|
| 528 |
+
"For example, write 'patients with X' "
|
| 529 |
+
"instead of 'I would like patients with X':"
|
| 530 |
+
),
|
| 531 |
submit_btn="Generate Cohort",
|
| 532 |
elem_id="description-input",
|
| 533 |
placeholder="Enter a cohort description to begin...",
|
|
|
|
| 563 |
elem_id="json-output",
|
| 564 |
)
|
| 565 |
|
| 566 |
+
with gr.Row(equal_height=True):
|
| 567 |
+
with gr.Column(scale=1, min_width=250):
|
| 568 |
+
gr.Markdown("## Currently Selected Filters")
|
| 569 |
+
with gr.Column(scale=4):
|
| 570 |
+
active_selections = gr.CheckboxGroup(
|
| 571 |
+
choices=[],
|
| 572 |
+
show_label=False,
|
| 573 |
+
interactive=True,
|
| 574 |
+
elem_id="active-selections",
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
with gr.Row():
|
| 578 |
gr.Markdown(
|
| 579 |
"The generated cohort filter will autopopulate into the filter cards below. "
|
| 580 |
+
"**GDC Cohort Copilot can make mistakes!** "
|
| 581 |
"Refine your search using the interactive checkboxes. "
|
| 582 |
"Note that many other options can be found by selecting the different tabs on the left."
|
| 583 |
)
|
|
|
|
| 668 |
fn=process_query,
|
| 669 |
inputs=text_input,
|
| 670 |
outputs=filter_cards + [json_output],
|
| 671 |
+
).success(
|
| 672 |
+
fn=update_active_selections,
|
| 673 |
+
inputs=filter_cards,
|
| 674 |
+
outputs=[active_selections],
|
| 675 |
)
|
| 676 |
|
| 677 |
# Update JSON based on cards
|
|
|
|
| 683 |
fn=update_json_from_cards,
|
| 684 |
inputs=filter_cards,
|
| 685 |
outputs=json_output,
|
| 686 |
+
).success(
|
| 687 |
+
fn=update_active_selections,
|
| 688 |
+
inputs=filter_cards,
|
| 689 |
+
outputs=[active_selections],
|
| 690 |
)
|
| 691 |
else:
|
| 692 |
filter_card.input(
|
| 693 |
fn=update_json_from_cards,
|
| 694 |
inputs=filter_cards,
|
| 695 |
outputs=json_output,
|
| 696 |
+
).success(
|
| 697 |
+
fn=update_active_selections,
|
| 698 |
+
inputs=filter_cards,
|
| 699 |
+
outputs=[active_selections],
|
| 700 |
)
|
| 701 |
|
| 702 |
+
# Enable functionality of the active filter selectors
|
| 703 |
+
active_selections.input(
|
| 704 |
+
fn=update_cards_from_active,
|
| 705 |
+
inputs=[active_selections] + filter_cards,
|
| 706 |
+
outputs=[active_selections] + filter_cards,
|
| 707 |
+
).success(
|
| 708 |
+
fn=update_json_from_cards,
|
| 709 |
+
inputs=filter_cards,
|
| 710 |
+
outputs=json_output,
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
# Update checkboxes after executing filter query
|
| 714 |
json_output.change(
|
| 715 |
fn=update_cards_with_counts,
|
style.css
CHANGED
|
@@ -27,6 +27,10 @@
|
|
| 27 |
height: 80% !important;
|
| 28 |
}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
.card-group, .card-group > div {
|
| 31 |
background-color: transparent;
|
| 32 |
border: 0px;
|
|
|
|
| 27 |
height: 80% !important;
|
| 28 |
}
|
| 29 |
|
| 30 |
+
/* Pin the element with id "active-selections" to a 50px height; !important overrides competing height rules. */
#active-selections {
|
| 31 |
+
height: 50px !important;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
.card-group, .card-group > div {
|
| 35 |
background-color: transparent;
|
| 36 |
border: 0px;
|