GAIA_agent / tools.py
ItzRoBeerT's picture
Added describe image tool
86c6428
raw
history blame
3.26 kB
from smolagents import Tool, DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool, WikipediaSearchTool
from openai import OpenAI
import whisper
import base64
import os
class read_file(Tool):
    """Tool that returns the text content of a file on disk."""

    name = "read_file"
    description = "Read a file and return the content."
    inputs = {
        "file_path": {
            "type": "string",
            "description": "The path to the file to read."
        }
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Read ``file_path`` as UTF-8 text and return its content.

        Args:
            file_path: Path of the file to read.

        Returns:
            The file's text, or a message starting with
            ``"Error reading file: "`` if the file cannot be opened or
            decoded.
        """
        try:
            # Decode explicitly as UTF-8 so the result does not depend on the
            # locale's default encoding of the machine running the agent.
            with open(file_path, "r", encoding="utf-8") as file:
                return file.read()
        except Exception as e:
            # Failures are reported as text instead of being raised.
            return f"Error reading file: {e}"
class transcribe_audio(Tool):
    """Tool that transcribes an audio file to text with Whisper."""

    name = "transcribe_audio"
    description = "Transcribe an audio file and return the text."
    inputs = {
        "audio_path": {
            "type": "string",
            "description": "The path to the audio file to transcribe."
        }
    }
    output_type = "string"

    # Whisper model cache; stays None until the first transcription request so
    # that constructing the tool remains cheap.
    _model = None

    def forward(self, audio_path: str) -> str:
        """Transcribe the audio file at ``audio_path``.

        Args:
            audio_path: Path of the audio file to transcribe.

        Returns:
            The ``'text'`` entry of the Whisper transcription result, or a
            message starting with ``"Error transcribing audio: "`` if loading
            the model or transcribing fails.
        """
        try:
            # Load the "small" Whisper model once and reuse it on later calls
            # instead of reloading it for every transcription.
            if self._model is None:
                self._model = whisper.load_model("small")
            result = self._model.transcribe(audio_path)
            return result["text"]
        except Exception as e:
            # Failures are reported as text instead of being raised.
            return f"Error transcribing audio: {e}"
def get_data_uri(image_path: str, base64_image: str) -> str:
    """Build a ``data:`` URI for an already base64-encoded image.

    The media type is derived from the extension of ``image_path``. Extension
    spellings that differ from the registered image subtype (for example
    ``.jpg`` -> ``image/jpeg``) are translated; any other extension is used
    as the subtype unchanged.

    Args:
        image_path: Path of the image; only its extension is inspected.
        base64_image: The image bytes, base64-encoded as text.

    Returns:
        A string of the form ``data:<mime type>;base64,<base64_image>``. When
        the path has no extension the generic ``application/octet-stream``
        type is used so the URI is still well-formed.
    """
    # Extensions whose registered image subtype is not the extension itself.
    subtype_aliases = {
        "jpg": "jpeg",
        "jpe": "jpeg",
        "jfif": "jpeg",
        "tif": "tiff",
        "svg": "svg+xml",
    }
    _, file_extension = os.path.splitext(image_path)
    file_extension = file_extension.lower().lstrip(".")
    if file_extension:
        subtype = subtype_aliases.get(file_extension, file_extension)
        mime_type = f"image/{subtype}"
    else:
        # Without an extension "image/" would be an invalid media type.
        mime_type = "application/octet-stream"
    return f"data:{mime_type};base64,{base64_image}"
class describe_image(Tool):
    """Tool that asks a vision-capable chat model to describe an image file."""

    name = "describe_image"
    description = "Describe an image and return the description."
    inputs = {
        "image_path": {
            "type": "string",
            "description": "The path to the image file to describe."
        }
    }
    output_type = "string"

    def forward(self, image_path: str) -> str:
        """Describe the image stored at ``image_path``.

        The image is read from disk, base64-encoded, wrapped in a data URI and
        sent with a fixed prompt to the chat-completions endpoint configured
        by the ``OPENROUTER_API_KEY`` and ``OPENROUTER_BASE_URL`` environment
        variables.

        Args:
            image_path: Path of the image file to describe.

        Returns:
            The description text from the first completion choice, or a
            message starting with ``"Error describing image: "`` if reading
            the file or calling the API fails.

        Raises:
            ValueError: If ``OPENROUTER_API_KEY`` is not set.
        """
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            # The message names the variable that is actually read above.
            raise ValueError(
                "OpenRouter API key not provided: "
                "OPENROUTER_API_KEY environment variable not set"
            )
        base_url = os.getenv("OPENROUTER_BASE_URL")
        client = OpenAI(api_key=api_key, base_url=base_url)

        try:
            with open(image_path, "rb") as image_file:
                base64_image = base64.b64encode(image_file.read()).decode("utf-8")
            data_uri = get_data_uri(image_path, base64_image)

            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "Describe this image in detail. Include information about the main subject, setting, colors, and any notable elements."},
                            {
                                "type": "image_url",
                                "image_url": {"url": data_uri}
                            }
                        ]
                    }
                ],
                max_tokens=500
            )
            # Fall back to an empty string if the message carries no content,
            # so the declared string output type still holds.
            return response.choices[0].message.content or ""
        except Exception as e:
            # Failures are reported as text instead of being raised.
            return f"Error describing image: {e}"
def return_tools() -> list[Tool]:
    """Instantiate and return every tool made available to the agent.

    The three tools defined in this module come first, followed by the
    ready-made smolagents tools.
    """
    tool_factories = (
        # Tools defined in this module.
        read_file,
        transcribe_audio,
        describe_image,
        # Tools shipped with smolagents.
        DuckDuckGoSearchTool,
        PythonInterpreterTool,
        VisitWebpageTool,
        WikipediaSearchTool,
    )
    return [make_tool() for make_tool in tool_factories]