"""
GambitFlow Bridge API - HuggingFace Space
Unified API gateway with Firebase analytics and rate limiting
"""
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
import requests
import time
import os
from functools import wraps
import firebase_admin
from firebase_admin import credentials, db
import json
app = Flask(__name__)
CORS(app)
# ==================== FIREBASE SETUP ====================
def initialize_firebase():
    """Initialize Firebase Admin SDK"""
    try:
        # Load credentials from environment variable
        firebase_creds = os.getenv('FIREBASE_CREDENTIALS')
        if firebase_creds:
            cred_dict = json.loads(firebase_creds)
            cred = credentials.Certificate(cred_dict)
        else:
            # Fallback to service account file
            cred = credentials.Certificate('firebase-credentials.json')
        firebase_admin.initialize_app(cred, {
            'databaseURL': os.getenv(
                'FIREBASE_DATABASE_URL',
                'https://chess-web-78351-default-rtdb.asia-southeast1.firebasedatabase.app'
            )
        })
        print("✅ Firebase initialized successfully")
    except Exception as e:
        print(f"⚠️ Firebase initialization failed: {e}")

# Initialize Firebase
initialize_firebase()
# ==================== MODEL CONFIGURATION ====================
MODELS = {
    'nano': {
        'name': 'Nexus-Nano',
        'endpoint': os.getenv('NANO_ENDPOINT', 'https://gambitflow-nexus-nano-inference-api.hf.space'),
        'timeout': 30
    },
    'core': {
        'name': 'Nexus-Core',
        'endpoint': os.getenv('CORE_ENDPOINT', 'https://gambitflow-nexus-core-inference-api.hf.space'),
        'timeout': 40
    },
    'base': {
        'name': 'Synapse-Base',
        'endpoint': os.getenv('BASE_ENDPOINT', 'https://gambitflow-synapse-base-inference-api.hf.space'),
        'timeout': 60
    }
}

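# Each model Space listed above is expected to expose a POST /predict endpoint
# accepting {'fen', 'depth', 'time_limit'} JSON, which is what predict() below
# forwards. The default endpoints can be overridden via the NANO_ENDPOINT,
# CORE_ENDPOINT and BASE_ENDPOINT environment variables.
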
# ==================== FIREBASE ANALYTICS ====================
def increment_stats(model_name, stat_type='moves'):
    """
    Increment statistics in Firebase
    stat_type: 'moves' or 'matches'
    """
    try:
        ref = db.reference('stats')
        # Increment total stats
        total_ref = ref.child('total').child(stat_type)
        current = total_ref.get() or 0
        total_ref.set(current + 1)
        # Increment model-specific stats
        model_ref = ref.child('models').child(model_name).child(stat_type)
        current = model_ref.get() or 0
        model_ref.set(current + 1)
        # Update last_updated timestamp
        ref.child('last_updated').set(int(time.time()))
    except Exception as e:
        print(f"Firebase stats update error: {e}")

def get_all_stats():
    """Get all statistics from Firebase"""
    try:
        ref = db.reference('stats')
        stats = ref.get() or {}
        if not stats:
            # Initialize default structure
            stats = {
                'total': {'moves': 0, 'matches': 0},
                'models': {
                    'nano': {'moves': 0, 'matches': 0},
                    'core': {'moves': 0, 'matches': 0},
                    'base': {'moves': 0, 'matches': 0}
                },
                'last_updated': int(time.time())
            }
            ref.set(stats)
        return stats
    except Exception as e:
        print(f"Firebase stats fetch error: {e}")
        return {
            'total': {'moves': 0, 'matches': 0},
            'models': {
                'nano': {'moves': 0, 'matches': 0},
                'core': {'moves': 0, 'matches': 0},
                'base': {'moves': 0, 'matches': 0}
            },
            'last_updated': int(time.time())
        }

# ==================== CACHE ====================
class SimpleCache:
    """In-memory TTL cache for prediction results."""

    def __init__(self, ttl=300):
        self.cache = {}
        self.ttl = ttl

    def get(self, key):
        if key in self.cache:
            value, timestamp = self.cache[key]
            if time.time() - timestamp < self.ttl:
                return value
            del self.cache[key]
        return None

    def set(self, key, value):
        self.cache[key] = (value, time.time())

    def clear_old(self):
        current_time = time.time()
        expired = [k for k, (_, t) in self.cache.items() if current_time - t >= self.ttl]
        for k in expired:
            del self.cache[k]

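# Shared response cache. /predict keys entries by model, FEN, depth and
# time_limit, so identical queries within the 300-second TTL are answered
# without calling the model Spaces again.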
cache = SimpleCache(ttl=300)
# ==================== ROUTES ====================
@app.route('/')
def index():
    """API documentation"""
    return jsonify({
        'name': 'GambitFlow Bridge API',
        'version': '1.0.0',
        'description': 'Unified gateway for all GambitFlow chess engines',
        'endpoints': {
            '/predict': 'POST - Get best move prediction',
            '/health': 'GET - Health check',
            '/stats': 'GET - Get usage statistics',
            '/models': 'GET - List available models'
        },
        'models': list(MODELS.keys())
    })

@app.route('/health')
def health():
    """Health check endpoint"""
    return jsonify({
        'status': 'healthy',
        'timestamp': int(time.time()),
        'models': len(MODELS),
        'cache_size': len(cache.cache)
    })

@app.route('/stats')
def get_stats():
    """Get usage statistics from Firebase"""
    stats = get_all_stats()
    return jsonify(stats)

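# Illustrative /stats response shape (mirrors the default structure created in
# get_all_stats(); actual numbers come from Firebase):
# {
#   "total": {"moves": 0, "matches": 0},
#   "models": {"nano": {...}, "core": {...}, "base": {...}},
#   "last_updated": <unix timestamp>
# }
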
@app.route('/models')
def list_models():
    """List all available models"""
    models_info = {}
    for key, config in MODELS.items():
        models_info[key] = {
            'name': config['name'],
            'endpoint': config['endpoint'],
            'timeout': config['timeout']
        }
    return jsonify({'models': models_info})

@app.route('/predict', methods=['POST'])
def predict():
    """
    Main prediction endpoint.
    Forwards the request to the appropriate model and tracks statistics.
    """
    try:
        data = request.get_json()
        if not data:
            return jsonify({'error': 'No data provided'}), 400

        # Extract parameters
        fen = data.get('fen')
        model = data.get('model', 'core')
        depth = data.get('depth', 5)
        time_limit = data.get('time_limit', 3000)
        track_stats = data.get('track_stats', True)  # Allow disabling stats tracking

        if not fen:
            return jsonify({'error': 'FEN position required'}), 400
        if model not in MODELS:
            return jsonify({'error': f'Invalid model: {model}'}), 400

        # Check cache
        cache_key = f"{model}:{fen}:{depth}:{time_limit}"
        cached = cache.get(cache_key)
        if cached:
            cached['from_cache'] = True
            if track_stats:
                increment_stats(model, 'moves')
            return jsonify(cached)

        # Forward to model API
        model_config = MODELS[model]
        endpoint = f"{model_config['endpoint']}/predict"
        response = requests.post(
            endpoint,
            json={
                'fen': fen,
                'depth': depth,
                'time_limit': time_limit
            },
            timeout=model_config['timeout']
        )

        if response.status_code == 200:
            result = response.json()
            # Cache the result
            cache.set(cache_key, result)
            # Track statistics in Firebase
            if track_stats:
                increment_stats(model, 'moves')
            result['from_cache'] = False
            result['model'] = model
            return jsonify(result)
        else:
            return jsonify({
                'error': 'Model API error',
                'status_code': response.status_code,
                'details': response.text
            }), response.status_code
    except requests.Timeout:
        return jsonify({'error': 'Request timeout'}), 504
    except Exception as e:
        return jsonify({'error': str(e)}), 500

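# Illustrative request to the /predict route above (the FEN is the standard
# starting position; response fields other than 'from_cache' and 'model' are
# defined by the upstream model Space, not by this bridge):
#   curl -X POST https://<this-space>/predict \
#        -H 'Content-Type: application/json' \
#        -d '{"fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
#             "model": "core", "depth": 5, "time_limit": 3000}'
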
@app.route('/match/start', methods=['POST'])
def start_match():
    """Track match start"""
    try:
        data = request.get_json()
        model = data.get('model', 'core')
        if model not in MODELS:
            return jsonify({'error': 'Invalid model'}), 400
        increment_stats(model, 'matches')
        return jsonify({
            'success': True,
            'model': model,
            'message': 'Match started'
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

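# Illustrative request to the /match/start route above (only 'model' is read
# from the body and it defaults to 'core'):
#   curl -X POST https://<this-space>/match/start \
#        -H 'Content-Type: application/json' -d '{"model": "nano"}'
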
@app.route('/batch', methods=['POST'])
def batch_predict():
    """Batch prediction endpoint for multiple positions"""
    try:
        data = request.get_json()
        positions = data.get('positions', [])
        model = data.get('model', 'core')
        if not positions:
            return jsonify({'error': 'No positions provided'}), 400
        if len(positions) > 10:
            return jsonify({'error': 'Maximum 10 positions per batch'}), 400

        results = []
        for pos in positions:
            fen = pos.get('fen')
            depth = pos.get('depth', 5)
            time_limit = pos.get('time_limit', 3000)
            # Make individual request
            pred_data = {
                'fen': fen,
                'model': model,
                'depth': depth,
                'time_limit': time_limit,
                'track_stats': False  # Don't track for batch
            }
            result = predict_single(pred_data)
            results.append(result)

        # Track batch as single operation
        increment_stats(model, 'moves')
        return jsonify({
            'success': True,
            'count': len(results),
            'results': results
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

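# Illustrative /batch request body for the route above (at most 10 positions;
# per-position fields mirror /predict, and stats are tracked once per batch):
#   {"model": "core",
#    "positions": [{"fen": "<FEN 1>", "depth": 5, "time_limit": 3000},
#                  {"fen": "<FEN 2>"}]}
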
def predict_single(data):
    """Helper function for a single prediction (used by /batch)"""
    try:
        fen = data.get('fen')
        model = data.get('model', 'core')
        depth = data.get('depth', 5)
        time_limit = data.get('time_limit', 3000)
        model_config = MODELS[model]
        endpoint = f"{model_config['endpoint']}/predict"
        response = requests.post(
            endpoint,
            json={
                'fen': fen,
                'depth': depth,
                'time_limit': time_limit
            },
            timeout=model_config['timeout']
        )
        if response.status_code == 200:
            return response.json()
        else:
            return {'error': 'Prediction failed'}
    except Exception:
        return {'error': 'Request failed'}

# ==================== CLEANUP ====================
@app.before_request
def before_request():
    """Clean old cache entries before each request"""
    cache.clear_old()

# ==================== RUN ====================
if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)