Add active selections box

#2
Files changed (2) hide show
  1. app.py +118 -3
  2. style.css +4 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
 
4
  import gradio as gr
5
  import requests
@@ -11,9 +12,9 @@ from guidance import json as gen_json
11
  from guidance.models import Transformers
12
  from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
13
 
14
- from schema import GDCCohortSchema
15
 
16
- DEBUG = False
17
  EXAMPLE_INPUTS = [
18
  "bam files for TCGA-BRCA",
19
  "kidney or adrenal gland cancers with alcohol history",
@@ -23,7 +24,7 @@ EXAMPLE_INPUTS = [
23
  GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
24
  MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
25
  TOKENIZER_NAME = MODEL_NAME
26
- AUTH_TOKEN = os.environ.get("HF_TOKEN", False)
27
 
28
  with open("config.yaml", "r") as f:
29
  CONFIG = yaml.safe_load(f)
@@ -380,6 +381,79 @@ def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
380
  return card_updates + [gr.update(value=f"{case_count} Cases")]
381
 
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  def prepare_value_count(value, count):
384
  return f"{value} [{count}]"
385
 
@@ -448,6 +522,12 @@ with gr.Blocks(css_paths="style.css") as demo:
448
  with gr.Column(scale=7):
449
  text_input = gr.Textbox(
450
  label="Describe the cohort you're looking for:",
 
 
 
 
 
 
451
  submit_btn="Generate Cohort",
452
  elem_id="description-input",
453
  placeholder="Enter a cohort description to begin...",
@@ -483,9 +563,21 @@ with gr.Blocks(css_paths="style.css") as demo:
483
  elem_id="json-output",
484
  )
485
 
 
 
 
 
 
 
 
 
 
 
 
486
  with gr.Row():
487
  gr.Markdown(
488
  "The generated cohort filter will autopopulate into the filter cards below. "
 
489
  "Refine your search using the interactive checkboxes. "
490
  "Note that many other options can be found by selecting the different tabs on the left."
491
  )
@@ -576,6 +668,10 @@ with gr.Blocks(css_paths="style.css") as demo:
576
  fn=process_query,
577
  inputs=text_input,
578
  outputs=filter_cards + [json_output],
 
 
 
 
579
  )
580
 
581
  # Update JSON based on cards
@@ -587,14 +683,33 @@ with gr.Blocks(css_paths="style.css") as demo:
587
  fn=update_json_from_cards,
588
  inputs=filter_cards,
589
  outputs=json_output,
 
 
 
 
590
  )
591
  else:
592
  filter_card.input(
593
  fn=update_json_from_cards,
594
  inputs=filter_cards,
595
  outputs=json_output,
 
 
 
 
596
  )
597
 
 
 
 
 
 
 
 
 
 
 
 
598
  # Update checkboxes after executing filter query
599
  json_output.change(
600
  fn=update_cards_with_counts,
 
1
  import json
2
  import os
3
+ from collections import defaultdict
4
 
5
  import gradio as gr
6
  import requests
 
12
  from guidance.models import Transformers
13
  from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
14
 
15
+ from schema import GDCCohortSchema # isort: skip
16
 
17
+ DEBUG = "DEBUG" in os.environ
18
  EXAMPLE_INPUTS = [
19
  "bam files for TCGA-BRCA",
20
  "kidney or adrenal gland cancers with alcohol history",
 
24
  GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
25
  MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
26
  TOKENIZER_NAME = MODEL_NAME
27
+ AUTH_TOKEN = os.environ.get("HF_TOKEN", False) # HF_TOKEN must be set to use auth
28
 
29
  with open("config.yaml", "r") as f:
30
  CONFIG = yaml.safe_load(f)
 
381
  return card_updates + [gr.update(value=f"{case_count} Cases")]
382
 
383
 
384
def update_active_selections(*selected_filters_per_card):
    """Build the update for the flattened "active selections" checkbox group.

    Args:
        *selected_filters_per_card: current value of each filter card, given
            in the same order as ``CARD_NAMES``. A checkbox card supplies an
            iterable of selected entries; a range card supplies a
            ``(low, high)`` pair.

    Returns:
        A ``gr.update`` whose ``choices`` and ``value`` are both the list of
        ``"<CARD NAME UPPERCASED>: <value>"`` labels, so every listed entry
        is shown as checked.

    Raises:
        ValueError: if a card's entry in ``CARD_2_VALUES`` is neither a list
            nor a dict.
        AssertionError: if a dict entry's ``"type"`` is not ``"range"``.
    """
    labels = []
    for card_name, card_value in zip(CARD_NAMES, selected_filters_per_card):
        # The configured defaults tell us which kind of widget this card is.
        card_defaults = CARD_2_VALUES[card_name]

        if isinstance(card_defaults, list):
            # Checkbox card: one label per selected entry.
            # NOTE(review): get_base_value presumably strips the display
            # decoration (e.g. a count suffix) from the entry - confirm.
            labels.extend(
                f"{card_name.upper()}: {get_base_value(entry)}"
                for entry in card_value
            )
            continue

        if not isinstance(card_defaults, dict):
            raise ValueError(f"Unknown values for card {card_name}")

        # Dict-configured card: only range sliders are supported for now.
        assert (
            card_defaults["type"] == "range"
        ), f"Expected range slider for card {card_name}"
        lower, upper = card_value
        at_default = (
            lower == card_defaults["min"] and upper == card_defaults["max"]
        )
        if not at_default:
            # A range still at its configured bounds is not an active filter.
            lower, upper = int(lower), int(upper)
            labels.append(f"{card_name.upper()}: {lower}-{upper}")

    return gr.update(choices=labels, value=labels)
408
+
409
+
410
def update_cards_from_active(current_selections, *selected_filters_per_card):
    """Propagate changes in the active-selections box back to the filter cards.

    Args:
        current_selections: labels still checked in the active-selections
            group, each of the form ``"<CARD NAME UPPERCASED>: <value>"``.
        *selected_filters_per_card: current value of each filter card, given
            in the same order as ``CARD_NAMES``.

    Returns:
        A list whose first element updates the active-selections group (its
        choices are narrowed to ``current_selections``), followed by one
        ``gr.update`` per card in ``CARD_NAMES`` order.

    Raises:
        ValueError: if a card's entry in ``CARD_2_VALUES`` is neither a list
            nor a dict.
        AssertionError: if a dict entry's ``"type"`` is not ``"range"``.
    """
    # The active selector is a flat list of "CARD: value" labels, so regroup
    # the values under their (uppercased) card name.
    grouped_selections = defaultdict(set)
    for label in current_selections:
        # partition splits on the first ": " and reports a missing separator
        # explicitly; slicing with str.find's -1 would silently produce a
        # bogus key/value pair instead.
        card_key, separator, value = label.partition(": ")
        if not separator:
            continue  # malformed label: cannot be attributed to any card
        grouped_selections[card_key].add(value)

    card_updates = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        # use the default values to determine card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        card_key = card_name.upper()

        if isinstance(default_values, list):
            # Checkbox card: keep only entries whose base value is still
            # active. `.get` avoids inserting empty keys into the defaultdict
            # as a side effect of the lookup.
            still_active = grouped_selections.get(card_key, ())
            kept_values = [
                selected_value
                for selected_value in selected_filters
                if get_base_value(selected_value) in still_active
            ]
            update_obj = gr.update(value=kept_values)
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            assert (
                default_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            # The active selector cannot change range values, so if the range
            # is still present as an active selection no action is needed;
            # otherwise reset the slider to its configured bounds.
            if card_key in grouped_selections:
                update_obj = gr.update()
            else:
                update_obj = gr.update(
                    value=(
                        default_values["min"],
                        default_values["max"],
                    )
                )
        else:
            raise ValueError(f"Unknown values for card {card_name}")

        card_updates.append(update_obj)

    # also remove unselected value as possible choice
    active_selection_update = gr.update(choices=current_selections)
    return [active_selection_update] + card_updates
455
+
456
+
457
  def prepare_value_count(value, count):
458
  return f"{value} [{count}]"
459
 
 
522
  with gr.Column(scale=7):
523
  text_input = gr.Textbox(
524
  label="Describe the cohort you're looking for:",
525
+ info=(
526
+ "Only provide the cohort characteristics. "
527
+ "Do not include extraneous text. "
528
+ "For example, write 'patients with X' "
529
+ "instead of 'I would like patients with X':"
530
+ ),
531
  submit_btn="Generate Cohort",
532
  elem_id="description-input",
533
  placeholder="Enter a cohort description to begin...",
 
563
  elem_id="json-output",
564
  )
565
 
566
+ with gr.Row(equal_height=True):
567
+ with gr.Column(scale=1, min_width=250):
568
+ gr.Markdown("## Currently Selected Filters")
569
+ with gr.Column(scale=4):
570
+ active_selections = gr.CheckboxGroup(
571
+ choices=[],
572
+ show_label=False,
573
+ interactive=True,
574
+ elem_id="active-selections",
575
+ )
576
+
577
  with gr.Row():
578
  gr.Markdown(
579
  "The generated cohort filter will autopopulate into the filter cards below. "
580
+ "**GDC Cohort Copilot can make mistakes!** "
581
  "Refine your search using the interactive checkboxes. "
582
  "Note that many other options can be found by selecting the different tabs on the left."
583
  )
 
668
  fn=process_query,
669
  inputs=text_input,
670
  outputs=filter_cards + [json_output],
671
+ ).success(
672
+ fn=update_active_selections,
673
+ inputs=filter_cards,
674
+ outputs=[active_selections],
675
  )
676
 
677
  # Update JSON based on cards
 
683
  fn=update_json_from_cards,
684
  inputs=filter_cards,
685
  outputs=json_output,
686
+ ).success(
687
+ fn=update_active_selections,
688
+ inputs=filter_cards,
689
+ outputs=[active_selections],
690
  )
691
  else:
692
  filter_card.input(
693
  fn=update_json_from_cards,
694
  inputs=filter_cards,
695
  outputs=json_output,
696
+ ).success(
697
+ fn=update_active_selections,
698
+ inputs=filter_cards,
699
+ outputs=[active_selections],
700
  )
701
 
702
+ # Enable functionality of the active filter selectors
703
+ active_selections.input(
704
+ fn=update_cards_from_active,
705
+ inputs=[active_selections] + filter_cards,
706
+ outputs=[active_selections] + filter_cards,
707
+ ).success(
708
+ fn=update_json_from_cards,
709
+ inputs=filter_cards,
710
+ outputs=json_output,
711
+ )
712
+
713
  # Update checkboxes after executing filter query
714
  json_output.change(
715
  fn=update_cards_with_counts,
style.css CHANGED
@@ -27,6 +27,10 @@
27
  height: 80% !important;
28
  }
29
 
 
 
 
 
30
  .card-group, .card-group > div {
31
  background-color: transparent;
32
  border: 0px;
 
27
  height: 80% !important;
28
  }
29
 
30
+ #active-selections {
31
+ height: 50px !important;
32
+ }
33
+
34
  .card-group, .card-group > div {
35
  background-color: transparent;
36
  border: 0px;