radius / tools /tool_agent.py
Batnini's picture
Update tools/tool_agent.py
f3ad633 verified
from transformers import pipeline
import json
class ToolCallingAgent:
    """Ask a text-generation model to pick a tool and parse its JSON answer.

    The model is prompted with the JSON description of the available tools
    and instructed to reply with a JSON object holding the keys
    ``tool_name`` and ``parameters``.
    """

    def __init__(self):
        """Load the Hugging Face ``text-generation`` pipeline used by the agent."""
        self.model = pipeline(
            "text-generation",
            model="cognitivecomputations/dolphin-2.9-llama3-8b",
            device_map="auto"
        )

    def generate(self, prompt, tools):
        """Generate a tool call for ``prompt`` and return it as a dict.

        Args:
            prompt: The user's request, inserted verbatim into the prompt.
            tools: JSON-serialisable description of the available tools.

        Returns:
            The first JSON object found in the model's completion, decoded to
            a dict. On failure a dict with an ``"error"`` key is returned
            instead of raising:

            * ``{"error": "No JSON found in model output", "raw_output": ...}``
              when the completion holds no ``{ ... }`` span;
            * ``{"error": "Failed to parse JSON", "message": ...,
              "raw_output": ..., "extracted_json": ...}`` when the span is
              not valid JSON.
        """
        tools_json = json.dumps(tools, ensure_ascii=False)
        system_msg = (
            "You are an AI assistant that can call tools.\n"
            f"Available tools: {tools_json}\n"
            "Respond ONLY with a valid JSON containing keys 'tool_name' and 'parameters'."
        )
        # Construct prompt with system and user tokens (assuming model supports these)
        full_prompt = f"<|system|>{system_msg}</s><|user|>{prompt}</s>"

        response = self.model(
            full_prompt,
            max_new_tokens=200,
            do_sample=False,  # deterministic output for better JSON consistency
            # Ask for the completion only: the prompt embeds the tools JSON, so
            # its braces would otherwise be mistaken for the model's answer.
            return_full_text=False,
        )
        text = response[0]['generated_text']

        # Defensive: if the backend echoed the prompt anyway, drop it before
        # searching for braces.
        completion = text[len(full_prompt):] if text.startswith(full_prompt) else text

        # Locate the widest '{' ... '}' span in the completion.
        json_start = completion.find("{")
        last_brace = completion.rfind("}")
        if json_start == -1 or last_brace < json_start:
            return {"error": "No JSON found in model output", "raw_output": text}
        json_text = completion[json_start:last_brace + 1]

        try:
            # raw_decode stops at the end of the first complete JSON value, so
            # trailing chatter (even chatter containing braces) is tolerated.
            parsed, _ = json.JSONDecoder().raw_decode(json_text)
        except json.JSONDecodeError as e:
            return {
                "error": "Failed to parse JSON",
                "message": str(e),
                "raw_output": text,
                "extracted_json": json_text
            }
        return parsed