"""
Anthropic-Compatible API Endpoint
Lightweight CPU-based implementation for Hugging Face Spaces
"""
import os
import time
import uuid
import logging
from datetime import datetime
from logging.handlers import RotatingFileHandler
from typing import List, Optional, Union
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import json

# ============== Logging Configuration ==============
LOG_DIR = "/tmp/logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "api.log")

# Shared formatter for every handler below
log_format = logging.Formatter(
    '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# File handler with rotation (10MB max, keep 5 backups)
file_handler = RotatingFileHandler(
    LOG_FILE,
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding='utf-8'
)
file_handler.setFormatter(log_format)
file_handler.setLevel(logging.DEBUG)

# Console handler (INFO and above only)
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_format)
console_handler.setLevel(logging.INFO)

# Root logger: every propagating logger ends up in both handlers.
logging.basicConfig(level=logging.DEBUG, handlers=[file_handler, console_handler])
logger = logging.getLogger("anthropic-api")

# Also capture uvicorn logs.
# Do NOT attach our handlers to the uvicorn loggers themselves: "uvicorn.error"
# and "uvicorn.access" propagate to "uvicorn" and then to the root logger, which
# already owns these handlers, so every record would be written two or three
# times. Instead drop any handlers uvicorn may have installed and let records
# propagate to root exactly once.
for uvicorn_logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
    uv_log = logging.getLogger(uvicorn_logger)
    uv_log.handlers = []
    uv_log.propagate = True

logger.info("=" * 60)
logger.info(f"Application Startup at {datetime.now().isoformat()}")
logger.info(f"Log file: {LOG_FILE}")
logger.info("=" * 60)

# ============== Configuration ==============
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"  # Ultra-lightweight 135M model
MAX_TOKENS_DEFAULT = 1024
DEVICE = "cpu"
# Global model and tokenizer (populated by `lifespan` at startup)
model = None
tokenizer = None

# Sampling temperature used when a request omits it or sends null.
TEMPERATURE_DEFAULT = 0.7


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer on startup; release them on shutdown.

    Any loading failure is logged and re-raised, which aborts application startup.
    """
    global model, tokenizer
    logger.info(f"Loading model: {MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Tokenizer loaded successfully")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            device_map=DEVICE,
            low_cpu_mem_usage=True
        )
        model.eval()
        logger.info("Model loaded successfully!")
        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        raise

    yield

    # Cleanup. Rebind to None instead of `del`: deleting the module globals made
    # any later read (e.g. /health checking `model is not None`) raise NameError.
    logger.info("Shutting down, cleaning up model...")
    model = None
    tokenizer = None


app = FastAPI(
    title="Anthropic-Compatible API",
    description="Lightweight CPU-based API with Anthropic Messages API compatibility",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True likely lets
# any website make credentialed browser requests; tighten if auth is ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request logging middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log the start, status code and duration of every HTTP request.

    Errors raised by downstream handlers are logged and re-raised unchanged.
    NOTE(review): for streaming responses the measured duration presumably ends
    when the response object is returned, not when the stream finishes.
    """
    request_id = str(uuid.uuid4())[:8]
    start_time = time.perf_counter()  # monotonic clock for durations
    logger.info(f"[{request_id}] {request.method} {request.url.path} - Started")

    try:
        response = await call_next(request)
    except Exception as e:
        duration = (time.perf_counter() - start_time) * 1000
        logger.error(f"[{request_id}] {request.method} {request.url.path} - Error: {e} ({duration:.2f}ms)")
        raise

    duration = (time.perf_counter() - start_time) * 1000
    logger.info(f"[{request_id}] {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)")
    return response


# ============== Pydantic Models (Anthropic-Compatible) ==============

class ContentBlock(BaseModel):
    type: str = "text"
    text: str


class Message(BaseModel):
    role: str
    content: Union[str, List[ContentBlock]]


class MessageRequest(BaseModel):
    model: str
    messages: List[Message]
    # Must be >= 1; invalid values are rejected with 422 instead of failing inside generate().
    max_tokens: int = Field(MAX_TOKENS_DEFAULT, ge=1)
    # null is accepted and treated as TEMPERATURE_DEFAULT; <= 0 selects greedy decoding.
    temperature: Optional[float] = TEMPERATURE_DEFAULT
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
    stream: Optional[bool] = False
    # The Anthropic API allows the system prompt as a string or as a list of text blocks.
    system: Optional[Union[str, List[ContentBlock]]] = None
    # NOTE(review): accepted for compatibility but not applied during generation.
    stop_sequences: Optional[List[str]] = None


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int


class MessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None
    usage: Usage


class ErrorResponse(BaseModel):
    type: str = "error"
    error: dict


# ============== Helper Functions ==============

def _content_to_text(content: Union[str, List[ContentBlock], None]) -> str:
    """Flatten string or content-block content into plain text.

    Only blocks with type "text" contribute; they are joined with single spaces.
    None yields an empty string.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    return " ".join(block.text for block in content if block.type == "text")


def format_messages(
    messages: List[Message],
    system: Optional[Union[str, List[ContentBlock]]] = None,
) -> str:
    """Format messages (and an optional system prompt) into a single prompt string.

    Uses the tokenizer's chat template when it has one, appending the generation
    prompt. Otherwise falls back to "Role: content" lines ending with "Assistant: ".
    """
    formatted_messages = []
    system_text = _content_to_text(system)
    if system_text:
        formatted_messages.append({"role": "system", "content": system_text})

    for msg in messages:
        formatted_messages.append({"role": msg.role, "content": _content_to_text(msg.content)})

    # Use chat template if available
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

    # Fallback simple format
    prompt = "".join(f"{m['role'].capitalize()}: {m['content']}\n" for m in formatted_messages)
    return prompt + "Assistant: "


def generate_id() -> str:
    """Generate a unique message ID"""
    return f"msg_{uuid.uuid4().hex[:24]}"


def _generation_kwargs(request: MessageRequest) -> dict:
    """Build model.generate() keyword arguments from a request.

    temperature <= 0 selects greedy decoding, and no sampling parameters are sent
    (avoids transformers' "do_sample is False" warnings). A null temperature falls
    back to TEMPERATURE_DEFAULT. Null top_p/top_k are omitted, so the model's
    generation config decides.
    """
    temperature = request.temperature if request.temperature is not None else TEMPERATURE_DEFAULT
    kwargs = {
        "max_new_tokens": request.max_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        kwargs["do_sample"] = True
        kwargs["temperature"] = temperature
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.top_k is not None:
            kwargs["top_k"] = request.top_k
    else:
        kwargs["do_sample"] = False
    return kwargs


def _run_generate(generation_kwargs: dict):
    """Call model.generate() with autograd disabled.

    torch.no_grad() is entered here, inside whichever thread runs the call,
    because grad mode is tracked per thread.
    """
    with torch.no_grad():
        return model.generate(**generation_kwargs)


def _stop_reason(generated_ids, max_tokens: int) -> str:
    """Return "max_tokens" if generation hit the token limit without emitting EOS, else "end_turn"."""
    hit_limit = len(generated_ids) >= max_tokens
    ended_with_eos = len(generated_ids) > 0 and int(generated_ids[-1]) == tokenizer.eos_token_id
    return "max_tokens" if hit_limit and not ended_with_eos else "end_turn"


def _sse_event(event_type: str, payload: dict) -> str:
    """Serialize one Server-Sent Event with a named event type and JSON data."""
    return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"


# ============== API Endpoints ==============

@app.get("/")
async def root():
    """Health check endpoint"""
    logger.debug("Root endpoint accessed")
    return {
        "status": "healthy",
        "model": MODEL_ID,
        "api_version": "2023-06-01",
        "compatibility": "anthropic-messages-api",
        "log_file": LOG_FILE
    }


@app.get("/v1/models")
async def list_models():
    """List available models (Anthropic-compatible)"""
    logger.debug("Models list requested")
    return {
        "object": "list",
        "data": [
            {
                "id": "smollm2-135m",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "huggingface",
                "display_name": "SmolLM2 135M Instruct"
            }
        ]
    }


@app.get("/logs")
async def get_logs(lines: int = 100):
    """Return the last `lines` lines of the log file.

    Non-positive `lines` returns no log lines. Only the tail is kept in memory
    while reading. NOTE(review): this endpoint is unauthenticated and exposes
    request logs to any caller.
    """
    from collections import deque  # local import keeps this block self-contained

    tail = deque(maxlen=max(lines, 0))
    total_lines = 0
    try:
        # The RotatingFileHandler writes UTF-8; don't rely on the locale encoding.
        with open(LOG_FILE, 'r', encoding='utf-8', errors='replace') as f:
            for total_lines, line in enumerate(f, start=1):
                tail.append(line)
    except FileNotFoundError:
        return {"error": "Log file not found", "log_file": LOG_FILE}

    return {
        "log_file": LOG_FILE,
        "total_lines": total_lines,
        "returned_lines": len(tail),
        "logs": "".join(tail)
    }


@app.post("/v1/messages")
async def create_message(
    request: MessageRequest,
    x_api_key: Optional[str] = Header(None, alias="x-api-key"),
    anthropic_version: Optional[str] = Header(None, alias="anthropic-version")
):
    """
    Create a message (Anthropic Messages API compatible)

    The x-api-key and anthropic-version headers are accepted for client
    compatibility but are not checked. Returns a MessageResponse or, when
    `stream` is true, an SSE stream. Failures before the response is produced
    are reported as HTTP 500 carrying the error text.
    """
    import asyncio  # local import keeps this block self-contained

    message_id = generate_id()
    logger.info(f"[{message_id}] Creating message - model: {request.model}, max_tokens: {request.max_tokens}, stream: {request.stream}")

    try:
        # Format the prompt
        prompt = format_messages(request.messages, request.system)
        logger.debug(f"[{message_id}] Prompt length: {len(prompt)} chars")

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_token_count = inputs.input_ids.shape[1]
        logger.info(f"[{message_id}] Input tokens: {input_token_count}")

        if request.stream:
            logger.info(f"[{message_id}] Starting streaming response")
            return await stream_response(request, inputs, input_token_count, message_id)

        # Generate. model.generate is synchronous and CPU-bound; running it in a
        # worker thread keeps the event loop free for other requests meanwhile.
        gen_start = time.perf_counter()
        outputs = await asyncio.to_thread(
            _run_generate, {**inputs, **_generation_kwargs(request)}
        )
        gen_time = time.perf_counter() - gen_start

        # Decode only new tokens
        generated_tokens = outputs[0][input_token_count:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        output_token_count = len(generated_tokens)

        tokens_per_sec = output_token_count / gen_time if gen_time > 0 else 0
        logger.info(f"[{message_id}] Generated {output_token_count} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")

        # Build response
        return MessageResponse(
            id=message_id,
            content=[ContentBlock(type="text", text=generated_text.strip())],
            model=request.model,
            stop_reason=_stop_reason(generated_tokens, request.max_tokens),
            usage=Usage(
                input_tokens=input_token_count,
                output_tokens=output_token_count
            )
        )

    except Exception as e:
        logger.error(f"[{message_id}] Error creating message: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
    """Stream response using SSE (Server-Sent Events)

    Emits message_start, content_block_start, one content_block_delta per text
    chunk, then content_block_stop, message_delta (stop reason and exact output
    token count) and message_stop. If generation fails mid-stream, a single
    Anthropic-style `error` event is sent instead of the closing events.

    The event generator is deliberately synchronous. Starlette iterates sync
    iterators in a worker thread, so blocking on the TextIteratorStreamer queue
    does not stall the event loop.
    """
    def generate():
        # Send message_start event
        yield _sse_event("message_start", {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_token_count, "output_tokens": 0}
            }
        })

        # Send content_block_start
        yield _sse_event("content_block_start", {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""}
        })

        # Setup streamer
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {**inputs, **_generation_kwargs(request), "streamer": streamer}
        result = {}

        def run_generation():
            # Capture generate()'s return value for exact token accounting. On
            # failure, signal end-of-stream: otherwise the consumer loop below
            # would wait on the streamer forever.
            try:
                result["output"] = _run_generate(generation_kwargs)
            except Exception as exc:
                result["error"] = exc
                logger.error(f"[{message_id}] Streaming generation failed: {exc}", exc_info=True)
                streamer.end()

        # Run generation in a thread.
        # NOTE(review): if the client disconnects, this generator is closed but the
        # generation thread still runs to completion.
        gen_start = time.perf_counter()
        thread = Thread(target=run_generation)
        thread.start()

        for text in streamer:
            if text:
                yield _sse_event("content_block_delta", {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text}
                })

        thread.join()
        gen_time = time.perf_counter() - gen_start

        if "error" in result:
            yield _sse_event("error", {
                "type": "error",
                "error": {"type": "api_error", "message": str(result["error"])}
            })
            return

        generated_tokens = result["output"][0][input_token_count:]
        output_tokens = len(generated_tokens)
        stop_reason = _stop_reason(generated_tokens, request.max_tokens)

        tokens_per_sec = output_tokens / gen_time if gen_time > 0 else 0
        logger.info(f"[{message_id}] Stream completed: {output_tokens} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")

        # Send content_block_stop
        yield _sse_event("content_block_stop", {"type": "content_block_stop", "index": 0})

        # Send message_delta
        yield _sse_event("message_delta", {
            "type": "message_delta",
            "delta": {"stop_reason": stop_reason, "stop_sequence": None},
            "usage": {"output_tokens": output_tokens}
        })

        # Send message_stop
        yield _sse_event("message_stop", {"type": "message_stop"})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )


# Token counting endpoint
@app.post("/v1/messages/count_tokens")
async def count_tokens(request: MessageRequest):
    """Count prompt tokens for a message request, using the same prompt formatting as /v1/messages."""
    prompt = format_messages(request.messages, request.system)
    tokens = tokenizer.encode(prompt)
    logger.debug(f"Token count request: {len(tokens)} tokens")
    return {"input_tokens": len(tokens)}


# Health check
@app.get("/health")
async def health():
    """Report liveness and whether the model is currently loaded."""
    return {"status": "ok", "model_loaded": model is not None, "log_file": LOG_FILE}


if __name__ == "__main__":
    import uvicorn
    # log_config=None: keep the logging handlers configured at module import.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_config=None)