Spaces:
Running
Running
| """ | |
| GambitFlow Bridge API - HuggingFace Space | |
| Unified API gateway with Firebase analytics and rate limiting | |
| """ | |
| from flask import Flask, request, jsonify, Response | |
| from flask_cors import CORS | |
| import requests | |
| import time | |
| import os | |
| from functools import wraps | |
| import firebase_admin | |
| from firebase_admin import credentials, db | |
| import json | |
app = Flask(import_name=__name__)
# Enable CORS on every route with flask_cors defaults (all origins allowed).
# NOTE(review): this includes the stats-mutating endpoints; consider restricting
# origins if only known frontends should call this API.
CORS(app=app)
# ==================== FIREBASE SETUP ====================
def initialize_firebase():
    """Initialize Firebase Admin SDK

    Credentials come from the FIREBASE_CREDENTIALS environment variable
    (parsed as a JSON service-account dict) when it is non-empty, otherwise
    from the local 'firebase-credentials.json' file. The database URL comes
    from FIREBASE_DATABASE_URL, falling back to a hard-coded project URL.
    Any exception is printed as a warning and swallowed, so the app still
    starts without Firebase.
    """
    try:
        env_creds = os.getenv('FIREBASE_CREDENTIALS')
        # credentials.Certificate is given either a parsed dict or a file path.
        cert_source = json.loads(env_creds) if env_creds else 'firebase-credentials.json'
        certificate = credentials.Certificate(cert_source)
        app_options = {
            'databaseURL': os.getenv(
                'FIREBASE_DATABASE_URL',
                'https://chess-web-78351-default-rtdb.asia-southeast1.firebasedatabase.app',
            ),
        }
        firebase_admin.initialize_app(certificate, app_options)
        print("✅ Firebase initialized successfully")
    except Exception as err:
        print(f"⚠️ Firebase initialization failed: {err}")


# Initialize Firebase once, at import time.
initialize_firebase()
# ==================== MODEL CONFIGURATION ====================
# Upstream inference APIs keyed by the short model id clients send as "model".
# Each endpoint can be overridden through its own environment variable;
# 'timeout' is passed as the `timeout` (seconds) of the upstream requests.post.
MODELS = {
    model_key: {
        'name': display_name,
        'endpoint': os.getenv(env_var, default_url),
        'timeout': timeout_s,
    }
    for model_key, display_name, env_var, default_url, timeout_s in (
        ('nano', 'Nexus-Nano', 'NANO_ENDPOINT',
         'https://gambitflow-nexus-nano-inference-api.hf.space', 30),
        ('core', 'Nexus-Core', 'CORE_ENDPOINT',
         'https://gambitflow-nexus-core-inference-api.hf.space', 40),
        ('base', 'Synapse-Base', 'BASE_ENDPOINT',
         'https://gambitflow-synapse-base-inference-api.hf.space', 60),
    )
}
# ==================== FIREBASE ANALYTICS ====================
def increment_stats(model_name, stat_type='moves'):
    """
    Increment statistics in Firebase
    stat_type: 'moves' or 'matches'

    Bumps ``stats/total/<stat_type>`` and
    ``stats/models/<model_name>/<stat_type>`` by one, then stamps
    ``stats/last_updated`` with the current Unix time. Any failure (e.g.
    Firebase not initialized) is printed and swallowed, so analytics never
    breaks the calling request.
    """
    def _bump(current):
        # Transaction callback: the SDK may re-run it when another writer
        # changed the value concurrently, so it must stay side-effect free.
        return (current or 0) + 1

    try:
        ref = db.reference('stats')
        # transaction() performs each read-modify-write atomically; a plain
        # get() followed by set() lost increments when requests overlapped.
        ref.child('total').child(stat_type).transaction(_bump)
        ref.child('models').child(model_name).child(stat_type).transaction(_bump)
        # Update last_updated timestamp
        ref.child('last_updated').set(int(time.time()))
    except Exception as e:
        print(f"Firebase stats update error: {e}")
def _default_stats():
    """Return a zeroed stats document with one counter pair per model in MODELS."""
    return {
        'total': {'moves': 0, 'matches': 0},
        'models': {key: {'moves': 0, 'matches': 0} for key in MODELS},
        'last_updated': int(time.time()),
    }


def get_all_stats():
    """Get all statistics from Firebase

    Returns the value stored under ``stats``. When nothing is stored yet, a
    zeroed document is written there and returned. When Firebase access
    fails (any exception), the error is printed and a zeroed document is
    returned without writing anything.
    """
    try:
        ref = db.reference('stats')
        stats = ref.get()
        if not stats:
            # Empty database: seed the default structure once.
            stats = _default_stats()
            ref.set(stats)
        return stats
    except Exception as e:
        print(f"Firebase stats fetch error: {e}")
        return _default_stats()
# ==================== CACHE ====================
class SimpleCache:
    """Minimal in-memory key/value cache with a fixed time-to-live.

    Entries are stored as ``(value, stored_at)`` tuples and count as expired
    once ``ttl`` seconds have elapsed since they were stored. Expired entries
    are dropped lazily by ``get`` and in bulk by ``clear_old``. The cache has
    no size bound.
    """

    def __init__(self, ttl=300):
        self.cache = {}  # key -> (value, time.time() at insertion)
        self.ttl = ttl  # entry lifetime in seconds

    def get(self, key):
        """Return the value stored under *key*, or None if absent or expired."""
        # One lookup instead of `in` + `[]`: another request thread may evict
        # the key between the two, which raised KeyError.
        entry = self.cache.get(key)
        if entry is None:
            return None
        value, timestamp = entry
        if time.time() - timestamp < self.ttl:
            return value
        # pop() with a default tolerates a concurrent eviction of the same key.
        self.cache.pop(key, None)
        return None

    def set(self, key, value):
        """Store *value* under *key*, stamped with the current time."""
        self.cache[key] = (value, time.time())

    def clear_old(self):
        """Remove every entry whose age has reached ``ttl``."""
        current_time = time.time()
        # Iterate over a snapshot rather than the live dict, which other
        # request threads may be inserting into.
        expired = [k for k, (_, t) in list(self.cache.items())
                   if current_time - t >= self.ttl]
        for k in expired:
            self.cache.pop(k, None)


cache = SimpleCache(ttl=300)
# ==================== ROUTES ====================
@app.route('/', methods=['GET'])
def index():
    """API documentation

    Returns a JSON overview of the service: name, version, description, the
    documented endpoints and the keys of the available models.
    """
    return jsonify({
        'name': 'GambitFlow Bridge API',
        'version': '1.0.0',
        'description': 'Unified gateway for all GambitFlow chess engines',
        'endpoints': {
            '/predict': 'POST - Get best move prediction',
            '/health': 'GET - Health check',
            '/stats': 'GET - Get usage statistics',
            '/models': 'GET - List available models'
        },
        'models': list(MODELS.keys())
    })
@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint

    Reports a static 'healthy' status, the current Unix timestamp, the number
    of configured models and the number of entries in the in-memory cache.
    It does not probe the upstream model APIs or Firebase.
    """
    return jsonify({
        'status': 'healthy',
        'timestamp': int(time.time()),
        'models': len(MODELS),
        'cache_size': len(cache.cache)
    })
@app.route('/stats', methods=['GET'])
def get_stats():
    """Get usage statistics from Firebase

    Returns whatever get_all_stats() yields (zeroed defaults when Firebase is
    empty or unavailable) as JSON.
    """
    stats = get_all_stats()
    return jsonify(stats)
@app.route('/models', methods=['GET'])
def list_models():
    """List all available models

    Returns ``{'models': {key: {'name', 'endpoint', 'timeout'}}}`` for every
    entry in MODELS.
    """
    models_info = {
        key: {
            'name': config['name'],
            'endpoint': config['endpoint'],
            'timeout': config['timeout'],
        }
        for key, config in MODELS.items()
    }
    return jsonify({'models': models_info})
@app.route('/predict', methods=['POST'])
def predict():
    """
    Main prediction endpoint
    Forwards request to appropriate model and tracks statistics

    JSON body: 'fen' (required); optional 'model' (default 'core'), 'depth'
    (default 5), 'time_limit' (default 3000) and 'track_stats' (default True;
    a falsy value skips the Firebase 'moves' counter).

    Returns the model API's JSON plus 'model' and 'from_cache' on success;
    400 for a missing/invalid body, FEN or model; the upstream status code
    (with its body in 'details') when the model API answers non-200; 504 on
    an upstream timeout; 500 on any other error.
    """
    try:
        # silent=True returns None for a missing or malformed JSON body, so it
        # is reported as a 400 below instead of surfacing as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        if not isinstance(data, dict):
            return jsonify({'error': 'JSON object required'}), 400

        # Extract parameters
        fen = data.get('fen')
        model = data.get('model', 'core')
        depth = data.get('depth', 5)
        time_limit = data.get('time_limit', 3000)
        track_stats = data.get('track_stats', True)  # Allow disabling stats tracking

        if not fen:
            return jsonify({'error': 'FEN position required'}), 400
        # The isinstance check keeps unhashable values (lists, dicts) from
        # raising TypeError in the membership test.
        if not isinstance(model, str) or model not in MODELS:
            return jsonify({'error': f'Invalid model: {model}'}), 400

        # Check cache
        cache_key = f"{model}:{fen}:{depth}:{time_limit}"
        cached = cache.get(cache_key)
        if cached is not None:
            if track_stats:
                increment_stats(model, 'moves')
            # Build a new dict: the cached one is shared across requests and
            # must not be mutated here.
            return jsonify({**cached, 'from_cache': True})

        # Forward to model API
        model_config = MODELS[model]
        response = requests.post(
            f"{model_config['endpoint']}/predict",
            json={
                'fen': fen,
                'depth': depth,
                'time_limit': time_limit
            },
            timeout=model_config['timeout']
        )

        if response.status_code != 200:
            return jsonify({
                'error': 'Model API error',
                'status_code': response.status_code,
                'details': response.text
            }), response.status_code

        result = response.json()
        result['model'] = model
        # Cache a complete copy: 'model' is set before storing, and later
        # per-request fields such as 'from_cache' never leak into the entry.
        cache.set(cache_key, dict(result))

        # Track statistics in Firebase
        if track_stats:
            increment_stats(model, 'moves')

        result['from_cache'] = False
        return jsonify(result)

    except requests.Timeout:
        return jsonify({'error': 'Request timeout'}), 504
    except Exception as e:
        return jsonify({'error': str(e)}), 500
def start_match():
    """Track match start

    Expects a JSON object with an optional 'model' (default 'core') and bumps
    that model's 'matches' counter in Firebase. Returns 400 when the body is
    not a JSON object or the model is unknown, 500 on unexpected errors.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view; confirm it is actually registered with `app`.
    try:
        # silent=True: a missing/malformed body becomes None and is rejected
        # with a 400 instead of crashing on None.get(...) with a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'No data provided'}), 400

        model = data.get('model', 'core')
        if not isinstance(model, str) or model not in MODELS:
            return jsonify({'error': 'Invalid model'}), 400

        increment_stats(model, 'matches')

        return jsonify({
            'success': True,
            'model': model,
            'message': 'Match started'
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
def batch_predict():
    """
    Batch prediction endpoint for multiple positions

    Expects a JSON object with 'positions' (a list of 1-10 objects, each with
    'fen' and optional 'depth' default 5 / 'time_limit' default 3000) and an
    optional 'model' (default 'core'). Positions are forwarded one by one via
    predict_single(); a per-position failure appears inline as an
    {'error': ...} entry. The whole batch counts as a single 'moves'
    increment. Returns 400 for invalid input, 500 on unexpected errors.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view; confirm it is actually registered with `app`.
    try:
        # silent=True: a missing/malformed body becomes None -> 400, not 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'No data provided'}), 400

        positions = data.get('positions', [])
        model = data.get('model', 'core')

        if not positions:
            return jsonify({'error': 'No positions provided'}), 400
        if not isinstance(positions, list) or not all(isinstance(pos, dict) for pos in positions):
            return jsonify({'error': 'positions must be a list of objects'}), 400
        if len(positions) > 10:
            return jsonify({'error': 'Maximum 10 positions per batch'}), 400
        # Validate up front: an unknown model made every position fail and was
        # still passed to increment_stats, writing an arbitrary user-supplied
        # string into the Firebase stats tree.
        if not isinstance(model, str) or model not in MODELS:
            return jsonify({'error': f'Invalid model: {model}'}), 400

        results = []
        for pos in positions:
            fen = pos.get('fen')
            if not fen:
                # Same rule as the single-position endpoint; skip the network call.
                results.append({'error': 'FEN position required'})
                continue
            results.append(predict_single({
                'fen': fen,
                'model': model,
                'depth': pos.get('depth', 5),
                'time_limit': pos.get('time_limit', 3000),
                'track_stats': False  # Don't track for batch
            }))

        # Track batch as single operation
        increment_stats(model, 'moves')

        return jsonify({
            'success': True,
            'count': len(results),
            'results': results
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
def predict_single(data):
    """Helper function for single prediction

    Forwards one position described by the dict *data* ('fen'; optional
    'model' default 'core', 'depth' default 5, 'time_limit' default 3000) to
    that model's /predict endpoint. Ordinary failures never raise: returns
    the model API's JSON on HTTP 200, {'error': 'Prediction failed'} on any
    other status, and {'error': 'Request failed'} when the call itself fails
    (network error, timeout, unknown model, non-JSON reply, ...).
    """
    try:
        model_config = MODELS[data.get('model', 'core')]
        response = requests.post(
            f"{model_config['endpoint']}/predict",
            json={
                'fen': data.get('fen'),
                'depth': data.get('depth', 5),
                'time_limit': data.get('time_limit', 3000)
            },
            timeout=model_config['timeout']
        )
        if response.status_code == 200:
            return response.json()
        return {'error': 'Prediction failed'}
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate; ordinary errors are reported per position.
        return {'error': 'Request failed'}
# ==================== CLEANUP ====================
@app.before_request
def before_request():
    """Clean old cache entries before each request

    Returns None, so Flask continues with normal request handling.
    """
    cache.clear_old()
# ==================== RUN ====================
if __name__ == '__main__':
    # Direct execution only: listen on all interfaces, on $PORT when set
    # (default 7860), with Flask debug mode off.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=listen_port)