"""
GambitFlow Bridge API - HuggingFace Space

Unified API gateway that forwards chess move-prediction requests to the
GambitFlow engine Spaces, caches successful responses in memory, and records
usage statistics in a Firebase Realtime Database.

No rate limiting is implemented in this module.
"""

import json
import os
import threading
import time
from functools import wraps

import firebase_admin
import requests
from firebase_admin import credentials, db
from flask import Flask, Response, jsonify, request
from flask_cors import CORS

app = Flask(__name__)
CORS(app)


# ==================== FIREBASE SETUP ====================

def initialize_firebase():
    """Initialize the Firebase Admin SDK (best effort).

    Credentials are read from the ``FIREBASE_CREDENTIALS`` environment
    variable (a JSON service-account document) or, when that is unset, from
    the local file ``firebase-credentials.json``. Any failure is printed and
    swallowed so the gateway keeps serving predictions without analytics.
    """
    try:
        firebase_creds = os.getenv('FIREBASE_CREDENTIALS')
        if firebase_creds:
            cred = credentials.Certificate(json.loads(firebase_creds))
        else:
            # Fallback to service account file
            cred = credentials.Certificate('firebase-credentials.json')

        firebase_admin.initialize_app(cred, {
            'databaseURL': os.getenv(
                'FIREBASE_DATABASE_URL',
                'https://chess-web-78351-default-rtdb.asia-southeast1.firebasedatabase.app',
            )
        })
        print("✅ Firebase initialized successfully")
    except Exception as e:
        print(f"⚠️ Firebase initialization failed: {e}")


# Initialize Firebase at import time so every route can use `db`.
initialize_firebase()


# ==================== MODEL CONFIGURATION ====================

# Public model key -> display name, engine Space base URL, request timeout (s).
MODELS = {
    'nano': {
        'name': 'Nexus-Nano',
        'endpoint': os.getenv('NANO_ENDPOINT', 'https://gambitflow-nexus-nano-inference-api.hf.space'),
        'timeout': 30,
    },
    'core': {
        'name': 'Nexus-Core',
        'endpoint': os.getenv('CORE_ENDPOINT', 'https://gambitflow-nexus-core-inference-api.hf.space'),
        'timeout': 40,
    },
    'base': {
        'name': 'Synapse-Base',
        'endpoint': os.getenv('BASE_ENDPOINT', 'https://gambitflow-synapse-base-inference-api.hf.space'),
        'timeout': 60,
    },
}

# Maximum number of positions accepted by a single /batch request.
MAX_BATCH_SIZE = 10


def _is_valid_model(model):
    """Return True if ``model`` is a configured model key.

    The isinstance check keeps unhashable JSON values (lists, objects) from
    raising TypeError on the dict membership test.
    """
    return isinstance(model, str) and model in MODELS


# ==================== FIREBASE ANALYTICS ====================

def _increment_counter(ref):
    """Atomically add 1 to the number stored at ``ref`` (missing counts as 0)."""
    # A transaction retries on concurrent writes. A plain get() followed by
    # set() would silently lose increments when two requests race.
    ref.transaction(lambda current: (current or 0) + 1)


def increment_stats(model_name, stat_type='moves'):
    """
    Increment statistics in Firebase (best effort).

    Bumps ``stats/total/<stat_type>`` and
    ``stats/models/<model_name>/<stat_type>``, then stamps
    ``stats/last_updated`` with the current Unix time. Errors are printed
    and swallowed so analytics can never break a prediction.

    stat_type: 'moves' or 'matches'
    model_name: must be a MODELS key. Callers validate it because the value
        becomes part of a database path.
    """
    try:
        ref = db.reference('stats')
        _increment_counter(ref.child('total').child(stat_type))
        _increment_counter(ref.child('models').child(model_name).child(stat_type))
        ref.child('last_updated').set(int(time.time()))
    except Exception as e:
        print(f"Firebase stats update error: {e}")


def _default_stats():
    """Return a zeroed statistics structure covering every configured model."""
    return {
        'total': {'moves': 0, 'matches': 0},
        'models': {key: {'moves': 0, 'matches': 0} for key in MODELS},
        'last_updated': int(time.time()),
    }


def get_all_stats():
    """Get all statistics from Firebase.

    If the ``stats`` node is empty, it is seeded with a zeroed structure. If
    Firebase is unreachable, a zeroed structure is returned without being
    stored.
    """
    try:
        ref = db.reference('stats')
        stats = ref.get()
        if not stats:
            # Initialize default structure
            stats = _default_stats()
            ref.set(stats)
        return stats
    except Exception as e:
        print(f"Firebase stats fetch error: {e}")
        return _default_stats()


# ==================== CACHE ====================

class SimpleCache:
    """In-memory TTL cache, guarded by a lock for Flask's threaded server."""

    def __init__(self, ttl=300):
        self.cache = {}  # key -> (value, insertion time from time.time())
        self.ttl = ttl   # entry lifetime in seconds
        self._lock = threading.Lock()

    def get(self, key):
        """Return the live value for ``key``, or None if missing or expired."""
        with self._lock:
            entry = self.cache.get(key)
            if entry is None:
                return None
            value, timestamp = entry
            if time.time() - timestamp < self.ttl:
                return value
            del self.cache[key]
            return None

    def set(self, key, value):
        """Store ``value`` under ``key``, stamped with the current time."""
        with self._lock:
            self.cache[key] = (value, time.time())

    def clear_old(self):
        """Drop every entry whose age has reached the TTL."""
        now = time.time()
        with self._lock:
            expired = [k for k, (_, t) in self.cache.items() if now - t >= self.ttl]
            for k in expired:
                del self.cache[k]


cache = SimpleCache(ttl=300)


# ==================== HELPERS ====================

def _json_body():
    """Return the request body if it is a JSON object, otherwise None.

    ``silent=True`` makes malformed or missing JSON yield None instead of
    raising, so routes can answer with a 400 rather than a 500.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else None


def _forward_prediction(model, fen, depth, time_limit):
    """POST a prediction request to ``model``'s engine Space.

    Returns the raw ``requests.Response``. Raises ``KeyError`` for an unknown
    model and ``requests.RequestException`` (including ``Timeout``) on
    transport failure.
    """
    model_config = MODELS[model]
    return requests.post(
        f"{model_config['endpoint']}/predict",
        json={'fen': fen, 'depth': depth, 'time_limit': time_limit},
        timeout=model_config['timeout'],
    )


# ==================== ROUTES ====================

@app.route('/')
def index():
    """API documentation"""
    return jsonify({
        'name': 'GambitFlow Bridge API',
        'version': '1.0.0',
        'description': 'Unified gateway for all GambitFlow chess engines',
        'endpoints': {
            '/predict': 'POST - Get best move prediction',
            '/batch': f'POST - Predict up to {MAX_BATCH_SIZE} positions in one call',
            '/match/start': 'POST - Record the start of a match',
            '/health': 'GET - Health check',
            '/stats': 'GET - Get usage statistics',
            '/models': 'GET - List available models',
        },
        'models': list(MODELS.keys()),
    })


@app.route('/health')
def health():
    """Health check endpoint"""
    return jsonify({
        'status': 'healthy',
        'timestamp': int(time.time()),
        'models': len(MODELS),
        'cache_size': len(cache.cache),
    })


@app.route('/stats')
def get_stats():
    """Get usage statistics from Firebase"""
    return jsonify(get_all_stats())


@app.route('/models')
def list_models():
    """List all available models"""
    models_info = {
        key: {
            'name': config['name'],
            'endpoint': config['endpoint'],
            'timeout': config['timeout'],
        }
        for key, config in MODELS.items()
    }
    return jsonify({'models': models_info})


@app.route('/predict', methods=['POST'])
def predict():
    """
    Main prediction endpoint.

    Forwards the request to the selected model (served from the cache when
    possible) and tracks statistics. The response is the engine's JSON plus
    ``model`` and ``from_cache``. Non-200 engine responses are relayed with
    the engine's status code. Timeouts yield 504, other failures 500.
    """
    try:
        data = _json_body()
        if not data:
            return jsonify({'error': 'No data provided'}), 400

        # Extract parameters
        fen = data.get('fen')
        model = data.get('model', 'core')
        depth = data.get('depth', 5)
        time_limit = data.get('time_limit', 3000)
        track_stats = data.get('track_stats', True)  # Allow disabling stats tracking

        if not fen:
            return jsonify({'error': 'FEN position required'}), 400
        if not _is_valid_model(model):
            return jsonify({'error': f'Invalid model: {model}'}), 400

        # Check cache
        cache_key = f"{model}:{fen}:{depth}:{time_limit}"
        cached = cache.get(cache_key)
        if cached is not None:
            if track_stats:
                increment_stats(model, 'moves')
            # Build a fresh dict so the shared cache entry is never mutated.
            return jsonify({**cached, 'model': model, 'from_cache': True})

        response = _forward_prediction(model, fen, depth, time_limit)
        if response.status_code != 200:
            return jsonify({
                'error': 'Model API error',
                'status_code': response.status_code,
                'details': response.text,
            }), response.status_code

        result = response.json()
        result['model'] = model
        # Cache a copy so later changes to `result` cannot leak into the cache.
        cache.set(cache_key, dict(result))

        # Track statistics in Firebase
        if track_stats:
            increment_stats(model, 'moves')

        result['from_cache'] = False
        return jsonify(result)

    except requests.Timeout:
        return jsonify({'error': 'Request timeout'}), 504
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/match/start', methods=['POST'])
def start_match():
    """Record a match start for the requested model (default 'core')."""
    try:
        data = _json_body() or {}
        model = data.get('model', 'core')

        if not _is_valid_model(model):
            return jsonify({'error': 'Invalid model'}), 400

        increment_stats(model, 'matches')
        return jsonify({
            'success': True,
            'model': model,
            'message': 'Match started',
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/batch', methods=['POST'])
def batch_predict():
    """
    Batch prediction endpoint for multiple positions.

    Expects ``{"positions": [{"fen": ..., "depth": ..., "time_limit": ...}],
    "model": ...}``. Each position is predicted independently, and per-position
    failures appear as ``{"error": ...}`` entries in ``results``. The whole
    batch counts as one 'moves' statistic.
    """
    try:
        data = _json_body() or {}
        positions = data.get('positions', [])
        model = data.get('model', 'core')

        if not positions:
            return jsonify({'error': 'No positions provided'}), 400
        if not isinstance(positions, list) or not all(isinstance(p, dict) for p in positions):
            return jsonify({'error': 'positions must be a list of objects'}), 400
        if len(positions) > MAX_BATCH_SIZE:
            return jsonify({'error': f'Maximum {MAX_BATCH_SIZE} positions per batch'}), 400
        # Validate before use: the model name becomes part of a Firebase path.
        if not _is_valid_model(model):
            return jsonify({'error': f'Invalid model: {model}'}), 400

        results = [
            predict_single({
                'fen': pos.get('fen'),
                'model': model,
                'depth': pos.get('depth', 5),
                'time_limit': pos.get('time_limit', 3000),
                'track_stats': False,  # Don't track for batch
            })
            for pos in positions
        ]

        # Track batch as single operation
        increment_stats(model, 'moves')

        return jsonify({
            'success': True,
            'count': len(results),
            'results': results,
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


def predict_single(data):
    """Run one uncached, untracked prediction for a batch entry.

    ``data`` holds 'fen', 'model', 'depth' and 'time_limit'. Returns the
    engine's JSON on success, otherwise a dict with a single 'error' key.
    """
    fen = data.get('fen')
    if not fen:
        return {'error': 'FEN position required'}
    try:
        response = _forward_prediction(
            data.get('model', 'core'),
            fen,
            data.get('depth', 5),
            data.get('time_limit', 3000),
        )
        if response.status_code == 200:
            return response.json()
        return {'error': 'Prediction failed'}
    except Exception as e:
        # Best effort: one failed position must not abort the whole batch.
        print(f"Batch prediction error: {e}")
        return {'error': 'Request failed'}


# ==================== CLEANUP ====================

@app.before_request
def before_request():
    """Clean old cache entries before each request"""
    cache.clear_old()


# ==================== RUN ====================

if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)