sivakorn-su commited on
Commit
c167971
·
1 Parent(s): 2f67175

feat: Add Predict text

Browse files
Files changed (5) hide show
  1. README.md +139 -49
  2. app.py +85 -15
  3. models.py +9 -10
  4. requirements.txt +5 -0
  5. utils.py +512 -120
README.md CHANGED
@@ -1,87 +1,177 @@
1
- ---
2
- title: WhisperPyanoteLLM
3
- emoji: 📉
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
 
13
- # WhisperPyanoteLLM
 
 
 
 
 
 
14
 
15
- A FastAPI-based app for speaker diarization and transcription using Whisper and PyAnnote, with LLM-powered summarization.
 
 
 
 
 
16
 
17
- ## Features
18
- - Speaker diarization with pyannote.audio
19
- - Transcription with OpenAI Whisper
20
- - Summarization with Together LLM
21
- - REST API for video/audio upload and processing
 
22
 
23
- ## Quick Start (Development)
24
 
25
- 1. **Clone the repository:**
26
- ```sh
 
 
27
  git clone <your-repo-url>
28
- cd WhisperPyanoteLLM
29
  ```
30
 
31
- 2. **Create a `.env` file:**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ```env
33
  HF_TOKEN=your_huggingface_token
34
  TOGETHER_API_KEY=your_together_api_key
35
  NGROK_AUTH_TOKEN=your_ngrok_token
36
  ```
37
 
38
- 3. **Install dependencies:**
39
- ```sh
40
  pip install -r requirements.txt
41
  ```
42
 
43
- 4. **Run the app:**
44
- ```sh
45
  uvicorn app:app --reload --port 8300
46
  ```
47
 
48
- 5. **Access the API:**
49
- - Health check: [http://localhost:8300/health](http://localhost:8300/health)
50
  - Upload endpoint: `/upload_video/`
 
51
 
52
- ---
53
-
54
- ## Production (Docker)
55
 
56
- 1. **Create a `.env.prod` file:**
57
  ```env
58
  HF_TOKEN=your_huggingface_token
59
  TOGETHER_API_KEY=your_together_api_key
60
  NGROK_AUTH_TOKEN=your_ngrok_token
61
  ```
62
 
63
- 2. **Build the Docker image:**
64
- ```sh
65
- docker build -t whisperpyanote .
66
  ```
67
 
68
- 3. **Run the Docker container:**
69
- ```sh
70
- docker run --env-file .env.prod -p 8300:8300 whisperpyanote
71
  ```
72
 
73
- 4. **Access the API:**
74
- - Health check: [http://localhost:8300/health](http://localhost:8300/health)
75
- - Upload endpoint: `/upload_video/`
76
-
77
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- ## Notes
80
- - Make sure your `.env` and `.env.prod` files are **not** committed to version control.
81
- - For best performance, run on a machine with a CUDA-enabled GPU.
82
- - For more details, see the code and comments in `app.py`.
83
 
84
  ---
85
 
86
- ## License
87
- Apache-2.0
 
 
1
+ # 🎤 Advanced Voice Diarization System
2
+
3
+ ระบบแยกเสียงพูดและถอดเสียงขั้นสูงที่รองรับการพูดทับซ้อนกัน พร้อมการปรับปรุงข้อความด้วย AI สำหรับภาษาไทย
 
 
 
 
 
 
4
 
5
+ ## คุณสมบัติหลัก
6
 
7
+ ### 🔄 **ขั้นตอนการประมวลผล 6 ขั้นตอน**
8
+ 1. **Preprocess** - ปรับเสียงเป็น 16 kHz mono และ normalize
9
+ 2. **Diarization** - แยก speaker และตรวจจับการพูดทับซ้อน
10
+ 3. **Branching Logic** - แยกเส้นทางการประมวลผล Clean vs Overlap
11
+ 4. **ASR Processing** - ถอดเสียงแบบ deterministic หรือแยกเสียงด้วย Asteroid
12
+ 5. **Timeline Stitching** - รวมผลลัพธ์ตามลำดับเวลา
13
+ 6. **Post-processing** - ปรับปรุงข้อความไทยด้วย LLM
14
 
15
+ ### 🎯 **เทคโนโลยีที่ใช้**
16
+ - **PyAnnote** - Speaker diarization และ overlap detection
17
+ - **Whisper** - Speech-to-text transcription
18
+ - **Asteroid ConvTasNet** - Source separation สำหรับเสียงทับซ้อน
19
+ - **SpeechBrain** - Speaker embedding และ matching
20
+ - **LLM** - Text correction และ normalization
21
 
22
+ ### 📊 **ผลลัพธ์ที่ได้**
23
+ - แยก speaker พร้อมช่วงเวลาที่แม่นยำ
24
+ - ข้อความที่ถอดจากเสียงพร้อมค่าความเชื่อมั่น
25
+ - การตรวจจับและประมวลผลเสียงทับซ้อน
26
+ - สถิติการประมวลผล (overlap ratio, confidence scores)
27
+ - Export หลายรูปแบบ: JSON, SRT, VTT, TXT
28
 
29
+ ## 🚀 การติดตั้งและใช้งาน
30
 
31
+ ### การพัฒนา (Development)
32
+
33
+ 1. **Clone repository:**
34
+ ```bash
35
  git clone <your-repo-url>
36
+ cd project-voice-diarzation
37
  ```
38
 
39
+ 2. **ตั้งค่า Python Environment (เลือก 1 วิธี):**
40
+
41
+ **Option A: ใช้ Conda (แนะนำ)**
42
+ ```bash
43
+ # สร้าง environment ใหม่
44
+ conda create -n voice-diarization python=3.9
45
+ conda activate voice-diarization
46
+
47
+ # ติดตั้ง PyTorch สำหรับ CUDA (ถ้ามี GPU)
48
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
49
+
50
+ # หรือสำหรับ CPU เท่านั้น
51
+ # conda install pytorch torchvision torchaudio cpuonly -c pytorch
52
+ ```
53
+
54
+ **Option B: ใช้ pyenv + venv**
55
+ ```bash
56
+ # ติดตั้ง Python version ที่ต้องการ
57
+ pyenv install 3.9.18
58
+ pyenv local 3.9.18
59
+
60
+ # สร้าง virtual environment
61
+ python -m venv venv
62
+ source venv/bin/activate # macOS/Linux
63
+ # หรือ venv\Scripts\activate # Windows
64
+ ```
65
+
66
+ **Option C: ใช้ pip + venv (พื้นฐาน)**
67
+ ```bash
68
+ python -m venv venv
69
+ source venv/bin/activate # macOS/Linux
70
+ # หรือ venv\Scripts\activate # Windows
71
+ ```
72
+
73
+ 3. **สร้างไฟล์ `.env`:**
74
  ```env
75
  HF_TOKEN=your_huggingface_token
76
  TOGETHER_API_KEY=your_together_api_key
77
  NGROK_AUTH_TOKEN=your_ngrok_token
78
  ```
79
 
80
+ 4. **ติดตั้ง dependencies:**
81
+ ```bash
82
  pip install -r requirements.txt
83
  ```
84
 
85
+ 5. **รันแอปพลิเคชัน:**
86
+ ```bash
87
  uvicorn app:app --reload --port 8300
88
  ```
89
 
90
+ 6. **เข้าใช้งาน API:**
91
+ - Health check: http://localhost:8300/health
92
  - Upload endpoint: `/upload_video/`
93
+ - API docs: http://localhost:8300/docs
94
 
95
+ ### การใช้งานจริง (Production)
 
 
96
 
97
+ 1. **สร้างไฟล์ `.env.prod`:**
98
  ```env
99
  HF_TOKEN=your_huggingface_token
100
  TOGETHER_API_KEY=your_together_api_key
101
  NGROK_AUTH_TOKEN=your_ngrok_token
102
  ```
103
 
104
+ 2. **Build Docker image:**
105
+ ```bash
106
+ docker build -t voice-diarization .
107
  ```
108
 
109
+ 3. **Run Docker container:**
110
+ ```bash
111
+ docker run --env-file .env.prod -p 8300:8300 voice-diarization
112
  ```
113
 
114
+ ## 📋 ตัวอย่างผลลัพธ์
115
+
116
+ ```json
117
+ {
118
+ "data": [
119
+ {
120
+ "speaker": "SPEAKER_00",
121
+ "start": 0.5,
122
+ "end": 3.2,
123
+ "text": "สวัสดีครับทุกคน วันนี้เราจะมาประชุมเรื่องโปรเจคใหม่",
124
+ "confidence": 0.92,
125
+ "has_overlap": false,
126
+ "processing_type": "clean"
127
+ }
128
+ ],
129
+ "processing_stats": {
130
+ "clean_segments": 3,
131
+ "overlap_segments": 2,
132
+ "overlap_ratio": 0.192,
133
+ "avg_confidence": 0.856
134
+ }
135
+ }
136
+ ```
137
+
138
+ ## 🔧 การกำหนดค่า
139
+
140
+ ### ข้อกำหนดระบบ
141
+ - **GPU**: CUDA-enabled GPU แนะนำสำหรับประสิทธิภาพสูงสุด
142
+ - **RAM**: อย่างน้อย 8GB
143
+ - **Python**: 3.8+
144
+
145
+ ### ไฟล์ที่รองรับ
146
+ - **Audio**: WAV, MP3, M4A, FLAC
147
+ - **Video**: MP4, AVI, MOV, MKV
148
+
149
+ ## 📚 API Documentation
150
+
151
+ ### POST `/upload_video/`
152
+ อัปโหลดไฟล์เสียงหรือวิดีโอเพื่อประมวลผล
153
+
154
+ **Parameters:**
155
+ - `file`: ไฟล์เสียงหรือวิดีโอ
156
+ - `num_speakers` (optional): จำนวน speaker ที่คาดหวัง
157
+
158
+ **Response:**
159
+ - ผลลัพธ์การแยกเสียงและถอดเสียงแบบละเอียด
160
+ - สถิติการประมวลผล
161
+ - ข้อมูลการพูดทับซ้อน
162
+
163
+ ## ⚠️ ข้อควรระวัง
164
+
165
+ - ไฟล์ `.env` และ `.env.prod` **ห้าม** commit ลง version control
166
+ - สำหรับประสิทธิภาพสูงสุด ควรใช้เครื่องที่มี CUDA GPU
167
+ - การประมวลผลไฟล์ขนาดใหญ่อาจใช้เวลานาน
168
+
169
+ ## 📄 License
170
 
171
+ Apache-2.0
 
 
 
172
 
173
  ---
174
 
175
+ **พัฒนาโดย:** Advanced Voice Processing Team
176
+ **เวอร์ชัน:** 2.0
177
+ **อัปเดตล่าสุด:** สิงหาคม 2024
app.py CHANGED
@@ -27,6 +27,7 @@ from utils import (
27
  summarize_texts,
28
  add_llm_spell_corrected_text_column,
29
  download_to_temp,
 
30
  )
31
  # from supabase import create_client, Client
32
 
@@ -59,7 +60,7 @@ async def startup_event():
59
 
60
  logger.info("🔁 Loading models at startup...")
61
  try:
62
- pipeline, model = await load_model_bundle()
63
  except Exception as e:
64
  logger.exception(f"❌ Model loading failed: {e}")
65
  import sys; sys.exit(1)
@@ -135,25 +136,86 @@ def upload_video(video_path: str):
135
  from config import together_api_key
136
  # video_path = save_uploaded_file(file)
137
  audio_path = extract_and_normalize_audio(video_path)
 
 
138
  df_diarization = diarize_audio(audio_path)
139
- segment_folder = split_segments(audio_path, df_diarization)
140
- df_transcriptions = transcribe_segments(segment_folder)
141
- min_len = min(len(df_diarization), len(df_transcriptions))
142
- df_merged = pd.concat([
143
- df_diarization.iloc[:min_len].reset_index(drop=True),
144
- df_transcriptions.iloc[:min_len].reset_index(drop=True)
145
- ], axis=1)
146
- # df_merged = add_corrected_text_column(df_merged)
147
- df_merged = add_llm_spell_corrected_text_column(df_merged)
148
- # summaries = summarize_texts(df_merged["text"].tolist(), together_api_key, delay=0)
149
- result = df_merged.to_dict(orient="records")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  speaker_array = df_diarization["speaker"].unique().tolist()
151
  counter = Counter(df_diarization["speaker"])
152
  result_array = [{"speaker": spk, "count": cnt} for spk, cnt in counter.most_common()]
 
153
  from pydub import AudioSegment
154
  duration_minutes = len(AudioSegment.from_wav(audio_path)) / 1000 / 60
155
- # save result to supabase
156
- # supabase.table("summaries").insert(result).execute()
 
 
 
 
 
 
 
157
  return {
158
  "video_path": video_path,
159
  "audio_path": audio_path,
@@ -162,8 +224,16 @@ def upload_video(video_path: str):
162
  "speaker_array": speaker_array,
163
  "count_speaker": result_array,
164
  "num_speakers": len(speaker_array),
165
- "total_sentence": len(df_merged['text']),
166
  "summaries": 'This feature not available',
 
 
 
 
 
 
 
 
167
  }
168
 
169
 
 
27
  summarize_texts,
28
  add_llm_spell_corrected_text_column,
29
  download_to_temp,
30
+ process_segments_with_branching
31
  )
32
  # from supabase import create_client, Client
33
 
 
60
 
61
  logger.info("🔁 Loading models at startup...")
62
  try:
63
+ pipeline, model, overlap_pipeline = await load_model_bundle()
64
  except Exception as e:
65
  logger.exception(f"❌ Model loading failed: {e}")
66
  import sys; sys.exit(1)
 
136
  from config import together_api_key
137
  # video_path = save_uploaded_file(file)
138
  audio_path = extract_and_normalize_audio(video_path)
139
+
140
+ # (1) Diarization + Overlap Detection
141
  df_diarization = diarize_audio(audio_path)
142
+
143
+ # (2-4) Branching Logic + Source Separation + Transcription + Timeline Stitching
144
+ branching_results = process_segments_with_branching(audio_path, df_diarization)
145
+
146
+ # รวมผลลัพธ์จาก clean และ overlap segments
147
+ all_transcriptions = []
148
+
149
+ # เพิ่มผลจาก clean segments
150
+ for i, clean_trans in enumerate(branching_results["clean_transcriptions"]):
151
+ # ใช้ index ในการจับคู่แทน filename
152
+ if i < len(branching_results["clean_segments"]):
153
+ segment = branching_results["clean_segments"][i]
154
+ all_transcriptions.append({
155
+ "speaker": segment["speaker"],
156
+ "start": segment["start"],
157
+ "end": segment["end"],
158
+ "duration": segment["duration"],
159
+ "confidence": segment.get("confidence", 0.5),
160
+ "text": clean_trans["text"],
161
+ "text_array": clean_trans.get("text_array", [clean_trans["text"]]),
162
+ "avg_probability": clean_trans["avg_probability"],
163
+ "has_overlap": segment.get("has_overlap", False),
164
+ "overlap_ratio": segment.get("overlap_ratio", 0.0),
165
+ "is_remove": segment.get("is_remove", False),
166
+ "remove_reason": segment.get("remove_reason", ""),
167
+ "processing_type": "clean",
168
+ "overlap_detail": []
169
+ })
170
+
171
+ # เพิ่มผลจาก overlap segments
172
+ for overlap_trans in branching_results["overlap_transcriptions"]:
173
+ original_segment = overlap_trans["original_segment"]
174
+ all_transcriptions.append({
175
+ "speaker": overlap_trans["speaker"],
176
+ "start": original_segment["start"],
177
+ "end": original_segment["end"],
178
+ "duration": original_segment["duration"],
179
+ "confidence": original_segment.get("confidence", 0.5),
180
+ "text": overlap_trans["transcription"]["text"],
181
+ "text_array": overlap_trans["transcription"].get("text_array", [overlap_trans["transcription"]["text"]]),
182
+ "avg_probability": overlap_trans["transcription"]["avg_probability"],
183
+ "has_overlap": True,
184
+ "overlap_ratio": original_segment.get("overlap_ratio", 1.0),
185
+ "is_remove": original_segment.get("is_remove", False),
186
+ "remove_reason": original_segment.get("remove_reason", ""),
187
+ "processing_type": "overlap_separated",
188
+ "stream_id": overlap_trans["stream_id"],
189
+ "overlap_detail": overlap_trans.get("matched_streams", [])
190
+ })
191
+
192
+ # เรียงตามเวลา
193
+ all_transcriptions.sort(key=lambda x: x["start"])
194
+
195
+ # (5) Post-process - LLM correction
196
+ df_merged = pd.DataFrame(all_transcriptions)
197
+ if not df_merged.empty:
198
+ df_merged = add_llm_spell_corrected_text_column(df_merged)
199
+ result = df_merged.to_dict(orient="records")
200
+ else:
201
+ result = []
202
+
203
+ # สถิติ
204
  speaker_array = df_diarization["speaker"].unique().tolist()
205
  counter = Counter(df_diarization["speaker"])
206
  result_array = [{"speaker": spk, "count": cnt} for spk, cnt in counter.most_common()]
207
+
208
  from pydub import AudioSegment
209
  duration_minutes = len(AudioSegment.from_wav(audio_path)) / 1000 / 60
210
+
211
+ # คำนวณ metrics
212
+ overlap_segments_count = len(branching_results["overlap_segments"])
213
+ clean_segments_count = len(branching_results["clean_segments"])
214
+ total_segments = overlap_segments_count + clean_segments_count
215
+ overlap_ratio = overlap_segments_count / max(total_segments, 1)
216
+
217
+ avg_confidence = np.mean([r.get("confidence", 0.5) for r in result]) if result else 0.0
218
+
219
  return {
220
  "video_path": video_path,
221
  "audio_path": audio_path,
 
224
  "speaker_array": speaker_array,
225
  "count_speaker": result_array,
226
  "num_speakers": len(speaker_array),
227
+ "total_sentence": len(result),
228
  "summaries": 'This feature not available',
229
+ # เพิ่ม metrics ใหม่
230
+ "processing_stats": {
231
+ "clean_segments": clean_segments_count,
232
+ "overlap_segments": overlap_segments_count,
233
+ "overlap_ratio": round(overlap_ratio, 3),
234
+ "avg_confidence": round(avg_confidence, 3),
235
+ "branching_enabled": True
236
+ }
237
  }
238
 
239
 
models.py CHANGED
@@ -34,11 +34,10 @@ def setup_together_and_ngrok():
34
  together = setup_together_and_ngrok()
35
 
36
  async def load_model_bundle():
37
- global pipelines, models
38
- # , overlap_pipeline
39
- if pipelines and models:
40
  logger.info("✅ Models already loaded. Skipping reinitialization.")
41
- return pipelines[0], models[0]
42
  def _load_models():
43
  n = torch.cuda.device_count()
44
  logger.info(f"🖥️ Found {n} CUDA device(s)")
@@ -58,11 +57,11 @@ async def load_model_bundle():
58
  cache_dir=HF_CACHE_DIR
59
  ).to(device_torch)
60
 
61
- # overlap_pipeline = Pipeline.from_pretrained(
62
- # "pyannote/overlapped-speech-detection",
63
- # use_auth_token=token,
64
- # cache_dir=HF_CACHE_DIR # ใช้ cache เดียวกับโมเดลอื่น
65
- # )
66
  model_fallback_chain = [PREFERRED_MODEL] + [m for m in FALLBACK_MODELS if m != PREFERRED_MODEL]
67
 
68
  model = None
@@ -80,7 +79,7 @@ async def load_model_bundle():
80
 
81
  pipelines.append(pipeline)
82
  models.append(model)
83
- return pipeline, model,
84
 
85
  loop = asyncio.get_event_loop()
86
  return await loop.run_in_executor(None, _load_models)
 
34
  together = setup_together_and_ngrok()
35
 
36
  async def load_model_bundle():
37
+ global pipelines, models, overlap_pipeline
38
+ if pipelines and models and overlap_pipeline:
 
39
  logger.info("✅ Models already loaded. Skipping reinitialization.")
40
+ return pipelines[0], models[0], overlap_pipeline
41
  def _load_models():
42
  n = torch.cuda.device_count()
43
  logger.info(f"🖥️ Found {n} CUDA device(s)")
 
57
  cache_dir=HF_CACHE_DIR
58
  ).to(device_torch)
59
 
60
+ overlap_pipeline = Pipeline.from_pretrained(
61
+ "pyannote/overlapped-speech-detection",
62
+ use_auth_token=token,
63
+ cache_dir=HF_CACHE_DIR
64
+ ).to(device_torch)
65
  model_fallback_chain = [PREFERRED_MODEL] + [m for m in FALLBACK_MODELS if m != PREFERRED_MODEL]
66
 
67
  model = None
 
79
 
80
  pipelines.append(pipeline)
81
  models.append(model)
82
+ return pipeline, model, overlap_pipeline
83
 
84
  loop = asyncio.get_event_loop()
85
  return await loop.run_in_executor(None, _load_models)
requirements.txt CHANGED
@@ -16,6 +16,11 @@ faster-whisper==1.1.1
16
  librosa==0.10.1
17
  soundfile==0.12.1
18
 
 
 
 
 
 
19
  # API and networking
20
  python-multipart==0.0.6
21
  pyngrok==7.0.0
 
16
  librosa==0.10.1
17
  soundfile==0.12.1
18
 
19
+ # Source separation and speaker recognition
20
+ asteroid-filterbanks==0.4.0
21
+ speechbrain==0.5.16
22
+ torchaudio>=0.13.0
23
+
24
  # API and networking
25
  python-multipart==0.0.6
26
  pyngrok==7.0.0
utils.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
11
  from collections import Counter
12
  import time
13
  from config import UPLOAD_FOLDER
14
- from models import pipelines, models, together
15
  import subprocess
16
  import librosa
17
  from pydantic import BaseModel, AnyHttpUrl
@@ -133,7 +133,7 @@ def split_segments(audio_path: str, df: pd.DataFrame, stretch_factor: float = 1.
133
 
134
  return segment_folder
135
 
136
- def transcribe_segments(segment_folder: str) -> pd.DataFrame:
137
  files = sorted([f for f in os.listdir(segment_folder) if f.endswith(".wav")])
138
  model = models[0]
139
 
@@ -143,41 +143,60 @@ def transcribe_segments(segment_folder: str) -> pd.DataFrame:
143
  segment_path = os.path.join(segment_folder, filename)
144
 
145
  try:
146
- segments, _ = model.transcribe(
147
- segment_path,
148
- language="th",
149
- beam_size=5,
150
- vad_filter=True,
151
- word_timestamps=True
152
- )
153
-
154
- # ดึงคำทั้งหมดจากทุก segment
155
- words = [word for seg in segments if hasattr(seg, "words") for word in seg.words]
156
-
157
- if words:
158
- full_text = ''.join([w.word for w in words])
159
- probs = [w.probability for w in words if w.probability is not None]
160
- avg_prob = float(np.mean(probs)) if probs else 0.0
161
- avg_prob = round(avg_prob, 4)
162
-
163
- results.append({
164
- "filename": filename,
165
- "text": full_text,
166
- "avg_probability": avg_prob,
167
- })
 
 
 
 
 
 
 
 
 
 
 
 
168
  else:
169
- results.append({
170
- "filename": filename,
171
- "text": "",
172
- "avg_probability": 0.0,
173
- })
 
 
 
 
 
174
 
175
  except Exception as e:
176
  print(f"❌ Error with {filename}: {e}")
177
  results.append({
178
  "filename": filename,
179
  "text": "",
 
180
  "avg_probability": 0.0,
 
181
  "error": str(e)
182
  })
183
 
@@ -301,18 +320,43 @@ def add_llm_spell_corrected_text_column(df, model="google/gemma-3-27b-it", delay
301
  ]
302
  return any(k in msg for k in keys)
303
 
304
- texts = df["text"].fillna("").astype(str).tolist()
 
 
 
 
 
 
305
  corrected = []
306
 
307
- for idx, text in enumerate(texts):
 
 
 
 
 
 
 
 
 
 
 
 
308
  prompt = f"""
309
- กรุณาแก้ไขข้อความต่อไปนี้ให้ถูกต้องตามหลักภาษาไทย:
 
 
 
 
 
 
 
310
 
311
- - แก้ไขคำสะกดผิด คำพิมพ์ผิด หรือคำที่ไม่ถูกต้องและการผันวรรณยุกต์ผิด
 
312
  - ห้ามเปลี่ยนความหมาย
313
- - ห้ามตอบเกิน
314
- - **ตอบกลับเฉพาะข้อความที่แก้แล้ว**
315
- {text}
316
  """.strip()
317
 
318
  try:
@@ -321,33 +365,20 @@ def add_llm_spell_corrected_text_column(df, model="google/gemma-3-27b-it", delay
321
  messages=[
322
  {
323
  "role": "system",
324
- "content": """คุณคือนักภาษาศาสตร์ผู้เชี่ยวชาญด้านการตรวจสอบคำสะกดผิด คำพิมพ์ผิด และการผันวรรณยุกต์ผิดของภาษาไทย
325
- หน้าที่ของคุณคือแก้ไขคำผิดในข้อความที่ได้รับให้ถูกต้องตามมาตรฐานภาษาไทย โดยไม่เปลี่ยนความหมายเดิม
326
-
327
- หน้าที่ของคุณ:
328
- - แก้ไขข้อความภาษาไทยให้ถูกต้องตามหลักภาษาไทยมาตรฐาน
329
- - ตรวจสอบคำสะกดผิด คำพิมพ์ผิด และการผันวรรณยุกต์ผิด
330
- - แก้คำเพี้ยน คำที่มาจากเสียงพูด เช่น ภาษาวัยรุ่นหรือคำพูดที่ออกเสียงคล้ายกัน ให้เป็นคำที่ถูกต้อง
331
- - รักษาความหมายเดิมของข้อความให้มากที่สุด
332
- - ห้ามแปลความใหม่ ห้ามตีความเกิน ห้ามปรับสำนวน
333
- - ห้ามอธิบาย หรือใส่คำพูดใด ๆ เพิ่มเติมก่อนหรือหลังข้อความ
334
- - **ให้ตอบกลับเฉพาะข้อความที่แก้ไขแล้วเท่านั้น**
335
-
336
- ตัวอย่าง:
337
-
338
- ผู้ใช้: ผมไช้คอมพิวเตอรทุกวัน
339
- คุณ: ผมใช้คอมพิวเตอร์ทุกวั��
340
-
341
- ผู้ใช้: ปวดหัวจะตายุ่ละ
342
- คุณ: ปวดหัวจะตายอยู่ละ
343
-
344
- ผู้ใช้: ไอ้เส้นหลั่งกุ้ง
345
- คุณ: ไอ้เส้นหลังกุ้ง
346
-
347
- ผู้ใช้: เซโยโมมันน่ากลัว
348
- คุณ: เชื้อโรคมันน่ากลัว
349
-
350
- จงตอบกลับเฉพาะข้อความที่แก้ไขแล้วตามตัวอย่างข้างต้นเท่านั้น
351
  """
352
  },
353
  {"role": "user", "content": prompt}
@@ -365,45 +396,47 @@ def add_llm_spell_corrected_text_column(df, model="google/gemma-3-27b-it", delay
365
  if _is_quota_error(err):
366
  corrected.append(" - ")
367
  else:
368
- corrected.append("")
 
 
369
 
370
- if idx < len(texts) - 1:
371
  time.sleep(delay)
372
 
373
  df["llm_corrected_text"] = corrected
374
  return df
375
 
376
- # def _merge_intervals(intervals, gap=0.0):
377
- # if not intervals:
378
- # return []
379
- # intervals = sorted(intervals, key=lambda x: x[0])
380
- # merged = [list(intervals[0])]
381
- # for s, e in intervals[1:]:
382
- # if s <= merged[-1][1] + gap:
383
- # merged[-1][1] = max(merged[-1][1], e)
384
- # else:
385
- # merged.append([s, e])
386
- # return [(float(a), float(b)) for a, b in merged]
387
-
388
- # def _interval_intersection(a, b):
389
- # s = max(a[0], b[0]); e = min(a[1], b[1])
390
- # return (s, e) if e > s else None
391
-
392
- # def detect_overlap_timeline(audio_path: str):
393
- # """
394
- # คืนรายการช่วงเวลาที่มีการพูดซ้อน [(start, end), ...]
395
- # ถ้าโหลดโมเดลไม่ได้ → คืน []
396
- # """
397
- # if overlap_pipeline is None:
398
- # return []
399
-
400
- # try:
401
- # ov = overlap_pipeline(audio_path) # pyannote Annotation
402
- # intervals = [(float(seg.start), float(seg.end)) for seg in ov.get_timeline()]
403
- # return _merge_intervals(intervals)
404
- # except Exception as e:
405
- # print(f"⚠️ Overlap detection failed: {e}")
406
- # return []
407
 
408
  def _confidence_metrics(audio_seg, sr):
409
  try:
@@ -464,24 +497,24 @@ def tag_segments_use_or_remove(segments: list, min_segment_duration=3.0, min_spe
464
 
465
  return kept, removed, sorted(list(valid_speakers))
466
 
467
- # def enrich_with_overlap(segments: list, overlap_timeline: list):
468
- # """
469
- # เติม: has_overlap, overlap_intervals, overlap_ratio
470
- # """
471
- # for seg in segments:
472
- # s, e = float(seg["start"]), float(seg["end"])
473
- # overlaps = []
474
- # total = 0.0
475
- # for (os, oe) in overlap_timeline:
476
- # inter = _interval_intersection((s, e), (os, oe))
477
- # if inter:
478
- # overlaps.append([round(inter[0], 3), round(inter[1], 3)])
479
- # total += (inter[1] - inter[0])
480
- # dur = max(1e-9, e - s)
481
- # seg["has_overlap"] = bool(overlaps)
482
- # seg["overlap_intervals"] = overlaps
483
- # seg["overlap_ratio"] = float(total / dur)
484
- # return segments
485
 
486
  def diarize_audio(audio_path: str) -> pd.DataFrame:
487
  sr = 16000
@@ -514,10 +547,10 @@ def diarize_audio(audio_path: str) -> pd.DataFrame:
514
  min_speaker_total=min_speaker_total
515
  )
516
 
517
- # # 4) Overlap
518
- # ov_tl = detect_overlap_timeline(audio_path)
519
- # kept = enrich_with_overlap(kept, ov_tl)
520
- # removed = enrich_with_overlap(removed, ov_tl)
521
 
522
  # 5) Combine
523
  all_rows = kept + removed
@@ -525,6 +558,365 @@ def diarize_audio(audio_path: str) -> pd.DataFrame:
525
 
526
  df = pd.DataFrame(all_rows, columns=[
527
  "speaker","start","end","duration","confidence",
528
- "tag","remove_reason"
529
  ])
530
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from collections import Counter
12
  import time
13
  from config import UPLOAD_FOLDER
14
+ from models import pipelines, models, together, overlap_pipeline
15
  import subprocess
16
  import librosa
17
  from pydantic import BaseModel, AnyHttpUrl
 
133
 
134
  return segment_folder
135
 
136
+ def transcribe_segments(segment_folder: str, num_rounds: int = 3) -> pd.DataFrame:
137
  files = sorted([f for f in os.listdir(segment_folder) if f.endswith(".wav")])
138
  model = models[0]
139
 
 
143
  segment_path = os.path.join(segment_folder, filename)
144
 
145
  try:
146
+ text_array = []
147
+ prob_array = []
148
+
149
+ # ทำ transcription หลายรอบ
150
+ for round_num in range(num_rounds):
151
+ segments, _ = model.transcribe(
152
+ segment_path,
153
+ language="th",
154
+ beam_size=5,
155
+ vad_filter=True,
156
+ word_timestamps=True,
157
+ temperature=0.0 if round_num == 0 else 0.2 # รอบแรกใช้ deterministic
158
+ )
159
+
160
+ # ดึงคำทั้งหมดจากทุก segment
161
+ words = [word for seg in segments if hasattr(seg, "words") for word in seg.words]
162
+
163
+ if words:
164
+ full_text = ''.join([w.word for w in words])
165
+ probs = [w.probability for w in words if w.probability is not None]
166
+ avg_prob = round(np.mean(probs), 4) if probs else 0.0
167
+ avg_prob = round(avg_prob, 4)
168
+
169
+ text_array.append(full_text)
170
+ prob_array.append(avg_prob)
171
+ else:
172
+ text_array.append("")
173
+ prob_array.append(0.0)
174
+
175
+ # เลือกผลลัพธ์ที่ดีที่สุด (probability สูงสุด)
176
+ if prob_array and max(prob_array) > 0:
177
+ best_idx = prob_array.index(max(prob_array))
178
+ best_text = text_array[best_idx]
179
+ best_prob = prob_array[best_idx]
180
  else:
181
+ best_text = text_array[0] if text_array else ""
182
+ best_prob = prob_array[0] if prob_array else 0.0
183
+
184
+ results.append({
185
+ "filename": filename,
186
+ "text": best_text,
187
+ "text_array": text_array,
188
+ "avg_probability": best_prob,
189
+ "prob_array": prob_array,
190
+ })
191
 
192
  except Exception as e:
193
  print(f"❌ Error with {filename}: {e}")
194
  results.append({
195
  "filename": filename,
196
  "text": "",
197
+ "text_array": ["", "", ""],
198
  "avg_probability": 0.0,
199
+ "prob_array": [0.0, 0.0, 0.0],
200
  "error": str(e)
201
  })
202
 
 
320
  ]
321
  return any(k in msg for k in keys)
322
 
323
+ # ใช้ text_array ถ้ามี ไม่งั้นใช้ text เดี่ยว
324
+ if "text_array" in df.columns:
325
+ text_arrays = df["text_array"].fillna("").tolist()
326
+ else:
327
+ texts = df["text"].fillna("").astype(str).tolist()
328
+ text_arrays = [[text] for text in texts] # แปลงเป็น array
329
+
330
  corrected = []
331
 
332
+ for idx, text_array in enumerate(text_arrays):
333
+ # ถ้าเป็น string เดี่ยว แปลงเป็น list
334
+ if isinstance(text_array, str):
335
+ text_array = [text_array]
336
+
337
+ # ถ้าไม่มีข้อความ skip
338
+ if not text_array or all(not t.strip() for t in text_array):
339
+ corrected.append("")
340
+ continue
341
+
342
+ # สร้าง prompt ให้ LLM เลือกและแก้ไข
343
+ text_options = "\n".join([f"ตัวเลือก {i+1}: {text}" for i, text in enumerate(text_array) if text.strip()])
344
+
345
  prompt = f"""
346
+ จากตัวเลือกข้อความต่อไปนี้ กรุณาเลือกตัวเลือกที่ดีที่สุด แล้วแก้ไขให้ถูกต้องตามหลักภาษาไทย:
347
+
348
+ {text_options}
349
+
350
+ หลักเกณฑ์การเลือก:
351
+ - เลือกข้อความที่มีความหมายชัดเจนที่สุด
352
+ - เลือกข้อความที่สมบูรณ์ที่สุด (ไม่ขาดคำ)
353
+ - หลีกเลี่ยงข้อความที่ซ้ำซ้อนหรือผิดพลาดชัดเจน
354
 
355
+ การแก้ไข:
356
+ - แก้ไขคำสะกดผิด คำพิมพ์ผิด หรือการผันวรรณยุกต์ผิด
357
  - ห้ามเปลี่ยนความหมาย
358
+ - ห้ามอธิบายหรือใส่คำพูดเพิ่มเติม
359
+ - **ตอบกลับเฉพาะข้อความที่เลือกและแก้ไขแล้วเท่านั้น**
 
360
  """.strip()
361
 
362
  try:
 
365
  messages=[
366
  {
367
  "role": "system",
368
+ "content": """คุณคือนักภาษาศาสตร์ผู้เชี่ยวชาญด้านการตรวจสอบและแก้ไขข้อความภาษาไทย
369
+ หน้าที่ของคุณคือเลือกข้อความที่ดีที่สุดจากตัวเลือกที่ให้มา แล้วแก้ไขให้ถูกต้องตามมาตรฐานภาษาไทย
370
+
371
+ หลักเกณฑ์การเลือก:
372
+ 1. ความสมบูรณ์ของข้อความ (ไม่ขาดคำสำคัญ)
373
+ 2. ความชัดเจนของความหมาย
374
+ 3. ความถูกต้องทางไวยากรณ์
375
+ 4. หลีกเลี่ยงการซ้ำซ้อนหรือข้อผิดพลาดชัดเจน
376
+
377
+ การแก้ไข:
378
+ - แก้ไขคำสะกดผิด คำพิมพ์ผิด และการผันวรรณยุกต์ผิด
379
+ - รักษาความหมายเดิมของข้อความ
380
+ - ห้ามแปลความใหม่ ห้ามตีความเกิน
381
+ - **ตอบกลับเฉพาะข้อความที่เลือกและแก้ไขแล้วเท่านั้น**
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  """
383
  },
384
  {"role": "user", "content": prompt}
 
396
  if _is_quota_error(err):
397
  corrected.append(" - ")
398
  else:
399
+ # Fallback: ใช้ตัวเลือกแรกที่ไม่ว่าง
400
+ fallback_text = next((t for t in text_array if t.strip()), "")
401
+ corrected.append(fallback_text)
402
 
403
+ if idx < len(text_arrays) - 1:
404
  time.sleep(delay)
405
 
406
  df["llm_corrected_text"] = corrected
407
  return df
408
 
409
+ def _merge_intervals(intervals, gap=0.0):
410
+ if not intervals:
411
+ return []
412
+ intervals = sorted(intervals, key=lambda x: x[0])
413
+ merged = [list(intervals[0])]
414
+ for s, e in intervals[1:]:
415
+ if s <= merged[-1][1] + gap:
416
+ merged[-1][1] = max(merged[-1][1], e)
417
+ else:
418
+ merged.append([s, e])
419
+ return [(float(a), float(b)) for a, b in merged]
420
+
421
+ def _interval_intersection(a, b):
422
+ s = max(a[0], b[0]); e = min(a[1], b[1])
423
+ return (s, e) if e > s else None
424
+
425
+ def detect_overlap_timeline(audio_path: str):
426
+ """
427
+ คืนรายการช่วงเวลาที่มีการพูดซ้อน [(start, end), ...]
428
+ ถ้าโหลดโมเดลไม่ได้ → คืน []
429
+ """
430
+ if overlap_pipeline is None:
431
+ return []
432
+
433
+ try:
434
+ ov = overlap_pipeline(audio_path) # pyannote Annotation
435
+ intervals = [(float(seg.start), float(seg.end)) for seg in ov.get_timeline()]
436
+ return _merge_intervals(intervals)
437
+ except Exception as e:
438
+ print(f"⚠️ Overlap detection failed: {e}")
439
+ return []
440
 
441
  def _confidence_metrics(audio_seg, sr):
442
  try:
 
497
 
498
  return kept, removed, sorted(list(valid_speakers))
499
 
500
+ def enrich_with_overlap(segments: list, overlap_timeline: list):
501
+ """
502
+ เติม: has_overlap, overlap_intervals, overlap_ratio
503
+ """
504
+ for seg in segments:
505
+ s, e = float(seg["start"]), float(seg["end"])
506
+ overlaps = []
507
+ total = 0.0
508
+ for (os, oe) in overlap_timeline:
509
+ inter = _interval_intersection((s, e), (os, oe))
510
+ if inter:
511
+ overlaps.append([round(inter[0], 3), round(inter[1], 3)])
512
+ total += (inter[1] - inter[0])
513
+ dur = max(1e-9, e - s)
514
+ seg["has_overlap"] = bool(overlaps)
515
+ seg["overlap_intervals"] = overlaps
516
+ seg["overlap_ratio"] = float(total / dur)
517
+ return segments
518
 
519
  def diarize_audio(audio_path: str) -> pd.DataFrame:
520
  sr = 16000
 
547
  min_speaker_total=min_speaker_total
548
  )
549
 
550
+ # 4) Overlap
551
+ ov_tl = detect_overlap_timeline(audio_path)
552
+ kept = enrich_with_overlap(kept, ov_tl)
553
+ removed = enrich_with_overlap(removed, ov_tl)
554
 
555
  # 5) Combine
556
  all_rows = kept + removed
 
558
 
559
  df = pd.DataFrame(all_rows, columns=[
560
  "speaker","start","end","duration","confidence",
561
+ "is_remove","remove_reason","has_overlap","overlap_intervals","overlap_ratio"
562
  ])
563
+ return df
564
+
565
+ def detect_speech_boundaries(audio_data: np.ndarray, sample_rate: int, offset_time: float,
566
+ energy_threshold: float = 0.01, min_speech_duration: float = 0.1):
567
+ """
568
+ หาขอบเขตของการพูดจริงใน audio stream ด้วย energy-based detection
569
+ """
570
+ import numpy as np
571
+
572
+ # คำนวณ energy ของ audio
573
+ frame_size = int(0.025 * sample_rate) # 25ms frames
574
+ hop_size = int(0.010 * sample_rate) # 10ms hop
575
+
576
+ energy = []
577
+ for i in range(0, len(audio_data) - frame_size, hop_size):
578
+ frame = audio_data[i:i + frame_size]
579
+ frame_energy = np.sum(frame ** 2) / len(frame)
580
+ energy.append(frame_energy)
581
+
582
+ energy = np.array(energy)
583
+
584
+ # หา threshold แบบ adaptive
585
+ if len(energy) > 0:
586
+ max_energy = np.max(energy)
587
+ adaptive_threshold = max_energy * energy_threshold
588
+
589
+ # หาจุดเริ่มต้นและสิ้นสุดของการพูด
590
+ speech_frames = energy > adaptive_threshold
591
+
592
+ if np.any(speech_frames):
593
+ # หาจุดเริ่มต้น
594
+ start_frame = np.where(speech_frames)[0][0]
595
+ end_frame = np.where(speech_frames)[0][-1]
596
+
597
+ # แปลงเป็นเวลา
598
+ start_time = offset_time + (start_frame * hop_size / sample_rate)
599
+ end_time = offset_time + ((end_frame + 1) * hop_size / sample_rate)
600
+
601
+ # ตรวจสอบ minimum duration
602
+ if end_time - start_time >= min_speech_duration:
603
+ return start_time, end_time
604
+
605
+ # Fallback: ใช้เวลาเต็ม
606
+ duration = len(audio_data) / sample_rate
607
+ return offset_time, offset_time + duration
608
+
609
+ def separate_overlapping_segments(audio_path: str, overlap_segments: list):
610
+ """
611
+ แยกเสียงสำหรับ segments ที่มี overlap ด้วย Asteroid
612
+ """
613
+ try:
614
+ import torch
615
+ import torchaudio
616
+ from asteroid.models import ConvTasNet
617
+
618
+ # โหลด pre-trained model
619
+ model = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepclean_8k")
620
+
621
+ separated_results = []
622
+
623
+ for segment in overlap_segments:
624
+ try:
625
+ # โหลดเสียงในช่วงที่ overlap
626
+ start_time = float(segment["start"])
627
+ end_time = float(segment["end"])
628
+
629
+ # โหลดเสียงด้วย torchaudio
630
+ waveform, sample_rate = torchaudio.load(audio_path,
631
+ frame_offset=int(start_time * sample_rate),
632
+ num_frames=int((end_time - start_time) * sample_rate))
633
+
634
+ # แยกเสียง (ConvTasNet คาดหวัง mono input)
635
+ if waveform.shape[0] > 1:
636
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
637
+
638
+ # Separate audio
639
+ with torch.no_grad():
640
+ separated = model(waveform.unsqueeze(0)) # Add batch dimension
641
+
642
+ # บันทึกผลลัพธ์
643
+ segment_result = {
644
+ "original_segment": segment,
645
+ "separated_streams": [],
646
+ "num_streams": separated.shape[1]
647
+ }
648
+
649
+ # บันทึกแต่ละ stream พร้อมหาเวลาจริง
650
+ for i in range(separated.shape[1]):
651
+ stream_audio = separated[0, i, :].cpu().numpy()
652
+ stream_duration = len(stream_audio) / sample_rate
653
+
654
+ # หาเวลาจริงด้วย energy-based detection
655
+ stream_start, stream_end = detect_speech_boundaries(stream_audio, sample_rate, start_time)
656
+
657
+ segment_result["separated_streams"].append({
658
+ "stream_id": i,
659
+ "audio_data": stream_audio,
660
+ "sample_rate": sample_rate,
661
+ "start": stream_start,
662
+ "end": stream_end,
663
+ "duration": stream_end - stream_start
664
+ })
665
+
666
+ separated_results.append(segment_result)
667
+
668
+ except Exception as e:
669
+ print(f"❌ Error separating segment {segment}: {e}")
670
+ # Fallback: ใช้เสียงต้นฉบับ
671
+ separated_results.append({
672
+ "original_segment": segment,
673
+ "separated_streams": [],
674
+ "num_streams": 0,
675
+ "error": str(e)
676
+ })
677
+
678
+ return separated_results
679
+
680
+ except ImportError:
681
+ print("⚠️ Asteroid not installed. Install with: pip install asteroid-filterbanks torch-audio")
682
+ return []
683
+ except Exception as e:
684
+ print(f"❌ Source separation failed: {e}")
685
+ return []
686
+
687
+ def match_streams_to_speakers(separated_results: list, audio_path: str):
688
+ """
689
+ จับคู่ separated streams กับ speakers โดยใช้ speaker embeddings
690
+ """
691
+ try:
692
+ from speechbrain.pretrained import EncoderClassifier
693
+ from sklearn.metrics.pairwise import cosine_similarity
694
+ import numpy as np
695
+
696
+ # โหลด speaker embedding model
697
+ classifier = EncoderClassifier.from_hparams(
698
+ source="speechbrain/spkrec-ecapa-voxceleb",
699
+ savedir="tmp/spkrec-ecapa-voxceleb"
700
+ )
701
+
702
+ matched_results = []
703
+
704
+ # สร้าง speaker profiles จาก clean segments ก่อน
705
+ speaker_profiles = {}
706
+
707
+ for result in separated_results:
708
+ if result.get("error") or not result["separated_streams"]:
709
+ matched_results.append(result)
710
+ continue
711
+
712
+ segment = result["original_segment"]
713
+ streams = result["separated_streams"]
714
+
715
+ # สร้าง embeddings สำหรับแต่ละ stream
716
+ stream_embeddings = []
717
+ for stream in streams:
718
+ try:
719
+ # แปลง audio data เป็น tensor
720
+ audio_tensor = torch.FloatTensor(stream["audio_data"]).unsqueeze(0)
721
+ embedding = classifier.encode_batch(audio_tensor)
722
+ stream_embeddings.append(embedding.squeeze().cpu().numpy())
723
+ except Exception as e:
724
+ print(f"⚠️ Failed to create embedding for stream: {e}")
725
+ stream_embeddings.append(None)
726
+
727
+ # สร้าง speaker profile ถ้ายังไม่มี
728
+ if segment["speaker"] not in speaker_profiles and len(stream_embeddings) > 0:
729
+ valid_embeddings = [emb for emb in stream_embeddings if emb is not None]
730
+ if valid_embeddings:
731
+ speaker_profiles[segment["speaker"]] = np.mean(valid_embeddings, axis=0)
732
+
733
+ # จับคู่กับ speaker (ใช้ cosine similarity)
734
+ matched_streams = []
735
+ for i, (stream, embedding) in enumerate(zip(streams, stream_embeddings)):
736
+ if embedding is not None and len(speaker_profiles) > 0:
737
+ # คำนวณ similarity กับทุก speaker
738
+ similarities = {}
739
+ for speaker_id, profile_embedding in speaker_profiles.items():
740
+ similarity = cosine_similarity([embedding], [profile_embedding])[0][0]
741
+ similarities[speaker_id] = similarity
742
+
743
+ # เลือก speaker ที่มี similarity สูงสุด
744
+ best_match = max(similarities, key=similarities.get)
745
+ confidence = float(similarities[best_match])
746
+
747
+ matched_streams.append({
748
+ **stream,
749
+ "speaker_embedding": embedding,
750
+ "matched_speaker": best_match,
751
+ "confidence": round(confidence, 3)
752
+ })
753
+ else:
754
+ # Fallback ถ้าไม่มี embedding หรือ profile
755
+ matched_streams.append({
756
+ **stream,
757
+ "matched_speaker": segment["speaker"] if embedding is not None else f"unknown_{i}",
758
+ "confidence": 0.5
759
+ })
760
+
761
+ result["matched_streams"] = matched_streams
762
+ matched_results.append(result)
763
+
764
+ return matched_results
765
+
766
+ except ImportError:
767
+ print("⚠️ SpeechBrain not installed. Install with: pip install speechbrain")
768
+ # Fallback: ใช้ speaker เดิม
769
+ for result in separated_results:
770
+ if not result.get("error") and result["separated_streams"]:
771
+ result["matched_streams"] = [
772
+ {**stream, "matched_speaker": result["original_segment"]["speaker"], "confidence": 0.5}
773
+ for stream in result["separated_streams"]
774
+ ]
775
+ return separated_results
776
+ except Exception as e:
777
+ print(f"❌ Speaker matching failed: {e}")
778
+ return separated_results
779
+
780
+ def branch_segments_by_overlap(df_diarization: pd.DataFrame, overlap_threshold: float = 0.1):
781
+ """
782
+ แยก segments เป็น clean และ overlap ตาม overlap_ratio
783
+ """
784
+ clean_segments = []
785
+ overlap_segments = []
786
+
787
+ for _, row in df_diarization.iterrows():
788
+ segment = row.to_dict()
789
+
790
+ # ตรวจสอบว่ามี overlap หรือไม่
791
+ has_overlap = segment.get("has_overlap", False)
792
+ overlap_ratio = segment.get("overlap_ratio", 0.0)
793
+
794
+ if has_overlap and overlap_ratio > overlap_threshold:
795
+ overlap_segments.append(segment)
796
+ else:
797
+ clean_segments.append(segment)
798
+
799
+ return clean_segments, overlap_segments
800
+
801
+ def process_segments_with_branching(audio_path: str, df_diarization: pd.DataFrame):
802
+ """
803
+ ประมวลผล segments แบบแยกเส้นทาง: Clean vs Overlap
804
+ """
805
+ # แยก segments
806
+ clean_segments, overlap_segments = branch_segments_by_overlap(df_diarization)
807
+
808
+ print(f"🔍 Found {len(clean_segments)} clean segments, {len(overlap_segments)} overlap segments")
809
+
810
+ results = {
811
+ "clean_segments": clean_segments,
812
+ "overlap_segments": overlap_segments,
813
+ "clean_transcriptions": [],
814
+ "overlap_transcriptions": []
815
+ }
816
+
817
+ # ประมวลผล clean segments (ใช้วิธีเดิม)
818
+ if clean_segments:
819
+ print("🎯 Processing clean segments...")
820
+ clean_df = pd.DataFrame(clean_segments)
821
+ segment_folder = split_segments(audio_path, clean_df)
822
+ clean_transcriptions = transcribe_segments(segment_folder)
823
+ results["clean_transcriptions"] = clean_transcriptions.to_dict(orient="records")
824
+
825
+ # ประมวลผล overlap segments (ใช้ source separation)
826
+ if overlap_segments:
827
+ print("🔀 Processing overlap segments with source separation...")
828
+ separated_results = separate_overlapping_segments(audio_path, overlap_segments)
829
+ matched_results = match_streams_to_speakers(separated_results, audio_path)
830
+
831
+ # Transcribe แต่ละ separated stream
832
+ overlap_transcriptions = []
833
+ for result in matched_results:
834
+ if result.get("matched_streams"):
835
+ for stream in result["matched_streams"]:
836
+ # บันทึก audio stream เป็นไฟล์ชั่วคราว
837
+ temp_audio_path = save_temp_audio_stream(stream)
838
+ if temp_audio_path:
839
+ # Transcribe stream
840
+ stream_transcription = transcribe_single_audio(temp_audio_path)
841
+ overlap_transcriptions.append({
842
+ "original_segment": result["original_segment"],
843
+ "stream_id": stream["stream_id"],
844
+ "speaker": stream.get("matched_speaker", "unknown"),
845
+ "transcription": stream_transcription,
846
+ "matched_streams": result["matched_streams"] # เพิ่ม matched_streams ทั้งหมด
847
+ })
848
+ # ลบไฟล์ชั่วคราว
849
+ os.remove(temp_audio_path)
850
+
851
+ results["overlap_transcriptions"] = overlap_transcriptions
852
+
853
+ return results
854
+
855
+ def save_temp_audio_stream(stream_data: dict) -> str:
856
+ """บันทึก audio stream เป็นไฟล์ชั่วคราว"""
857
+ try:
858
+ import tempfile
859
+ import soundfile as sf
860
+
861
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
862
+ sf.write(tmp_file.name, stream_data["audio_data"], stream_data["sample_rate"])
863
+ return tmp_file.name
864
+ except Exception as e:
865
+ print(f"❌ Failed to save temp audio: {e}")
866
+ return None
867
+
868
+ def transcribe_single_audio(audio_path: str, num_rounds: int = 3) -> dict:
869
+ """Transcribe ไฟล์เสียงเดี่ยว"""
870
+ try:
871
+ model = models[0]
872
+ text_array = []
873
+ prob_array = []
874
+
875
+ # ทำ transcription หลายรอบ
876
+ for round_num in range(num_rounds):
877
+ segments, _ = model.transcribe(
878
+ audio_path,
879
+ language="th",
880
+ beam_size=5,
881
+ vad_filter=True,
882
+ word_timestamps=True,
883
+ temperature=0.0 if round_num == 0 else 0.2 # รอบแรกใช้ deterministic
884
+ )
885
+
886
+ words = [word for seg in segments if hasattr(seg, "words") for word in seg.words]
887
+
888
+ if words:
889
+ full_text = ''.join([w.word for w in words])
890
+ probs = [w.probability for w in words if w.probability is not None]
891
+ avg_prob = round(np.mean(probs), 4) if probs else 0.0
892
+
893
+ text_array.append(full_text)
894
+ prob_array.append(avg_prob)
895
+ else:
896
+ text_array.append("")
897
+ prob_array.append(0.0)
898
+
899
+ # เลือกผลลัพธ์ที่ดีที่สุด (probability สูงสุด)
900
+ if prob_array and max(prob_array) > 0:
901
+ best_idx = prob_array.index(max(prob_array))
902
+ best_text = text_array[best_idx]
903
+ best_prob = prob_array[best_idx]
904
+ else:
905
+ best_text = text_array[0] if text_array else ""
906
+ best_prob = prob_array[0] if prob_array else 0.0
907
+
908
+ return {
909
+ "text": best_text,
910
+ "text_array": text_array,
911
+ "avg_probability": best_prob,
912
+ "prob_array": prob_array,
913
+ }
914
+ except Exception as e:
915
+ print(f"❌ Transcription failed: {e}")
916
+ return {
917
+ "text": "",
918
+ "text_array": ["", "", ""],
919
+ "avg_probability": 0.0,
920
+ "prob_array": [0.0, 0.0, 0.0],
921
+ "error": str(e)
922
+ }