import gradio as gr
import os
import tempfile
import logging
import json
import requests # For Gemini API calls
# Import your dispatcher class from the local summarizer_tool.py file
from summarizer_tool import AllInOneDispatcher
# --- Gemini API Configuration ---
# The API key will be automatically provided by the Canvas environment at runtime
# if left as an empty string. DO NOT hardcode your API key here.
GEMINI_API_KEY = "" # Leave as empty string for Canvas environment
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# Configure logging for the Gradio app
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Initialize the dispatcher globally.
# This ensures models are loaded only once when the Gradio app starts up.
# This can take time, especially on CPU.
try:
    dispatcher = AllInOneDispatcher()
    logging.info("AllInOneDispatcher initialized successfully for Gradio app.")
except Exception as e:
    logging.error(f"Failed to initialize AllInOneDispatcher: {e}")
    raise RuntimeError(f"Failed to initialize AI models. Check logs for details: {e}") from e
# --- Helper Function for Gemini API Call ---
def call_gemini_api(prompt: str) -> str:
"""
Calls the Gemini API with the given prompt and returns the generated text.
"""
headers = {
'Content-Type': 'application/json',
}
payload = {
"contents": [{"role": "user", "parts": [{"text": prompt}]}],
}
full_api_url = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}" if GEMINI_API_KEY else GEMINI_API_URL
try:
response = requests.post(full_api_url, headers=headers, data=json.dumps(payload))
response.raise_for_status() # Raise an exception for HTTP errors
result = response.json()
if result.get("candidates") and len(result["candidates"]) > 0 and \
result["candidates"][0].get("content") and \
result["candidates"][0]["content"].get("parts") and \
len(result["candidates"][0]["content"]["parts"]) > 0:
return result["candidates"][0]["content"]["parts"][0]["text"]
else:
return "I couldn't generate a response for that."
except requests.exceptions.RequestException as e:
logging.error(f"Gemini API Call Error: {e}")
return f"An error occurred while connecting to the AI: {e}"
except json.JSONDecodeError:
logging.error(f"Gemini API Response Error: Could not decode JSON. Response: {response.text}")
return "An error occurred while processing the AI's response."
except Exception as e:
logging.error(f"An unexpected error occurred during Gemini API call: {e}")
return f"An unexpected error occurred: {e}"
# --- Main Chat Function for Gradio ---
async def chat_with_ai(message: str, history: list, selected_task: str, uploaded_file):
"""
Processes user messages, selected tasks, and uploaded files.
"""
response_text = ""
file_path = None
# Handle file upload first, if any
if uploaded_file is not None:
file_path = uploaded_file # Gradio passes the path directly for type="filepath"
logging.info(f"Received file: {file_path} for task: {selected_task}")
# Determine file type for task mapping
file_extension = os.path.splitext(file_path)[1].lower()
if file_extension in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']:
if selected_task not in ["Image Classification", "Object Detection"]:
return "Please select 'Image Classification' or 'Object Detection' for image files."
elif file_extension in ['.mp3', '.wav', '.ogg', '.flac', '.m4a']:
if selected_task != "Automatic Speech Recognition":
return "Please select 'Automatic Speech Recognition' for audio files."
elif file_extension in ['.mp4', '.mov', '.avi', '.mkv']:
if selected_task != "Video Analysis":
return "Please select 'Video Analysis' for video files."
elif file_extension == '.pdf':
if selected_task != "PDF Summarization (RAG)":
return "Please select 'PDF Summarization (RAG)' for PDF files."
else:
return f"Unsupported file type: {file_extension}. Please upload a supported file or select 'General Chat'."
try:
if selected_task == "General Chat":
# Use Gemini for general chat
prompt = f"User: {message}\nAI:"
response_text = call_gemini_api(prompt)
return response_text
elif selected_task == "Summarize Text":
if not message.strip(): return "Please provide text to summarize."
result = dispatcher.process(message, task="summarization", max_length=150, min_length=30)
response_text = f"Here's a summary of your text:\n\n{json.dumps(result, indent=2)}"
return response_text
elif selected_task == "Sentiment Analysis":
if not message.strip(): return "Please provide text for sentiment analysis."
result = dispatcher.process(message, task="sentiment-analysis")
response_text = f"The sentiment of your text is: {json.dumps(result, indent=2)}"
return response_text
elif selected_task == "Text Generation":
if not message.strip(): return "Please provide a prompt for text generation."
result = dispatcher.process(message, task="text-generation", max_new_tokens=100, num_return_sequences=1)
generated_text = result[0]['generated_text'] if result and isinstance(result, list) and result[0].get('generated_text') else str(result)
response_text = f"Here's the generated text:\n\n{generated_text}"
return response_text
elif selected_task == "Text-to-Speech (TTS)":
if not message.strip(): return "Please provide text for speech generation."
audio_path = dispatcher.process(message, task="tts", lang="en") # Default to English
if os.path.exists(audio_path):
# Gradio ChatInterface can return audio directly
return (f"Here's the audio for your text:", gr.Audio(audio_path, label="Generated Speech", autoplay=True))
else:
return "Failed to generate speech."
elif selected_task == "Translation (EN to FR)":
if not message.strip(): return "Please provide text to translate."
result = dispatcher.process(message, task="translation_en_to_fr")
translated_text = result[0]['translation_text'] if result and isinstance(result, list) and result[0].get('translation_text') else str(result)
response_text = f"Here's the English to French translation:\n\n{translated_text}"
return response_text
elif selected_task == "Image Classification":
if not file_path: return "Please upload an image file for classification."
result = dispatcher.process(file_path, task="image-classification")
response_text = f"Image Classification Result:\n\n{json.dumps(result, indent=2)}"
return response_text
elif selected_task == "Object Detection":
if not file_path: return "Please upload an image file for object detection."
result = dispatcher.process(file_path, task="object-detection")
response_text = f"Object Detection Result:\n\n{json.dumps(result, indent=2)}"
return response_text
elif selected_task == "Automatic Speech Recognition":
if not file_path: return "Please upload an audio file for transcription."
result = dispatcher.process(file_path, task="automatic-speech-recognition")
transcription = result.get('text', 'No transcription found.')
response_text = f"Audio Transcription:\n\n{transcription}"
return response_text
elif selected_task == "Video Analysis":
if not file_path: return "Please upload a video file for analysis."
result = dispatcher.process(file_path, task="video")
image_analysis = json.dumps(result.get('image_analysis'), indent=2)
audio_analysis = json.dumps(result.get('audio_analysis'), indent=2)
response_text = f"Video Analysis Result:\n\nImage Analysis:\n{image_analysis}\n\nAudio Analysis:\n{audio_analysis}"
return response_text
elif selected_task == "PDF Summarization (RAG)":
if not file_path: return "Please upload a PDF file for summarization."
result = dispatcher.process(file_path, task="pdf")
response_text = f"PDF Summary:\n\n{result}"
return response_text
elif selected_task == "Process Dataset":
# This task requires more specific parameters (dataset name, column, etc.)
# It's not directly compatible with a single chat message input.
# We'll guide the user to a separate interface for this, or simplify.
# For now, let's keep it simple: user provides dataset_name, subset, split, column in message.
# A more robust solution would involve a separate Gradio component for this.
return "For 'Process Dataset', please use the dedicated 'Dataset Analyzer' tab if it were available, or provide all parameters in your message like: 'dataset: glue, subset: sst2, split: train, column: sentence, task: sentiment-analysis, samples: 2'."
# Example of parsing:
# parts = message.split(',')
# params = {p.split(':')[0].strip(): p.split(':')[1].strip() for p in parts if ':' in p}
# dataset_name = params.get('dataset')
# subset_name = params.get('subset', '')
# split = params.get('split', 'train')
# column = params.get('column')
# task_for_dataset = params.get('task')
# num_samples = int(params.get('samples', 2))
# if not all([dataset_name, column, task_for_dataset]):
# return "Please provide dataset name, column, and task for dataset processing."
# result = dispatcher.process_dataset_from_hub(dataset_name, subset_name, split, column, task_for_dataset, num_samples)
# return f"Dataset Processing Results:\n\n{json.dumps(result, indent=2)}"
else:
return "Please select a valid task from the dropdown."
    except Exception as e:
        logging.error(f"An error occurred in chat_with_ai: {e}")
        return f"An unexpected error occurred during processing: {e}"
    finally:
        # Clean up temporary file if it was uploaded and processed
        if file_path and os.path.exists(file_path):
            # Gradio handles temp file cleanup for gr.File(type="filepath").
            # However, if you manually copy/save files, ensure cleanup.
            # For this setup, Gradio should handle it.
            pass
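# --- Optional sketch: parsing 'Process Dataset' parameters from a chat message ---
# The commented example inside chat_with_ai above outlines how the (currently disabled)
# "Process Dataset" task could accept parameters from free text. The helper below is a
# minimal, hedged sketch of that parsing step only; it is NOT wired into the app, and it
# assumes the dispatcher exposes process_dataset_from_hub() with the argument order shown
# in that comment.
def parse_dataset_request(message):
    """Parse 'key: value' pairs such as 'dataset: glue, subset: sst2, split: train,
    column: sentence, task: sentiment-analysis, samples: 2' from a chat message.
    Returns a dict of parameters, or None if required fields are missing."""
    parts = [p for p in message.split(',') if ':' in p]
    params = {k.strip().lower(): v.strip() for k, v in (p.split(':', 1) for p in parts)}
    if not all(params.get(k) for k in ('dataset', 'column', 'task')):
        return None  # dataset name, column, and task are required
    return {
        "dataset_name": params["dataset"],
        "subset_name": params.get("subset", ""),
        "split": params.get("split", "train"),
        "column": params["column"],
        "task": params["task"],
        "num_samples": int(params.get("samples", 2)),
    }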
# --- Gradio Interface Definition ---
# Define the choices for the task dropdown
task_choices = [
    "General Chat",
    "Summarize Text",
    "Sentiment Analysis",
    "Text Generation",
    "Text-to-Speech (TTS)",
    "Translation (EN to FR)",
    "Image Classification",
    "Object Detection",
    "Automatic Speech Recognition",
    "Video Analysis",
    "PDF Summarization (RAG)",
    # "Process Dataset" - removed for now as it needs more complex input than a simple chat message
]
# Create the ChatInterface
demo = gr.ChatInterface(
    fn=chat_with_ai,
    textbox=gr.Textbox(placeholder="Ask me anything or provide text/files for analysis...", container=False, scale=7),
    chatbot=gr.Chatbot(height=500),
    # Additional inputs: a task selector and an optional file upload component
    additional_inputs=[
        gr.Dropdown(task_choices, label="Select Task", value="General Chat", container=True),
        gr.File(label="Upload File (Optional)", type="filepath", file_types=[
            ".pdf", ".mp3", ".wav", ".jpg", ".jpeg", ".png", ".mov", ".mp4", ".avi", ".mkv"
        ])
    ],
    title="💬 Multimodal AI Assistant (Chat Interface)",
    description="Interact with various AI models. Select a task and provide your input (text or file)."
)
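# Optional note (assumption): several tasks here (video analysis, ASR, summarization) can be
# slow on CPU, as mentioned above. If requests start timing out, enabling Gradio's built-in
# request queue before launching is one option, e.g.:
# demo.queue()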
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # For local testing, use demo.launch()
    # For Hugging Face Spaces, ensure all dependencies are in requirements.txt
    demo.launch(share=True)  # share=True creates a public link for easy sharing (temporary)