Y Phung Nguyen committed on
Commit
5096447
·
1 Parent(s): 0cd2df1

Update MCP ASR & TTS

Browse files
Files changed (2) hide show
  1. agent.py +161 -13
  2. voice.py +63 -18
agent.py CHANGED
@@ -16,10 +16,14 @@ from pathlib import Path
16
 
17
  # MCP imports
18
  try:
 
 
 
 
19
  from mcp import types as mcp_types
20
- from mcp.server import Server, NotificationOptions
21
- from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource
22
  from mcp.server.models import InitializationOptions
 
23
  except ImportError:
24
  print("Error: MCP SDK not installed. Install with: pip install mcp", file=sys.stderr)
25
  sys.exit(1)
@@ -60,8 +64,8 @@ GEMINI_MAX_FILES = int(os.environ.get("GEMINI_MAX_FILES", "10"))
60
  GEMINI_MAX_TOTAL_FILE_SIZE = int(os.environ.get("GEMINI_MAX_TOTAL_FILE_SIZE", "50")) # MB
61
  GEMINI_TEMPERATURE = float(os.environ.get("GEMINI_TEMPERATURE", "0.2"))
62
 
63
- # Create MCP server
64
- app = Server("gemini-mcp-server")
65
 
66
  def decode_base64_file(content: str, mime_type: str = None) -> bytes:
67
  """Decode base64 encoded file content"""
@@ -117,7 +121,7 @@ def prepare_gemini_files(files: list) -> list:
117
 
118
  return gemini_parts
119
 
120
- @app.list_tools()
121
  async def list_tools() -> list[Tool]:
122
  """List available tools"""
123
  try:
@@ -159,6 +163,46 @@ async def list_tools() -> list[Tool]:
159
  },
160
  "required": ["user_prompt"]
161
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  )
163
  ]
164
  return tools
@@ -166,7 +210,7 @@ async def list_tools() -> list[Tool]:
166
  logger.error(f"Error in list_tools(): {e}")
167
  raise
168
 
169
- @app.call_tool()
170
  async def call_tool(name: str, arguments: dict) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
171
  """Handle tool calls"""
172
  logger.info(f"🔵 MCP tool call received: {name}")
@@ -277,6 +321,111 @@ async def call_tool(name: str, arguments: dict) -> Sequence[TextContent | ImageC
277
  except Exception as e:
278
  logger.error(f"Error in generate_content: {e}")
279
  return [TextContent(type="text", text=f"Error: {str(e)}")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  else:
281
  return [TextContent(type="text", text=f"Unknown tool: {name}")]
282
 
@@ -289,21 +438,20 @@ async def main():
289
  logger.info(f"Default Lite Model: {GEMINI_MODEL_LITE}")
290
  logger.info("=" * 60)
291
 
292
- # Use stdio_server from mcp.server.stdio
293
- from mcp.server.stdio import stdio_server
294
-
295
  # Keep logging enabled for debugging
296
  original_root_level = logging.getLogger("root").level
297
  logging.getLogger("root").setLevel(logging.INFO)
298
 
299
  try:
 
 
300
  async with stdio_server() as streams:
301
  # Prepare server capabilities for initialization
302
  try:
303
- if hasattr(app, "get_capabilities"):
304
  notification_options = NotificationOptions()
305
  experimental_capabilities: dict[str, dict[str, Any]] = {}
306
- server_capabilities = app.get_capabilities(
307
  notification_options=notification_options,
308
  experimental_capabilities=experimental_capabilities,
309
  )
@@ -322,13 +470,13 @@ async def main():
322
  logger.info("MCP server ready")
323
  try:
324
  # Run the server - it will automatically handle the initialization handshake
325
- await app.run(
326
  read_stream=streams[0],
327
  write_stream=streams[1],
328
  initialization_options=init_options,
329
  )
330
  except Exception as run_error:
331
- logger.error(f"Error in app.run(): {run_error}")
332
  raise
333
  except Exception as e:
334
  logging.getLogger("root").setLevel(original_root_level)
 
16
 
17
  # MCP imports
18
  try:
19
+ from mcp.server import Server
20
+ from mcp.types import Tool, TextContent
21
+ import mcp.server.stdio
22
+ # Additional imports needed for server functionality
23
  from mcp import types as mcp_types
24
+ from mcp.types import ImageContent, EmbeddedResource
 
25
  from mcp.server.models import InitializationOptions
26
+ from mcp.server import NotificationOptions
27
  except ImportError:
28
  print("Error: MCP SDK not installed. Install with: pip install mcp", file=sys.stderr)
29
  sys.exit(1)
 
64
  GEMINI_MAX_TOTAL_FILE_SIZE = int(os.environ.get("GEMINI_MAX_TOTAL_FILE_SIZE", "50")) # MB
65
  GEMINI_TEMPERATURE = float(os.environ.get("GEMINI_TEMPERATURE", "0.2"))
66
 
67
+ # Initialize MCP server
68
+ server = Server("gemini-mcp-server")
69
 
70
  def decode_base64_file(content: str, mime_type: str = None) -> bytes:
71
  """Decode base64 encoded file content"""
 
121
 
122
  return gemini_parts
123
 
124
+ @server.list_tools()
125
  async def list_tools() -> list[Tool]:
126
  """List available tools"""
127
  try:
 
163
  },
164
  "required": ["user_prompt"]
165
  }
166
+ ),
167
+ Tool(
168
+ name="transcribe_audio",
169
+ description="Transcribe audio file to text using Gemini AI. Supports various audio formats (WAV, MP3, M4A, etc.).",
170
+ inputSchema={
171
+ "type": "object",
172
+ "properties": {
173
+ "audio_path": {
174
+ "type": "string",
175
+ "description": "Path to audio file to transcribe (required)"
176
+ },
177
+ "language": {
178
+ "type": "string",
179
+ "description": "Language code (optional, defaults to auto-detect)"
180
+ }
181
+ },
182
+ "required": ["audio_path"]
183
+ }
184
+ ),
185
+ Tool(
186
+ name="text_to_speech",
187
+ description="Convert text to speech audio using Gemini AI. Returns path to generated audio file.",
188
+ inputSchema={
189
+ "type": "object",
190
+ "properties": {
191
+ "text": {
192
+ "type": "string",
193
+ "description": "Text to convert to speech (required)"
194
+ },
195
+ "language": {
196
+ "type": "string",
197
+ "description": "Language code (optional, defaults to 'en')"
198
+ },
199
+ "voice": {
200
+ "type": "string",
201
+ "description": "Voice selection (optional)"
202
+ }
203
+ },
204
+ "required": ["text"]
205
+ }
206
  )
207
  ]
208
  return tools
 
210
  logger.error(f"Error in list_tools(): {e}")
211
  raise
212
 
213
+ @server.call_tool()
214
  async def call_tool(name: str, arguments: dict) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
215
  """Handle tool calls"""
216
  logger.info(f"🔵 MCP tool call received: {name}")
 
321
  except Exception as e:
322
  logger.error(f"Error in generate_content: {e}")
323
  return [TextContent(type="text", text=f"Error: {str(e)}")]
324
+ elif name == "transcribe_audio":
325
+ try:
326
+ audio_path = arguments.get("audio_path")
327
+ if not audio_path:
328
+ logger.error("❌ audio_path is required but missing")
329
+ return [TextContent(type="text", text="Error: audio_path is required")]
330
+
331
+ language = arguments.get("language", "auto")
332
+
333
+ # Check if file exists
334
+ if not os.path.exists(audio_path):
335
+ logger.error(f"❌ Audio file not found: {audio_path}")
336
+ return [TextContent(type="text", text=f"Error: Audio file not found: {audio_path}")]
337
+
338
+ # Use Gemini to transcribe audio
339
+ system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
340
+ user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
341
+
342
+ files = [{"path": os.path.abspath(audio_path)}]
343
+
344
+ try:
345
+ generation_config = {
346
+ "temperature": 0.2,
347
+ "max_output_tokens": GEMINI_MAX_OUTPUT_TOKENS
348
+ }
349
+
350
+ timeout_seconds = min(GEMINI_TIMEOUT / 1000.0, 20.0)
351
+ logger.info(f"🔵 Transcribing audio with Gemini API, timeout={timeout_seconds}s...")
352
+
353
+ gemini_contents = [f"{system_prompt}\n\n{user_prompt}"]
354
+ file_parts = prepare_gemini_files(files)
355
+ for file_part in file_parts:
356
+ gemini_contents.append({
357
+ "inline_data": {
358
+ "mime_type": file_part["mime_type"],
359
+ "data": base64.b64encode(file_part["data"]).decode('utf-8')
360
+ }
361
+ })
362
+
363
+ def transcribe_sync():
364
+ return gemini_client.models.generate_content(
365
+ model=GEMINI_MODEL_LITE,
366
+ contents=gemini_contents,
367
+ config=generation_config,
368
+ )
369
+
370
+ response = await asyncio.wait_for(
371
+ asyncio.to_thread(transcribe_sync),
372
+ timeout=timeout_seconds
373
+ )
374
+
375
+ logger.info(f"✅ Audio transcription completed successfully")
376
+
377
+ if response and hasattr(response, 'text') and response.text:
378
+ return [TextContent(type="text", text=response.text.strip())]
379
+ elif response and hasattr(response, 'candidates') and response.candidates:
380
+ text_parts = []
381
+ for candidate in response.candidates:
382
+ if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
383
+ for part in candidate.content.parts:
384
+ if hasattr(part, 'text'):
385
+ text_parts.append(part.text)
386
+ if text_parts:
387
+ text = ''.join(text_parts).strip()
388
+ return [TextContent(type="text", text=text)]
389
+ else:
390
+ return [TextContent(type="text", text="Error: No text in transcription response")]
391
+ else:
392
+ return [TextContent(type="text", text="Error: No response from transcription")]
393
+
394
+ except asyncio.TimeoutError:
395
+ error_msg = f"Audio transcription timed out"
396
+ logger.error(f"❌ {error_msg}")
397
+ return [TextContent(type="text", text=f"Error: {error_msg}")]
398
+ except Exception as e:
399
+ logger.error(f"❌ Error transcribing audio: {type(e).__name__}: {e}")
400
+ import traceback
401
+ logger.debug(f"Full traceback: {traceback.format_exc()}")
402
+ return [TextContent(type="text", text=f"Error: {str(e)}")]
403
+
404
+ except Exception as e:
405
+ logger.error(f"Error in transcribe_audio: {e}")
406
+ return [TextContent(type="text", text=f"Error: {str(e)}")]
407
+ elif name == "text_to_speech":
408
+ try:
409
+ text = arguments.get("text")
410
+ if not text:
411
+ logger.error("❌ text is required but missing")
412
+ return [TextContent(type="text", text="Error: text is required")]
413
+
414
+ language = arguments.get("language", "en")
415
+
416
+ # Note: Gemini API doesn't directly support TTS audio generation
417
+ # This tool is provided for MCP protocol compliance, but the client
418
+ # should use local TTS models (like maya1) for actual audio generation
419
+ logger.info(f"🔵 TTS request received for text: {text[:50]}...")
420
+ logger.info("ℹ️ Gemini API doesn't support direct TTS. Client should use local TTS model.")
421
+
422
+ # Return a signal that client should handle TTS locally
423
+ # The client will interpret this and use its local TTS model
424
+ return [TextContent(type="text", text="USE_LOCAL_TTS")]
425
+
426
+ except Exception as e:
427
+ logger.error(f"Error in text_to_speech: {e}")
428
+ return [TextContent(type="text", text=f"Error: {str(e)}")]
429
  else:
430
  return [TextContent(type="text", text=f"Unknown tool: {name}")]
431
 
 
438
  logger.info(f"Default Lite Model: {GEMINI_MODEL_LITE}")
439
  logger.info("=" * 60)
440
 
 
 
 
441
  # Keep logging enabled for debugging
442
  original_root_level = logging.getLogger("root").level
443
  logging.getLogger("root").setLevel(logging.INFO)
444
 
445
  try:
446
+ # Use stdio_server from mcp.server.stdio
447
+ from mcp.server.stdio import stdio_server
448
  async with stdio_server() as streams:
449
  # Prepare server capabilities for initialization
450
  try:
451
+ if hasattr(server, "get_capabilities"):
452
  notification_options = NotificationOptions()
453
  experimental_capabilities: dict[str, dict[str, Any]] = {}
454
+ server_capabilities = server.get_capabilities(
455
  notification_options=notification_options,
456
  experimental_capabilities=experimental_capabilities,
457
  )
 
470
  logger.info("MCP server ready")
471
  try:
472
  # Run the server - it will automatically handle the initialization handshake
473
+ await server.run(
474
  read_stream=streams[0],
475
  write_stream=streams[1],
476
  initialization_options=init_options,
477
  )
478
  except Exception as run_error:
479
+ logger.error(f"Error in server.run(): {run_error}")
480
  raise
481
  except Exception as e:
482
  logging.getLogger("root").setLevel(original_root_level)
voice.py CHANGED
@@ -15,26 +15,57 @@ except ImportError:
15
 
16
 
17
  async def transcribe_audio_gemini(audio_path: str) -> str:
18
- """Transcribe audio using Gemini MCP"""
19
  if not MCP_AVAILABLE:
20
  return ""
21
 
22
  try:
23
- audio_path_abs = os.path.abspath(audio_path)
24
- files = [{"path": audio_path_abs}]
 
 
 
 
 
 
 
 
 
 
25
 
26
- system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
27
- user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- result = await call_agent(
30
- user_prompt=user_prompt,
31
- system_prompt=system_prompt,
32
- files=files,
33
- model=config.GEMINI_MODEL_LITE,
34
- temperature=0.2
35
  )
36
 
37
- return result.strip()
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
  logger.error(f"Gemini transcription error: {e}")
40
  return ""
@@ -83,24 +114,33 @@ def transcribe_audio(audio):
83
 
84
 
85
  async def generate_speech_mcp(text: str) -> str:
86
- """Generate speech using MCP TTS tool"""
87
  if not MCP_AVAILABLE:
88
  return None
89
 
90
  try:
91
  session = await get_mcp_session()
92
  if session is None:
 
93
  return None
94
 
95
  tools = await get_cached_mcp_tools()
96
  tts_tool = None
97
  for tool in tools:
98
- tool_name_lower = tool.name.lower()
99
- if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
100
  tts_tool = tool
101
- logger.info(f"Found MCP TTS tool: {tool.name}")
102
  break
103
 
 
 
 
 
 
 
 
 
 
104
  if tts_tool:
105
  result = await session.call_tool(
106
  tts_tool.name,
@@ -110,8 +150,13 @@ async def generate_speech_mcp(text: str) -> str:
110
  if hasattr(result, 'content') and result.content:
111
  for item in result.content:
112
  if hasattr(item, 'text'):
113
- if os.path.exists(item.text):
114
- return item.text
 
 
 
 
 
115
  elif hasattr(item, 'data') and item.data:
116
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
117
  tmp_file.write(item.data)
 
15
 
16
 
17
  async def transcribe_audio_gemini(audio_path: str) -> str:
18
+ """Transcribe audio using Gemini MCP transcribe_audio tool"""
19
  if not MCP_AVAILABLE:
20
  return ""
21
 
22
  try:
23
+ session = await get_mcp_session()
24
+ if session is None:
25
+ logger.warning("MCP session not available for transcription")
26
+ return ""
27
+
28
+ tools = await get_cached_mcp_tools()
29
+ transcribe_tool = None
30
+ for tool in tools:
31
+ if tool.name == "transcribe_audio":
32
+ transcribe_tool = tool
33
+ logger.info(f"Found MCP transcribe_audio tool: {tool.name}")
34
+ break
35
 
36
+ if not transcribe_tool:
37
+ logger.warning("transcribe_audio MCP tool not found, falling back to generate_content")
38
+ # Fallback to using generate_content
39
+ audio_path_abs = os.path.abspath(audio_path)
40
+ files = [{"path": audio_path_abs}]
41
+ system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
42
+ user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
43
+ result = await call_agent(
44
+ user_prompt=user_prompt,
45
+ system_prompt=system_prompt,
46
+ files=files,
47
+ model=config.GEMINI_MODEL_LITE,
48
+ temperature=0.2
49
+ )
50
+ return result.strip()
51
 
52
+ # Use the transcribe_audio tool
53
+ audio_path_abs = os.path.abspath(audio_path)
54
+ result = await session.call_tool(
55
+ transcribe_tool.name,
56
+ arguments={"audio_path": audio_path_abs}
 
57
  )
58
 
59
+ if hasattr(result, 'content') and result.content:
60
+ for item in result.content:
61
+ if hasattr(item, 'text'):
62
+ transcribed_text = item.text.strip()
63
+ if transcribed_text:
64
+ logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...")
65
+ return transcribed_text
66
+
67
+ logger.warning("MCP transcribe_audio returned empty result")
68
+ return ""
69
  except Exception as e:
70
  logger.error(f"Gemini transcription error: {e}")
71
  return ""
 
114
 
115
 
116
  async def generate_speech_mcp(text: str) -> str:
117
+ """Generate speech using MCP text_to_speech tool"""
118
  if not MCP_AVAILABLE:
119
  return None
120
 
121
  try:
122
  session = await get_mcp_session()
123
  if session is None:
124
+ logger.warning("MCP session not available for TTS")
125
  return None
126
 
127
  tools = await get_cached_mcp_tools()
128
  tts_tool = None
129
  for tool in tools:
130
+ if tool.name == "text_to_speech":
 
131
  tts_tool = tool
132
+ logger.info(f"Found MCP text_to_speech tool: {tool.name}")
133
  break
134
 
135
+ if not tts_tool:
136
+ # Fallback: search for any TTS-related tool
137
+ for tool in tools:
138
+ tool_name_lower = tool.name.lower()
139
+ if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
140
+ tts_tool = tool
141
+ logger.info(f"Found MCP TTS tool (fallback): {tool.name}")
142
+ break
143
+
144
  if tts_tool:
145
  result = await session.call_tool(
146
  tts_tool.name,
 
150
  if hasattr(result, 'content') and result.content:
151
  for item in result.content:
152
  if hasattr(item, 'text'):
153
+ text_result = item.text
154
+ # Check if it's a signal to use local TTS
155
+ if text_result == "USE_LOCAL_TTS":
156
+ logger.info("MCP TTS tool indicates client-side TTS should be used")
157
+ return None # Return None to trigger client-side TTS
158
+ elif os.path.exists(text_result):
159
+ return text_result
160
  elif hasattr(item, 'data') and item.data:
161
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
162
  tmp_file.write(item.data)