Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import tempfile | |
| import logging | |
| import json | |
| import requests # For Gemini API calls | |
| # Import your dispatcher class from the local summarizer_tool.py file | |
| from summarizer_tool import AllInOneDispatcher | |
# --- Gemini API Configuration ---
# Read the key from the environment so it is never hardcoded in source.
# On hosts such as Hugging Face Spaces, set GEMINI_API_KEY as a secret;
# when unset it defaults to "" and the Canvas environment is expected to
# inject credentials automatically at runtime (original behavior preserved).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"

# Configure logging for the Gradio app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialize the dispatcher globally so the underlying models are loaded
# exactly once when the app starts up (loading can be slow, especially on CPU).
try:
    dispatcher = AllInOneDispatcher()
    logging.info("AllInOneDispatcher initialized successfully for Gradio app.")
except Exception as e:
    logging.error(f"Failed to initialize AllInOneDispatcher: {e}")
    # Fail fast: the app is unusable without its models, so surface the error.
    raise RuntimeError(f"Failed to initialize AI models. Check logs for details: {e}") from e
# --- Helper Function for Gemini API Call ---
def call_gemini_api(prompt: str) -> str:
    """Send *prompt* to the Gemini generateContent endpoint and return the reply.

    Never raises: any failure is logged and returned as a human-readable
    error string so the caller can show it directly in the chat window.
    """
    payload = {
        "contents": [{"role": "user", "parts": [{"text": prompt}]}],
    }
    # Append the key only when one is configured; with no key the Canvas
    # environment is expected to supply credentials at runtime.
    full_api_url = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}" if GEMINI_API_KEY else GEMINI_API_URL
    try:
        # json= serializes the payload and sets Content-Type in one step.
        # requests has NO default timeout, so without one a stalled API call
        # would hang this worker forever.
        response = requests.post(full_api_url, json=payload, timeout=60)
        response.raise_for_status()  # Raise an exception for HTTP errors.
        result = response.json()
        # Defensive navigation of candidates[0].content.parts[0].text.
        candidates = result.get("candidates") or []
        if candidates:
            parts = candidates[0].get("content", {}).get("parts") or []
            if parts:
                return parts[0].get("text", "I couldn't generate a response for that.")
        return "I couldn't generate a response for that."
    except requests.exceptions.JSONDecodeError:
        # Must be caught BEFORE RequestException: in requests >= 2.27 the
        # decode error subclasses RequestException, so the original ordering
        # made this handler unreachable.
        logging.error(f"Gemini API Response Error: Could not decode JSON. Response: {response.text}")
        return "An error occurred while processing the AI's response."
    except requests.exceptions.RequestException as e:
        logging.error(f"Gemini API Call Error: {e}")
        return f"An error occurred while connecting to the AI: {e}"
    except Exception as e:
        logging.error(f"An unexpected error occurred during Gemini API call: {e}")
        return f"An unexpected error occurred: {e}"
# --- Main Chat Function for Gradio ---
async def chat_with_ai(message: str, history: list, selected_task: str, uploaded_file):
    """Dispatch one chat turn to the model matching *selected_task*.

    Args:
        message: The user's text input; may be empty when only a file matters.
        history: Conversation history supplied by gr.ChatInterface (unused).
        selected_task: One of the labels in ``task_choices``.
        uploaded_file: Filesystem path of the uploaded file
            (gr.File(type="filepath")), or None when nothing was uploaded.

    Returns:
        A response string, or (for TTS) a (text, gr.Audio) pair.
    """
    message = message or ""  # Guard: the .strip() checks below must not see None.
    file_path = None

    # Validate any uploaded file against the chosen task up front, so a
    # mismatch yields a clear instruction instead of a model error.
    if uploaded_file is not None:
        file_path = uploaded_file  # Gradio passes the path directly for type="filepath"
        logging.info(f"Received file: {file_path} for task: {selected_task}")
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']:
            if selected_task not in ["Image Classification", "Object Detection"]:
                return "Please select 'Image Classification' or 'Object Detection' for image files."
        elif file_extension in ['.mp3', '.wav', '.ogg', '.flac', '.m4a']:
            if selected_task != "Automatic Speech Recognition":
                return "Please select 'Automatic Speech Recognition' for audio files."
        elif file_extension in ['.mp4', '.mov', '.avi', '.mkv']:
            if selected_task != "Video Analysis":
                return "Please select 'Video Analysis' for video files."
        elif file_extension == '.pdf':
            if selected_task != "PDF Summarization (RAG)":
                return "Please select 'PDF Summarization (RAG)' for PDF files."
        else:
            return f"Unsupported file type: {file_extension}. Please upload a supported file or select 'General Chat'."

    # NOTE(review): call_gemini_api and dispatcher.process are blocking calls
    # inside an async handler; if the UI stalls under load, consider wrapping
    # them in asyncio.to_thread (verify dispatcher thread-safety first).
    try:
        if selected_task == "General Chat":
            # Free-form conversation goes to Gemini.
            prompt = f"User: {message}\nAI:"
            return call_gemini_api(prompt)

        elif selected_task == "Summarize Text":
            if not message.strip():
                return "Please provide text to summarize."
            result = dispatcher.process(message, task="summarization", max_length=150, min_length=30)
            return f"Here's a summary of your text:\n\n{json.dumps(result, indent=2)}"

        elif selected_task == "Sentiment Analysis":
            if not message.strip():
                return "Please provide text for sentiment analysis."
            result = dispatcher.process(message, task="sentiment-analysis")
            return f"The sentiment of your text is: {json.dumps(result, indent=2)}"

        elif selected_task == "Text Generation":
            if not message.strip():
                return "Please provide a prompt for text generation."
            result = dispatcher.process(message, task="text-generation", max_new_tokens=100, num_return_sequences=1)
            # Fall back to the raw result when the expected pipeline shape
            # (list of dicts with 'generated_text') is not present.
            generated_text = result[0]['generated_text'] if result and isinstance(result, list) and result[0].get('generated_text') else str(result)
            return f"Here's the generated text:\n\n{generated_text}"

        elif selected_task == "Text-to-Speech (TTS)":
            if not message.strip():
                return "Please provide text for speech generation."
            audio_path = dispatcher.process(message, task="tts", lang="en")  # Default to English
            if os.path.exists(audio_path):
                # NOTE(review): returning a (text, gr.Audio) tuple assumes the
                # installed gr.ChatInterface version accepts it — verify
                # against the Gradio version pinned in requirements.txt.
                return (f"Here's the audio for your text:", gr.Audio(audio_path, label="Generated Speech", autoplay=True))
            else:
                return "Failed to generate speech."

        elif selected_task == "Translation (EN to FR)":
            if not message.strip():
                return "Please provide text to translate."
            result = dispatcher.process(message, task="translation_en_to_fr")
            translated_text = result[0]['translation_text'] if result and isinstance(result, list) and result[0].get('translation_text') else str(result)
            return f"Here's the English to French translation:\n\n{translated_text}"

        elif selected_task == "Image Classification":
            if not file_path:
                return "Please upload an image file for classification."
            result = dispatcher.process(file_path, task="image-classification")
            return f"Image Classification Result:\n\n{json.dumps(result, indent=2)}"

        elif selected_task == "Object Detection":
            if not file_path:
                return "Please upload an image file for object detection."
            result = dispatcher.process(file_path, task="object-detection")
            return f"Object Detection Result:\n\n{json.dumps(result, indent=2)}"

        elif selected_task == "Automatic Speech Recognition":
            if not file_path:
                return "Please upload an audio file for transcription."
            result = dispatcher.process(file_path, task="automatic-speech-recognition")
            transcription = result.get('text', 'No transcription found.')
            return f"Audio Transcription:\n\n{transcription}"

        elif selected_task == "Video Analysis":
            if not file_path:
                return "Please upload a video file for analysis."
            result = dispatcher.process(file_path, task="video")
            image_analysis = json.dumps(result.get('image_analysis'), indent=2)
            audio_analysis = json.dumps(result.get('audio_analysis'), indent=2)
            return f"Video Analysis Result:\n\nImage Analysis:\n{image_analysis}\n\nAudio Analysis:\n{audio_analysis}"

        elif selected_task == "PDF Summarization (RAG)":
            if not file_path:
                return "Please upload a PDF file for summarization."
            result = dispatcher.process(file_path, task="pdf")
            return f"PDF Summary:\n\n{result}"

        elif selected_task == "Process Dataset":
            # Dataset processing needs several parameters (dataset, subset,
            # split, column, ...) and does not fit a single chat message, so
            # explain the expected format to the user instead.
            return "For 'Process Dataset', please use the dedicated 'Dataset Analyzer' tab if it were available, or provide all parameters in your message like: 'dataset: glue, subset: sst2, split: train, column: sentence, task: sentiment-analysis, samples: 2'."

        else:
            return "Please select a valid task from the dropdown."
    except Exception as e:
        logging.error(f"An error occurred in chat_with_ai: {e}")
        return f"An unexpected error occurred during processing: {e}"
    # Temp-file cleanup is handled by Gradio itself for gr.File(type="filepath"),
    # so the previous no-op try/finally cleanup block was removed.
# --- Gradio Interface Definition ---
# Choices for the task dropdown. Each label must match a branch in
# chat_with_ai exactly — the dispatch there compares these strings verbatim.
task_choices = [
    "General Chat",
    "Summarize Text",
    "Sentiment Analysis",
    "Text Generation",
    "Text-to-Speech (TTS)",
    "Translation (EN to FR)",
    "Image Classification",
    "Object Detection",
    "Automatic Speech Recognition",
    "Video Analysis",
    "PDF Summarization (RAG)",
    # "Process Dataset" - Removed for now as it needs more complex input than a simple chat
]
# Create the ChatInterface: a chat window plus two extra inputs (task
# dropdown and optional file upload) that gr.ChatInterface forwards to
# chat_with_ai as the selected_task and uploaded_file arguments.
demo = gr.ChatInterface(
    fn=chat_with_ai,
    textbox=gr.Textbox(placeholder="Ask me anything or provide text/files for analysis...", container=False, scale=7),
    chatbot=gr.Chatbot(height=500),
    # Add a file upload component; file_types restricts the picker to the
    # extensions chat_with_ai knows how to validate and route.
    additional_inputs=[
        gr.Dropdown(task_choices, label="Select Task", value="General Chat", container=True),
        gr.File(label="Upload File (Optional)", type="filepath", file_types=[
            ".pdf", ".mp3", ".wav", ".jpg", ".jpeg", ".png", ".mov", ".mp4", ".avi", ".mkv"
        ])
    ],
    title="💬 Multimodal AI Assistant (Chat Interface)",
    description="Interact with various AI models. Select a task and provide your input (text or file)."
)
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # For local testing, use demo.launch().
    # For Hugging Face Spaces, ensure all dependencies are in requirements.txt.
    # NOTE(review): share=True creates a temporary public tunnel link; on a
    # hosted Space it is presumably unnecessary — confirm for the target host.
    demo.launch(share=True)  # share=True creates a public link for easy sharing (temporary)