"""Main chat pipeline - stream_chat function"""
import os
import json
import time
import logging
import threading
import concurrent.futures
import hashlib
import gradio as gr
import spaces
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core import Settings
from llama_index.core.retrievers import AutoMergingRetriever
from logger import logger, ThoughtCaptureHandler
from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state, move_model_to_gpu
from utils import detect_language, translate_text, format_url_as_domain
from search import search_web, summarize_web_content
from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy
from supervisor import (
    gemini_supervisor_breakdown, gemini_supervisor_search_strategies,
    gemini_supervisor_rag_brainstorm, execute_medswin_task,
    gemini_supervisor_synthesize, gemini_supervisor_challenge,
    gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity,
    gemini_clinical_intake_triage, gemini_summarize_clinical_insights,
    MAX_SEARCH_STRATEGIES
)

MAX_CLINICAL_QA_ROUNDS = 5
MAX_DURATION = 120
_clinical_intake_sessions = {}
_clinical_intake_lock = threading.Lock()

# Thread pool executor for running Gemini supervisor calls without blocking GPU task
_gemini_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="gemini-supervisor")


def run_gemini_in_thread(fn, *args, **kwargs):
    """
    Run a Gemini supervisor function in a separate thread so it does not block the GPU task.
    This keeps Gemini API calls from consuming GPU task time and causing timeouts.
    """
    try:
        future = _gemini_executor.submit(fn, *args, **kwargs)
        # Set a reasonable timeout (30s) to prevent hanging
        result = future.result(timeout=30.0)
        return result
    except concurrent.futures.TimeoutError:
        logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s")
        # Return fallback based on function
        return _supervisor_logics(fn.__name__, args)
    except Exception as e:
        logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} failed with error: {type(e).__name__}: {str(e)}")
        # Return fallback based on function
        return _supervisor_logics(fn.__name__, args)


def _supervisor_logics(fn_name: str, args: tuple):
    """Get appropriate fallback value based on function name"""
    try:
        if "breakdown" in fn_name:
            return {
                "sub_topics": [
                    {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
                ],
                "strategy": "Direct answer (fallback)",
                "exploration_note": "Gemini supervisor error"
            }
        elif "search_strategies" in fn_name:
            return {
                "search_strategies": [
                    {"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"}
                ],
                "max_strategies": 1
            }
        elif "rag_brainstorm" in fn_name:
            return {
                "contexts": [
                    {"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"}
                ],
                "max_contexts": 1
            }
        elif "synthesize" in fn_name:
            # Return concatenated MedSwin answers as fallback
            return "\n\n".join(args[1] if len(args) > 1 and args[1] else [])
        elif "challenge" in fn_name:
            return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
        elif "enhance_answer" in fn_name:
            return args[1] if len(args) > 1 else ""
        elif "check_clarity" in fn_name:
            return {"is_unclear": False, "needs_search": False, "search_queries": []}
        elif "clinical_intake_triage" in fn_name:
            return {
                "needs_additional_info": False,
                "decision_reason": "Error fallback",
                "max_rounds": args[2] if len(args) > 2 else 5,
                "questions": [],
                "initial_hypotheses": []
            }
        elif "summarize_clinical_insights" in fn_name:
            return {
                "patient_profile": "",
                "refined_problem_statement": args[0] if args else "",
                "key_findings": [],
                "handoff_note": "Proceed with regular workflow."
            }
        else:
            logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn_name}, returning None")
            return None
    except Exception as e:
        logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}")
        # Return appropriate fallback
        if "breakdown" in fn.__name__:
            return {
                "sub_topics": [
                    {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
                ],
                "strategy": "Direct answer (error fallback)",
                "exploration_note": "Gemini supervisor error"
            }
        return None


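# Thread-safe accessors for the per-session clinical intake state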
def _get_clinical_intake_state(session_id: str):
    with _clinical_intake_lock:
        return _clinical_intake_sessions.get(session_id)


def _set_clinical_intake_state(session_id: str, state: dict):
    with _clinical_intake_lock:
        _clinical_intake_sessions[session_id] = state


def _clear_clinical_intake_state(session_id: str):
    with _clinical_intake_lock:
        _clinical_intake_sessions.pop(session_id, None)


def _history_to_text(history: list, limit: int = 6) -> str:
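    """Render the most recent conversation turns as "role: content" lines for supervisor context."""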
    if not history:
        return "No prior conversation."
    recent = history[-limit:]
    lines = []
    for turn in recent:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        lines.append(f"{role}: {content}")
    return "\n".join(lines)

def _format_intake_question(question: dict, round_idx: int, max_rounds: int, target_lang: str) -> str:
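    """Build the clarification prompt shown to the user, translated to target_lang when it is not English."""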
    header = f"🩺 Question for clarity {round_idx}/{max_rounds}"
    body = question.get("question") or "Could you share a bit more detail so I can give an accurate answer?"
    prompt_parts = [
        header,
        body,
        "Please answer in 1-2 sentences so I can continue."
    ]
    prompt_text = "\n\n".join(prompt_parts)
    if target_lang and target_lang != "en":
        try:
            prompt_text = translate_text(prompt_text, target_lang=target_lang, source_lang="en")
        except Exception as exc:
            logger.warning(f"[INTAKE] Question translation failed: {exc}")
    return prompt_text


def _format_qa_transcript(qa_pairs: list) -> str:
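    """Flatten intake question/answer pairs into a numbered Q/A transcript."""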
    if not qa_pairs:
        return ""
    lines = []
    for idx, qa in enumerate(qa_pairs, 1):
        question = qa.get("question", "").strip()
        answer = qa.get("answer", "").strip()
        if question:
            lines.append(f"Q{idx}: {question}")
        if answer:
            lines.append(f"A{idx}: {answer}")
        lines.append("")
    return "\n".join(lines).strip()


def _format_insights_block(insights: dict) -> str:
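    """Render Gemini intake insights as bullet lines (patient profile plus key findings)."""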
    if not insights:
        return ""
    lines = []
    profile = insights.get("patient_profile")
    if profile:
        lines.append(f"- Patient profile: {profile}")
    for finding in insights.get("key_findings", []):
        title = finding.get("title", "Insight")
        detail = finding.get("detail", "")
        implication = finding.get("clinical_implication", "")
        line = f"- {title}: {detail}"
        if implication:
            line += f" (Clinical note: {implication})"
        lines.append(line)
    return "\n".join(lines)


def _build_refined_query(base_query: str, insights: dict, insights_block: str) -> str:
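    """Combine the base query with the intake summary, refined problem statement, and handoff note."""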
    sections = [base_query.strip()] if base_query else []
    if insights_block:
        sections.append(f"Clinical intake summary:\n{insights_block}")
    refined = insights.get("refined_problem_statement")
    if refined:
        sections.append(f"Refined problem statement:\n{refined}")
    handoff = insights.get("handoff_note")
    if handoff:
        sections.append(f"Handoff note:\n{handoff}")
    return "\n\n".join([section for section in sections if section])


def _hash_prompt_text(text: str) -> str:
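    """SHA-1 hex digest of the stripped prompt text, used to match pending intake prompts across requests."""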
    if not text:
        return ""
    digest = hashlib.sha1()
    digest.update(text.strip().encode("utf-8"))
    return digest.hexdigest()


def _extract_pending_intake_prompt(history: list) -> str:
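    """Return the most recent assistant message that is a clinical intake prompt, or ""."""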
    if not history:
        return ""
    for turn in reversed(history):
        if turn.get("role") != "assistant":
            continue
        content = turn.get("content", "")
        if content.startswith("🩺 Question for clarity"):
            return content
    return ""


def _rehydrate_intake_state(session_id: str, history: list):
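    """Recover intake state stored under a different session id (e.g. if the session hash changed).

    If no state exists for session_id but the history contains an unanswered intake prompt,
    find the session whose last_prompt_hash matches that prompt and re-key its state under
    the current session_id.
    """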
    state = _get_clinical_intake_state(session_id)
    if state or not history:
        return state
    pending_prompt = _extract_pending_intake_prompt(history)
    if not pending_prompt:
        return None
    prompt_hash = _hash_prompt_text(pending_prompt)
    if not prompt_hash:
        return None
    with _clinical_intake_lock:
        for existing_id, existing_state in list(_clinical_intake_sessions.items()):
            if existing_state.get("awaiting_answer") and existing_state.get("last_prompt_hash") == prompt_hash:
                if existing_id != session_id:
                    _clinical_intake_sessions.pop(existing_id, None)
                    _clinical_intake_sessions[session_id] = existing_state
                return existing_state
    return None


def _get_last_assistant_answer(history: list) -> str:
    """
    Extract the last non-empty assistant answer from history.
    Skips clinical intake clarification prompts so that follow-up
    questions like "clarify your answer" refer to the real medical
    answer, not an intake question.
    """
    if not history:
        return ""
    for turn in reversed(history):
        if turn.get("role") != "assistant":
            continue
        content = (turn.get("content") or "").strip()
        if not content:
            continue
        # Skip intake prompts that start with the standard header
        if content.startswith("🩺 Question for clarity"):
            continue
        return content
    return ""


def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str):
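    """Initialize intake state for the session and return the first clarification prompt.

    Returns None when the triage plan contains no questions. The number of rounds is capped
    by MAX_CLINICAL_QA_ROUNDS, the plan's max_rounds, and the number of planned questions.
    """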
    questions = plan.get("questions", []) or []
    if not questions:
        return None
    max_rounds = plan.get("max_rounds") or len(questions)
    max_rounds = max(1, min(MAX_CLINICAL_QA_ROUNDS, max_rounds, len(questions)))
    state = {
        "base_query": base_query,
        "original_language": original_language or "en",
        "questions": questions,
        "max_rounds": max_rounds,
        "current_round": 1,
        "pending_question_index": 0,
        "awaiting_answer": True,
        "answers": [],
        "decision_reason": plan.get("decision_reason", ""),
        "initial_hypotheses": plan.get("initial_hypotheses", []),
        "started_at": time.time(),
        "last_prompt_hash": ""
    }
    _set_clinical_intake_state(session_id, state)
    first_prompt = _format_intake_question(
        questions[0],
        round_idx=1,
        max_rounds=max_rounds,
        target_lang=state["original_language"]
    )
    state["last_prompt_hash"] = _hash_prompt_text(first_prompt)
    _set_clinical_intake_state(session_id, state)
    return first_prompt


def _handle_clinical_answer(session_id: str, answer_text: str):
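    """Record the patient's reply to the pending intake question.

    Returns a dict whose "type" is "question" (next clarification prompt), "insights"
    (intake finished; includes the Gemini summary, refined query, and Q&A transcript),
    or "error" (no active state or pending question index out of range).
    """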
    state = _get_clinical_intake_state(session_id)
    if not state:
        return {"type": "error"}
    questions = state.get("questions", [])
    idx = state.get("pending_question_index", 0)
    if idx >= len(questions):
        logger.warning("[INTAKE] Pending question index out of range, ending intake session")
        _clear_clinical_intake_state(session_id)
        return {"type": "error"}
    question_meta = questions[idx] or {}
    qa_entry = {
        "question": question_meta.get("question", ""),
        "focus": question_meta.get("clinical_focus"),
        "why_it_matters": question_meta.get("why_it_matters"),
        "round": state.get("current_round", len(state.get("answers", [])) + 1),
        "answer": answer_text.strip()
    }
    state["answers"].append(qa_entry)
    next_index = idx + 1
    reached_round_limit = len(state["answers"]) >= state["max_rounds"]
    if reached_round_limit or next_index >= len(questions):
        # Run in thread pool to avoid blocking GPU task
        insights = run_gemini_in_thread(gemini_summarize_clinical_insights, state["base_query"], state["answers"])
        insights_block = _format_insights_block(insights)
        refined_query = _build_refined_query(state["base_query"], insights, insights_block)
        transcript = _format_qa_transcript(state["answers"])
        _clear_clinical_intake_state(session_id)
        return {
            "type": "insights",
            "insights": insights,
            "insights_block": insights_block,
            "refined_query": refined_query,
            "qa_pairs": state["answers"],
            "qa_transcript": transcript
        }
    state["pending_question_index"] = next_index
    state["current_round"] = len(state["answers"]) + 1
    state["awaiting_answer"] = True
    _set_clinical_intake_state(session_id, state)
    next_question = questions[next_index]
    prompt = _format_intake_question(
        next_question,
        round_idx=state["current_round"],
        max_rounds=state["max_rounds"],
        target_lang=state["original_language"]
    )
    state["last_prompt_hash"] = _hash_prompt_text(prompt)
    _set_clinical_intake_state(session_id, state)
    return {"type": "question", "prompt": prompt}


@spaces.GPU(max_duration=MAX_DURATION)
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    use_rag: bool,
    medical_model: str,
    use_web_search: bool,
    enable_clinical_intake: bool,
    disable_agentic_reasoning: bool,
    show_thoughts: bool,
    request: gr.Request
):
    """Main chat pipeline implementing MAC architecture"""
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], ""
        return
    
    # Check if model is loaded before proceeding
    # NOTE: We don't load the model here to save time - it should be pre-loaded before stream_chat is called
    if not is_model_loaded(medical_model):
        loading_state = get_model_loading_state(medical_model)
        if loading_state == "loading":
            error_msg = f"⏳ {medical_model} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
        else:
            error_msg = f"⚠️ {medical_model} is not loaded. Please wait for the model to finish loading or select a model from the dropdown."
        yield history + [{"role": "assistant", "content": error_msg}], ""
        return
    
    # ZeroGPU best practice: If model is on CPU, move it to GPU now (we're in a GPU-decorated function)
    # This ensures the model is ready for inference without consuming GPU quota during startup
    try:
        import config
        if medical_model in config.global_medical_models:
            model = config.global_medical_models[medical_model]
            if model is not None:
                # Check if model is on CPU (device_map="cpu" or device is CPU)
                model_on_cpu = False
                if hasattr(model, 'device'):
                    if str(model.device) == 'cpu':
                        model_on_cpu = True
                elif hasattr(model, 'hf_device_map'):
                    # Model loaded with device_map - check if it's on CPU
                    if isinstance(model.hf_device_map, dict):
                        # If all devices are CPU, move to GPU
                        if all('cpu' in str(dev).lower() for dev in model.hf_device_map.values()):
                            model_on_cpu = True
                
                if model_on_cpu:
                    logger.info(f"[STREAM_CHAT] Model {medical_model} is on CPU, moving to GPU for inference...")
                    move_model_to_gpu(medical_model)
                    logger.info(f"[STREAM_CHAT] ✅ Model {medical_model} moved to GPU successfully")
    except Exception as e:
        logger.warning(f"[STREAM_CHAT] Could not move model to GPU (may already be on GPU): {e}")
        # Continue anyway - model might already be on GPU
    
    thought_handler = None
    if show_thoughts:
        thought_handler = ThoughtCaptureHandler()
        thought_handler.setLevel(logging.INFO)
        thought_handler.clear()
        logger.addHandler(thought_handler)
    
    session_start = time.time()
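    # Time budgets inside the ZeroGPU window: optional stages (RAG retrieval, extra search,
    # the challenge loop) are skipped as elapsed() nears soft_timeout, and MedSwin task
    # generation stops once elapsed() nears hard_timeout.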
    soft_timeout = 100
    hard_timeout = 118
    
    def elapsed():
        return time.time() - session_start
    
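    # Per-session persisted document index; RAG is only attempted when this directory already exists.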
    user_id = request.session_hash or "anonymous"
    index_dir = f"./{user_id}_index"
    has_rag_index = os.path.exists(index_dir)
    
    original_lang = detect_language(message)
    original_message = message
    needs_translation = original_lang != "en"
    
    pipeline_diagnostics = {
        "reasoning": None,
        "plan": None,
        "strategy_decisions": [],
        "stage_metrics": {},
        "search": {"strategies": [], "total_results": 0},
        "clinical_intake": {
            "enabled": enable_clinical_intake,
            "activated": False,
            "rounds": 0,
            "reason": "",
            "insights": [],
            "plan": [],
            "qa_pairs": [],
            "transcript": "",
            "insights_block": ""
        }
    }
    def record_stage(stage_name: str, start_time: float):
        pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
    
    translation_stage_start = time.time()
    if needs_translation:
        logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
        message = translate_text(message, target_lang="en", source_lang=original_lang)
        logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
    record_stage("translation", translation_stage_start)
    
    final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
    final_use_web_search = use_web_search and not disable_agentic_reasoning
    
    # Initialize updated_history early to avoid UnboundLocalError
    updated_history = history + [
        {"role": "user", "content": original_message},
        {"role": "assistant", "content": ""}
    ]
    
    clinical_intake_context_block = ""
    
    # Clinical intake currently uses Gemini-based supervisors.
    # When agentic reasoning is disabled, we also skip all Gemini-driven
    # intake planning and summarization so the flow is purely MedSwin.
    if disable_agentic_reasoning or not enable_clinical_intake:
        _clear_clinical_intake_state(user_id)
    else:
        intake_state = _rehydrate_intake_state(user_id, history)
        if intake_state and intake_state.get("awaiting_answer"):
            logger.info("[INTAKE] Awaiting patient response - processing answer")
            intake_result = _handle_clinical_answer(user_id, message)
            if intake_result.get("type") == "question":
                logger.info("[INTAKE] Requesting additional follow-up")
                updated_history[-1]["content"] = intake_result["prompt"]
                thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
                yield updated_history, thoughts_text
                if thought_handler:
                    logger.removeHandler(thought_handler)
                return
            if intake_result.get("type") == "insights":
                pipeline_diagnostics["clinical_intake"]["activated"] = True
                pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
                pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
                pipeline_diagnostics["clinical_intake"]["qa_pairs"] = intake_result.get("qa_pairs", [])
                pipeline_diagnostics["clinical_intake"]["transcript"] = intake_result.get("qa_transcript", "")
                pipeline_diagnostics["clinical_intake"]["insights_block"] = intake_result.get("insights_block", "")
                base_refined = intake_result.get("refined_query", message)
                summary_section = ""
                transcript_section = ""
                if intake_result.get("insights_block"):
                    summary_section = f"Clinical intake summary:\n{intake_result['insights_block']}"
                if intake_result.get("qa_transcript"):
                    transcript_section = f"Clinical intake Q&A transcript:\n{intake_result['qa_transcript']}"
                sections = [base_refined, summary_section, transcript_section]
                message = "\n\n---\n\n".join([section for section in sections if section])
                clinical_intake_context_block = "\n\n".join([seg for seg in [summary_section, transcript_section] if seg])
        else:
            history_context = _history_to_text(history)
            # Run in thread pool to avoid blocking GPU task
            triage_plan = run_gemini_in_thread(gemini_clinical_intake_triage, message, history_context, MAX_CLINICAL_QA_ROUNDS)
            pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
            pipeline_diagnostics["clinical_intake"]["plan"] = triage_plan.get("questions", [])
            needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
            if needs_intake:
                first_prompt = _start_clinical_intake_session(
                    user_id,
                    triage_plan,
                    message,
                    original_lang
                )
                if first_prompt:
                    pipeline_diagnostics["clinical_intake"]["activated"] = True
                    updated_history[-1]["content"] = first_prompt
                    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
                    yield updated_history, thoughts_text
                    if thought_handler:
                        logger.removeHandler(thought_handler)
                    return
    
    plan = None
    if not disable_agentic_reasoning:
        reasoning_stage_start = time.time()
        reasoning = autonomous_reasoning(message, history)
        record_stage("autonomous_reasoning", reasoning_stage_start)
        pipeline_diagnostics["reasoning"] = reasoning
        plan = create_execution_plan(reasoning, message, has_rag_index)
        pipeline_diagnostics["plan"] = plan
        execution_strategy = autonomous_execution_strategy(
            reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
        )
        
        if final_use_rag and not reasoning.get("requires_rag", True):
            final_use_rag = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
        elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
            pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")
        
        if final_use_web_search and not reasoning.get("requires_web_search", False):
            final_use_web_search = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
        elif not final_use_web_search and reasoning.get("requires_web_search", False):
            if not use_web_search:
                pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
            else:
                pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
    else:
        pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")
    
    # Update thoughts after reasoning stage
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
    
    if disable_agentic_reasoning:
        logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
        breakdown = {
            "sub_topics": [
                {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
            ],
            "strategy": "Direct answer",
            "exploration_note": "Direct mode - no breakdown"
        }
    else:
        logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
        # Provide previous assistant answer as context so Gemini can
        # interpret follow-up queries like "clarify your answer".
        previous_answer = _get_last_assistant_answer(history)
        # Run in thread pool to avoid blocking GPU task
        breakdown = run_gemini_in_thread(
            gemini_supervisor_breakdown,
            message,
            final_use_rag,
            final_use_web_search,
            elapsed(),
            120,
            previous_answer,
        )
        logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")
    
    # Update thoughts after breakdown
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
    
    search_contexts = []
    web_urls = []
    if final_use_web_search:
        search_stage_start = time.time()
        logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
        # Run in thread pool to avoid blocking GPU task
        search_strategies = run_gemini_in_thread(gemini_supervisor_search_strategies, message, elapsed())
        
        all_search_results = []
        strategy_jobs = []
        for strategy in search_strategies.get("search_strategies", [])[:MAX_SEARCH_STRATEGIES]:
            search_query = strategy.get("strategy", message)
            target_sources = strategy.get("target_sources", 2)
            strategy_jobs.append({
                "query": search_query,
                "target_sources": target_sources,
                "meta": strategy
            })
        
        def execute_search(job):
            job_start = time.time()
            try:
                results = search_web(job["query"], max_results=job["target_sources"])
                duration = time.time() - job_start
                return results, duration, None
            except Exception as exc:
                return [], time.time() - job_start, exc
        
        def record_search_diag(job, duration, results_count, error=None):
            entry = {
                "query": job["query"],
                "target_sources": job["target_sources"],
                "duration": round(duration, 3),
                "results": results_count
            }
            if error:
                entry["error"] = str(error)
            pipeline_diagnostics["search"]["strategies"].append(entry)
        
        if strategy_jobs:
            max_workers = min(len(strategy_jobs), 4)
            if len(strategy_jobs) > 1:
                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
                    for future in concurrent.futures.as_completed(future_map):
                        job = future_map[future]
                        try:
                            results, duration, error = future.result()
                        except Exception as exc:
                            results, duration, error = [], 0.0, exc
                        record_search_diag(job, duration, len(results), error)
                        if not error and results:
                            all_search_results.extend(results)
                            web_urls.extend([r.get('url', '') for r in results if r.get('url')])
            else:
                job = strategy_jobs[0]
                results, duration, error = execute_search(job)
                record_search_diag(job, duration, len(results), error)
                if not error and results:
                    all_search_results.extend(results)
                    web_urls.extend([r.get('url', '') for r in results if r.get('url')])
        else:
            pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")
        
        pipeline_diagnostics["search"]["total_results"] = len(all_search_results)
        
        if all_search_results:
            logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
            search_summary = summarize_web_content(all_search_results, message)
            if search_summary:
                search_contexts.append(search_summary)
                logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
        record_stage("web_search", search_stage_start)
    
    rag_contexts = []
    if final_use_rag and has_rag_index:
        rag_stage_start = time.time()
        if elapsed() >= soft_timeout - 10:
            logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
            final_use_rag = False
        else:
            logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
            embed_model = get_or_create_embed_model()
            Settings.embed_model = embed_model
            storage_context = StorageContext.from_defaults(persist_dir=index_dir)
            index = load_index_from_storage(storage_context, settings=Settings)
            base_retriever = index.as_retriever(similarity_top_k=retriever_k)
            auto_merging_retriever = AutoMergingRetriever(
                base_retriever,
                storage_context=storage_context,
                simple_ratio_thresh=merge_threshold,
                verbose=False
            )
            merged_nodes = auto_merging_retriever.retrieve(message)
            retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
            logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")
            
            logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
            # Run in thread pool to avoid blocking GPU task
            rag_brainstorm = run_gemini_in_thread(gemini_supervisor_rag_brainstorm, message, retrieved_docs, elapsed())
            rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
            logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
        record_stage("rag_retrieval", rag_stage_start)
    
    medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
    
    base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables. Provide answers directly without conversational prefixes like 'Here is...', 'This is...'. Start with the actual content immediately."
    
    context_sections = []
    if clinical_intake_context_block:
        context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
    if rag_contexts:
        context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
    if search_contexts:
        context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
    combined_context = "\n\n".join(context_sections)
    
    logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
    medswin_answers = []
    
    # Update thoughts before starting MedSwin tasks
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
    
    medswin_stage_start = time.time()
    for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
        if elapsed() >= hard_timeout - 5:
            logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
            break
        
        task_instruction = sub_topic.get("instruction", "")
        topic_name = sub_topic.get("topic", f"Topic {idx}")
        priority = sub_topic.get("priority", "medium")
        
        logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
        
        task_context = combined_context
        if len(rag_contexts) > 1 and idx <= len(rag_contexts):
            task_context = rag_contexts[idx - 1]
        
        # Add small delay between GPU requests to prevent ZeroGPU scheduler conflicts
        if idx > 1:
            delay = 0.5  # 500ms delay between sequential GPU requests
            logger.debug(f"[MEDSWIN] Waiting {delay}s before next GPU request to avoid scheduler conflicts...")
            time.sleep(delay)
        
        try:
            task_answer = execute_medswin_task(
                medical_model_obj=medical_model_obj,
                medical_tokenizer=medical_tokenizer,
                task_instruction=task_instruction,
                context=task_context if task_context else "",
                system_prompt_base=base_system_prompt,
                temperature=temperature,
                max_new_tokens=min(max_new_tokens, 800),
                top_p=top_p,
                top_k=top_k,
                penalty=penalty
            )
            
            formatted_answer = f"## {topic_name}\n\n{task_answer}"
            medswin_answers.append(formatted_answer)
            logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
            
            partial_final = "\n\n".join(medswin_answers)
            updated_history[-1]["content"] = partial_final
            thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
            yield updated_history, thoughts_text
    
        except Exception as e:
            logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
            continue
    record_stage("medswin_tasks", medswin_stage_start)
    
    # If agentic reasoning is disabled, we skip all Gemini-based synthesis,
    # challenge, and enhancement loops. The final answer is just the
    # concatenation of MedSwin task outputs.
    if disable_agentic_reasoning:
        logger.info("[MAC] Agentic reasoning disabled - skipping Gemini synthesis and challenge")
        if medswin_answers:
            final_answer = "\n\n".join(medswin_answers)
        else:
            final_answer = "I apologize, but I was unable to generate a response."
    else:
        logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
        raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers]
        synthesis_stage_start = time.time()
        # Run in thread pool to avoid blocking GPU task
        final_answer = run_gemini_in_thread(
            gemini_supervisor_synthesize, message, raw_medswin_answers, rag_contexts, search_contexts, breakdown
        )
        record_stage("synthesis", synthesis_stage_start)
        
        if not final_answer or len(final_answer.strip()) < 50:
            logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation")
            final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."
        
        if "|" in final_answer and "---" in final_answer:
            logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
            lines = final_answer.split('\n')
            cleaned_lines = []
            for line in lines:
                if '|' in line and '---' not in line:
                    cells = [cell.strip() for cell in line.split('|') if cell.strip()]
                    if cells:
                        cleaned_lines.append(f"- {' / '.join(cells)}")
                elif '---' not in line:
                    cleaned_lines.append(line)
            final_answer = '\n'.join(cleaned_lines)
        
        max_challenge_iterations = 2
        challenge_iteration = 0
        challenge_stage_start = time.time()
        
        while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
            challenge_iteration += 1
            logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...")
            
            # Run in thread pool to avoid blocking GPU task
            evaluation = run_gemini_in_thread(
                gemini_supervisor_challenge, message, final_answer, raw_medswin_answers, rag_contexts, search_contexts
            )
            
            if evaluation.get("is_optimal", False):
                logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)")
                break
            
            enhancement_instructions = evaluation.get("enhancement_instructions", "")
            if not enhancement_instructions:
                logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal")
                break
            
            logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...")
            # Run in thread pool to avoid blocking GPU task
            enhanced_answer = run_gemini_in_thread(
                gemini_supervisor_enhance_answer, message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts
            )
            
            if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8:
                final_answer = enhanced_answer
                logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)")
            else:
                logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
                break
        record_stage("challenge_loop", challenge_stage_start)
    
    if final_use_web_search and elapsed() < soft_timeout - 10:
        logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
        clarity_stage_start = time.time()
        # Run in thread pool to avoid blocking GPU task
        clarity_check = run_gemini_in_thread(gemini_supervisor_check_clarity, message, final_answer, final_use_web_search)
        record_stage("clarity_check", clarity_stage_start)
        
        if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
            logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
            additional_search_results = []
            followup_stage_start = time.time()
            for search_query in clarity_check.get("search_queries", [])[:3]:
                if elapsed() >= soft_timeout - 5:
                    break
                extra_start = time.time()
                results = search_web(search_query, max_results=2)
                extra_duration = time.time() - extra_start
                pipeline_diagnostics["search"]["strategies"].append({
                    "query": search_query,
                    "target_sources": 2,
                    "duration": round(extra_duration, 3),
                    "results": len(results),
                    "type": "followup"
                })
                additional_search_results.extend(results)
                web_urls.extend([r.get('url', '') for r in results if r.get('url')])
            
            if additional_search_results:
                pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
                logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
                additional_summary = summarize_web_content(additional_search_results, message)
                if additional_summary:
                    search_contexts.append(additional_summary)
                    logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...")
                    # Run in thread pool to avoid blocking GPU task
                    enhanced_with_search = run_gemini_in_thread(
                        gemini_supervisor_enhance_answer, message, final_answer,
                        f"Incorporate the following additional information from web search: {additional_summary}",
                        raw_medswin_answers, rag_contexts, search_contexts
                    )
                    if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
                        final_answer = enhanced_with_search
                        logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
            record_stage("followup_search", followup_stage_start)
            
            # Update thoughts after followup search
            thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
            yield updated_history, thoughts_text
    
    citations_text = ""
    
    if needs_translation and final_answer:
        logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
        final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")
    
    if web_urls:
        unique_urls = list(dict.fromkeys(web_urls))
        citation_links = []
        for url in unique_urls[:5]:
            domain = format_url_as_domain(url)
            if domain:
                citation_links.append(f"[{domain}]({url})")
        
        if citation_links:
            citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
    
    speaker_icon = ' 🔊'
    final_answer_with_metadata = final_answer + citations_text + speaker_icon
    
    updated_history[-1]["content"] = final_answer_with_metadata
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    # Always yield thoughts_text, even if empty, to ensure UI updates
    yield updated_history, thoughts_text
    
    if thought_handler:
        logger.removeHandler(thought_handler)
    
    diag_summary = {
        "stage_metrics": pipeline_diagnostics["stage_metrics"],
        "decisions": pipeline_diagnostics["strategy_decisions"],
        "search": pipeline_diagnostics["search"],
    }
    try:
        logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}")
    except Exception:
        logger.info(f"[MAC] Diagnostics summary (non-serializable)")
    logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")