First commit
- .gitignore +2 -0
- README.md +45 -14
- agent.py +82 -0
- app.py +85 -0
- configs.py +26 -0
- embeder.py +39 -0
- prompt.py +13 -0
- rag.py +273 -0
- requirements.txt +8 -0
- tools.py +122 -0
- transcriber.py +134 -0
- utils.py +181 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
.env
data/*

README.md
CHANGED
@@ -1,14 +1,45 @@
# Chatbot for Video Question Answering Demo

An AI chatbot that answers questions about video content. The project combines a multi-modal LLM with a multi-modal RAG pipeline to process video frames, transcribe audio, and retrieve the information needed to answer questions about a video accurately.

## Requirements

- Python 3.12+
- [uv](https://docs.astral.sh/uv/) for package and project management
- [FFmpeg](https://ffmpeg.org/) installed and available in PATH
- [Google Gemini API key](https://aistudio.google.com/apikey) for the LLM functionality

## Installation

1. Clone this repository
```bash
git clone [repository-url]
cd VideoChatbot
```

2. Install dependencies using uv
```bash
uv sync
```

3. Create a `.env` file in the project root with your API key
```
GEMINI_API_KEY=your_api_key_here
```

## Usage

1. Start the application
```bash
python app.py
```

2. Access the UI through your browser (typically at http://127.0.0.1:7860)

3. Upload a video file or provide a YouTube URL and ask questions about it

4. The system will process the video (extract frames, transcribe audio), index the content, and then answer your questions

## Notes

This project is designed as a demo and may require additional configuration for production use. Video processing and indexing can take time depending on the video's length and complexity. For better performance and accuracy, use larger LLMs, embedding models, transcription models, and vector databases.

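The model choices mentioned in the note above all live in `configs.py`, so swapping components is an edit to the `Settings` defaults. A sketch of such a swap (the replacement model names are illustrative assumptions, not tested recommendations):

```python
# configs.py -- illustrative edits only; these model names are assumptions
CHATBOT_MODEL: str = 'gemini-2.5-pro'                                   # a larger LLM
TEXT_EMBEDDING_MODEL: str = 'sentence-transformers/all-mpnet-base-v2'   # 768-dim embeddings
IMAGE_EMBEDDING_MODEL: str = 'facebook/dinov2-base'                     # 768-dim embeddings
```

One caveat: the LanceDB schemas in `rag.py` hardcode 384-dimensional vectors to match the default MiniLM and DINOv2-small models, so swapping either embedding model also means updating the `pa.list_(pa.float32(), 384)` fields to the new dimensionality.
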
agent.py
ADDED
@@ -0,0 +1,82 @@
import re
from typing import Generator

from smolagents import ToolCallingAgent, OpenAIServerModel, ActionStep
from PIL import Image

import tools
from configs import settings
from prompt import video_to_text_prompt
from rag import VideoRAG


class VideoChatbot:
    def __init__(
        self,
        model: str = 'gemini-2.0-flash',
        api_base: str = None,
        api_key: str = None
    ):
        self.video_rag = VideoRAG(
            video_frame_rate=settings.VIDEO_EXTRACTION_FRAME_RATE,
            audio_segment_length=settings.AUDIO_SEGMENT_LENGTH,
        )
        self.agent = ToolCallingAgent(
            tools=[
                tools.download_video,
                *tools.create_video_rag_tools(self.video_rag)
            ],
            model=OpenAIServerModel(
                model_id=model,
                api_base=api_base,
                api_key=api_key
            ),
            step_callbacks=[self._step_callback],
        )

    def chat(self, message: str, attachments: list[str] = None) -> Generator:
        """Chats with the bot, including handling attachments (images and videos).

        Args:
            message: The text message to send to the bot.
            attachments: A list of file paths for images or videos to include in the chat.

        Returns:
            A generator yielding step objects representing the bot's responses and actions.
        """
        images = []
        for filepath in attachments or []:
            if filepath.endswith(('.jpg', '.jpeg', '.png')):
                images.append(Image.open(filepath))
            if filepath.endswith('.mp4'):
                message = video_to_text_prompt(filepath) + message

        for step in self.agent.run(
            message,
            stream=True,
            reset=False,
            images=images,
        ):
            yield step

    def clear(self):
        """Clears the chatbot message history and context."""
        self.agent.state.clear()
        self.agent.memory.reset()
        self.agent.monitor.reset()
        self.video_rag.clear()

    def _step_callback(self, step: ActionStep, agent: ToolCallingAgent):
        if step.observations:
            if step.observations_images is None:
                # defensive fix: observations_images may start out as None
                step.observations_images = []
            image_index = 0
            for image_path in re.findall(r'<observation_image>(.*?)</observation_image>', step.observations):
                try:
                    image = Image.open(image_path)
                    step.observations_images.append(image)
                    step.observations = step.observations.replace(image_path, str(image_index))
                    image_index += 1
                except Exception as e:
                    print(f'Error loading image {image_path}: {e}')

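For orientation, a minimal sketch of driving `VideoChatbot` outside the Gradio UI (the video path is hypothetical, and `GEMINI_API_KEY` must be set, mirroring how `app.py` constructs the bot):

```python
import os

from agent import VideoChatbot
from configs import settings

bot = VideoChatbot(
    model=settings.CHATBOT_MODEL,
    api_base=settings.MODEL_BASE_API,
    api_key=os.environ['GEMINI_API_KEY'],
)
# steps stream in as tool calls, action steps, and finally a FinalAnswerStep
for step in bot.chat('What happens in this video?', attachments=['data/clip.mp4']):
    print(type(step).__name__)
```
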
app.py
ADDED
@@ -0,0 +1,85 @@
import os
import shutil

import gradio as gr
from smolagents import ChatMessageToolCall, ActionStep, FinalAnswerStep

from agent import VideoChatbot
from configs import settings


bot = VideoChatbot(
    model=settings.CHATBOT_MODEL,
    api_base=settings.MODEL_BASE_API,
    api_key=os.environ['GEMINI_API_KEY']
)


def chat(message: dict, history: list[dict]):
    # move the uploaded files to the data directory
    message['files'] = [shutil.copy(file, settings.DATA_DIR) for file in message['files']]

    # add the input message to the history
    history.extend([{'role': 'user', 'content': {'path': file}} for file in message['files']])
    history.append({'role': 'user', 'content': message['text']})
    yield history, ''

    for step in bot.chat(message['text'], message['files']):
        match step:
            case ChatMessageToolCall():
                if step.function.name == 'download_video':
                    history.append({
                        'role': 'assistant',
                        'content': f'📥 Downloading video from {step.function.arguments["url"]}'
                    })
                elif step.function.name == 'add_video':
                    history.append({
                        'role': 'assistant',
                        'content': f'🎥 Processing and adding video `{step.function.arguments["filename"]}` '
                                   f'to the knowledge base. This may take a while...'
                    })
                elif step.function.name == 'search_in_video':
                    filename = os.path.basename(bot.video_rag.videos[step.function.arguments["video_id"]]['video_path'])
                    history.append({
                        'role': 'assistant',
                        'content': f'🔍 Searching in video `{filename}` '
                                   f'for query: *{step.function.arguments.get("text_query", step.function.arguments.get("image_query", ""))}*'
                    })
                elif step.function.name == 'final_answer':
                    continue
                yield history, ''
            case ActionStep():
                yield history, ''
            case FinalAnswerStep():
                history.append({'role': 'assistant', 'content': step.output})
                yield history, ''


def clear_chat(history):
    bot.clear()  # reset the agent memory and the RAG index, not just the displayed history
    history.clear()
    return history, gr.update(value='')


def main():
    with gr.Blocks() as demo:
        gr.Markdown('# Video Chatbot Demo')
        gr.Markdown('This demo showcases a video chatbot that can process and search videos using '
                    'RAG (Retrieval-Augmented Generation). You can upload videos/images or link to YouTube videos, '
                    'ask questions, and get answers based on the video content.')
        chatbot = gr.Chatbot(type='messages', label='Video Chatbot', height=800, resizable=True)
        textbox = gr.MultimodalTextbox(
            sources=['upload'],
            file_types=['image', '.mp4'],
            show_label=False,
            placeholder='Type a message or upload an image/video...',
        )
        textbox.submit(chat, [textbox, chatbot], [chatbot, textbox])
        clear = gr.Button('Clear Chat')
        clear.click(clear_chat, [chatbot], [chatbot, textbox])

    demo.launch(debug=True)


if __name__ == '__main__':
    main()

configs.py
ADDED
@@ -0,0 +1,26 @@
from dataclasses import dataclass

from dotenv import load_dotenv

load_dotenv()


@dataclass
class Settings:
    DATA_DIR: str = 'data'
    FFMPEG_PATH: str = 'ffmpeg'
    MAX_VIDEO_RESOLUTION: int = 360
    MAX_VIDEO_FPS: float = 30
    VIDEO_EXTENSION: str = 'mp4'
    VIDEO_EXTRACTION_FRAME_RATE: float = 1.0
    AUDIO_SEGMENT_LENGTH: int = 300
    CHATBOT_MODEL: str = 'gemini-2.0-flash'
    MODEL_BASE_API: str = 'https://generativelanguage.googleapis.com/v1beta/'
    TEXT_EMBEDDING_MODEL: str = 'sentence-transformers/all-MiniLM-L6-v2'
    IMAGE_EMBEDDING_MODEL: str = 'facebook/dinov2-small'


settings = Settings()

embeder.py
ADDED
@@ -0,0 +1,39 @@
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from PIL import Image


class MultimodalEmbedder:
    """A multimodal embedder that supports text and image embeddings."""

    def __init__(
        self,
        text_model: str = 'sentence-transformers/all-MiniLM-L6-v2',
        image_model: str = 'facebook/dinov2-small'
    ):
        self.text_model = SentenceTransformer(text_model)
        self.image_model = pipeline(
            'image-feature-extraction',
            model=image_model,
            device=0 if torch.cuda.is_available() else -1,
            pool=True
        )

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts."""
        return self.text_model.encode(
            texts,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            show_progress_bar=True
        ).tolist()

    def embed_images(self, images: list[str | Image.Image]) -> list[list[float]]:
        """Embed a list of images, which can be file paths or PIL Image objects."""
        images = [Image.open(img) if isinstance(img, str) else img for img in images]
        images = [img.convert('RGB') for img in images]

        embeddings = self.image_model(images)

        # with pool=True the pipeline returns each pooled vector wrapped in a batch dimension; unwrap it
        return [emb[0] for emb in embeddings]

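A quick sanity check of the embedder (a sketch; the first run downloads both Hugging Face models):

```python
from embeder import MultimodalEmbedder

embedder = MultimodalEmbedder()
vectors = embedder.embed_texts(['a cat on a sofa', 'a rocket launch'])
print(len(vectors), len(vectors[0]))  # 2 384 -- all-MiniLM-L6-v2 produces 384-dim vectors
```

Both defaults, all-MiniLM-L6-v2 and dinov2-small, emit 384-dimensional vectors, which is exactly the width the LanceDB schemas in `rag.py` expect.
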
prompt.py
ADDED
@@ -0,0 +1,13 @@
import os


def video_to_text_prompt(video_path: str, metadata: dict = None) -> str:
    """Generate a text prompt to represent a video file with its metadata."""
    metadata = metadata or {}
    return f'''<video>
Filename: {os.path.basename(video_path)}
Metadata:
{'\n'.join(f'- {key}: {value}' for key, value in metadata.items())}
</video>
'''

rag.py
ADDED
@@ -0,0 +1,273 @@
import os.path
import uuid

import lancedb
import pyarrow as pa
from PIL import Image
from scipy.spatial import distance
from tqdm import tqdm

import utils
from configs import settings
from embeder import MultimodalEmbedder
from transcriber import AudioTranscriber


class VideoRAG:
    """Video RAG (Retrieval-Augmented Generation) system for processing and searching video content."""

    def __init__(self, video_frame_rate: float = 1, audio_segment_length: int = 300):
        self.video_frame_rate = video_frame_rate
        self.audio_segment_length = audio_segment_length

        print('Loading embedding and audio transcription models...')
        self.embedder = MultimodalEmbedder(
            text_model=settings.TEXT_EMBEDDING_MODEL,
            image_model=settings.IMAGE_EMBEDDING_MODEL,
        )
        self.transcriber = AudioTranscriber()

        # init DB and tables
        self._init_db()

    def _init_db(self):
        print('Initializing LanceDB...')
        self.db = lancedb.connect(f'{settings.DATA_DIR}/vectordb')
        self.frames_table = self.db.create_table('frames', mode='overwrite', schema=pa.schema([
            pa.field('vector', pa.list_(pa.float32(), 384)),
            pa.field('video_id', pa.string()),
            pa.field('frame_index', pa.int32()),
            pa.field('frame_path', pa.string()),
        ]))
        self.transcripts_table = self.db.create_table('transcripts', mode='overwrite', schema=pa.schema([
            pa.field('vector', pa.list_(pa.float32(), 384)),
            pa.field('video_id', pa.string()),
            pa.field('segment_index', pa.int32()),
            pa.field('start', pa.float64()),
            pa.field('end', pa.float64()),
            pa.field('text', pa.string()),
        ]))

        # save video metadata
        self.videos = {}

    def is_video_exists(self, video_id: str) -> bool:
        """Check if a video exists in the RAG system by video ID.

        Args:
            video_id (str): The ID of the video to check.

        Returns:
            bool: True if the video exists, False otherwise.
        """
        return video_id in self.videos

    def get_video(self, video_id: str) -> dict:
        """Retrieve video metadata by video ID.

        Args:
            video_id (str): The ID of the video to retrieve.

        Returns:
            dict: A dictionary containing video metadata, including video path, frame directory, frame rate, and transcript segments.
        """
        if video_id not in self.videos:
            raise ValueError(f'Video with ID {video_id} not found.')
        return self.videos[video_id]

    def add_video(self, video_path: str) -> str:
        """Add a video to the RAG system by processing its frames and transcripts.

        Args:
            video_path (str): The path to the video file to be added.

        Returns:
            str: A unique video ID generated for the added video.
        """
        # create a unique video ID
        video_id = uuid.uuid4().hex[:8]

        print(f'Adding video "{video_path}" with ID {video_id} to the RAG system...')

        print('Extracting video frames')
        # process video frames
        frame_paths = utils.extract_video_frames(video_path, output_dir=f'{video_path}_frames',
                                                 frame_rate=self.video_frame_rate)
        print(f'Computing embeddings for {len(frame_paths)} frames...')
        # calculate embeddings for frames
        frame_embeddings = self.embedder.embed_images(frame_paths)
        # keep only significant frames to reduce the number of stored frames
        frame_indexes = get_significant_frames(frame_embeddings, threshold=0.6)
        # add frames to the database
        self.frames_table.add(
            [{
                'vector': frame_embeddings[i],
                'video_id': video_id,
                'frame_index': i,
                'frame_path': frame_paths[i],
            } for i in frame_indexes]
        )
        print(f'Added {len(frame_indexes)} significant frames to the database.')

        print('Extracting audio from video')
        # transcribe video to text
        audio_path = utils.extract_audio(video_path)
        print('Splitting and transcribing audio...')
        segments = []
        for i, segment_path in tqdm(enumerate(utils.split_media_file(
            audio_path,
            output_dir=f'{video_path}_audio_segments',
            segment_length=self.audio_segment_length
        )), desc='Transcribing audio'):
            for segment in self.transcriber.transcribe(segment_path)['segments']:
                segment['start'] += i * self.audio_segment_length
                segment['end'] += i * self.audio_segment_length
                segments.append(segment)
        segments = sorted(segments, key=lambda s: s['start'])

        print(f'Computing embeddings for {len(segments)} transcript segments...')
        # calculate embeddings for transcripts
        transcript_embeddings = self.embedder.embed_texts([s['text'] for s in segments])
        # add transcripts to the database
        self.transcripts_table.add(
            [{
                'vector': transcript_embeddings[i],
                'video_id': video_id,
                'segment_index': i,
                'start': segment['start'],
                'end': segment['end'],
                'text': segment['text'],
            } for i, segment in enumerate(segments)],
        )
        print(f'Added {len(segments)} transcript segments to the database.')

        # add video metadata to the database
        self.videos[video_id] = {
            'video_path': video_path,
            'frame_dir': f'{video_path}_frames',
            'video_frame_rate': self.video_frame_rate,
            'transcript_segments': segments,
        }

        print(f'Video "{video_path}" added with ID {video_id}.')
        return video_id

    def search(self, video_id: str, text: str = None, image: str | Image.Image = None, limit: int = 10) -> list[dict]:
        """Search for relevant video frames or transcripts based on text or image input.

        Args:
            video_id (str): The ID of the video to search in.
            text (str, optional): The text query to search for in the video transcripts.
            image (str | Image.Image, optional): The image query to search for in the video frames. If a string is provided, it should be the path to the image file.
            limit (int, optional): The maximum number of results to return. Defaults to 10.

        Returns:
            list[dict]: A list of dictionaries containing the search results, each with start and end times, distance, frame paths, and transcript segments.
        """
        video_metadata = self.get_video(video_id)

        # search for transcripts based on text
        timespans = []
        if text is not None:
            text_embedding = self.embedder.embed_texts([text])[0]
            query = (self.transcripts_table
                     .search(text_embedding)
                     .where(f"video_id = '{video_id}'")
                     .limit(limit))
            for result in query.to_list():
                timespans.append({
                    'start': result['start'],
                    'end': result['end'],
                    'distance': distance.cosine(text_embedding, result['vector']),
                })

        # search for frames based on image
        if image is not None:
            image_embedding = self.embedder.embed_images([image])[0]
            query = (self.frames_table
                     .search(image_embedding)
                     .where(f"video_id = '{video_id}'")
                     .limit(limit))
            for result in query.to_list():
                start = result['frame_index'] / self.video_frame_rate
                timespans.append({
                    'start': start,
                    'end': start + 1,
                    # recompute the cosine distance; the raw distance LanceDB returns here is too large to compare
                    'distance': distance.cosine(image_embedding, result['vector']),
                })

        # merge nearby timespans
        timespans = merge_searched_timespans(timespans, threshold=5)
        # sort timespans by distance
        timespans = sorted(timespans, key=lambda x: x['distance'])
        # limit to k results
        timespans = timespans[:limit]

        for timespan in timespans:
            # extend timespans to at least 5 seconds
            duration = timespan['end'] - timespan['start']
            if duration < 5:
                timespan['start'] = max(0, timespan['start'] - (5 - duration) / 2)
                timespan['end'] = timespan['start'] + 5
            # add frame paths
            timespan['frame_paths'] = []
            for frame_index in range(
                int(timespan['start'] * self.video_frame_rate),
                int(timespan['end'] * self.video_frame_rate)
            ):
                timespan['frame_paths'].append(os.path.join(video_metadata['frame_dir'], f'{frame_index + 1}.jpg'))
            # add transcript segments
            timespan['transcript_segments'] = []
            for segment in video_metadata['transcript_segments']:
                if utils.span_iou((segment['start'], segment['end']),
                                  (timespan['start'], timespan['end'])) > 0:
                    timespan['transcript_segments'].append(segment)

        return timespans

    def clear(self):
        """Clear the RAG system by dropping all tables and resetting video metadata."""
        self._init_db()


def get_significant_frames(frame_embeddings: list[list[float]], threshold: float = 0.8) -> list[int]:
    """Select significant frames by comparing embeddings."""
    selected_frames = []
    current_frame = 0
    for i, embedding in enumerate(frame_embeddings):
        similarity = 1 - distance.cosine(frame_embeddings[current_frame], embedding)
        if similarity < threshold:
            selected_frames.append(current_frame)
            current_frame = i

    selected_frames.append(current_frame)

    return selected_frames


def merge_searched_timespans(timespans: list[dict], threshold: float) -> list[dict]:
    """Merge timespans if the gap between them is less than or equal to the threshold."""
    if not timespans:
        return []

    # sort spans by start time
    sorted_spans = sorted(timespans, key=lambda s: s['start'])

    merged_spans = []
    current_span = sorted_spans[0].copy()

    for next_span in sorted_spans[1:]:
        gap = next_span['start'] - current_span['end']
        if gap <= threshold:
            # extend the current span's end if needed
            current_span['end'] = max(current_span['end'], next_span['end'])
            current_span['distance'] = min(current_span['distance'], next_span['distance'])
        else:
            # no merge: push the current span and start a new one
            merged_spans.append(current_span)
            current_span = next_span.copy()

    # add the last span
    merged_spans.append(current_span)
    return merged_spans

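The module-level helpers are easy to verify in isolation. A hand-worked example of the span merging (importing `rag` pulls in its heavy dependencies, but no models load until `VideoRAG` is instantiated):

```python
from rag import merge_searched_timespans

spans = [
    {'start': 0.0, 'end': 4.0, 'distance': 0.3},
    {'start': 7.0, 'end': 12.0, 'distance': 0.1},   # 3 s gap -> merged with the first
    {'start': 30.0, 'end': 31.0, 'distance': 0.5},  # 18 s gap -> kept separate
]
print(merge_searched_timespans(spans, threshold=5))
# [{'start': 0.0, 'end': 12.0, 'distance': 0.1}, {'start': 30.0, 'end': 31.0, 'distance': 0.5}]
```

Merged spans keep the smaller (better) distance of their members, which is why the first result carries 0.1.
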
requirements.txt
ADDED
@@ -0,0 +1,9 @@
google-genai>=1.22.0
lancedb>=0.24.0
pillow>=10.4.0
python-dotenv>=1.0.0
sentence-transformers>=4.1.0
smolagents[openai]>=1.19.0
tqdm>=4.67.1
transformers>=4.53.0
yt-dlp>=2025.6.25

tools.py
ADDED
@@ -0,0 +1,122 @@
import os

from smolagents import tool, Tool

import utils
from configs import settings
from prompt import video_to_text_prompt
from rag import VideoRAG


@tool
def download_video(url: str) -> str:
    """
    Download a video from YouTube or other supported platforms.

    Args:
        url (str): The URL of the video.

    Returns:
        str: The video information, including the filename.
    """
    try:
        filepath, info = utils.download_video(
            url,
            output_dir=settings.DATA_DIR,
            max_resolution=settings.MAX_VIDEO_RESOLUTION,
            max_fps=settings.MAX_VIDEO_FPS,
            extension=settings.VIDEO_EXTENSION
        )
    except Exception as e:
        return f'Error downloading video: {e.__class__.__name__}: {e}'

    return video_to_text_prompt(
        filepath,
        metadata={
            'URL': url,
            'Title': info.get('title', 'N/A'),
            'Channel': info.get('channel', 'N/A'),
            'Duration': info.get('duration', 'N/A'),
        }
    )


def create_video_rag_tools(video_rag: VideoRAG) -> list[Tool]:

    @tool
    def add_video(filename: str) -> str:
        """
        Add a video file to the RAG knowledge-base for further search and analysis.

        Args:
            filename (str): The video filename to add.

        Returns:
            str: The video ID if added successfully, or an error message.
        """
        try:
            video_id = video_rag.add_video(os.path.join(settings.DATA_DIR, filename))
            return f'Video added with ID: {video_id}'
        except Exception as e:
            return f'Error adding video: {e.__class__.__name__}: {e}'

    @tool
    def search_in_video(video_id: str, text_query: str = None, image_query: str = None) -> str:
        """
        Search for relevant video frames and transcripts based on a text or image query. Allows searching within a specific video added to the RAG knowledge-base.
        At least one of `text_query` or `image_query` must be provided.

        Args:
            video_id (str): The ID of the video to search in. This should be the ID returned by `add_video`.
            text_query (str, optional): The text query to search for in the video transcripts.
            image_query (str, optional): The image query to search for in the video frames. This is the filename of the image.

        Returns:
            str: A message indicating the search results or an error message if the video is not found.
        """
        if not video_rag.is_video_exists(video_id):
            return f'Video with ID "{video_id}" not found in the knowledge-base. Please add the video first using the `add_video` tool.'
        if not text_query and not image_query:
            return 'Please provide at least one of `text_query` or `image_query` to search in the video.'

        try:
            results = video_rag.search(
                video_id=video_id,
                text=text_query,
                image=image_query,
                limit=5
            )
        except Exception as e:
            return f'Error searching in video: {e.__class__.__name__}: {e}'

        if not results:
            return f'No results found for the given query in video ID {video_id}.'

        # build the output message
        output = f'Search results for video ID {video_id}:\n'
        for result in results:
            # include timespans, transcript segments, and frame paths in the output
            timespan_text = f'{utils.seconds_to_hms(int(result["start"]))} - {utils.seconds_to_hms(int(result["end"]))}'
            transcript_texts = []
            for segment in result['transcript_segments']:
                transcript_texts.append(
                    f'- {utils.seconds_to_hms(int(segment["start"]), drop_hours=True)}'
                    f'-{utils.seconds_to_hms(int(segment["end"]), drop_hours=True)}: {segment["text"]}')
            observation_image_texts = []
            for frame_path in result['frame_paths'][::5]:  # take every 5th frame for brevity
                observation_image_texts.append(f'<observation_image>{frame_path}</observation_image>')

            output += f'''<video_segment>
Timespan: {timespan_text}
Transcript:
{'\n'.join(transcript_texts)}
Frame images: {' '.join(observation_image_texts)}
</video_segment>\n'''

        return output

    return [add_video, search_in_video]

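Note the contract here: `search_in_video` wraps each returned frame path in `<observation_image>…</observation_image>` tags, and `VideoChatbot._step_callback` in `agent.py` scans tool observations for exactly those tags, loading each path as a PIL image and replacing it with an image index so the model can reference the attached frames.
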
transcriber.py
ADDED
@@ -0,0 +1,134 @@
import time
from typing import Any

from google import genai
from google.genai import types


class AudioTranscriber:
    """A class to transcribe audio files."""

    SYSTEM_INSTRUCTION = '''You are an advanced audio transcription model. Your task is to accurately transcribe provided audio input into a structured JSON format.

**Output Format Specification:**

Your response MUST be a valid JSON object with the following structure:

```json
{
    "segments": [
        {
            "text": "The transcribed text for the segment.",
            "start": "The start time of the segment in seconds.",
            "end": "The end time of the segment in seconds.",
            "speaker": "The speaker ID for the segment."
        }
    ],
    "language": "The language of the transcribed text in ISO 639-1 format."
}
```

**Detailed Instructions and Rules:**

1. Segments:
    - A "segment" is defined as a continuous section of speech from a single speaker that may include multiple sentences or phrases.
    - Each segment object MUST contain `text`, `start`, `end`, and `speaker` fields.
    - `text`: The verbatim transcription of the speech within that segment.
    - `start`: The precise start time of the segment in seconds, represented as a floating-point number (e.g., 0.0, 5.25).
    - `end`: The precise end time of the segment in seconds, represented as a floating-point number (e.g., 4.9, 10.12).
    - `speaker`: An integer representing the speaker ID.
        + Speaker IDs start at `0` for the first detected speaker.
        + The speaker ID MUST increment by 1 each time a new, distinct speaker is identified in the audio. Do not reuse speaker IDs within the same transcription.
        + If the same speaker talks again after another speaker, they retain their original speaker ID.
        + **Segment Splitting Rule**: A segment for the same speaker should only be split if there is a period of silence lasting more than 5 seconds. Otherwise, continuous speech from the same speaker, even with short pauses, should remain within a single segment.

2. Language:
    - `language`: A two-letter ISO 639-1 code representing the primary language of the transcribed text (e.g., "en" for English, "es" for Spanish, "fr" for French).
    - If multiple languages are detected in the audio, you MUST select and output only the ISO 639-1 code for the primary language used throughout the audio.
'''

    RESPONSE_SCHEMA = {
        'type': 'object',
        'properties': {
            'segments': {
                'type': 'array',
                'description': 'A list of transcribed segments from the audio file.',
                'items': {
                    'type': 'object',
                    'properties': {
                        'text': {
                            'type': 'string',
                            'description': 'The transcribed text for the segment.'
                        },
                        'start': {
                            'type': 'number',
                            'description': 'The start time of the segment in seconds.'
                        },
                        'end': {
                            'type': 'number',
                            'description': 'The end time of the segment in seconds.'
                        },
                        'speaker': {
                            'type': 'integer',
                            'description': 'The speaker ID for the segment.'
                        }
                    },
                    'required': ['text', 'start', 'end', 'speaker'],
                    'propertyOrdering': ['text', 'start', 'end', 'speaker']
                },
            },
            'language': {
                'type': 'string',
                'description': 'The language of the transcribed text in ISO 639-1 format.',
            }
        },
        'required': ['segments', 'language'],
        'propertyOrdering': ['segments', 'language']
    }

    def __init__(self, model: str = 'gemini-2.0-flash', api_key: str = None):
        self.model = model
        self.client = genai.Client(api_key=api_key)

    def transcribe(self, audio_path: str) -> dict[str, Any]:
        """Transcribe an audio file from the given path.

        Args:
            audio_path (str): The path to the audio file to be transcribed.

        Returns:
            dict[str, Any]: The transcription result.
            ```{
                "segments": [
                    {
                        "text": "Transcribed text",
                        "start": 0.0,
                        "end": 5.0,
                        "speaker": 0
                    }
                ],
                "language": "en"
            }```
        """
        uploaded_file = self.client.files.upload(file=audio_path)
        while uploaded_file.state != 'ACTIVE':
            time.sleep(1)
            uploaded_file = self.client.files.get(name=uploaded_file.name)
            if uploaded_file.state == 'FAILED':
                raise ValueError('Failed to upload the audio file')

        response = self.client.models.generate_content(
            model=self.model,
            contents=uploaded_file,
            config=types.GenerateContentConfig(
                system_instruction=self.SYSTEM_INSTRUCTION,
                temperature=0.2,
                response_mime_type='application/json',
                response_schema=self.RESPONSE_SCHEMA,
            )
        )

        if response.parsed is None:
            raise ValueError('Failed to transcribe the audio file')

        return response.parsed  # type: ignore

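Usage is a single call; a sketch (the audio path is hypothetical, and with `api_key=None` the `genai.Client` is expected to pick up the key from the environment):

```python
from transcriber import AudioTranscriber

transcriber = AudioTranscriber()
result = transcriber.transcribe('data/segment_000.m4a')  # hypothetical file
for seg in result['segments']:
    print(f"[{seg['start']:.1f}-{seg['end']:.1f}] speaker {seg['speaker']}: {seg['text']}")
```
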
utils.py
ADDED
@@ -0,0 +1,181 @@
import glob
import os.path
import subprocess

from yt_dlp import YoutubeDL

from configs import settings


def download_video(
    url: str,
    output_dir: str = None,
    max_resolution: int = 1080,
    max_fps: float = 60,
    extension: str = 'mp4'
) -> tuple[str, dict]:
    """Download a video from YouTube or other supported sites. Returns the file path and video metadata.

    Args:
        url (str): The URL of the video.
        output_dir (str, optional): Directory to save the downloaded video. Defaults to current directory.
        max_resolution (int, optional): Maximum resolution of the video to download. Defaults to 1080.
        max_fps (float, optional): Maximum frames per second of the video to download. Defaults to 60.
        extension (str, optional): File extension for the downloaded video. Defaults to 'mp4'.

    Returns:
        tuple[str, dict]: A tuple containing the path to the downloaded video file and its metadata.
    """
    ydl_opts = {
        'format': f'bestvideo[height<={max_resolution}][fps<={max_fps}][ext={extension}]+'
                  f'bestaudio/best[height<={max_resolution}][fps<={max_fps}][ext={extension}]/best',
        'merge_output_format': extension,
        'outtmpl': f'{output_dir or "."}/%(title)s.%(ext)s',
        'noplaylist': True,
    }
    with YoutubeDL(ydl_opts) as ydl:
        # extract_info with download=True already downloads the file
        info = ydl.extract_info(url, download=True)
        # prepare_filename renders the full outtmpl, which already includes output_dir
        output_path = ydl.prepare_filename(info)

    return output_path, info


def extract_video_frames(video_path: str, output_dir: str, frame_rate: float = 1, extension: str = 'jpg') -> list[str]:
    """Extract frames from a video file at a specified frame rate.

    Args:
        video_path (str): Path to the video file.
        output_dir (str): Directory to save the extracted frames.
        frame_rate (float, optional): Frame rate for extraction. Defaults to 1 frame per second.
        extension (str, optional): File extension for the extracted frames. Defaults to 'jpg'.

    Returns:
        list[str]: A sorted list of paths to the extracted frame images.
    """
    os.makedirs(output_dir, exist_ok=True)

    subprocess.run(
        [
            settings.FFMPEG_PATH,
            # '-v', 'quiet',
            '-i', video_path,
            '-vf', f'fps={frame_rate}',
            '-y',
            f'{output_dir or "."}/%d.{extension}'
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # get all extracted frames, sorted by frame number
    results = sorted(glob.glob(f'{output_dir or "."}/*.{extension}'),
                     key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    if not results:
        raise FileNotFoundError(f'No frames found in "{output_dir}" for video "{video_path}"')

    return results


def extract_audio(video_path: str, output_dir: str = None, extension: str = 'm4a') -> str:
    """Extract audio from a video file and save it as an audio file (M4A by default).

    Args:
        video_path (str): Path to the video file.
        output_dir (str, optional): Directory to save the extracted audio. Defaults to the same directory as the video.
        extension (str, optional): File extension for the extracted audio. Defaults to 'm4a'.

    Returns:
        str: Path to the extracted audio file.
    """
    if output_dir is None:
        output_dir = os.path.dirname(video_path)

    audio_path = os.path.join(output_dir, f'{os.path.splitext(os.path.basename(video_path))[0]}.{extension}')

    subprocess.run(
        [
            settings.FFMPEG_PATH,
            '-i', video_path,
            '-q:a', '0',
            '-map', 'a',
            '-y',
            audio_path
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    if not os.path.exists(audio_path):
        raise FileNotFoundError(f'Audio extraction failed: "{audio_path}" does not exist.')

    return audio_path


def split_media_file(file_path: str, output_dir: str, segment_length: int = 60) -> list[str]:
    """Split a media file into segments of a specified length in seconds.

    Args:
        file_path (str): Path to the media file to be split.
        output_dir (str): Directory to save the split segments.
        segment_length (int, optional): Length of each segment in seconds. Defaults to 60 seconds.

    Returns:
        list[str]: A sorted list of paths to the split media segments.
    """
    os.makedirs(output_dir, exist_ok=True)

    base_name = os.path.splitext(os.path.basename(file_path))[0]
    # strip the leading dot so the segment pattern has a single separator
    extension = os.path.splitext(file_path)[1].lstrip('.')
    segment_pattern = os.path.join(output_dir, f'{base_name}_%03d.{extension}')

    subprocess.run(
        [
            settings.FFMPEG_PATH,
            '-i', file_path,
            '-c', 'copy',
            '-map', '0',
            '-segment_time', str(segment_length),
            '-f', 'segment',
            '-y',
            segment_pattern
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    return sorted(glob.glob(f'{output_dir}/{base_name}_*.{extension}'))


def span_iou(span1: tuple[float, float], span2: tuple[float, float]) -> float:
    """Calculate the Intersection over Union (IoU) of two spans."""
    start1, end1 = span1
    start2, end2 = span2

    intersection_start = max(start1, start2)
    intersection_end = min(end1, end2)

    if intersection_start >= intersection_end:
        return 0.0  # no overlap

    intersection_length = intersection_end - intersection_start
    union_length = (end1 - start1) + (end2 - start2) - intersection_length

    return intersection_length / union_length if union_length > 0 else 0.0


def seconds_to_hms(total_seconds: int, drop_hours: bool = False) -> str:
    """Convert a number of seconds to a string formatted as HH:MM:SS."""
    # ensure we're working with non-negative integers
    if total_seconds < 0:
        raise ValueError('total_seconds must be non-negative')

    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    if drop_hours and hours == 0:
        return f'{minutes:02d}:{seconds:02d}'

    return f'{hours:02d}:{minutes:02d}:{seconds:02d}'
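
`span_iou` is the predicate `rag.search` uses to attach transcript segments to a retrieved timespan; a worked example:

```python
from utils import span_iou

# spans (0, 10) and (5, 20): 5 s intersection over a 20 s union -> 0.25
print(span_iou((0, 10), (5, 20)))  # 0.25
print(span_iou((0, 5), (10, 20)))  # 0.0 -- disjoint spans
```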