"""
Anthropic-Compatible API Endpoint

Lightweight CPU-based implementation for Hugging Face Spaces
"""

import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from logging.handlers import RotatingFileHandler
from threading import Thread
from typing import List, Optional, Union

import torch
from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

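# Logging: DEBUG-level entries go to a rotating file under /tmp/logs, INFO-level
# entries go to the console, and uvicorn's own loggers are routed to the same handlers.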
LOG_DIR = "/tmp/logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "api.log")

log_format = logging.Formatter(
    '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

file_handler = RotatingFileHandler(
    LOG_FILE,
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding='utf-8'
)
file_handler.setFormatter(log_format)
file_handler.setLevel(logging.DEBUG)

console_handler = logging.StreamHandler()
console_handler.setFormatter(log_format)
console_handler.setLevel(logging.INFO)

logging.basicConfig(level=logging.DEBUG, handlers=[file_handler, console_handler])
logger = logging.getLogger("anthropic-api")

for uvicorn_logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
    uv_log = logging.getLogger(uvicorn_logger)
    uv_log.handlers = [file_handler, console_handler]

logger.info("=" * 60)
logger.info(f"Application Startup at {datetime.now().isoformat()}")
logger.info(f"Log file: {LOG_FILE}")
logger.info("=" * 60)

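# Model configuration: a small instruct-tuned model, loaded once and run on CPU.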
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
MAX_TOKENS_DEFAULT = 1024
DEVICE = "cpu"

model = None
tokenizer = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer at startup; release them on shutdown."""
    global model, tokenizer
    logger.info(f"Loading model: {MODEL_ID}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Tokenizer loaded successfully")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            device_map=DEVICE,
            low_cpu_mem_usage=True
        )
        model.eval()
        logger.info("Model loaded successfully!")
        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        raise

    yield

    logger.info("Shutting down, cleaning up model...")
    del model, tokenizer

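# FastAPI application with permissive CORS and per-request logging middleware.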
app = FastAPI(
    title="Anthropic-Compatible API",
    description="Lightweight CPU-based API with Anthropic Messages API compatibility",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.middleware("http")
async def log_requests(request: Request, call_next):
    request_id = str(uuid.uuid4())[:8]
    start_time = time.time()

    logger.info(f"[{request_id}] {request.method} {request.url.path} - Started")

    try:
        response = await call_next(request)
        duration = (time.time() - start_time) * 1000
        logger.info(f"[{request_id}] {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)")
        return response
    except Exception as e:
        duration = (time.time() - start_time) * 1000
        logger.error(f"[{request_id}] {request.method} {request.url.path} - Error: {e} ({duration:.2f}ms)")
        raise

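# Request/response schemas mirroring the Anthropic Messages API shapes.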
class ContentBlock(BaseModel):
    type: str = "text"
    text: str


class Message(BaseModel):
    role: str
    content: Union[str, List[ContentBlock]]


class MessageRequest(BaseModel):
    model: str
    messages: List[Message]
    max_tokens: int = MAX_TOKENS_DEFAULT
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
    stream: Optional[bool] = False
    system: Optional[str] = None
    # Accepted for API compatibility; not currently applied during generation.
    stop_sequences: Optional[List[str]] = None


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int


class MessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None
    usage: Usage


class ErrorResponse(BaseModel):
    type: str = "error"
    error: dict

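# Prompt construction and ID helpers.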
def format_messages(messages: List[Message], system: Optional[str] = None) -> str:
    """Format messages into a prompt string"""
    formatted_messages = []

    if system:
        formatted_messages.append({"role": "system", "content": system})

    for msg in messages:
        content = msg.content
        if isinstance(content, list):
            content = " ".join([block.text for block in content if block.type == "text"])
        formatted_messages.append({"role": msg.role, "content": content})

    # Prefer the tokenizer's chat template; otherwise fall back to a plain transcript.
    if tokenizer.chat_template:
        return tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

    prompt = ""
    for msg in formatted_messages:
        role = msg["role"].capitalize()
        prompt += f"{role}: {msg['content']}\n"
    prompt += "Assistant: "
    return prompt


def generate_id() -> str:
    """Generate a unique message ID"""
    return f"msg_{uuid.uuid4().hex[:24]}"

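# Endpoints: health/metadata, log access, and the Anthropic-compatible Messages API.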
@app.get("/") |
|
|
async def root(): |
|
|
"""Health check endpoint""" |
|
|
logger.debug("Root endpoint accessed") |
|
|
return { |
|
|
"status": "healthy", |
|
|
"model": MODEL_ID, |
|
|
"api_version": "2023-06-01", |
|
|
"compatibility": "anthropic-messages-api", |
|
|
"log_file": LOG_FILE |
|
|
} |
|
|
|
|
|
@app.get("/v1/models") |
|
|
async def list_models(): |
|
|
"""List available models (Anthropic-compatible)""" |
|
|
logger.debug("Models list requested") |
|
|
return { |
|
|
"object": "list", |
|
|
"data": [ |
|
|
{ |
|
|
"id": "smollm2-135m", |
|
|
"object": "model", |
|
|
"created": int(time.time()), |
|
|
"owned_by": "huggingface", |
|
|
"display_name": "SmolLM2 135M Instruct" |
|
|
} |
|
|
] |
|
|
} |
|
|
|
|
|
@app.get("/logs") |
|
|
async def get_logs(lines: int = 100): |
|
|
"""Get recent log entries""" |
|
|
try: |
|
|
with open(LOG_FILE, 'r') as f: |
|
|
all_lines = f.readlines() |
|
|
recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines |
|
|
return { |
|
|
"log_file": LOG_FILE, |
|
|
"total_lines": len(all_lines), |
|
|
"returned_lines": len(recent_lines), |
|
|
"logs": "".join(recent_lines) |
|
|
} |
|
|
except FileNotFoundError: |
|
|
return {"error": "Log file not found", "log_file": LOG_FILE} |
|
|
|
|
|
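# Core endpoint: build a prompt from the conversation, tokenize it, run
# model.generate(), and wrap the decoded completion in an Anthropic-style
# MessageResponse. When stream=True the request is handed off to
# stream_response() below, which emits Server-Sent Events.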
@app.post("/v1/messages") |
|
|
async def create_message( |
|
|
request: MessageRequest, |
|
|
x_api_key: Optional[str] = Header(None, alias="x-api-key"), |
|
|
anthropic_version: Optional[str] = Header(None, alias="anthropic-version") |
|
|
): |
|
|
""" |
|
|
Create a message (Anthropic Messages API compatible) |
|
|
""" |
|
|
message_id = generate_id() |
|
|
logger.info(f"[{message_id}] Creating message - model: {request.model}, max_tokens: {request.max_tokens}, stream: {request.stream}") |
|
|
|
|
|
try: |
|
|
|
|
|
prompt = format_messages(request.messages, request.system) |
|
|
logger.debug(f"[{message_id}] Prompt length: {len(prompt)} chars") |
|
|
|
|
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) |
|
|
input_token_count = inputs.input_ids.shape[1] |
|
|
logger.info(f"[{message_id}] Input tokens: {input_token_count}") |
|
|
|
|
|
if request.stream: |
|
|
logger.info(f"[{message_id}] Starting streaming response") |
|
|
return await stream_response(request, inputs, input_token_count, message_id) |
|
|
|
|
|
|
|
|
gen_start = time.time() |
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=request.max_tokens, |
|
|
temperature=request.temperature if request.temperature > 0 else 1.0, |
|
|
top_p=request.top_p, |
|
|
top_k=request.top_k, |
|
|
do_sample=request.temperature > 0, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
eos_token_id=tokenizer.eos_token_id, |
|
|
) |
|
|
gen_time = time.time() - gen_start |
|
|
|
|
|
|
|
|
generated_tokens = outputs[0][input_token_count:] |
|
|
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
|
output_token_count = len(generated_tokens) |
|
|
|
|
|
tokens_per_sec = output_token_count / gen_time if gen_time > 0 else 0 |
|
|
logger.info(f"[{message_id}] Generated {output_token_count} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)") |
|
|
|
|
|
|
|
|
response = MessageResponse( |
|
|
id=message_id, |
|
|
content=[ContentBlock(type="text", text=generated_text.strip())], |
|
|
model=request.model, |
|
|
stop_reason="end_turn", |
|
|
usage=Usage( |
|
|
input_tokens=input_token_count, |
|
|
output_tokens=output_token_count |
|
|
) |
|
|
) |
|
|
|
|
|
return response |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[{message_id}] Error creating message: {e}", exc_info=True) |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
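# Streaming: events follow the Anthropic SSE sequence
#   message_start -> content_block_start -> content_block_delta (repeated)
#   -> content_block_stop -> message_delta -> message_stop
# Generation runs in a background thread and text chunks are read from a
# TextIteratorStreamer as they become available.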
async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
    """Stream response using SSE (Server-Sent Events)"""

    async def generate():
        # Announce the message and its (initially empty) content.
        start_event = {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_token_count, "output_tokens": 0}
            }
        }
        yield f"event: message_start\ndata: {json.dumps(start_event)}\n\n"

        block_start = {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""}
        }
        yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n"

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "max_new_tokens": request.max_tokens,
            "temperature": request.temperature if request.temperature > 0 else 1.0,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "do_sample": request.temperature > 0,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }

        # Run generation in a background thread so tokens can be consumed as they arrive.
        gen_start = time.time()
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output_tokens = 0
        for text in streamer:
            if text:
                output_tokens += len(tokenizer.encode(text, add_special_tokens=False))
                delta_event = {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text}
                }
                yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"

        thread.join()
        gen_time = time.time() - gen_start
        tokens_per_sec = output_tokens / gen_time if gen_time > 0 else 0
        logger.info(f"[{message_id}] Stream completed: {output_tokens} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")

        # Close the content block and report the final stop reason and usage.
        block_stop = {"type": "content_block_stop", "index": 0}
        yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n"

        delta = {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {"output_tokens": output_tokens}
        }
        yield f"event: message_delta\ndata: {json.dumps(delta)}\n\n"

        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )

@app.post("/v1/messages/count_tokens") |
|
|
async def count_tokens(request: MessageRequest): |
|
|
"""Count tokens for a message request""" |
|
|
prompt = format_messages(request.messages, request.system) |
|
|
tokens = tokenizer.encode(prompt) |
|
|
logger.debug(f"Token count request: {len(tokens)} tokens") |
|
|
return {"input_tokens": len(tokens)} |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health(): |
|
|
return {"status": "ok", "model_loaded": model is not None, "log_file": LOG_FILE} |
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=7860, log_config=None) |
|
|
|
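# Example client call: a minimal sketch of how the endpoint might be exercised,
# assuming the server is reachable at http://localhost:7860 (the port used by the
# uvicorn call above) and that the `requests` package is installed. The x-api-key
# header is accepted but not validated by this implementation.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/v1/messages",
#       headers={"x-api-key": "not-checked", "anthropic-version": "2023-06-01"},
#       json={
#           "model": "smollm2-135m",
#           "max_tokens": 128,
#           "messages": [{"role": "user", "content": "Say hello in one sentence."}],
#       },
#   )
#   print(resp.json()["content"][0]["text"])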