import gradio as gr
import os
import tempfile
import logging
import json
import requests # For Gemini API calls

# Import your dispatcher class from the local summarizer_tool.py file
from summarizer_tool import AllInOneDispatcher

# --- Gemini API Configuration ---
# The API key will be automatically provided by the Canvas environment at runtime
# if left as an empty string. DO NOT hardcode your API key here.
GEMINI_API_KEY = "" # Leave as empty string for Canvas environment
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
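# Outside the Canvas environment you can instead read the key from an environment
# variable rather than leaving the string above empty (optional alternative, commented out):
# GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")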


# Configure logging for the Gradio app
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialize the dispatcher globally.
# This ensures models are loaded only once when the Gradio app starts up.
# This can take time, especially on CPU.
try:
    dispatcher = AllInOneDispatcher()
    logging.info("AllInOneDispatcher initialized successfully for Gradio app.")
except Exception as e:
    logging.error(f"Failed to initialize AllInOneDispatcher: {e}")
    raise RuntimeError(f"Failed to initialize AI models. Check logs for details: {e}") from e
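# Note: if initialization fails, the RuntimeError above prevents the app from starting,
# surfacing the model-loading problem immediately instead of failing on every request.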

# --- Helper Function for Gemini API Call ---
def call_gemini_api(prompt: str) -> str:
    """
    Calls the Gemini API with the given prompt and returns the generated text.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    payload = {
        "contents": [{"role": "user", "parts": [{"text": prompt}]}],
    }

    full_api_url = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}" if GEMINI_API_KEY else GEMINI_API_URL

    try:
        # A timeout prevents the chat from hanging indefinitely if the API is unreachable.
        response = requests.post(full_api_url, headers=headers, data=json.dumps(payload), timeout=60)
        response.raise_for_status() # Raise an exception for HTTP errors (4xx/5xx)
        
        result = response.json()

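        # A successful response typically has the shape:
        #   {"candidates": [{"content": {"parts": [{"text": "..."}]}, ...}]}
        # The nested checks below guard against any of these keys being absent.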
        if result.get("candidates") and len(result["candidates"]) > 0 and \
           result["candidates"][0].get("content") and \
           result["candidates"][0]["content"].get("parts") and \
           len(result["candidates"][0]["content"]["parts"]) > 0:
            return result["candidates"][0]["content"]["parts"][0]["text"]
        else:
            return "I couldn't generate a response for that."
    except requests.exceptions.RequestException as e:
        logging.error(f"Gemini API Call Error: {e}")
        return f"An error occurred while connecting to the AI: {e}"
    except json.JSONDecodeError:
        logging.error(f"Gemini API Response Error: Could not decode JSON. Response: {response.text}")
        return "An error occurred while processing the AI's response."
    except Exception as e:
        logging.error(f"An unexpected error occurred during Gemini API call: {e}")
        return f"An unexpected error occurred: {e}"

# --- Main Chat Function for Gradio ---
async def chat_with_ai(message: str, history: list, selected_task: str, uploaded_file):
    """
    Routes the user's message, the selected task, and an optional uploaded file to
    either the Gemini API (general chat) or the local AllInOneDispatcher pipelines.
    """
    response_text = ""
    file_path = None

    # Handle file upload first, if any
    if uploaded_file is not None:
        file_path = uploaded_file # Gradio passes the path directly for type="filepath"
        logging.info(f"Received file: {file_path} for task: {selected_task}")
        
        # Determine file type for task mapping
        file_extension = os.path.splitext(file_path)[1].lower()
        
        if file_extension in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']:
            if selected_task not in ["Image Classification", "Object Detection"]:
                return "Please select 'Image Classification' or 'Object Detection' for image files."
        elif file_extension in ['.mp3', '.wav', '.ogg', '.flac', '.m4a']:
            if selected_task != "Automatic Speech Recognition":
                return "Please select 'Automatic Speech Recognition' for audio files."
        elif file_extension in ['.mp4', '.mov', '.avi', '.mkv']:
            if selected_task != "Video Analysis":
                return "Please select 'Video Analysis' for video files."
        elif file_extension == '.pdf':
            if selected_task != "PDF Summarization (RAG)":
                return "Please select 'PDF Summarization (RAG)' for PDF files."
        else:
            return f"Unsupported file type: {file_extension}. Please upload a supported file or select 'General Chat'."


    try:
        if selected_task == "General Chat":
            # Use Gemini for general chat
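            # Only the latest message is sent; previous turns in `history` are not
            # included in the prompt, so the model will not remember earlier context.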
            prompt = f"User: {message}\nAI:"
            response_text = call_gemini_api(prompt)
            return response_text

        elif selected_task == "Summarize Text":
            if not message.strip(): return "Please provide text to summarize."
            result = dispatcher.process(message, task="summarization", max_length=150, min_length=30)
            response_text = f"Here's a summary of your text:\n\n{json.dumps(result, indent=2)}"
            return response_text

        elif selected_task == "Sentiment Analysis":
            if not message.strip(): return "Please provide text for sentiment analysis."
            result = dispatcher.process(message, task="sentiment-analysis")
            response_text = f"The sentiment of your text is: {json.dumps(result, indent=2)}"
            return response_text
        
        elif selected_task == "Text Generation":
            if not message.strip(): return "Please provide a prompt for text generation."
            result = dispatcher.process(message, task="text-generation", max_new_tokens=100, num_return_sequences=1)
            generated_text = result[0]['generated_text'] if result and isinstance(result, list) and result[0].get('generated_text') else str(result)
            response_text = f"Here's the generated text:\n\n{generated_text}"
            return response_text

        elif selected_task == "Text-to-Speech (TTS)":
            if not message.strip(): return "Please provide text for speech generation."
            audio_path = dispatcher.process(message, task="tts", lang="en") # Default to English
            if os.path.exists(audio_path):
                # Recent Gradio versions let a chat function return a component directly;
                # on older versions, fall back to returning the audio file path as text.
                return gr.Audio(audio_path, label="Generated Speech", autoplay=True)
            else:
                return "Failed to generate speech."

        elif selected_task == "Translation (EN to FR)":
            if not message.strip(): return "Please provide text to translate."
            result = dispatcher.process(message, task="translation_en_to_fr")
            translated_text = result[0]['translation_text'] if result and isinstance(result, list) and result[0].get('translation_text') else str(result)
            response_text = f"Here's the English to French translation:\n\n{translated_text}"
            return response_text

        elif selected_task == "Image Classification":
            if not file_path: return "Please upload an image file for classification."
            result = dispatcher.process(file_path, task="image-classification")
            response_text = f"Image Classification Result:\n\n{json.dumps(result, indent=2)}"
            return response_text

        elif selected_task == "Object Detection":
            if not file_path: return "Please upload an image file for object detection."
            result = dispatcher.process(file_path, task="object-detection")
            response_text = f"Object Detection Result:\n\n{json.dumps(result, indent=2)}"
            return response_text

        elif selected_task == "Automatic Speech Recognition":
            if not file_path: return "Please upload an audio file for transcription."
            result = dispatcher.process(file_path, task="automatic-speech-recognition")
            transcription = result.get('text', 'No transcription found.')
            response_text = f"Audio Transcription:\n\n{transcription}"
            return response_text

        elif selected_task == "Video Analysis":
            if not file_path: return "Please upload a video file for analysis."
            result = dispatcher.process(file_path, task="video")
            image_analysis = json.dumps(result.get('image_analysis'), indent=2)
            audio_analysis = json.dumps(result.get('audio_analysis'), indent=2)
            response_text = f"Video Analysis Result:\n\nImage Analysis:\n{image_analysis}\n\nAudio Analysis:\n{audio_analysis}"
            return response_text

        elif selected_task == "PDF Summarization (RAG)":
            if not file_path: return "Please upload a PDF file for summarization."
            result = dispatcher.process(file_path, task="pdf")
            response_text = f"PDF Summary:\n\n{result}"
            return response_text

        elif selected_task == "Process Dataset":
            # This task needs several parameters (dataset name, subset, split, column, ...),
            # so it does not map cleanly onto a single chat message. For now, guide the user
            # to pack every parameter into one message; a more robust solution would use
            # dedicated Gradio components instead of free-form text.
            return ("'Process Dataset' needs structured input. Please provide all parameters "
                    "in your message, for example: 'dataset: glue, subset: sst2, split: train, "
                    "column: sentence, task: sentiment-analysis, samples: 2'.")
            # Example of parsing:
            # parts = message.split(',')
            # params = {p.split(':')[0].strip(): p.split(':')[1].strip() for p in parts if ':' in p}
            # dataset_name = params.get('dataset')
            # subset_name = params.get('subset', '')
            # split = params.get('split', 'train')
            # column = params.get('column')
            # task_for_dataset = params.get('task')
            # num_samples = int(params.get('samples', 2))
            # if not all([dataset_name, column, task_for_dataset]):
            #     return "Please provide dataset name, column, and task for dataset processing."
            # result = dispatcher.process_dataset_from_hub(dataset_name, subset_name, split, column, task_for_dataset, num_samples)
            # return f"Dataset Processing Results:\n\n{json.dumps(result, indent=2)}"
        
        else:
            return "Please select a valid task from the dropdown."

    except Exception as e:
        logging.error(f"An error occurred in chat_with_ai: {e}")
        return f"An unexpected error occurred during processing: {e}"
    finally:
        # Clean up temporary file if it was uploaded and processed
        if file_path and os.path.exists(file_path):
            # Gradio handles temp file cleanup for gr.File(type="filepath")
            # However, if you manually copy/save, ensure cleanup.
            # For this setup, Gradio should handle it.
            pass


# --- Gradio Interface Definition ---

# Define the choices for the task dropdown
task_choices = [
    "General Chat",
    "Summarize Text",
    "Sentiment Analysis",
    "Text Generation",
    "Text-to-Speech (TTS)",
    "Translation (EN to FR)",
    "Image Classification",
    "Object Detection",
    "Automatic Speech Recognition",
    "Video Analysis",
    "PDF Summarization (RAG)",
    # "Process Dataset" - Removed for now as it needs more complex input than a simple chat
]

# Create the ChatInterface
demo = gr.ChatInterface(
    fn=chat_with_ai,
    textbox=gr.Textbox(placeholder="Ask me anything or provide text/files for analysis...", container=False, scale=7),
    chatbot=gr.Chatbot(height=500),
    # Additional inputs: a task selector and an optional file upload
    additional_inputs=[
        gr.Dropdown(task_choices, label="Select Task", value="General Chat", container=True),
        gr.File(label="Upload File (Optional)", type="filepath", file_types=[
            ".pdf", ".mp3", ".wav", ".jpg", ".jpeg", ".png", ".mov", ".mp4", ".avi", ".mkv"
        ])
    ],
    title="💬 Multimodal AI Assistant (Chat Interface)",
    description="Interact with various AI models. Select a task and provide your input (text or file)."
)
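
# The additional_inputs above are passed to chat_with_ai in order after (message, history):
# first the task dropdown, then the optional file upload.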

# --- Launch the Gradio App ---
if __name__ == "__main__":
    # For local testing, use demo.launch()
    # For Hugging Face Spaces, ensure all dependencies are in requirements.txt
    demo.launch(share=True) # share=True creates a public link for easy sharing (temporary)
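
# Note: on Hugging Face Spaces the share link is generally unnecessary; the Space already
# exposes a public URL, so demo.launch() without share=True is sufficient there.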