|
|
from transformers import pipeline |
|
|
import json |
|
|
|
|
|
class ToolCallingAgent:
    """Prompt a text-generation model to choose a tool and parse its JSON reply.

    The agent wraps a Hugging Face ``text-generation`` pipeline. ``generate``
    builds a prompt that lists the available tools, runs the model greedily,
    and extracts the first JSON object from the completion.
    """

    def __init__(self, model_name="cognitivecomputations/dolphin-2.9-llama3-8b"):
        """Load the text-generation pipeline.

        Args:
            model_name: Model identifier handed to ``pipeline``; defaults to
                the model this agent was originally hard-wired to.
        """
        self.model = pipeline(
            "text-generation",
            model=model_name,
            device_map="auto",
        )

    def generate(self, prompt, tools, max_new_tokens=200):
        """Ask the model which tool to call for ``prompt``.

        Args:
            prompt: The user request, inserted verbatim into the prompt.
            tools: JSON-serialisable description of the available tools.
            max_new_tokens: Upper bound on generated tokens (default 200).

        Returns:
            The decoded JSON object on success. On failure, a dict with an
            ``"error"`` key: ``"No JSON found in model output"`` (with
            ``"raw_output"``) or ``"Failed to parse JSON"`` (with
            ``"message"``, ``"raw_output"`` and ``"extracted_json"``).
        """
        tools_json = json.dumps(tools, ensure_ascii=False)
        system_msg = f"""You are an AI assistant that can call tools.
Available tools: {tools_json}
Respond ONLY with a valid JSON containing keys 'tool_name' and 'parameters'."""

        # NOTE(review): this markup may not match the chat template the
        # configured model was trained with -- confirm against the model card
        # (or use the tokenizer's chat template) before relying on it.
        full_prompt = f"<|system|>{system_msg}</s><|user|>{prompt}</s>"

        response = self.model(
            full_prompt,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Without this the pipeline returns prompt + completion, and the
            # prompt itself contains braces from ``tools_json``.
            return_full_text=False,
        )
        text = response[0]["generated_text"]

        # Defensive: if the backend echoed the prompt anyway, drop it so the
        # JSON search only ever sees the model's own output.
        if text.startswith(full_prompt):
            text = text[len(full_prompt):]

        return self._parse_tool_call(text)

    @staticmethod
    def _parse_tool_call(text):
        """Extract and decode the first JSON object found in ``text``."""
        json_start = text.find("{")
        json_end = text.rfind("}")
        # Compare the raw indices: adding 1 before the check would turn the
        # "not found" sentinel (-1) into 0 and hide the missing brace.
        if json_start == -1 or json_end < json_start:
            return {"error": "No JSON found in model output", "raw_output": text}

        json_text = text[json_start:json_end + 1]
        try:
            # raw_decode stops at the end of the first complete object, so
            # trailing chatter (even chatter containing '}') is tolerated.
            parsed, _ = json.JSONDecoder().raw_decode(json_text)
        except json.JSONDecodeError as e:
            return {
                "error": "Failed to parse JSON",
                "message": str(e),
                "raw_output": text,
                "extracted_json": json_text,
            }
        return parsed
|
|
|