Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
8b76adf
1
Parent(s):
4036b4b
- dispatcher.py +282 -4
- main.py +10 -2
- start_system.sh +3 -0
- static/index.html +30 -3
- worker.py +103 -8
dispatcher.py
CHANGED
|
@@ -11,11 +11,157 @@ from enum import Enum
|
|
| 11 |
import uuid
|
| 12 |
import aiohttp
|
| 13 |
import logging
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Configure logging
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class SessionStatus(Enum):
|
| 20 |
QUEUED = "queued"
|
| 21 |
ACTIVE = "active"
|
|
@@ -33,6 +179,9 @@ class UserSession:
|
|
| 33 |
last_activity: Optional[float] = None
|
| 34 |
max_session_time: Optional[float] = None
|
| 35 |
user_has_interacted: bool = False
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
@dataclass
|
| 38 |
class WorkerInfo:
|
|
@@ -67,6 +216,15 @@ class SessionManager:
|
|
| 67 |
last_ping=time.time()
|
| 68 |
)
|
| 69 |
logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
async def get_available_worker(self) -> Optional[WorkerInfo]:
|
| 72 |
"""Get an available worker"""
|
|
@@ -80,6 +238,7 @@ class SessionManager:
|
|
| 80 |
self.sessions[session.session_id] = session
|
| 81 |
self.session_queue.append(session.session_id)
|
| 82 |
session.status = SessionStatus.QUEUED
|
|
|
|
| 83 |
logger.info(f"Added session {session.session_id} to queue. Queue size: {len(self.session_queue)}")
|
| 84 |
|
| 85 |
async def process_queue(self):
|
|
@@ -94,8 +253,15 @@ class SessionManager:
|
|
| 94 |
|
| 95 |
worker = await self.get_available_worker()
|
| 96 |
if not worker:
|
|
|
|
|
|
|
|
|
|
| 97 |
break # No available workers
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# Assign session to worker
|
| 100 |
self.session_queue.pop(0)
|
| 101 |
session.status = SessionStatus.ACTIVE
|
|
@@ -112,6 +278,28 @@ class SessionManager:
|
|
| 112 |
|
| 113 |
logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
# Notify user that their session is starting
|
| 116 |
await self.notify_session_start(session)
|
| 117 |
|
|
@@ -199,12 +387,25 @@ class SessionManager:
|
|
| 199 |
|
| 200 |
session.status = status
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
# Free up the worker
|
| 203 |
if session.worker_id and session.worker_id in self.workers:
|
| 204 |
worker = self.workers[session.worker_id]
|
| 205 |
worker.is_available = True
|
| 206 |
worker.current_session = None
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
# Notify worker to clean up
|
| 209 |
try:
|
| 210 |
async with aiohttp.ClientSession() as client_session:
|
|
@@ -241,6 +442,11 @@ class SessionManager:
|
|
| 241 |
})
|
| 242 |
except Exception as e:
|
| 243 |
logger.error(f"Failed to send queue update to session {session_id}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
def _calculate_dynamic_wait_time(self, position_in_queue: int) -> float:
|
| 246 |
"""Calculate dynamic estimated wait time based on current session progress"""
|
|
@@ -308,6 +514,7 @@ class SessionManager:
|
|
| 308 |
session = self.sessions.get(session_id)
|
| 309 |
if session:
|
| 310 |
session.last_activity = time.time()
|
|
|
|
| 311 |
if not session.user_has_interacted:
|
| 312 |
session.user_has_interacted = True
|
| 313 |
logger.info(f"User started interacting in session {session_id}")
|
|
@@ -335,6 +542,9 @@ session_manager = SessionManager()
|
|
| 335 |
app = FastAPI()
|
| 336 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 337 |
|
|
|
|
|
|
|
|
|
|
| 338 |
@app.get("/")
|
| 339 |
async def get():
|
| 340 |
return HTMLResponse(open("static/index.html").read())
|
|
@@ -383,21 +593,39 @@ async def worker_result(result_data: dict):
|
|
| 383 |
|
| 384 |
@app.websocket("/ws")
|
| 385 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
|
| 386 |
await websocket.accept()
|
| 387 |
|
| 388 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
session_id = str(uuid.uuid4())
|
| 390 |
-
client_id = f"{int(time.time())}_{
|
| 391 |
|
| 392 |
session = UserSession(
|
| 393 |
session_id=session_id,
|
| 394 |
client_id=client_id,
|
| 395 |
websocket=websocket,
|
| 396 |
created_at=time.time(),
|
| 397 |
-
status=SessionStatus.QUEUED
|
|
|
|
| 398 |
)
|
| 399 |
|
| 400 |
-
logger.info(f"New WebSocket connection: {client_id}")
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
try:
|
| 403 |
# Add to queue
|
|
@@ -492,10 +720,60 @@ async def periodic_queue_update():
|
|
| 492 |
except Exception as e:
|
| 493 |
logger.error(f"Error in periodic queue update: {e}")
|
| 494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
@app.on_event("startup")
|
| 496 |
async def startup_event():
|
| 497 |
# Start background tasks
|
| 498 |
asyncio.create_task(periodic_queue_update())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
if __name__ == "__main__":
|
| 501 |
import uvicorn
|
|
|
|
| 11 |
import uuid
|
| 12 |
import aiohttp
|
| 13 |
import logging
|
| 14 |
+
from collections import defaultdict, deque
|
| 15 |
+
from datetime import datetime
|
| 16 |
|
| 17 |
# Configure logging
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
+
# Analytics and monitoring
|
| 22 |
+
class SystemAnalytics:
    """Collect in-memory usage statistics and mirror every event to a log file.

    Counters (connections, interactions, queue bypass/wait) are plain
    integers; rolling figures (session durations, waiting times, GPU
    utilization, queue size) keep only the most recent 100 samples.  Each
    ``log_*`` method updates the relevant statistics and then writes one or
    more human-readable lines through :meth:`_write_log`.

    Creating an instance immediately creates a timestamped
    ``system_analytics_<YYYYmmdd_HHMMSS>.log`` file in the current working
    directory and writes a header to it.
    """

    def __init__(self):
        self.start_time = time.time()
        self.total_connections = 0
        self.active_connections = 0
        self.total_interactions = 0
        self.ip_addresses = defaultdict(int)  # IP -> connection count
        self.session_durations = deque(maxlen=100)  # Last 100 session durations
        self.waiting_times = deque(maxlen=100)  # Last 100 waiting times
        self.users_bypassed_queue = 0  # Users who got GPU immediately
        self.users_waited_in_queue = 0  # Users who had to wait
        self.gpu_utilization_samples = deque(maxlen=100)  # GPU utilization over time
        self.queue_size_samples = deque(maxlen=100)  # Queue size over time
        self.log_file = None
        self._init_log_file()

    @staticmethod
    def _mean(samples):
        """Return the arithmetic mean of *samples*, or 0 when it is empty."""
        return sum(samples) / len(samples) if samples else 0

    def _percent_of_connections(self, count):
        """Return *count* as a percentage of total connections (0 if none yet)."""
        if self.total_connections > 0:
            return (count / self.total_connections) * 100
        return 0

    def _init_log_file(self):
        """Initialize the system log file"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_filename = f"system_analytics_{timestamp}.log"
        self.log_file = log_filename
        self._write_log("="*80)
        self._write_log("NEURAL OS MULTI-GPU SYSTEM ANALYTICS")
        self._write_log("="*80)
        self._write_log(f"System started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        self._write_log("")

    def _write_log(self, message):
        """Write message to log file and console.

        The message is prefixed with the current ``HH:MM:SS`` time, printed,
        and appended to ``self.log_file``.
        """
        timestamp = datetime.now().strftime("%H:%M:%S")
        log_message = f"[{timestamp}] {message}"
        print(log_message)
        # Messages contain emoji; an explicit UTF-8 encoding keeps the write
        # from failing when the platform's default encoding cannot represent
        # them.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(log_message + "\n")

    def log_new_connection(self, client_id: str, ip: str):
        """Count a new connection from *ip* and log the running totals."""
        self.total_connections += 1
        self.active_connections += 1
        self.ip_addresses[ip] += 1

        unique_ips = len(self.ip_addresses)
        self._write_log(f"🔗 NEW CONNECTION: {client_id} from {ip}")
        self._write_log(f" 📊 Total connections: {self.total_connections} | Active: {self.active_connections} | Unique IPs: {unique_ips}")

    def log_connection_closed(self, client_id: str, duration: float, interactions: int, reason: str = "normal"):
        """Record a closed connection.

        Decrements the active-connection count, adds *interactions* to the
        running total, stores *duration* as a session-duration sample, and
        logs the details together with the rolling average duration.
        """
        self.active_connections -= 1
        self.total_interactions += interactions
        self.session_durations.append(duration)

        avg_duration = self._mean(self.session_durations)

        self._write_log(f"🚪 CONNECTION CLOSED: {client_id}")
        self._write_log(f" ⏱️ Duration: {duration:.1f}s | Interactions: {interactions} | Reason: {reason}")
        self._write_log(f" 📊 Active connections: {self.active_connections} | Avg session duration: {avg_duration:.1f}s")

    def log_queue_bypass(self, client_id: str):
        """Log when user bypasses queue (gets GPU immediately)"""
        self.users_bypassed_queue += 1
        bypass_rate = self._percent_of_connections(self.users_bypassed_queue)
        self._write_log(f"⚡ QUEUE BYPASS: {client_id} got GPU immediately")
        self._write_log(f" 📊 Bypass rate: {bypass_rate:.1f}% ({self.users_bypassed_queue}/{self.total_connections})")

    def log_queue_wait(self, client_id: str, wait_time: float, queue_position: int):
        """Log when user had to wait in queue"""
        self.users_waited_in_queue += 1
        self.waiting_times.append(wait_time)

        avg_wait = self._mean(self.waiting_times)
        wait_rate = self._percent_of_connections(self.users_waited_in_queue)

        self._write_log(f"⏳ QUEUE WAIT: {client_id} waited {wait_time:.1f}s (was #{queue_position})")
        self._write_log(f" 📊 Wait rate: {wait_rate:.1f}% | Avg wait time: {avg_wait:.1f}s")

    def log_gpu_status(self, total_gpus: int, active_gpus: int, available_gpus: int):
        """Record a GPU utilization sample (percent in use) and log it."""
        utilization = (active_gpus / total_gpus) * 100 if total_gpus > 0 else 0
        self.gpu_utilization_samples.append(utilization)

        avg_utilization = self._mean(self.gpu_utilization_samples)

        self._write_log(f"🖥️ GPU STATUS: {active_gpus}/{total_gpus} in use ({utilization:.1f}% utilization)")
        self._write_log(f" 📊 Available: {available_gpus} | Avg utilization: {avg_utilization:.1f}%")

    def log_worker_registered(self, worker_id: str, gpu_id: int, endpoint: str):
        """Log when a worker registers"""
        self._write_log(f"⚙️ WORKER REGISTERED: {worker_id} (GPU {gpu_id}) at {endpoint}")

    def log_worker_disconnected(self, worker_id: str, gpu_id: int):
        """Log when a worker disconnects"""
        self._write_log(f"⚙️ WORKER DISCONNECTED: {worker_id} (GPU {gpu_id})")

    def log_no_workers_available(self, queue_size: int):
        """Log critical situation when no workers are available"""
        self._write_log(f"⚠️ CRITICAL: No GPU workers available! {queue_size} users waiting")
        self._write_log(" Please check worker processes and GPU availability")

    def log_queue_status(self, queue_size: int, estimated_wait: float):
        """Record a queue-size sample; log it only when the queue is non-empty."""
        self.queue_size_samples.append(queue_size)

        avg_queue_size = self._mean(self.queue_size_samples)

        if queue_size > 0:
            self._write_log(f"📝 QUEUE STATUS: {queue_size} users waiting | Est. wait: {estimated_wait:.1f}s")
            self._write_log(f" 📊 Avg queue size: {avg_queue_size:.1f}")

    def log_periodic_summary(self):
        """Log a summary of all counters, rolling averages and the top 10 IPs."""
        uptime = time.time() - self.start_time
        uptime_hours = uptime / 3600

        unique_ips = len(self.ip_addresses)
        avg_duration = self._mean(self.session_durations)
        avg_wait = self._mean(self.waiting_times)
        avg_utilization = self._mean(self.gpu_utilization_samples)
        avg_queue_size = self._mean(self.queue_size_samples)

        bypass_rate = self._percent_of_connections(self.users_bypassed_queue)

        self._write_log("")
        self._write_log("="*60)
        self._write_log("📊 SYSTEM SUMMARY")
        self._write_log("="*60)
        self._write_log(f"⏱️ Uptime: {uptime_hours:.1f} hours")
        self._write_log(f"🔗 Connections: {self.total_connections} total | {self.active_connections} active | {unique_ips} unique IPs")
        self._write_log(f"💬 Total interactions: {self.total_interactions}")
        self._write_log(f"⚡ Queue bypass rate: {bypass_rate:.1f}% ({self.users_bypassed_queue}/{self.total_connections})")
        self._write_log(f"⏳ Avg waiting time: {avg_wait:.1f}s")
        self._write_log(f"📝 Avg queue size: {avg_queue_size:.1f}")
        self._write_log(f"🖥️ Avg GPU utilization: {avg_utilization:.1f}%")
        self._write_log(f"⏱️ Avg session duration: {avg_duration:.1f}s")
        self._write_log("")
        self._write_log("🌍 TOP IP ADDRESSES:")
        # Highest connection counts first; only the first ten are listed.
        for ip, count in sorted(self.ip_addresses.items(), key=lambda x: x[1], reverse=True)[:10]:
            self._write_log(f" {ip}: {count} connections")
        self._write_log("="*60)
        self._write_log("")
|
| 161 |
+
|
| 162 |
+
# Initialize analytics
|
| 163 |
+
analytics = SystemAnalytics()
|
| 164 |
+
|
| 165 |
class SessionStatus(Enum):
|
| 166 |
QUEUED = "queued"
|
| 167 |
ACTIVE = "active"
|
|
|
|
| 179 |
last_activity: Optional[float] = None
|
| 180 |
max_session_time: Optional[float] = None
|
| 181 |
user_has_interacted: bool = False
|
| 182 |
+
ip_address: Optional[str] = None
|
| 183 |
+
interaction_count: int = 0
|
| 184 |
+
queue_start_time: Optional[float] = None
|
| 185 |
|
| 186 |
@dataclass
|
| 187 |
class WorkerInfo:
|
|
|
|
| 216 |
last_ping=time.time()
|
| 217 |
)
|
| 218 |
logger.info(f"Registered worker {worker_id} on GPU {gpu_id} at {endpoint}")
|
| 219 |
+
|
| 220 |
+
# Log worker registration
|
| 221 |
+
analytics.log_worker_registered(worker_id, gpu_id, endpoint)
|
| 222 |
+
|
| 223 |
+
# Log GPU status
|
| 224 |
+
total_gpus = len(self.workers)
|
| 225 |
+
active_gpus = len([w for w in self.workers.values() if not w.is_available])
|
| 226 |
+
available_gpus = total_gpus - active_gpus
|
| 227 |
+
analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
|
| 228 |
|
| 229 |
async def get_available_worker(self) -> Optional[WorkerInfo]:
|
| 230 |
"""Get an available worker"""
|
|
|
|
| 238 |
self.sessions[session.session_id] = session
|
| 239 |
self.session_queue.append(session.session_id)
|
| 240 |
session.status = SessionStatus.QUEUED
|
| 241 |
+
session.queue_start_time = time.time()
|
| 242 |
logger.info(f"Added session {session.session_id} to queue. Queue size: {len(self.session_queue)}")
|
| 243 |
|
| 244 |
async def process_queue(self):
|
|
|
|
| 253 |
|
| 254 |
worker = await self.get_available_worker()
|
| 255 |
if not worker:
|
| 256 |
+
# Log critical situation if no workers are available
|
| 257 |
+
if len(self.workers) == 0:
|
| 258 |
+
analytics.log_no_workers_available(len(self.session_queue))
|
| 259 |
break # No available workers
|
| 260 |
|
| 261 |
+
# Calculate wait time
|
| 262 |
+
wait_time = time.time() - session.queue_start_time if session.queue_start_time else 0
|
| 263 |
+
queue_position = self.session_queue.index(session_id) + 1
|
| 264 |
+
|
| 265 |
# Assign session to worker
|
| 266 |
self.session_queue.pop(0)
|
| 267 |
session.status = SessionStatus.ACTIVE
|
|
|
|
| 278 |
|
| 279 |
logger.info(f"Assigned session {session_id} to worker {worker.worker_id}")
|
| 280 |
|
| 281 |
+
# Log analytics
|
| 282 |
+
if wait_time > 0:
|
| 283 |
+
analytics.log_queue_wait(session.client_id, wait_time, queue_position)
|
| 284 |
+
else:
|
| 285 |
+
analytics.log_queue_bypass(session.client_id)
|
| 286 |
+
|
| 287 |
+
# Log GPU status
|
| 288 |
+
total_gpus = len(self.workers)
|
| 289 |
+
active_gpus = len([w for w in self.workers.values() if not w.is_available])
|
| 290 |
+
available_gpus = total_gpus - active_gpus
|
| 291 |
+
analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
|
| 292 |
+
|
| 293 |
+
# Initialize session on worker with client_id for logging
|
| 294 |
+
try:
|
| 295 |
+
async with aiohttp.ClientSession() as client_session:
|
| 296 |
+
await client_session.post(f"{worker.endpoint}/init_session", json={
|
| 297 |
+
"session_id": session_id,
|
| 298 |
+
"client_id": session.client_id
|
| 299 |
+
})
|
| 300 |
+
except Exception as e:
|
| 301 |
+
logger.error(f"Failed to initialize session on worker {worker.worker_id}: {e}")
|
| 302 |
+
|
| 303 |
# Notify user that their session is starting
|
| 304 |
await self.notify_session_start(session)
|
| 305 |
|
|
|
|
| 387 |
|
| 388 |
session.status = status
|
| 389 |
|
| 390 |
+
# Calculate session duration
|
| 391 |
+
duration = time.time() - session.created_at
|
| 392 |
+
|
| 393 |
+
# Log analytics
|
| 394 |
+
reason = "timeout" if status == SessionStatus.TIMEOUT else "normal"
|
| 395 |
+
analytics.log_connection_closed(session.client_id, duration, session.interaction_count, reason)
|
| 396 |
+
|
| 397 |
# Free up the worker
|
| 398 |
if session.worker_id and session.worker_id in self.workers:
|
| 399 |
worker = self.workers[session.worker_id]
|
| 400 |
worker.is_available = True
|
| 401 |
worker.current_session = None
|
| 402 |
|
| 403 |
+
# Log GPU status
|
| 404 |
+
total_gpus = len(self.workers)
|
| 405 |
+
active_gpus = len([w for w in self.workers.values() if not w.is_available])
|
| 406 |
+
available_gpus = total_gpus - active_gpus
|
| 407 |
+
analytics.log_gpu_status(total_gpus, active_gpus, available_gpus)
|
| 408 |
+
|
| 409 |
# Notify worker to clean up
|
| 410 |
try:
|
| 411 |
async with aiohttp.ClientSession() as client_session:
|
|
|
|
| 442 |
})
|
| 443 |
except Exception as e:
|
| 444 |
logger.error(f"Failed to send queue update to session {session_id}: {e}")
|
| 445 |
+
|
| 446 |
+
# Log queue status if there's a queue
|
| 447 |
+
if self.session_queue:
|
| 448 |
+
estimated_wait = self._calculate_dynamic_wait_time(1)
|
| 449 |
+
analytics.log_queue_status(len(self.session_queue), estimated_wait)
|
| 450 |
|
| 451 |
def _calculate_dynamic_wait_time(self, position_in_queue: int) -> float:
|
| 452 |
"""Calculate dynamic estimated wait time based on current session progress"""
|
|
|
|
| 514 |
session = self.sessions.get(session_id)
|
| 515 |
if session:
|
| 516 |
session.last_activity = time.time()
|
| 517 |
+
session.interaction_count += 1
|
| 518 |
if not session.user_has_interacted:
|
| 519 |
session.user_has_interacted = True
|
| 520 |
logger.info(f"User started interacting in session {session_id}")
|
|
|
|
| 542 |
app = FastAPI()
|
| 543 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 544 |
|
| 545 |
+
# Global connection counter like in main.py
|
| 546 |
+
connection_counter = 0
|
| 547 |
+
|
| 548 |
@app.get("/")
|
| 549 |
async def get():
|
| 550 |
return HTMLResponse(open("static/index.html").read())
|
|
|
|
| 593 |
|
| 594 |
@app.websocket("/ws")
|
| 595 |
async def websocket_endpoint(websocket: WebSocket):
|
| 596 |
+
global connection_counter
|
| 597 |
await websocket.accept()
|
| 598 |
|
| 599 |
+
# Extract client IP address
|
| 600 |
+
client_ip = "unknown"
|
| 601 |
+
if websocket.client and hasattr(websocket.client, 'host'):
|
| 602 |
+
client_ip = websocket.client.host
|
| 603 |
+
elif hasattr(websocket, 'headers'):
|
| 604 |
+
# Try to get real IP from headers (for proxy setups)
|
| 605 |
+
client_ip = websocket.headers.get('x-forwarded-for',
|
| 606 |
+
websocket.headers.get('x-real-ip',
|
| 607 |
+
websocket.headers.get('cf-connecting-ip', 'unknown')))
|
| 608 |
+
if ',' in client_ip:
|
| 609 |
+
client_ip = client_ip.split(',')[0].strip()
|
| 610 |
+
|
| 611 |
+
# Create session with connection counter like in main.py
|
| 612 |
+
connection_counter += 1
|
| 613 |
session_id = str(uuid.uuid4())
|
| 614 |
+
client_id = f"{int(time.time())}_{connection_counter}"
|
| 615 |
|
| 616 |
session = UserSession(
|
| 617 |
session_id=session_id,
|
| 618 |
client_id=client_id,
|
| 619 |
websocket=websocket,
|
| 620 |
created_at=time.time(),
|
| 621 |
+
status=SessionStatus.QUEUED,
|
| 622 |
+
ip_address=client_ip
|
| 623 |
)
|
| 624 |
|
| 625 |
+
logger.info(f"New WebSocket connection: {client_id} from {client_ip}")
|
| 626 |
+
|
| 627 |
+
# Log new connection analytics
|
| 628 |
+
analytics.log_new_connection(client_id, client_ip)
|
| 629 |
|
| 630 |
try:
|
| 631 |
# Add to queue
|
|
|
|
| 720 |
except Exception as e:
|
| 721 |
logger.error(f"Error in periodic queue update: {e}")
|
| 722 |
|
| 723 |
+
# Background task to periodically log analytics summary
|
| 724 |
+
async def periodic_analytics_summary():
    """Background loop that writes the analytics summary every 300 seconds.

    Runs forever: sleep, then call ``analytics.log_periodic_summary()``.
    Any ``Exception`` raised in one iteration is logged and the loop
    carries on with the next iteration.
    """
    summary_interval_seconds = 300  # Log summary every 5 minutes
    while True:
        try:
            await asyncio.sleep(summary_interval_seconds)
            analytics.log_periodic_summary()
        except Exception as exc:
            logger.error(f"Error in periodic analytics summary: {exc}")
|
| 731 |
+
|
| 732 |
+
# Background task to check worker health
|
| 733 |
+
async def periodic_worker_health_check():
    """Background loop that drops workers whose last ping is too old.

    Every 60 seconds, each worker in ``session_manager.workers`` whose
    ``last_ping`` is more than 30 seconds in the past is logged as
    disconnected and removed from the registry.  When at least one worker
    was removed, the resulting GPU status is logged as well.  Any
    ``Exception`` raised in one iteration is logged and the loop continues.
    """
    check_interval_seconds = 60  # Check every minute
    ping_timeout_seconds = 30  # 30 second timeout
    while True:
        try:
            await asyncio.sleep(check_interval_seconds)
            now = time.time()

            # Snapshot the items first so removals below cannot disturb iteration.
            stale_workers = [
                (wid, info.gpu_id)
                for wid, info in list(session_manager.workers.items())
                if now - info.last_ping > ping_timeout_seconds
            ]

            for wid, gpu in stale_workers:
                analytics.log_worker_disconnected(wid, gpu)
                del session_manager.workers[wid]
                logger.warning(f"Removed disconnected worker {wid} (GPU {gpu})")

            if not stale_workers:
                continue

            # Log updated GPU status
            remaining = session_manager.workers
            total_gpus = len(remaining)
            active_gpus = sum(1 for w in remaining.values() if not w.is_available)
            analytics.log_gpu_status(total_gpus, active_gpus, total_gpus - active_gpus)

        except Exception as exc:
            logger.error(f"Error in periodic worker health check: {exc}")
|
| 758 |
+
|
| 759 |
# Strong references to the long-running background tasks.  The event loop
# itself only keeps weak references to tasks, so a task whose handle is
# discarded may be garbage-collected before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Start the periodic background loops and log that the system is ready.

    Launches the queue-update, analytics-summary and worker-health-check
    coroutines as asyncio tasks (in that order), then writes two status
    lines to the analytics log.
    """
    # Start background tasks
    for background_job in (
        periodic_queue_update,
        periodic_analytics_summary,
        periodic_worker_health_check,
    ):
        task = asyncio.create_task(background_job())
        _background_tasks.add(task)
        # Drop the reference once the task ends so the set does not grow.
        task.add_done_callback(_background_tasks.discard)

    # Log initial system status
    analytics._write_log("🚀 System initialized and ready to accept connections")
    analytics._write_log(" Waiting for GPU workers to register...")
|
| 769 |
+
|
| 770 |
+
@app.on_event("shutdown")
async def shutdown_event():
    """Write the closing entries to the analytics log on application shutdown.

    Emits a blank line and a shutdown notice, then the full periodic
    summary, then a completion line.
    """
    write = analytics._write_log

    write("")
    write("🛑 System shutting down...")
    # Log final system summary
    analytics.log_periodic_summary()
    write("System shutdown complete.")
|
| 777 |
|
| 778 |
if __name__ == "__main__":
|
| 779 |
import uvicorn
|
main.py
CHANGED
|
@@ -526,7 +526,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 526 |
if not user_has_interacted:
|
| 527 |
user_has_interacted = True
|
| 528 |
print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
|
| 529 |
-
|
|
|
|
|
|
|
| 530 |
|
| 531 |
# Update the set based on the received data
|
| 532 |
for key in keys_down_list:
|
|
@@ -649,7 +651,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 649 |
is_interesting = (current_input.get("is_left_click") or
|
| 650 |
current_input.get("is_right_click") or
|
| 651 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
| 652 |
-
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0)
|
|
|
|
|
|
|
| 653 |
|
| 654 |
# Process immediately if interesting
|
| 655 |
if is_interesting:
|
|
@@ -802,6 +806,8 @@ def log_interaction(client_id, data, generated_frame=None, is_end_of_session=Fal
|
|
| 802 |
"is_right_click": data.get("is_right_click"),
|
| 803 |
"keys_down": data.get("keys_down", []),
|
| 804 |
"keys_up": data.get("keys_up", []),
|
|
|
|
|
|
|
| 805 |
"is_auto_input": data.get("is_auto_input", False)
|
| 806 |
}
|
| 807 |
else:
|
|
@@ -809,6 +815,8 @@ def log_interaction(client_id, data, generated_frame=None, is_end_of_session=Fal
|
|
| 809 |
log_entry["inputs"] = None
|
| 810 |
|
| 811 |
# Save to a file (one file per session)
|
|
|
|
|
|
|
| 812 |
session_file = f"interaction_logs/session_{client_id}.jsonl"
|
| 813 |
with open(session_file, "a") as f:
|
| 814 |
f.write(json.dumps(log_entry) + "\n")
|
|
|
|
| 526 |
if not user_has_interacted:
|
| 527 |
user_has_interacted = True
|
| 528 |
print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
|
| 529 |
+
wheel_delta_x = data.get("wheel_delta_x", 0)
|
| 530 |
+
wheel_delta_y = data.get("wheel_delta_y", 0)
|
| 531 |
+
print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}, wheel: ({wheel_delta_x},{wheel_delta_y}), time_since_activity: {time.perf_counter() - last_user_activity_time:.3f}')
|
| 532 |
|
| 533 |
# Update the set based on the received data
|
| 534 |
for key in keys_down_list:
|
|
|
|
| 651 |
is_interesting = (current_input.get("is_left_click") or
|
| 652 |
current_input.get("is_right_click") or
|
| 653 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
| 654 |
+
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
|
| 655 |
+
current_input.get("wheel_delta_x", 0) != 0 or
|
| 656 |
+
current_input.get("wheel_delta_y", 0) != 0)
|
| 657 |
|
| 658 |
# Process immediately if interesting
|
| 659 |
if is_interesting:
|
|
|
|
| 806 |
"is_right_click": data.get("is_right_click"),
|
| 807 |
"keys_down": data.get("keys_down", []),
|
| 808 |
"keys_up": data.get("keys_up", []),
|
| 809 |
+
"wheel_delta_x": data.get("wheel_delta_x", 0),
|
| 810 |
+
"wheel_delta_y": data.get("wheel_delta_y", 0),
|
| 811 |
"is_auto_input": data.get("is_auto_input", False)
|
| 812 |
}
|
| 813 |
else:
|
|
|
|
| 815 |
log_entry["inputs"] = None
|
| 816 |
|
| 817 |
# Save to a file (one file per session)
|
| 818 |
+
if not os.path.exists("interaction_logs"):
|
| 819 |
+
os.makedirs("interaction_logs", exist_ok=True)
|
| 820 |
session_file = f"interaction_logs/session_{client_id}.jsonl"
|
| 821 |
with open(session_file, "a") as f:
|
| 822 |
f.write(json.dumps(log_entry) + "\n")
|
start_system.sh
CHANGED
|
@@ -61,6 +61,7 @@ echo "========================================"
|
|
| 61 |
echo "📊 Number of GPUs: $NUM_GPUS"
|
| 62 |
echo "🌐 Dispatcher port: $DISPATCHER_PORT"
|
| 63 |
echo "💻 Worker ports: $(seq -s', ' 8001 $((8000 + NUM_GPUS)))"
|
|
|
|
| 64 |
echo ""
|
| 65 |
|
| 66 |
# Check if required files exist
|
|
@@ -130,12 +131,14 @@ for ((i=0; i<NUM_GPUS; i++)); do
|
|
| 130 |
done
|
| 131 |
echo ""
|
| 132 |
echo "📋 Log files:"
|
|
|
|
| 133 |
echo " Dispatcher: dispatcher.log"
|
| 134 |
echo " Workers summary: workers.log"
|
| 135 |
for ((i=0; i<NUM_GPUS; i++)); do
|
| 136 |
echo " GPU $i worker: worker_gpu_$i.log"
|
| 137 |
done
|
| 138 |
echo ""
|
|
|
|
| 139 |
echo "Press Ctrl+C to stop the system"
|
| 140 |
echo "================================"
|
| 141 |
|
|
|
|
| 61 |
echo "📊 Number of GPUs: $NUM_GPUS"
|
| 62 |
echo "🌐 Dispatcher port: $DISPATCHER_PORT"
|
| 63 |
echo "💻 Worker ports: $(seq -s', ' 8001 $((8000 + NUM_GPUS)))"
|
| 64 |
+
echo "📈 Analytics logging: system_analytics_$(date +%Y%m%d_%H%M%S).log"
|
| 65 |
echo ""
|
| 66 |
|
| 67 |
# Check if required files exist
|
|
|
|
| 131 |
done
|
| 132 |
echo ""
|
| 133 |
echo "📋 Log files:"
|
| 134 |
+
echo " System analytics: system_analytics_*.log (real-time monitoring)"
|
| 135 |
echo " Dispatcher: dispatcher.log"
|
| 136 |
echo " Workers summary: workers.log"
|
| 137 |
for ((i=0; i<NUM_GPUS; i++)); do
|
| 138 |
echo " GPU $i worker: worker_gpu_$i.log"
|
| 139 |
done
|
| 140 |
echo ""
|
| 141 |
+
echo "💡 Monitor system in real-time: tail -f system_analytics_*.log"
|
| 142 |
echo "Press Ctrl+C to stop the system"
|
| 143 |
echo "================================"
|
| 144 |
|
static/index.html
CHANGED
|
@@ -414,6 +414,8 @@
|
|
| 414 |
"is_right_click": false,
|
| 415 |
"keys_down": [],
|
| 416 |
"keys_up": [],
|
|
|
|
|
|
|
| 417 |
"is_auto_input": true // Flag to identify auto-generated inputs
|
| 418 |
}));
|
| 419 |
lastAutoInputTime = currentTime;
|
|
@@ -531,7 +533,9 @@
|
|
| 531 |
"is_left_click": false,
|
| 532 |
"is_right_click": false,
|
| 533 |
"keys_down": [],
|
| 534 |
-
"keys_up": []
|
|
|
|
|
|
|
| 535 |
}));
|
| 536 |
updateLastUserInputTime(); // Update for auto-input mechanism
|
| 537 |
} catch (error) {
|
|
@@ -541,9 +545,9 @@
|
|
| 541 |
stopTimeoutCountdown();
|
| 542 |
}
|
| 543 |
|
| 544 |
-
function sendInputState(x, y, isLeftClick = false, isRightClick = false, keysDownArr = [], keysUpArr = []) {
|
| 545 |
const currentTime = Date.now();
|
| 546 |
-
if (isConnected && socket.readyState === WebSocket.OPEN && (isLeftClick || isRightClick || keysDownArr.length > 0 || keysUpArr.length > 0 || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
|
| 547 |
try {
|
| 548 |
socket.send(JSON.stringify({
|
| 549 |
"x": x,
|
|
@@ -552,6 +556,8 @@
|
|
| 552 |
"is_right_click": isRightClick,
|
| 553 |
"keys_down": keysDownArr,
|
| 554 |
"keys_up": keysUpArr,
|
|
|
|
|
|
|
| 555 |
}));
|
| 556 |
lastSentPosition = { x, y };
|
| 557 |
lastSentTime = currentTime;
|
|
@@ -638,6 +644,27 @@
|
|
| 638 |
sendInputState(x, y, false, true);
|
| 639 |
});
|
| 640 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
// Track keyboard events
|
| 642 |
const TROUBLESOME = new Set([
|
| 643 |
"Tab", // focus change
|
|
|
|
| 414 |
"is_right_click": false,
|
| 415 |
"keys_down": [],
|
| 416 |
"keys_up": [],
|
| 417 |
+
"wheel_delta_x": 0,
|
| 418 |
+
"wheel_delta_y": 0,
|
| 419 |
"is_auto_input": true // Flag to identify auto-generated inputs
|
| 420 |
}));
|
| 421 |
lastAutoInputTime = currentTime;
|
|
|
|
| 533 |
"is_left_click": false,
|
| 534 |
"is_right_click": false,
|
| 535 |
"keys_down": [],
|
| 536 |
+
"keys_up": [],
|
| 537 |
+
"wheel_delta_x": 0,
|
| 538 |
+
"wheel_delta_y": 0
|
| 539 |
}));
|
| 540 |
updateLastUserInputTime(); // Update for auto-input mechanism
|
| 541 |
} catch (error) {
|
|
|
|
| 545 |
stopTimeoutCountdown();
|
| 546 |
}
|
| 547 |
|
| 548 |
+
function sendInputState(x, y, isLeftClick = false, isRightClick = false, keysDownArr = [], keysUpArr = [], wheelDeltaX = 0, wheelDeltaY = 0) {
|
| 549 |
const currentTime = Date.now();
|
| 550 |
+
if (isConnected && socket.readyState === WebSocket.OPEN && (isLeftClick || isRightClick || keysDownArr.length > 0 || keysUpArr.length > 0 || wheelDeltaX !== 0 || wheelDeltaY !== 0 || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
|
| 551 |
try {
|
| 552 |
socket.send(JSON.stringify({
|
| 553 |
"x": x,
|
|
|
|
| 556 |
"is_right_click": isRightClick,
|
| 557 |
"keys_down": keysDownArr,
|
| 558 |
"keys_up": keysUpArr,
|
| 559 |
+
"wheel_delta_x": wheelDeltaX,
|
| 560 |
+
"wheel_delta_y": wheelDeltaY,
|
| 561 |
}));
|
| 562 |
lastSentPosition = { x, y };
|
| 563 |
lastSentTime = currentTime;
|
|
|
|
| 644 |
sendInputState(x, y, false, true);
|
| 645 |
});
|
| 646 |
|
| 647 |
+
// Handle mouse wheel events
|
| 648 |
+
canvas.addEventListener("wheel", (event) => {
    // Stop the browser from scrolling the page while the pointer is over the canvas.
    event.preventDefault();
    if (!isConnected || isProcessing) {
        return;
    }

    // Pointer position relative to the canvas's top-left corner.
    const bounds = canvas.getBoundingClientRect();
    const x = event.clientX - bounds.left;
    const y = event.clientY - bounds.top;

    // The raw browser deltas are forwarded as-is: no clamping and no
    // deltaMode normalization is applied before sending.
    const { deltaX, deltaY } = event;

    console.log(`Wheel event: deltaX=${deltaX}, deltaY=${deltaY} at (${x}, ${y})`);
    // Wheel input carries no clicks or key changes, only position and deltas.
    sendInputState(x, y, false, false, [], [], deltaX, deltaY);
});
|
| 667 |
+
|
| 668 |
// Track keyboard events
|
| 669 |
const TROUBLESOME = new Set([
|
| 670 |
"Tab", // focus change
|
worker.py
CHANGED
|
@@ -293,9 +293,17 @@ class GPUWorker:
|
|
| 293 |
|
| 294 |
return sample_latent, sample_img, hidden_states, timing
|
| 295 |
|
| 296 |
-
def initialize_session(self, session_id: str):
|
| 297 |
"""Initialize a new session"""
|
| 298 |
self.current_session = session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
self.session_data[session_id] = {
|
| 300 |
'previous_frame': self.padding_image,
|
| 301 |
'hidden_states': None,
|
|
@@ -306,9 +314,10 @@ class GPUWorker:
|
|
| 306 |
'sampling_steps': self.NUM_SAMPLING_STEPS
|
| 307 |
},
|
| 308 |
'input_queue': asyncio.Queue(),
|
| 309 |
-
'is_processing': False
|
|
|
|
| 310 |
}
|
| 311 |
-
logger.info(f"Initialized session {session_id}")
|
| 312 |
|
| 313 |
# Start processing task for this session
|
| 314 |
asyncio.create_task(self._process_session_queue(session_id))
|
|
@@ -316,8 +325,12 @@ class GPUWorker:
|
|
| 316 |
def end_session(self, session_id: str):
|
| 317 |
"""End a session and clean up"""
|
| 318 |
if session_id in self.session_data:
|
| 319 |
-
#
|
| 320 |
session = self.session_data[session_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
while not session['input_queue'].empty():
|
| 322 |
try:
|
| 323 |
session['input_queue'].get_nowait()
|
|
@@ -391,7 +404,9 @@ class GPUWorker:
|
|
| 391 |
is_interesting = (current_input.get("is_left_click") or
|
| 392 |
current_input.get("is_right_click") or
|
| 393 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
| 394 |
-
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0)
|
|
|
|
|
|
|
| 395 |
|
| 396 |
# Process immediately if interesting
|
| 397 |
if is_interesting:
|
|
@@ -416,13 +431,17 @@ class GPUWorker:
|
|
| 416 |
async def process_input(self, session_id: str, data: dict) -> dict:
|
| 417 |
"""Process input for a session - adds to queue or handles control messages"""
|
| 418 |
if session_id not in self.session_data:
|
| 419 |
-
self.initialize_session(session_id)
|
| 420 |
|
| 421 |
session = self.session_data[session_id]
|
| 422 |
|
| 423 |
# Handle control messages immediately (don't queue these)
|
| 424 |
if data.get("type") == "reset":
|
| 425 |
logger.info(f"Received reset command for session {session_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
# Clear the queue
|
| 427 |
while not session['input_queue'].empty():
|
| 428 |
try:
|
|
@@ -484,6 +503,8 @@ class GPUWorker:
|
|
| 484 |
is_right_click = data.get("is_right_click", False)
|
| 485 |
keys_down_list = data.get("keys_down", [])
|
| 486 |
keys_up_list = data.get("keys_up", [])
|
|
|
|
|
|
|
| 487 |
|
| 488 |
# Update keys_down set
|
| 489 |
for key in keys_down_list:
|
|
@@ -518,8 +539,13 @@ class GPUWorker:
|
|
| 518 |
session['frame_num']
|
| 519 |
)
|
| 520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
# Process frame
|
| 522 |
-
logger.info(f"Processing frame {session['frame_num']} for session {session_id}")
|
| 523 |
sample_latent, sample_img, hidden_states, timing_info = await self.process_frame(
|
| 524 |
inputs,
|
| 525 |
use_rnn=session['client_settings']['use_rnn'],
|
|
@@ -539,6 +565,10 @@ class GPUWorker:
|
|
| 539 |
# Log timing
|
| 540 |
logger.info(f"Frame {session['frame_num']} processed in {timing_info['total']:.4f}s (FPS: {1.0/timing_info['total']:.2f})")
|
| 541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
# Send result back to dispatcher
|
| 543 |
await self._send_result_to_dispatcher(session_id, {"image": img_str})
|
| 544 |
|
|
@@ -566,6 +596,55 @@ app = FastAPI()
|
|
| 566 |
# Global worker instance
|
| 567 |
worker: Optional[GPUWorker] = None
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
@app.post("/process_input")
|
| 570 |
async def process_input_endpoint(request: dict):
|
| 571 |
"""Process input from dispatcher"""
|
|
@@ -581,13 +660,29 @@ async def process_input_endpoint(request: dict):
|
|
| 581 |
result = await worker.process_input(session_id, data)
|
| 582 |
return result
|
| 583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
@app.post("/end_session")
|
| 585 |
async def end_session_endpoint(request: dict):
|
| 586 |
-
"""End
|
| 587 |
if not worker:
|
| 588 |
raise HTTPException(status_code=500, detail="Worker not initialized")
|
| 589 |
|
| 590 |
session_id = request.get("session_id")
|
|
|
|
| 591 |
if not session_id:
|
| 592 |
raise HTTPException(status_code=400, detail="Missing session_id")
|
| 593 |
|
|
|
|
| 293 |
|
| 294 |
return sample_latent, sample_img, hidden_states, timing
|
| 295 |
|
| 296 |
+
def initialize_session(self, session_id: str, client_id: str = None):
|
| 297 |
"""Initialize a new session"""
|
| 298 |
self.current_session = session_id
|
| 299 |
+
# Use client_id from dispatcher if provided, otherwise create one
|
| 300 |
+
if client_id:
|
| 301 |
+
log_session_id = client_id
|
| 302 |
+
else:
|
| 303 |
+
# Fallback: create a time-prefixed session identifier for logging
|
| 304 |
+
session_start_time = int(time.time())
|
| 305 |
+
log_session_id = f"{session_start_time}_{session_id}"
|
| 306 |
+
|
| 307 |
self.session_data[session_id] = {
|
| 308 |
'previous_frame': self.padding_image,
|
| 309 |
'hidden_states': None,
|
|
|
|
| 314 |
'sampling_steps': self.NUM_SAMPLING_STEPS
|
| 315 |
},
|
| 316 |
'input_queue': asyncio.Queue(),
|
| 317 |
+
'is_processing': False,
|
| 318 |
+
'log_session_id': log_session_id # Store the time-prefixed ID for logging
|
| 319 |
}
|
| 320 |
+
logger.info(f"Initialized session {session_id} with log ID {log_session_id}")
|
| 321 |
|
| 322 |
# Start processing task for this session
|
| 323 |
asyncio.create_task(self._process_session_queue(session_id))
|
|
|
|
| 325 |
def end_session(self, session_id: str):
|
| 326 |
"""End a session and clean up"""
|
| 327 |
if session_id in self.session_data:
|
| 328 |
+
# Log session end using the stored log_session_id
|
| 329 |
session = self.session_data[session_id]
|
| 330 |
+
log_session_id = session.get('log_session_id', session_id) # Fallback to session_id if not found
|
| 331 |
+
log_interaction(log_session_id, {}, is_end_of_session=True)
|
| 332 |
+
|
| 333 |
+
# Clear any remaining items in the queue
|
| 334 |
while not session['input_queue'].empty():
|
| 335 |
try:
|
| 336 |
session['input_queue'].get_nowait()
|
|
|
|
| 404 |
is_interesting = (current_input.get("is_left_click") or
|
| 405 |
current_input.get("is_right_click") or
|
| 406 |
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
|
| 407 |
+
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
|
| 408 |
+
current_input.get("wheel_delta_x", 0) != 0 or
|
| 409 |
+
current_input.get("wheel_delta_y", 0) != 0)
|
| 410 |
|
| 411 |
# Process immediately if interesting
|
| 412 |
if is_interesting:
|
|
|
|
| 431 |
async def process_input(self, session_id: str, data: dict) -> dict:
|
| 432 |
"""Process input for a session - adds to queue or handles control messages"""
|
| 433 |
if session_id not in self.session_data:
|
| 434 |
+
self.initialize_session(session_id) # Fallback initialization without client_id
|
| 435 |
|
| 436 |
session = self.session_data[session_id]
|
| 437 |
|
| 438 |
# Handle control messages immediately (don't queue these)
|
| 439 |
if data.get("type") == "reset":
|
| 440 |
logger.info(f"Received reset command for session {session_id}")
|
| 441 |
+
# Log the reset action using the stored log_session_id
|
| 442 |
+
log_session_id = session.get('log_session_id', session_id) # Fallback to session_id if not found
|
| 443 |
+
log_interaction(log_session_id, data, is_reset=True)
|
| 444 |
+
|
| 445 |
# Clear the queue
|
| 446 |
while not session['input_queue'].empty():
|
| 447 |
try:
|
|
|
|
| 503 |
is_right_click = data.get("is_right_click", False)
|
| 504 |
keys_down_list = data.get("keys_down", [])
|
| 505 |
keys_up_list = data.get("keys_up", [])
|
| 506 |
+
wheel_delta_x = data.get("wheel_delta_x", 0)
|
| 507 |
+
wheel_delta_y = data.get("wheel_delta_y", 0)
|
| 508 |
|
| 509 |
# Update keys_down set
|
| 510 |
for key in keys_down_list:
|
|
|
|
| 539 |
session['frame_num']
|
| 540 |
)
|
| 541 |
|
| 542 |
+
# Log the input data being processed
|
| 543 |
+
logger.info(f"Processing frame {session['frame_num']} for session {session_id}: "
|
| 544 |
+
f"pos=({x},{y}), clicks=(L:{is_left_click},R:{is_right_click}), "
|
| 545 |
+
f"keys_down={keys_down_list}, keys_up={keys_up_list}, "
|
| 546 |
+
f"wheel=({wheel_delta_x},{wheel_delta_y})")
|
| 547 |
+
|
| 548 |
# Process frame
|
|
|
|
| 549 |
sample_latent, sample_img, hidden_states, timing_info = await self.process_frame(
|
| 550 |
inputs,
|
| 551 |
use_rnn=session['client_settings']['use_rnn'],
|
|
|
|
| 565 |
# Log timing
|
| 566 |
logger.info(f"Frame {session['frame_num']} processed in {timing_info['total']:.4f}s (FPS: {1.0/timing_info['total']:.2f})")
|
| 567 |
|
| 568 |
+
# Log the interaction using the stored log_session_id
|
| 569 |
+
log_session_id = session.get('log_session_id', session_id) # Fallback to session_id if not found
|
| 570 |
+
log_interaction(log_session_id, data, generated_frame=sample_img)
|
| 571 |
+
|
| 572 |
# Send result back to dispatcher
|
| 573 |
await self._send_result_to_dispatcher(session_id, {"image": img_str})
|
| 574 |
|
|
|
|
| 596 |
# Global worker instance
|
| 597 |
worker: Optional[GPUWorker] = None
|
| 598 |
|
| 599 |
+
def log_interaction(log_session_id, data, generated_frame=None, is_end_of_session=False, is_reset=False):
    """Append one interaction record to the session's JSONL log.

    Every call appends a single JSON line to
    ``interaction_logs/session_<log_session_id>.jsonl`` (the directory is
    created on demand). Regular input records also get their generated frame
    written as a PNG named after the record timestamp, when one is supplied.

    Args:
        log_session_id: Identifier used both inside the record and in the
            log file / frame directory names.
        data: Input message dict; missing input fields fall back to defaults.
        generated_frame: Optional frame to save next to the log. It is handed
            to ``Image.fromarray`` (presumably PIL's ``Image`` — confirm
            against the file's imports).
        is_end_of_session: Marks the record as an end-of-session marker.
        is_reset: Marks the record as a reset marker.
    """
    now = time.time()
    os.makedirs("interaction_logs", exist_ok=True)

    # End-of-session and reset records are control markers: they carry no
    # input payload and never trigger a frame dump.
    is_control_record = is_end_of_session or is_reset

    record = {
        "timestamp": now,
        "session_id": log_session_id,
        "is_eos": is_end_of_session,
        "is_reset": is_reset,
    }

    # Carry over the message type (e.g. "reset") only when it is set.
    message_type = data.get("type")
    if message_type:
        record["type"] = message_type

    if is_control_record:
        record["inputs"] = None
    else:
        # (field name, value used when the field is absent from ``data``)
        input_fields = (
            ("x", None),
            ("y", None),
            ("is_left_click", None),
            ("is_right_click", None),
            ("keys_down", []),
            ("keys_up", []),
            ("wheel_delta_x", 0),
            ("wheel_delta_y", 0),
            ("is_auto_input", False),
        )
        record["inputs"] = {
            field: data.get(field, fallback) for field, fallback in input_fields
        }

    # One JSON document per line, appended to the per-session file.
    with open(f"interaction_logs/session_{log_session_id}.jsonl", "a") as log_file:
        log_file.write(f"{json.dumps(record)}\n")

    if generated_frame is not None and not is_control_record:
        frames_path = f"interaction_logs/frames_{log_session_id}"
        os.makedirs(frames_path, exist_ok=True)
        # The frame file shares the record's timestamp so the two can be paired.
        Image.fromarray(generated_frame).save(f"{frames_path}/{now:.6f}.png")
|
| 647 |
+
|
| 648 |
@app.post("/process_input")
|
| 649 |
async def process_input_endpoint(request: dict):
|
| 650 |
"""Process input from dispatcher"""
|
|
|
|
| 660 |
result = await worker.process_input(session_id, data)
|
| 661 |
return result
|
| 662 |
|
| 663 |
+
@app.post("/init_session")
async def init_session_endpoint(request: dict):
    """Create a worker session for the ``session_id`` given in the request body.

    The optional ``client_id`` from the body is passed through to
    ``worker.initialize_session``.

    Raises:
        HTTPException: 500 when the global worker is not set, 400 when the
            body has no (or an empty) ``session_id``.
    """
    if not worker:
        raise HTTPException(status_code=500, detail="Worker not initialized")

    sid = request.get("session_id")
    if not sid:
        raise HTTPException(status_code=400, detail="Missing session_id")

    # client_id may be absent; initialize_session then receives None.
    worker.initialize_session(sid, request.get("client_id"))
    return {"status": "session_initialized"}
|
| 677 |
+
|
| 678 |
@app.post("/end_session")
|
| 679 |
async def end_session_endpoint(request: dict):
|
| 680 |
+
"""End session from dispatcher"""
|
| 681 |
if not worker:
|
| 682 |
raise HTTPException(status_code=500, detail="Worker not initialized")
|
| 683 |
|
| 684 |
session_id = request.get("session_id")
|
| 685 |
+
|
| 686 |
if not session_id:
|
| 687 |
raise HTTPException(status_code=400, detail="Missing session_id")
|
| 688 |
|