Y Phung Nguyen committed on
Commit ef322a1 · 1 Parent(s): b4f06b4

Update langdetect accuracy

Files changed (2)
  1. model.py +0 -128
  2. utils.py +62 -4
model.py DELETED
@@ -1,128 +0,0 @@
-"""
-Model inference functions that require GPU.
-These functions are tagged with @spaces.GPU(max_duration=120) to ensure
-they only run on GPU and don't waste GPU time on CPU operations.
-"""
-
-import os
-import torch
-import logging
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-)
-from llama_index.llms.huggingface import HuggingFaceLLM
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-import spaces
-import threading
-
-logger = logging.getLogger(__name__)
-
-# Model configurations
-MEDSWIN_MODELS = {
-    "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
-    "MedSwin KD": "MedSwin/MedSwin-7B-KD",
-    "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
-}
-DEFAULT_MEDICAL_MODEL = "MedSwin TA"
-EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1"
-HF_TOKEN = os.environ.get("HF_TOKEN")
-
-# Global model storage (shared with app.py)
-# These will be initialized in app.py and accessed here
-global_medical_models = {}
-global_medical_tokenizers = {}
-
-
-def initialize_medical_model(model_name: str):
-    """Initialize medical model (MedSwin) - download on demand"""
-    global global_medical_models, global_medical_tokenizers
-    if model_name not in global_medical_models or global_medical_models[model_name] is None:
-        logger.info(f"Initializing medical model: {model_name}...")
-        model_path = MEDSWIN_MODELS[model_name]
-        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            device_map="auto",
-            trust_remote_code=True,
-            token=HF_TOKEN,
-            torch_dtype=torch.float16
-        )
-        global_medical_models[model_name] = model
-        global_medical_tokenizers[model_name] = tokenizer
-        logger.info(f"Medical model {model_name} initialized successfully")
-    return global_medical_models[model_name], global_medical_tokenizers[model_name]
-
-
-@spaces.GPU(max_duration=120)
-def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
-    """Get LLM for RAG indexing (uses medical model) - GPU only"""
-    # Use medical model for RAG indexing instead of translation model
-    medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
-
-    return HuggingFaceLLM(
-        context_window=4096,
-        max_new_tokens=max_new_tokens,
-        tokenizer=medical_tokenizer,
-        model=medical_model_obj,
-        generate_kwargs={
-            "do_sample": True,
-            "temperature": temperature,
-            "top_k": top_k,
-            "top_p": top_p
-        }
-    )
-
-
-@spaces.GPU(max_duration=120)
-def get_embedding_model():
-    """Get embedding model for RAG - GPU only"""
-    return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
-
-
-@spaces.GPU(max_duration=120)
-def generate_with_medswin(
-    medical_model_obj,
-    medical_tokenizer,
-    prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-    penalty: float,
-    eos_token_id: int,
-    pad_token_id: int,
-    stop_event: threading.Event,
-    streamer: TextIteratorStreamer,
-    stopping_criteria: StoppingCriteriaList
-):
-    """
-    Generate text with MedSwin model - GPU only
-
-    This function only performs the actual model inference on GPU.
-    All other operations (prompt preparation, post-processing) should be done outside.
-    """
-    # Tokenize prompt (this is a CPU operation but happens here for simplicity)
-    # The actual GPU work is in model.generate()
-    inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
-
-    # Prepare generation kwargs
-    generation_kwargs = dict(
-        **inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=penalty,
-        do_sample=True,
-        stopping_criteria=stopping_criteria,
-        eos_token_id=eos_token_id,
-        pad_token_id=pad_token_id
-    )
-
-    # Run generation on GPU - this is the only GPU operation
-    medical_model_obj.generate(**generation_kwargs)
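
For context on the deleted file: generate_with_medswin only launched model.generate with a caller-supplied TextIteratorStreamer, so consumers were expected to run it on a background thread and iterate the streamer for decoded text chunks. A minimal caller sketch, assuming the functions above were still importable from model.py; the prompt and sampling values are made up, and this code is not part of the commit:

# Hypothetical caller sketch (not part of this commit); assumes model.py
# was importable before its removal in this commit.
import threading
from transformers import TextIteratorStreamer, StoppingCriteriaList
from model import initialize_medical_model, generate_with_medswin, DEFAULT_MEDICAL_MODEL

model, tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Run the GPU-tagged generation function off the main thread so the
# streamer can be consumed while tokens are still being produced.
thread = threading.Thread(
    target=generate_with_medswin,
    kwargs=dict(
        medical_model_obj=model,
        medical_tokenizer=tokenizer,
        prompt="What are common causes of chest pain?",  # made-up example prompt
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        penalty=1.1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        stop_event=threading.Event(),
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([]),
    ),
)
thread.start()
for chunk in streamer:  # yields decoded text pieces as generation proceeds
    print(chunk, end="", flush=True)
thread.join()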
utils.py CHANGED
@@ -1,6 +1,7 @@
 """Utility functions for translation, language detection, and formatting"""
 import asyncio
-from langdetect import detect, LangDetectException
+import re
+from langdetect import detect_langs, LangDetectException
 from logger import logger
 from client import MCP_AVAILABLE, call_agent
 from config import GEMINI_MODEL_LITE
@@ -33,13 +34,70 @@ def format_prompt_manually(messages: list, tokenizer) -> str:
     return prompt
 
 
+MIN_TEXT_LENGTH_FOR_DETECTION = 12
+LANG_CONFIDENCE_THRESHOLD = 0.8
+ASCII_DOMINANCE_THRESHOLD = 0.97
+ENGLISH_HINT_RATIO = 0.2
+ENGLISH_HINT_WORDS = {
+    "the", "and", "with", "for", "you", "your", "have", "has", "that", "this",
+    "pain", "blood", "pressure", "please", "what", "how", "can", "should", "need"
+}
+
+
+def _ascii_ratio(text: str) -> float:
+    if not text:
+        return 1.0
+    ascii_chars = sum(1 for ch in text if ord(ch) < 128)
+    return ascii_chars / max(len(text), 1)
+
+
+def _looks_english(text: str) -> bool:
+    words = re.findall(r"[A-Za-z']+", text.lower())
+    if not words:
+        return False
+    english_hits = sum(1 for word in words if word in ENGLISH_HINT_WORDS)
+    return english_hits / len(words) >= ENGLISH_HINT_RATIO
+
+
 def detect_language(text: str) -> str:
-    """Detect language of input text"""
+    """Detect language of input text with basic confidence heuristics"""
+    if not text:
+        return "en"
+    sample = text.strip()
+    if not sample:
+        return "en"
+
+    ascii_ratio = _ascii_ratio(sample)
+    has_non_ascii = ascii_ratio < 1.0
+    if len(sample) < MIN_TEXT_LENGTH_FOR_DETECTION and not has_non_ascii:
+        return "en"
+
     try:
-        lang = detect(text)
-        return lang
+        detections = detect_langs(sample)
     except LangDetectException:
         return "en"
+    except Exception as exc:
+        logger.debug(f"[LANG-DETECT] Unexpected error, defaulting to English: {exc}")
+        return "en"
+
+    if not detections:
+        return "en"
+
+    top = detections[0]
+    lang_code = top.lang
+    confidence = getattr(top, "prob", 0.0)
+
+    if confidence < LANG_CONFIDENCE_THRESHOLD:
+        return "en"
+
+    if lang_code == "en":
+        return "en"
+
+    if not has_non_ascii and ascii_ratio >= ASCII_DOMINANCE_THRESHOLD and _looks_english(sample):
+        logger.info(f"[LANG-DETECT] Overrode {lang_code} due to English heuristics (ascii_ratio={ascii_ratio:.2f})")
+        return "en"
+
+    return lang_code
 
 
 def format_url_as_domain(url: str) -> str:
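
For illustration only (not part of the commit): the switch from detect to detect_langs matters because detect_langs returns candidate languages ranked by probability, which is what the new confidence gate keys on, while the length and ASCII checks avoid calling the detector on short inputs where langdetect is unstable. A rough usage sketch, assuming langdetect is installed and utils.py is importable; the sample strings are made up:

# Hypothetical usage sketch (not part of this commit).
from langdetect import detect_langs

print(detect_langs("Where does it hurt?"))
# -> e.g. [en:0.999...]; each entry exposes .lang and .prob

from utils import detect_language

print(detect_language("ok"))  # short pure-ASCII input short-circuits to "en"
print(detect_language("J'ai mal à la tête depuis hier"))
# -> "fr" only if the top guess clears LANG_CONFIDENCE_THRESHOLD (0.8); otherwise "en"

Note that langdetect is non-deterministic unless seeded, so the exact probabilities vary between runs; the 0.8 threshold absorbs most of that jitter.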