neyugncol committed on
Commit d994d22 · verified · Parent(s): bbc095d

First commit

Files changed (12)
  1. .gitignore +2 -0
  2. README.md +45 -14
  3. agent.py +82 -0
  4. app.py +85 -0
  5. configs.py +26 -0
  6. embeder.py +39 -0
  7. prompt.py +13 -0
  8. rag.py +273 -0
  9. requirements.txt +8 -0
  10. tools.py +122 -0
  11. transcriber.py +134 -0
  12. utils.py +181 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ .env
2
+ data/*
README.md CHANGED
@@ -1,14 +1,45 @@
1
- ---
2
- title: Video Chatbot
3
- emoji: 🐢
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.35.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: A chatbot that can answer questions about a video.
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Chatbot for Video Question Answering Demo
2
+
3
+ An AI chatbot that can answer questions about video content. The project combines a multimodal LLM with a multimodal RAG pipeline to process video frames, transcribe audio, and retrieve information, so it can give accurate answers to questions about a video.
4
+
5
+ ## Requirements
6
+
7
+ - Python 3.12+
8
+ - [uv](https://docs.astral.sh/uv/) as the package and project manager
9
+ - [FFmpeg](https://ffmpeg.org/) installed and available in PATH
10
+ - [Google Gemini API key](https://aistudio.google.com/apikey) for the LLM functionality
11
+
12
+ ## Installation
13
+
14
+ 1. Clone this repository
15
+ ```bash
16
+ git clone [repository-url]
17
+ cd VideoChatbot
18
+ ```
19
+
20
+ 2. Install dependencies using uv
21
+ ```bash
22
+ uv sync
23
+ ```
24
+
25
+ 3. Create a `.env` file in the project root with your API key
26
+ ```
27
+ GEMINI_API_KEY=your_api_key_here
28
+ ```
29
+
30
+ ## Usage
31
+
32
+ 1. Start the application
33
+ ```bash
34
+ uv run app.py
35
+ ```
36
+
37
+ 2. Access the UI through your browser (typically at http://127.0.0.1:7860)
38
+
39
+ 3. Upload a video file or provide a YouTube URL and ask questions about it
40
+
41
+ 4. The system will process the video (extract frames, transcribe audio), index the content, and then answer your questions
42
+
43
+ ## Notes
44
+
45
+ This project is designed as a demo and may require additional configuration for production use. Video processing and indexing can take time, depending on the video's length and complexity. Use larger LLMs, embedding models, transcription models, and vector databases for better performance and accuracy.
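
For reference, a minimal sketch of driving the chatbot headlessly with the modules added in this commit (an illustration, not part of the repository; the question and URL placeholder are made up):

```python
import os

from smolagents import FinalAnswerStep

from agent import VideoChatbot
from configs import settings

# Build the agent the same way app.py does, reading the key from the environment.
bot = VideoChatbot(
    model=settings.CHATBOT_MODEL,
    api_base=settings.MODEL_BASE_API,
    api_key=os.environ['GEMINI_API_KEY'],
)

# chat() streams smolagents step objects; the final answer arrives last.
for step in bot.chat('Download https://www.youtube.com/watch?v=<VIDEO_ID> and summarize it'):
    if isinstance(step, FinalAnswerStep):
        print(step.output)
```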
agent.py ADDED
@@ -0,0 +1,82 @@
1
+ import re
2
+ from typing import Generator
3
+
4
+ from smolagents import ToolCallingAgent, OpenAIServerModel, ActionStep
5
+ from PIL import Image
6
+
7
+ import tools
8
+ from configs import settings
9
+ from prompt import video_to_text_prompt
10
+ from rag import VideoRAG
11
+
12
+
13
+ class VideoChatbot:
14
+ def __init__(
15
+ self,
16
+ model: str = 'gemini-2.0-flash',
17
+ api_base: str = None,
18
+ api_key: str = None
19
+ ):
20
+ self.video_rag = VideoRAG(
21
+ video_frame_rate=settings.VIDEO_EXTRACTION_FRAME_RATE,
22
+ audio_segment_length=settings.AUDIO_SEGMENT_LENGTH,
23
+ )
24
+ self.agent = ToolCallingAgent(
25
+ tools=[
26
+ tools.download_video,
27
+ *tools.create_video_rag_tools(self.video_rag)
28
+ ],
29
+ model=OpenAIServerModel(
30
+ model_id=model,
31
+ api_base=api_base,
32
+ api_key=api_key
33
+ ),
34
+ step_callbacks=[self._step_callback],
35
+ )
36
+
37
+ def chat(self, message: str, attachments: list[str] = None) -> Generator:
38
+ """Chats with the bot, including handling attachments (images and videos).
39
+
40
+ Args:
41
+ message: The text message to send to the bot.
42
+ attachments: A list of file paths for images or videos to include in the chat.
43
+
44
+ Returns:
45
+ A generator yielding step objects representing the bot's responses and actions.
46
+ """
47
+
48
+ images = []
49
+ for filepath in attachments or []:
50
+ if filepath.endswith(('.jpg', '.jpeg', '.png')):
51
+ images.append(Image.open(filepath))
52
+ if filepath.endswith('.mp4'):
53
+ message = video_to_text_prompt(filepath) + message
54
+
55
+ for step in self.agent.run(
56
+ message,
57
+ stream=True,
58
+ reset=False,
59
+ images=images,
60
+ ):
61
+ yield step
62
+
63
+ def clear(self):
64
+ """Clears the chatbot message history and context."""
65
+ self.agent.state.clear()
66
+ self.agent.memory.reset()
67
+ self.agent.monitor.reset()
68
+ self.video_rag.clear()
69
+
70
+ def _step_callback(self, step: ActionStep, agent: ToolCallingAgent):
71
+ if step.observations:
72
+ image_index = 0
73
+ for image_path in re.findall(r'<observation_image>(.*?)</observation_image>', step.observations):
74
+ try:
75
+ image = Image.open(image_path)
76
+ step.observations_images.append(image)
77
+ step.observations = step.observations.replace(image_path, str(image_index))
78
+ image_index += 1
79
+ except Exception as e:
80
+ print(f'Error loading image {image_path}: {e}')
81
+
82
+
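
A note on `_step_callback`: it relies on a small convention shared with `tools.search_in_video`, which wraps frame paths in `<observation_image>` tags inside tool observations. A tiny illustration (not part of the commit; the observation string and path are made up):

```python
import re

# A tool observation in the format produced by tools.search_in_video.
observation = (
    'Timespan: 00:01:05 - 00:01:12 '
    '<observation_image>data/lecture.mp4_frames/66.jpg</observation_image>'
)

# The same regex the callback uses to pull tagged frame paths out of the text.
paths = re.findall(r'<observation_image>(.*?)</observation_image>', observation)
print(paths)  # ['data/lecture.mp4_frames/66.jpg']

# The callback then opens each path with PIL, appends the image to
# step.observations_images, and replaces the path in the text with the image index.
```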
app.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+ import shutil
3
+
4
+ import gradio as gr
5
+ from smolagents import ChatMessageToolCall, ActionStep, FinalAnswerStep
6
+
7
+ from agent import VideoChatbot
8
+ from configs import settings
9
+
10
+
11
+ bot = VideoChatbot(
12
+ model=settings.CHATBOT_MODEL,
13
+ api_base=settings.MODEL_BASE_API,
14
+ api_key=os.environ['GEMINI_API_KEY']
15
+ )
16
+
17
+
18
+ def chat(message: dict, history: list[dict]):
19
+
20
+ # move the file to the data directory
21
+ message['files'] = [shutil.copy(file, settings.DATA_DIR) for file in message['files']]
22
+
23
+ # add the input message to the history
24
+ history.extend([{'role': 'user', 'content': {'path': file}} for file in message['files']])
25
+ history.append({'role': 'user', 'content': message['text']})
26
+ yield history, ''
27
+
28
+ for step in bot.chat(message['text'], message['files']):
29
+ match step:
30
+ case ChatMessageToolCall():
31
+ if step.function.name == 'download_video':
32
+ history.append({
33
+ 'role': 'assistant',
34
+ 'content': f'📥 Downloading video from {step.function.arguments['url']}'
35
+ })
36
+ elif step.function.name == 'add_video':
37
+ history.append({
38
+ 'role': 'assistant',
39
+ 'content': f'🎥 Processing and adding video `{step.function.arguments["filename"]}` '
40
+ f'to the knowledge base. This may take a while...'
41
+ })
42
+ elif step.function.name == 'search_in_video':
43
+ filename = os.path.basename(bot.video_rag.videos[step.function.arguments["video_id"]]['video_path'])
44
+ history.append({
45
+ 'role': 'assistant',
46
+ 'content': f'🔍 Searching in video `{filename}` '
47
+ f'for query: *{step.function.arguments.get("text_query", step.function.arguments.get("image_query", ""))}*'
48
+ })
49
+ elif step.function.name == 'final_answer':
50
+ continue
51
+ yield history, ''
52
+ case ActionStep():
53
+ yield history, ''
54
+ case FinalAnswerStep():
55
+ history.append({'role': 'assistant', 'content': step.output})
56
+ yield history, ''
57
+
58
+
59
+ def clear_chat(chatbot):
60
+ bot.clear()  # also reset the agent memory and the video knowledge base
+ chatbot.clear()
61
+ return chatbot, gr.update(value='')
62
+
63
+
64
+ def main():
65
+ with gr.Blocks() as demo:
66
+ gr.Markdown('# Video Chatbot Demo')
67
+ gr.Markdown('This demo showcases a video chatbot that can process and search videos using '
68
+ 'RAG (Retrieval-Augmented Generation). You can upload videos/images or link to YouTube videos, '
69
+ 'ask questions, and get answers based on the video content.')
70
+ chatbot = gr.Chatbot(type='messages', label='Video Chatbot', height=800, resizable=True)
71
+ textbox = gr.MultimodalTextbox(
72
+ sources=['upload'],
73
+ file_types=['image', '.mp4'],
74
+ show_label=False,
75
+ placeholder='Type a message or upload an image/video...',
76
+
77
+ )
78
+ textbox.submit(chat, [textbox, chatbot], [chatbot, textbox])
79
+ clear = gr.Button('Clear Chat')
80
+ clear.click(clear_chat, [chatbot], [chatbot, textbox])
81
+
82
+ demo.launch(debug=True)
83
+
84
+ if __name__ == '__main__':
85
+ main()
configs.py ADDED
@@ -0,0 +1,26 @@
1
+ from dataclasses import dataclass
2
+
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
+
8
+ @dataclass
9
+ class Settings:
10
+ DATA_DIR: str = 'data'
11
+ FFMPEG_PATH: str = 'ffmpeg'
12
+ MAX_VIDEO_RESOLUTION: int = 360
13
+ MAX_VIDEO_FPS: float = 30
14
+ VIDEO_EXTENSION: str = 'mp4'
15
+ VIDEO_EXTRACTION_FRAME_RATE: float = 1.0
16
+ AUDIO_SEGMENT_LENGTH: int = 300
17
+ CHATBOT_MODEL: str = 'gemini-2.0-flash'
18
+ MODEL_BASE_API: str = 'https://generativelanguage.googleapis.com/v1beta/'
19
+ TEXT_EMBEDDING_MODEL: str = 'sentence-transformers/all-MiniLM-L6-v2'
20
+ IMAGE_EMBEDDING_MODEL: str = 'facebook/dinov2-small'
21
+
22
+
23
+
24
+
25
+ settings = Settings()
26
+
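
Since `Settings` is a plain dataclass, alternative configurations can be built by passing field values explicitly. A hypothetical override (not used anywhere in this commit):

```python
from configs import Settings

# Sample one frame every two seconds and transcribe audio in 2-minute chunks.
fast_settings = Settings(
    VIDEO_EXTRACTION_FRAME_RATE=0.5,
    AUDIO_SEGMENT_LENGTH=120,
)
print(fast_settings.CHATBOT_MODEL)  # unchanged default: 'gemini-2.0-flash'
```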
embeder.py ADDED
@@ -0,0 +1,39 @@
1
+
2
+ import torch
3
+ from sentence_transformers import SentenceTransformer
4
+ from transformers import pipeline
5
+ from PIL import Image
6
+
7
+
8
+ class MultimodalEmbedder:
9
+ """A multimodal embedder that supports text and image embeddings."""
10
+ def __init__(
11
+ self,
12
+ text_model: str = 'sentence-transformers/all-MiniLM-L6-v2',
13
+ image_model: str = 'facebook/dinov2-small'
14
+ ):
15
+ self.text_model = SentenceTransformer(text_model)
16
+ self.image_model = pipeline(
17
+ 'image-feature-extraction',
18
+ model=image_model,
19
+ device=0 if torch.cuda.is_available() else -1,
20
+ pool=True
21
+ )
22
+
23
+ def embed_texts(self, texts: list[str]) -> list[list[float]]:
24
+ """Embed a list of texts"""
25
+ return self.text_model.encode(
26
+ texts,
27
+ device='cuda' if torch.cuda.is_available() else 'cpu',
28
+ show_progress_bar=True
29
+ ).tolist()
30
+
31
+ def embed_images(self, images: list[str | Image.Image]) -> list[list[float]]:
32
+ """Embed a list of images, which can be file paths or PIL Image objects."""
33
+ images = [Image.open(img) if isinstance(img, str) else img for img in images]
34
+ images = [img.convert('RGB') for img in images]
35
+
36
+ embeddings = self.image_model(images)
37
+
38
+ return [emb[0] for emb in embeddings]
39
+
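
A short sketch of exercising `MultimodalEmbedder` on its own (not in the repo; the frame path is hypothetical, and both models are downloaded on first use). Both all-MiniLM-L6-v2 and dinov2-small produce 384-dimensional vectors, matching the LanceDB schemas in rag.py:

```python
from scipy.spatial import distance

from embeder import MultimodalEmbedder

embedder = MultimodalEmbedder()

# Text embeddings: one 384-dimensional vector per input string.
text_vecs = embedder.embed_texts(['a cat on a sofa', 'stock market news'])
print(len(text_vecs[0]))  # 384

# Image embeddings accept file paths or PIL images.
frame_vec = embedder.embed_images(['data/lecture.mp4_frames/1.jpg'])[0]
print(1 - distance.cosine(frame_vec, frame_vec))  # 1.0 for identical inputs
```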
prompt.py ADDED
@@ -0,0 +1,13 @@
1
+ import os
2
+
3
+
4
+ def video_to_text_prompt(video_path: str, metadata: dict = None) -> str:
5
+ """Generate a text prompt to represent a video file with its metadata."""
6
+ metadata = metadata or {}
7
+ return f'''<video>
8
+ Filename: {os.path.basename(video_path)}
9
+ Metadata:
10
+ {'\n'.join(f'- {key}: {value}' for key, value in metadata.items())}
11
+ </video>
12
+ '''
13
+
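
For illustration (not part of the commit; path and metadata are made up), the stand-in text that replaces a raw `.mp4` attachment looks like this:

```python
from prompt import video_to_text_prompt

print(video_to_text_prompt('data/lecture.mp4', metadata={'Duration': 300}))
# Prints a block roughly like:
# <video>
# Filename: lecture.mp4
# Metadata:
# - Duration: 300
# </video>
```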
rag.py ADDED
@@ -0,0 +1,273 @@
1
+ import os.path
2
+ import uuid
3
+
4
+ import lancedb
5
+ import pyarrow as pa
6
+ from PIL import Image
7
+ from scipy.spatial import distance
8
+ from tqdm import tqdm
9
+
10
+ import utils
11
+ from configs import settings
12
+ from embeder import MultimodalEmbedder
13
+ from transcriber import AudioTranscriber
14
+
15
+
16
+ class VideoRAG:
17
+ """Video RAG (Retrieval-Augmented Generation) system for processing and searching video content."""
18
+
19
+ def __init__(self, video_frame_rate: float = 1, audio_segment_length: int = 300):
20
+ self.video_frame_rate = video_frame_rate
21
+ self.audio_segment_length = audio_segment_length
22
+
23
+ print('Loading embedding and audio transcription models...')
24
+ self.embedder = MultimodalEmbedder(
25
+ text_model=settings.TEXT_EMBEDDING_MODEL,
26
+ image_model=settings.IMAGE_EMBEDDING_MODEL,
27
+ )
28
+ self.transcriber = AudioTranscriber()
29
+
30
+ # init DB and tables
31
+ self._init_db()
32
+
33
+ def _init_db(self):
34
+ print('Initializing LanceDB...')
35
+ self.db = lancedb.connect(f'{settings.DATA_DIR}/vectordb')
36
+ self.frames_table = self.db.create_table('frames', mode='overwrite', schema=pa.schema([
37
+ pa.field('vector', pa.list_(pa.float32(), 384)),
38
+ pa.field('video_id', pa.string()),
39
+ pa.field('frame_index', pa.int32()),
40
+ pa.field('frame_path', pa.string()),
41
+ ]))
42
+ self.transcripts_table = self.db.create_table('transcripts', mode='overwrite', schema=pa.schema([
43
+ pa.field('vector', pa.list_(pa.float32(), 384)),
44
+ pa.field('video_id', pa.string()),
45
+ pa.field('segment_index', pa.int32()),
46
+ pa.field('start', pa.float64()),
47
+ pa.field('end', pa.float64()),
48
+ pa.field('text', pa.string()),
49
+ ]))
50
+
51
+ # save video metadata
52
+ self.videos = {}
53
+
54
+ def is_video_exists(self, video_id: str) -> bool:
55
+ """Check if a video exists in the RAG system by video ID.
56
+
57
+ Args:
58
+ video_id (str): The ID of the video to check.
59
+
60
+ Returns:
61
+ bool: True if the video exists, False otherwise.
62
+ """
63
+ return video_id in self.videos
64
+
65
+ def get_video(self, video_id: str) -> dict:
66
+ """Retrieve video metadata by video ID.
67
+
68
+ Args:
69
+ video_id (str): The ID of the video to retrieve.
70
+
71
+ Returns:
72
+ dict: A dictionary containing video metadata, including video path, frame directory, frame rate, and transcript segments.
73
+ """
74
+ if video_id not in self.videos:
75
+ raise ValueError(f'Video with ID {video_id} not found.')
76
+ return self.videos[video_id]
77
+
78
+ def add_video(self, video_path: str) -> str:
79
+ """Add a video to the RAG system by processing its frames and transcripts.
80
+
81
+ Args:
82
+ video_path (str): The path to the video file to be added.
83
+
84
+ Returns:
85
+ str: A unique video ID generated for the added video.
86
+ """
87
+ # create a unique video ID
88
+ video_id = uuid.uuid4().hex[:8]
89
+
90
+ print(f'Adding video "{video_path}" with ID {video_id} to the RAG system...')
91
+
92
+ print('Extracting video frames')
93
+ # process video frames
94
+ frame_paths = utils.extract_video_frames(video_path, output_dir=f'{video_path}_frames',
95
+ frame_rate=self.video_frame_rate)
96
+ print(f'Computing embeddings for {len(frame_paths)} frames...')
97
+ # calculate embeddings for frames
98
+ frame_embeddings = self.embedder.embed_images(frame_paths)
99
+ # get significant frames to reduce the number of frames
100
+ frame_indexes = get_significant_frames(frame_embeddings, threshold=0.6)
101
+ # add frames to the database
102
+ self.frames_table.add(
103
+ [{
104
+ 'vector': frame_embeddings[i],
105
+ 'video_id': video_id,
106
+ 'frame_index': i,
107
+ 'frame_path': frame_paths[i],
108
+ } for i in frame_indexes]
109
+ )
110
+ print(f'Added {len(frame_indexes)} significant frames to the database.')
111
+
112
+ print('Extracting audio from video')
113
+ # transcribe video to text
114
+ audio_path = utils.extract_audio(video_path)
115
+ print(f'Splitting and transcribing audio...')
116
+ segments = []
117
+ for i, segment_path in tqdm(enumerate(utils.split_media_file(
118
+ audio_path,
119
+ output_dir=f'{video_path}_audio_segments',
120
+ segment_length=self.audio_segment_length
121
+ )), desc='Transcribing audio'):
122
+ for segment in self.transcriber.transcribe(segment_path)['segments']:
123
+ segment['start'] += i * self.audio_segment_length
124
+ segment['end'] += i * self.audio_segment_length
125
+ segments.append(segment)
126
+ segments = sorted(segments, key=lambda s: s['start'])
127
+
128
+ print(f'Computing embeddings for {len(segments)} transcript segments...')
129
+ # calculate embeddings for transcripts
130
+ transcript_embeddings = self.embedder.embed_texts([s['text'] for s in segments])
131
+ # add transcripts to the database
132
+ self.transcripts_table.add(
133
+ [{
134
+ 'vector': transcript_embeddings[i],
135
+ 'video_id': video_id,
136
+ 'segment_index': i,
137
+ 'start': segment['start'],
138
+ 'end': segment['end'],
139
+ 'text': segment['text'],
140
+ } for i, segment in enumerate(segments)],
141
+ )
142
+ print(f'Added {len(segments)} transcript segments to the database.')
143
+
144
+ # add video metadata to the database
145
+ self.videos[video_id] = {
146
+ 'video_path': video_path,
147
+ 'frame_dir': f'{video_path}_frames',
148
+ 'video_frame_rate': self.video_frame_rate,
149
+ 'transcript_segments': segments,
150
+ }
151
+
152
+ print(f'Video "{video_path}" added with ID {video_id}.')
153
+ return video_id
154
+
155
+ def search(self, video_id: str, text: str = None, image: str | Image.Image = None, limit: int = 10) -> list[dict]:
156
+ """Search for relevant video frames or transcripts based on text or image input.
157
+
158
+ Args:
159
+ video_id (str): The ID of the video to search in.
160
+ text (str, optional): The text query to search for in the video transcripts.
161
+ image (str | Image.Image, optional): The image query to search for in the video frames. If a string is provided, it should be the path to the image file.
162
+ limit (int, optional): The maximum number of results to return. Defaults to 10.
163
+
164
+ Returns:
165
+ list[dict]: A list of dictionaries containing the search results, each with start and end times, distance, frame paths, and transcript segments.
166
+ """
167
+
168
+ video_metadata = self.get_video(video_id)
169
+
170
+ # search for transcripts based on text
171
+ timespans = []
172
+ if text is not None:
173
+ text_embedding = self.embedder.embed_texts([text])[0]
174
+ query = (self.transcripts_table
175
+ .search(text_embedding)
176
+ .where(f'video_id = \'{video_id}\'')
177
+ .limit(limit))
178
+ for result in query.to_list():
179
+ timespans.append({
180
+ 'start': result['start'],
181
+ 'end': result['end'],
182
+ 'distance': distance.cosine(text_embedding, result['vector']),
183
+ })
184
+
185
+ # search for frames based on image
186
+ if image is not None:
187
+ image_embedding = self.embedder.embed_images([image])[0]
188
+ query = (self.frames_table
189
+ .search(image_embedding)
190
+ .where(f'video_id = \'{video_id}\'')
191
+ .limit(limit))
192
+ for result in query.to_list():
193
+ start = result['frame_index'] / self.video_frame_rate
194
+ timespans.append({
195
+ 'start': start,
196
+ 'end': start + 1,
197
+ 'distance': distance.cosine(image_embedding, result['vector']),  # recompute cosine distance; the distance LanceDB returns is on a different scale
198
+ })
199
+
200
+ # merge nearby timespans
201
+ timespans = merge_searched_timespans(timespans, threshold=5)
202
+ # sort timespans by distance
203
+ timespans = sorted(timespans, key=lambda x: x['distance'])
204
+ # limit to k results
205
+ timespans = timespans[:limit]
206
+
207
+ for timespan in timespans:
208
+ # extend timespans to at least 5 seconds
209
+ duration = timespan['end'] - timespan['start']
210
+ if duration < 5:
211
+ timespan['start'] = max(0, timespan['start'] - (5 - duration) / 2)
212
+ timespan['end'] = timespan['start'] + 5
213
+ # add frame paths
214
+ timespan['frame_paths'] = []
215
+ for frame_index in range(
216
+ int(timespan['start'] * self.video_frame_rate),
217
+ int(timespan['end'] * self.video_frame_rate)
218
+ ):
219
+ timespan['frame_paths'].append(os.path.join(video_metadata['frame_dir'], f'{frame_index + 1}.jpg'))
220
+ # add transcript segments
221
+ timespan['transcript_segments'] = []
222
+ for segment in video_metadata['transcript_segments']:
223
+ if utils.span_iou((segment['start'], segment['end']),
224
+ (timespan['start'], timespan['end'])) > 0:
225
+ timespan['transcript_segments'].append(segment)
226
+
227
+ return timespans
228
+
229
+ def clear(self):
230
+ """Clear the RAG system by dropping all tables and resetting video metadata."""
231
+ self._init_db()
232
+
233
+
234
+ def get_significant_frames(frame_embeddings: list[list[float]], threshold: float = 0.8) -> list[int]:
235
+ """Select significant frames by comparing embeddings."""
236
+ selected_frames = []
237
+ current_frame = 0
238
+ for i, embedding in enumerate(frame_embeddings):
239
+ similarity = 1 - distance.cosine(frame_embeddings[current_frame], embedding)
240
+ if similarity < threshold:
241
+ selected_frames.append(current_frame)
242
+ current_frame = i
243
+
244
+ selected_frames.append(current_frame)
245
+
246
+ return selected_frames
247
+
248
+
249
+ def merge_searched_timespans(timespans: list[dict], threshold: float) -> list[dict]:
250
+ """Merge timespans if the gap between them is less than or equal to threshold."""
251
+ if not timespans:
252
+ return []
253
+
254
+ # Sort spans by start time
255
+ sorted_spans = sorted(timespans, key=lambda s: s['start'])
256
+
257
+ merged_spans = []
258
+ current_span = sorted_spans[0].copy()
259
+
260
+ for next_span in sorted_spans[1:]:
261
+ gap = next_span['start'] - current_span['end']
262
+ if gap <= threshold:
263
+ # Extend the current span’s end if needed
264
+ current_span['end'] = max(current_span['end'], next_span['end'])
265
+ current_span['distance'] = min(current_span['distance'], next_span['distance'])
266
+ else:
267
+ # No merge push current and start a new one
268
+ merged_spans.append(current_span)
269
+ current_span = next_span.copy()
270
+
271
+ # Add the last span
272
+ merged_spans.append(current_span)
273
+ return merged_spans
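
A minimal sketch of driving `VideoRAG` directly, without the agent (not in the repo; it assumes FFmpeg is on PATH, `GEMINI_API_KEY` is set for the transcriber, and the video path is made up):

```python
from rag import VideoRAG

rag = VideoRAG(video_frame_rate=1.0, audio_segment_length=300)

# Extract frames, transcribe the audio, and index both into LanceDB.
video_id = rag.add_video('data/lecture.mp4')

# Text queries search the transcript table, image queries the frame table;
# nearby hits are merged and padded to at least 5-second spans.
for span in rag.search(video_id, text='neural network training', limit=3):
    print(span['start'], span['end'],
          len(span['frame_paths']), len(span['transcript_segments']))
```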
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ google-genai>=1.22.0
2
+ lancedb>=0.24.0
3
+ pillow>=10.4.0
4
+ sentence-transformers>=4.1.0
5
+ smolagents[openai]>=1.19.0
6
+ tqdm>=4.67.1
7
+ transformers>=4.53.0
8
+ yt-dlp>=2025.6.25
tools.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+
3
+ from smolagents import tool, Tool
4
+
5
+ import utils
6
+ from configs import settings
7
+ from prompt import video_to_text_prompt
8
+ from rag import VideoRAG
9
+
10
+
11
+ @tool
12
+ def download_video(url: str) -> str:
13
+ """
14
+ Download a video from YouTube or other supported platforms.
15
+
16
+ Args:
17
+ url (str): The URL of the video.
18
+
19
+ Returns:
20
+ str: The video information, including the filename.
21
+ """
22
+
23
+ try:
24
+ filepath, info = utils.download_video(
25
+ url,
26
+ output_dir=settings.DATA_DIR,
27
+ max_resolution=settings.MAX_VIDEO_RESOLUTION,
28
+ max_fps=settings.MAX_VIDEO_FPS,
29
+ extension=settings.VIDEO_EXTENSION
30
+ )
31
+ except Exception as e:
32
+ return f'Error downloading video: {e.__class__.__name__}: {e}'
33
+
34
+ return video_to_text_prompt(
35
+ filepath,
36
+ metadata={
37
+ 'URL': url,
38
+ 'Title': info.get('title', 'N/A'),
39
+ 'Channel': info.get('channel', 'N/A'),
40
+ 'Duration': info.get('duration', 'N/A'),
41
+ }
42
+ )
43
+
44
+
45
+ def create_video_rag_tools(video_rag: VideoRAG) -> list[Tool]:
46
+
47
+ @tool
48
+ def add_video(filename: str) -> str:
49
+ """
50
+ Add a video file to the RAG knowledge-base for further search and analysis.
51
+
52
+ Args:
53
+ filename (str): The video filename to add.
54
+
55
+ Returns:
56
+ str: The video ID if added successfully, or an error message.
57
+ """
58
+ try:
59
+ video_id = video_rag.add_video(os.path.join(settings.DATA_DIR, filename))
60
+ return f'Video added with ID: {video_id}'
61
+ except Exception as e:
62
+ return f'Error adding video: {e.__class__.__name__}: {e}'
63
+
64
+
65
+ @tool
66
+ def search_in_video(video_id: str, text_query: str = None, image_query: str = None) -> str:
67
+ """
68
+ Search for relevant video frames and transcripts based on text or image query. Allows searching within a specific video added to the RAG knowledge-base.
69
+ At least one of `text_query` or `image_query` must be provided.
70
+
71
+ Args:
72
+ video_id (str): The ID of the video to search in. This should be the ID returned by `add_video`.
73
+ text_query (str, optional): The text query to search for in the video transcripts.
74
+ image_query (str, optional): The image query to search for in the video frames. This is the filename of the image.
75
+
76
+ Returns:
77
+ str: A message indicating the search results or an error message if the video is not found.
78
+ """
79
+
80
+ if not video_rag.is_video_exists(video_id):
81
+ return f'Video with ID "{video_id}" not found in the knowledge-base. Please add the video first using `add_video` tool.'
82
+ if not text_query and not image_query:
83
+ return 'Please provide at least one of `text_query` or `image_query` to search in the video.'
84
+
85
+ try:
86
+ results = video_rag.search(
87
+ video_id=video_id,
88
+ text=text_query,
89
+ image=image_query,
90
+ limit=5
91
+ )
92
+ except Exception as e:
93
+ return f'Error searching in video: {e.__class__.__name__}: {e}'
94
+
95
+ if not results:
96
+ return f'No results found for the given query in video ID {video_id}.'
97
+
98
+ # build the output message
99
+ output = f'Search results for video ID {video_id}:\n'
100
+ for result in results:
101
+ # include timespans, transcript segments, and frame paths in the output
102
+ timespan_text = f'{utils.seconds_to_hms(int(result['start']))} - {utils.seconds_to_hms(int(result['end']))}'
103
+ transcript_texts = []
104
+ for segment in result['transcript_segments']:
105
+ transcript_texts.append(
106
+ f'- {utils.seconds_to_hms(int(segment['start']), drop_hours=True)}'
107
+ f'-{utils.seconds_to_hms(int(segment['end']), drop_hours=True)}: {segment['text']}')
108
+ observation_image_texts = []
109
+ for frame_path in result['frame_paths'][::5]: # take every 5th frame for brevity
110
+ observation_image_texts.append(f'<observation_image>{frame_path}</observation_image>')
111
+
112
+ output += f'''<video_segment>
113
+ Timespan: {timespan_text}
114
+ Transcript:
115
+ {'\n'.join(transcript_texts)}
116
+ Frame images: {' '.join(observation_image_texts)}
117
+ </video_segment>\n'''
118
+
119
+ return output
120
+
121
+ return [add_video, search_in_video]
122
+
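
Because `create_video_rag_tools` closes over a `VideoRAG` instance, the returned `Tool` objects can also be invoked directly, which is handy for testing outside the agent (a sketch, not in the repo; the filename and returned ID are placeholders):

```python
from rag import VideoRAG
from tools import create_video_rag_tools

rag = VideoRAG()
add_video, search_in_video = create_video_rag_tools(rag)

# add_video expects a filename inside settings.DATA_DIR.
print(add_video('lecture.mp4'))  # e.g. 'Video added with ID: <8-char id>'
print(search_in_video('<8-char id>', text_query='introduction'))
```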
transcriber.py ADDED
@@ -0,0 +1,134 @@
1
+ import time
2
+ from typing import Any
3
+
4
+ from google import genai
5
+ from google.genai import types
6
+
7
+ class AudioTranscriber:
8
+ """A class to transcribe audio files"""
9
+
10
+ SYSTEM_INSTRUCTION = '''You are an advanced audio transcription model. Your task is to accurately transcribe provided audio input into a structured JSON format.
11
+
12
+ **Output Format Specification:**
13
+
14
+ Your response MUST be a valid JSON object with the following structure:
15
+
16
+ ```json
17
+ {
18
+ "segments": [
19
+ {
20
+ "text": "The transcribed text for the segment.",
21
+ "start": "The start time of the segment in seconds.",
22
+ "end": "The end time of the segment in seconds.",
23
+ "speaker": "The speaker ID for the segment."
24
+ }
25
+ ],
26
+ "language": "The language of the transcribed text in ISO 639-1 format."
27
+ }
28
+ ```
29
+
30
+ **Detailed Instructions and Rules:**
31
+
32
+ 1. Segments:
33
+ - A "segment" is defined as a continuous section of speech from a single speaker and may include multiple sentences or phrases.
34
+ - Each segment object MUST contain `text`, `start`, `end`, and `speaker` fields.
35
+ - `text`: The verbatim transcription of the speech within that segment.
36
+ - `start`: The precise start time of the segment in seconds, represented as a floating-point number (e.g., 0.0, 5.25).
37
+ - `end`: The precise end time of the segment in seconds, represented as a floating-point number (e.g., 4.9, 10.12).
38
+ - `speaker`: An integer representing the speaker ID.
39
+ + Speaker IDs start at `0` for the first detected speaker.
40
+ + The speaker ID MUST increment by 1 each time a new, distinct speaker is identified in the audio. Do not reuse speaker IDs within the same transcription.
41
+ + If the same speaker talks again after another speaker, they retain their original speaker ID.
42
+ + **Segment Splitting Rule**: A segment for the same speaker should only be split if there is a period of silence lasting more than 5 seconds. Otherwise, continuous speech from the same speaker, even with short pauses, should remain within a single segment.
43
+
44
+ 2. Language:
45
+ - `language`: A two-letter ISO 639-1 code representing the primary language of the transcribed text (e.g., "en" for English, "es" for Spanish, "fr" for French).
46
+ - If multiple languages are detected in the audio, you MUST select and output only the ISO 639-1 code for the primary language used throughout the audio.
47
+ '''
48
+
49
+ RESPONSE_SCHEMA = {
50
+ 'type': 'object',
51
+ 'properties': {
52
+ 'segments': {
53
+ 'type': 'array',
54
+ "description": 'A list of transcribed segments from the audio file.',
55
+ 'items': {
56
+ 'type': 'object',
57
+ 'properties': {
58
+ 'text': {
59
+ 'type': 'string',
60
+ 'description': 'The transcribed text for the segment.'
61
+ },
62
+ 'start': {
63
+ 'type': 'number',
64
+ 'description': 'The start time of the segment in seconds.'
65
+ },
66
+ 'end': {
67
+ 'type': 'number',
68
+ 'description': 'The end time of the segment in seconds.'
69
+ },
70
+ 'speaker': {
71
+ 'type': 'integer',
72
+ 'description': 'The speaker ID for the segment.'
73
+ }
74
+ },
75
+ 'required': ['text', 'start', 'end', 'speaker'],
76
+ 'propertyOrdering': ['text', 'start', 'end', 'speaker']
77
+ },
78
+ },
79
+ 'language': {
80
+ 'type': 'string',
81
+ 'description': 'The language of the transcribed text in ISO 639-1 format.',
82
+ }
83
+ },
84
+ 'required': ['segments', 'language'],
85
+ 'propertyOrdering': ['segments', 'language']
86
+ }
87
+
88
+ def __init__(self, model: str = 'gemini-2.0-flash', api_key: str = None):
89
+ self.model = model
90
+ self.client = genai.Client(api_key=api_key)
91
+
92
+ def transcribe(self, audio_path: str) -> dict[str, Any]:
93
+ """Transcribe an audio file from the given path.
94
+
95
+ Args:
96
+ audio_path (str): The path to the audio file to be transcribed.
97
+
98
+ Returns:
99
+ dict[str, Any]: The transcription result.
100
+ ```{
101
+ "segments": [
102
+ {
103
+ "text": "Transcribed text",
104
+ "start": 0.0,
105
+ "end": 5.0,
106
+ "speaker": 0
107
+ }
108
+ ],
109
+ "language": "en"
110
+ }```
111
+ """
112
+
113
+ uploaded_file = self.client.files.upload(file=audio_path)
114
+ while uploaded_file.state != 'ACTIVE':
115
+ time.sleep(1)
116
+ uploaded_file = self.client.files.get(name=uploaded_file.name)
117
+ if uploaded_file.state == 'FAILED':
118
+ raise ValueError('Failed to upload the audio file')
119
+
120
+ response = self.client.models.generate_content(
121
+ model=self.model,
122
+ contents=uploaded_file,
123
+ config=types.GenerateContentConfig(
124
+ system_instruction=self.SYSTEM_INSTRUCTION,
125
+ temperature=0.2,
126
+ response_mime_type='application/json',
127
+ response_schema=self.RESPONSE_SCHEMA,
128
+ )
129
+ )
130
+
131
+ if response.parsed is None:
132
+ raise ValueError('Failed to transcribe the audio file')
133
+
134
+ return response.parsed # type: ignore
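
A small usage sketch (not in the repo): transcribing a single audio segment produced by `utils.split_media_file`. It assumes `GEMINI_API_KEY` is set (the `genai.Client` falls back to the environment when `api_key` is None), that the file path exists, and that the parsed response matches the documented schema:

```python
from transcriber import AudioTranscriber

transcriber = AudioTranscriber(model='gemini-2.0-flash')
result = transcriber.transcribe('data/lecture_000.m4a')

print(result['language'])
for seg in result['segments']:
    print(f"[{seg['start']:.1f}-{seg['end']:.1f}] "
          f"speaker {seg['speaker']}: {seg['text']}")
```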
utils.py ADDED
@@ -0,0 +1,181 @@
1
+ import glob
2
+ import os.path
3
+ import subprocess
4
+
5
+ from yt_dlp import YoutubeDL
6
+
7
+ from configs import settings
8
+
9
+
10
+ def download_video(
11
+ url: str,
12
+ output_dir: str = None,
13
+ max_resolution: int = 1080,
14
+ max_fps: float = 60,
15
+ extension: str = 'mp4'
16
+ ) -> tuple[str, dict]:
17
+ """Download a video from YouTube or other supported sites. Returns the file path and video metadata.
18
+
19
+ Args:
20
+ url (str): The URL of the video.
21
+ output_dir (str, optional): Directory to save the downloaded video. Defaults to current directory.
22
+ max_resolution (int, optional): Maximum resolution of the video to download. Defaults to 1080.
23
+ max_fps (float, optional): Maximum frames per second of the video to download. Defaults to 60.
24
+ extension (str, optional): File extension for the downloaded video. Defaults to 'mp4'.
25
+
26
+ Returns:
27
+ tuple[str, dict]: A tuple containing the path to the downloaded video file and its metadata.
28
+ """
29
+
30
+ ydl_opts = {
31
+ 'format': f'bestvideo[height<={max_resolution}][fps<={max_fps}][ext={extension}]+'
32
+ f'bestaudio/best[height<={max_resolution}][fps<={max_fps}][ext={extension}]/best',
33
+ 'merge_output_format': extension,
34
+ 'outtmpl': f'{output_dir or "."}/%(title)s.%(ext)s',
35
+ 'noplaylist': True,
36
+ }
37
+ with YoutubeDL(ydl_opts) as ydl:
38
+ info = ydl.extract_info(url, download=True)
+ # extract_info(download=True) already downloads the file, and the outtmpl
+ # above already places it in output_dir, so prepare_filename() returns the
+ # final path directly.
+ output_path = ydl.prepare_filename(info)
44
+
45
+ return output_path, info
46
+
47
+
48
+ def extract_video_frames(video_path: str, output_dir: str, frame_rate: float = 1, extension: str = 'jpg') -> list[str]:
49
+ """Extract frames from a video file at a specified frame rate.
50
+
51
+ Args:
52
+ video_path (str): Path to the video file.
53
+ output_dir (str): Directory to save the extracted frames.
54
+ frame_rate (float, optional): Frame rate for extraction. Defaults to 1 frame per second.
55
+ extension (str, optional): File extension for the extracted frames. Defaults to 'jpg'.
56
+
57
+ Returns:
58
+ list[str]: A sorted list of paths to the extracted frame images.
59
+ """
60
+ os.makedirs(output_dir, exist_ok=True)
61
+
62
+ subprocess.run(
63
+ [
64
+ settings.FFMPEG_PATH,
65
+ # '-v', 'quiet',
66
+ '-i', video_path,
67
+ '-vf', f'fps={frame_rate}',
68
+ '-y',
69
+ f'{output_dir or "."}/%d.{extension}'
70
+ ],
71
+ stdout=subprocess.DEVNULL,
72
+ stderr=subprocess.DEVNULL,
73
+ )
74
+ # Get all extracted frames
75
+ results = sorted(glob.glob(f'{output_dir or "."}/*.{extension}'),
76
+ key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
77
+ if not results:
78
+ raise FileNotFoundError(f'No frames found in "{output_dir}" for video "{video_path}"')
79
+
80
+ return results
81
+
82
+
83
+ def extract_audio(video_path: str, output_dir: str = None, extension: str = 'm4a') -> str:
84
+ """Extract audio from a video file and save it as an M4A file.
85
+
86
+ Args:
87
+ video_path (str): Path to the video file.
88
+ output_dir (str, optional): Directory to save the extracted audio. Defaults to the same directory as the video.
89
+ extension (str, optional): File extension for the extracted audio. Defaults to 'm4a'.
90
+ Returns:
91
+ str: Path to the extracted audio file.
92
+ """
93
+ if output_dir is None:
94
+ output_dir = os.path.dirname(video_path)
95
+
96
+ audio_path = os.path.join(output_dir, f'{os.path.splitext(os.path.basename(video_path))[0]}.{extension}')
97
+
98
+ subprocess.run(
99
+ [
100
+ settings.FFMPEG_PATH,
101
+ '-i', video_path,
102
+ '-q:a', '0',
103
+ '-map', 'a',
104
+ '-y',
105
+ audio_path
106
+ ],
107
+ stdout=subprocess.DEVNULL,
108
+ stderr=subprocess.DEVNULL,
109
+ )
110
+
111
+ if not os.path.exists(audio_path):
112
+ raise FileNotFoundError(f'Audio extraction failed: "{audio_path}" does not exist.')
113
+
114
+ return audio_path
115
+
116
+
117
+ def split_media_file(file_path: str, output_dir: str, segment_length: int = 60) -> list[str]:
118
+ """Split a media file into segments of specified length in seconds.
119
+
120
+ Args:
121
+ file_path (str): Path to the media file to be split.
122
+ output_dir (str): Directory to save the split segments.
123
+ segment_length (int, optional): Length of each segment in seconds. Defaults to 60 seconds.
124
+
125
+ Returns:
126
+ list[str]: A sorted list of paths to the split media segments.
127
+ """
128
+ os.makedirs(output_dir, exist_ok=True)
129
+
130
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
131
+ extension = os.path.splitext(file_path)[1]
132
+ segment_pattern = os.path.join(output_dir, f'{base_name}_%03d.{extension}')
133
+
134
+ subprocess.run(
135
+ [
136
+ settings.FFMPEG_PATH,
137
+ '-i', file_path,
138
+ '-c', 'copy',
139
+ '-map', '0',
140
+ '-segment_time', str(segment_length),
141
+ '-f', 'segment',
142
+ '-y',
143
+ segment_pattern
144
+ ],
145
+ stdout=subprocess.DEVNULL,
146
+ stderr=subprocess.DEVNULL,
147
+ )
148
+
149
+ return sorted(glob.glob(f'{output_dir}/*{base_name}_*.{extension}'))
150
+
151
+
152
+ def span_iou(span1: tuple[float, float], span2: tuple[float, float]) -> float:
153
+ """Calculate the Intersection over Union (IoU) of two spans."""
154
+ start1, end1 = span1
155
+ start2, end2 = span2
156
+
157
+ intersection_start = max(start1, start2)
158
+ intersection_end = min(end1, end2)
159
+
160
+ if intersection_start >= intersection_end:
161
+ return 0.0 # No overlap
162
+
163
+ intersection_length = intersection_end - intersection_start
164
+ union_length = (end1 - start1) + (end2 - start2) - intersection_length
165
+
166
+ return intersection_length / union_length if union_length > 0 else 0.0
167
+
168
+
169
+ def seconds_to_hms(total_seconds: int, drop_hours: bool = False) -> str:
170
+ """Convert a number of seconds to a string formatted as HH:MM:SS."""
171
+ # Ensure we’re working with non-negative integers
172
+ if total_seconds < 0:
173
+ raise ValueError('total_seconds must be non-negative')
174
+
175
+ hours, remainder = divmod(total_seconds, 3600)
176
+ minutes, seconds = divmod(remainder, 60)
177
+
178
+ if drop_hours and hours == 0:
179
+ return f'{minutes:02d}:{seconds:02d}'
180
+
181
+ return f'{hours:02d}:{minutes:02d}:{seconds:02d}'
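
Finally, a sketch chaining the FFmpeg helpers the way `VideoRAG.add_video` does (not in the repo; `data/lecture.mp4` is a hypothetical input and FFmpeg must be on PATH):

```python
import utils

video = 'data/lecture.mp4'

# One frame per second, written as data/lecture.mp4_frames/1.jpg, 2.jpg, ...
frames = utils.extract_video_frames(video, output_dir=f'{video}_frames', frame_rate=1.0)

# Pull the audio track, then cut it into 5-minute segments for transcription.
audio = utils.extract_audio(video)
segments = utils.split_media_file(audio, output_dir=f'{video}_audio_segments', segment_length=300)

print(len(frames), 'frames;', len(segments), 'audio segments')
print(utils.seconds_to_hms(3671))  # '01:01:11'
```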