Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException | |
| from typing import List, Tuple, Dict, Any, Optional | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import base64 | |
| import io | |
| import json | |
| import asyncio | |
| import time | |
| import torch | |
| import os | |
| import logging | |
| from utils import initialize_model, sample_frame | |
| from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler | |
| import concurrent.futures | |
| import aiohttp | |
| import argparse | |
| import uuid | |
# Module-wide logging: INFO and above goes to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Permit TF32 kernels for cuDNN ops and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
class GPUWorker:
    """Inference worker bound to one device.

    The worker loads one model, keeps per-session state (previous latent
    frame, hidden states, pressed keys, an input queue) in ``session_data``,
    runs the model in a single-thread executor, and reports results to the
    dispatcher over HTTP.
    """

    def __init__(self, gpu_id: int, dispatcher_url: str = "http://localhost:8000"):
        """Load the model and key tables for ``gpu_id``.

        Args:
            gpu_id: CUDA device index; CPU is used when CUDA is unavailable.
            dispatcher_url: Base URL of the dispatcher this worker talks to.
        """
        self.gpu_id = gpu_id
        self.dispatcher_url = dispatcher_url
        self.worker_id = f"worker_{gpu_id}_{uuid.uuid4().hex[:8]}"
        self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
        self.current_session: Optional[str] = None
        self.session_data: Dict[str, Any] = {}

        # Model configuration from main.py
        self.DEBUG_MODE = False
        self.DEBUG_MODE_2 = False
        self.NUM_MAX_FRAMES = 1
        self.TIMESTEPS = 1000
        self.SCREEN_WIDTH = 512
        self.SCREEN_HEIGHT = 384
        self.NUM_SAMPLING_STEPS = 32
        self.USE_RNN = False
        self.MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-06k"

        # Initialize model
        self._initialize_model()

        # One worker thread: frames for this device are processed strictly
        # one at a time, off the event loop.
        self.thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        # Load keyboard mappings
        self._load_keyboard_mappings()

        logger.info(f"GPU Worker {self.worker_id} initialized on GPU {gpu_id}")

    def _initialize_model(self):
        """Load latent statistics, build the model and the padding frame."""
        logger.info(f"Initializing model on GPU {self.gpu_id}")

        # Per-channel mean/std used to (de)normalize latents.
        with open('latent_stats.json', 'r') as f:
            latent_stats = json.load(f)
        self.DATA_NORMALIZATION = {
            'mean': torch.tensor(latent_stats['mean']).to(self.device),
            'std': torch.tensor(latent_stats['std']).to(self.device)
        }
        self.LATENT_DIMS = (16, self.SCREEN_HEIGHT // 8, self.SCREEN_WIDTH // 8)

        # The config file is selected by substrings of the model name.
        name = self.MODEL_NAME
        if 'origunet' in name:
            config_name = "config_final_model_origunet_nospatial"
            if 'x0' in name:
                config_name += "_x0"
            if 'ddpm32' in name:
                self.TIMESTEPS = 32
                config_name += "_ddpm32"
            config_name += ".yaml"
        else:
            config_name = "config_final_model.yaml"
        self.model = initialize_model(config_name, self.MODEL_NAME)
        self.model = self.model.to(self.device)

        # Padding frame: an all-zero latent expressed in normalized space.
        mean = self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)
        std = self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
        self.padding_image = torch.zeros(*self.LATENT_DIMS).unsqueeze(0).to(self.device)
        self.padding_image = (self.padding_image - mean) / std

        logger.info(f"Model initialized successfully on GPU {self.gpu_id}")

    def _load_keyboard_mappings(self):
        """Build the key vocabulary (``itos``/``stoi``) and browser-name aliases."""
        self.KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
                     ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
                     '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
                     'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
                     'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
                     'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
                     'browserback', 'browserfavorites', 'browserforward', 'browserhome',
                     'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
                     'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
                     'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
                     'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
                     'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
                     'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
                     'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
                     'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
                     'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
                     'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
                     'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
                     'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
                     'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
                     'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
                     'command', 'option', 'optionleft', 'optionright']
        # Names as sent by the client -> names used in KEYS.
        self.KEYMAPPING = {
            'arrowup': 'up',
            'arrowdown': 'down',
            'arrowleft': 'left',
            'arrowright': 'right',
            'meta': 'command',
            'contextmenu': 'apps',
            'control': 'ctrl',
        }
        self.INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                             'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
        self.VALID_KEYS = [key for key in self.KEYS if key not in self.INVALID_KEYS]
        self.itos = self.VALID_KEYS
        self.stoi = {key: i for i, key in enumerate(self.itos)}

    def _normalize_key(self, key: str) -> str:
        """Lower-case ``key`` and translate it through ``KEYMAPPING`` if aliased."""
        key = key.lower()
        return self.KEYMAPPING.get(key, key)

    @staticmethod
    def _drain_queue(queue: "asyncio.Queue") -> None:
        """Discard every item currently waiting in ``queue``."""
        while not queue.empty():
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break

    async def register_with_dispatcher(self):
        """Announce this worker to the dispatcher; failures are logged, not raised."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(f"{self.dispatcher_url}/register_worker", json={
                    "worker_id": self.worker_id,
                    "gpu_id": self.gpu_id,
                    "endpoint": f"http://localhost:{8001 + self.gpu_id}"
                }) as response:
                    # A non-2xx answer must not be reported as a successful registration.
                    response.raise_for_status()
            logger.info(f"Successfully registered worker {self.worker_id} with dispatcher")
        except Exception as e:
            logger.error(f"Failed to register with dispatcher: {e}")

    async def ping_dispatcher(self):
        """Report availability to the dispatcher forever (10 s cadence, 5 s after an error)."""
        while True:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(f"{self.dispatcher_url}/worker_ping", json={
                        "worker_id": self.worker_id,
                        "is_available": self.current_session is None
                    }) as response:
                        response.raise_for_status()
                await asyncio.sleep(10)  # Ping every 10 seconds
            except Exception as e:
                logger.error(f"Failed to ping dispatcher: {e}")
                await asyncio.sleep(5)  # Retry after 5 seconds on error

    def prepare_model_inputs(
        self,
        previous_frame: torch.Tensor,
        hidden_states: Any,
        x: int,
        y: int,
        right_click: bool,
        left_click: bool,
        keys_down: List[str],
        time_step: int
    ) -> Dict[str, torch.Tensor]:
        """Assemble the input dict passed to ``temporal_encoder.forward_step``.

        Coordinates are clamped to the screen (``None`` becomes 0), pressed
        keys become a multi-hot vector over ``itos``, and ``hidden_states``
        is included only when it is not ``None`` (and not stripped by a
        debug mode).
        """
        # Clamp coordinates to valid ranges
        x = min(max(0, x), self.SCREEN_WIDTH - 1) if x is not None else 0
        y = min(max(0, y), self.SCREEN_HEIGHT - 1) if y is not None else 0

        if self.DEBUG_MODE:
            logger.info('DEBUG MODE, SETTING TIME STEP TO 0')
            time_step = 0
        if self.DEBUG_MODE_2:
            if time_step > self.NUM_MAX_FRAMES-1:
                logger.info('DEBUG MODE_2, SETTING TIME STEP TO 0')
                time_step = 0

        inputs = {
            'image_features': previous_frame.to(self.device),
            'is_padding': torch.BoolTensor([time_step == 0]).to(self.device),
            'x': torch.LongTensor([x]).unsqueeze(0).to(self.device),
            'y': torch.LongTensor([y]).unsqueeze(0).to(self.device),
            'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(self.device),
            'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(self.device),
            'key_events': torch.zeros(len(self.itos), dtype=torch.long).to(self.device)
        }
        for key in keys_down:
            key = self._normalize_key(key)
            if key in self.stoi:
                inputs['key_events'][self.stoi[key]] = 1
            else:
                logger.warning(f'Key {key} not found in stoi')

        if hidden_states is not None:
            inputs['hidden_states'] = hidden_states

        if self.DEBUG_MODE:
            logger.info('DEBUG MODE, REMOVING INPUTS')
            if 'hidden_states' in inputs:
                del inputs['hidden_states']
        if self.DEBUG_MODE_2:
            if time_step > self.NUM_MAX_FRAMES-1:
                logger.info('DEBUG MODE_2, REMOVING HIDDEN STATES')
                if 'hidden_states' in inputs:
                    del inputs['hidden_states']

        logger.info(f'Time step: {time_step}')
        return inputs

    async def process_frame(
        self,
        inputs: Dict[str, torch.Tensor],
        use_rnn: bool = False,
        num_sampling_steps: int = 32
    ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
        """Run ``_process_frame_sync`` on the worker thread and await its result.

        Returns:
            ``(sample_latent, sample_img, hidden_states, timing)``.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.thread_executor,
            lambda: self._process_frame_sync(inputs, use_rnn, num_sampling_steps)
        )

    def _process_frame_sync(self, inputs, use_rnn, num_sampling_steps):
        """Generate one frame: temporal encoder -> latent sampling -> decode.

        Runs on the executor thread. Returns the normalized latent, the
        decoded image as an HxWx3 uint8 array, the new hidden states, and a
        dict of per-stage wall-clock timings in seconds.
        """
        timing = {}
        # Grad mode is thread-local, so it has to be disabled here, inside
        # the function the executor thread runs. The returned latent and
        # hidden states are fed into the next frame; without this they could
        # keep autograd history alive from frame to frame.
        with torch.no_grad():
            # Temporal encoding
            start = time.perf_counter()
            output_from_rnn, hidden_states = self.model.temporal_encoder.forward_step(inputs)
            timing['temporal_encoder'] = time.perf_counter() - start

            # UNet sampling
            start = time.perf_counter()
            logger.info(f"model.clip_denoised: {self.model.clip_denoised}")
            self.model.clip_denoised = False
            logger.info(f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")
            conditioning = {'c_concat': output_from_rnn}
            if use_rnn:
                # Use the encoder output's first 16 channels directly as the latent.
                sample_latent = output_from_rnn[:, :16]
            elif num_sampling_steps >= self.TIMESTEPS:
                sample_latent = self.model.p_sample_loop(
                    cond=conditioning,
                    shape=[1, *self.LATENT_DIMS],
                    return_intermediates=False,
                    verbose=True
                )
            elif num_sampling_steps == 1:
                # Single model call from pure noise at the last timestep.
                x = torch.randn([1, *self.LATENT_DIMS], device=self.device)
                t = torch.full((1,), self.TIMESTEPS-1, device=self.device, dtype=torch.long)
                sample_latent = self.model.apply_model(x, t, conditioning)
            else:
                sampler = DDIMSampler(self.model)
                sample_latent, _ = sampler.sample(
                    S=num_sampling_steps,
                    conditioning=conditioning,
                    batch_size=1,
                    shape=self.LATENT_DIMS,
                    verbose=False
                )
            timing['unet'] = time.perf_counter() - start

            # Decoding
            start = time.perf_counter()
            std = self.DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
            mean = self.DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)
            sample = self.model.decode_first_stage(sample_latent * std + mean)
            sample = sample.squeeze(0).clamp(-1, 1)
            timing['decode'] = time.perf_counter() - start

            # Convert to image: CHW in [-1, 1] -> HWC uint8 in [0, 255]
            hwc = sample[:3].transpose(0, 1).transpose(1, 2).cpu().float().numpy()
            sample_img = ((hwc + 1) * 127.5).astype(np.uint8)

        timing['total'] = sum(timing.values())
        return sample_latent, sample_img, hidden_states, timing

    def initialize_session(self, session_id: str, client_id: Optional[str] = None):
        """Create fresh state for ``session_id`` and start its queue processor.

        Must be called from within a running event loop (it creates a task).

        Args:
            session_id: Key under which the session state is stored.
            client_id: Identifier used for interaction logs; when falsy, a
                ``<unix time>_<session_id>`` identifier is generated.
        """
        self.current_session = session_id

        # Use client_id from dispatcher if provided, otherwise create one
        if client_id:
            log_session_id = client_id
        else:
            log_session_id = f"{int(time.time())}_{session_id}"

        session = {
            'previous_frame': self.padding_image,
            'hidden_states': None,
            'keys_down': set(),
            'frame_num': -1,
            'client_settings': {
                'use_rnn': self.USE_RNN,
                'sampling_steps': self.NUM_SAMPLING_STEPS
            },
            'input_queue': asyncio.Queue(),
            'is_processing': False,
            'log_session_id': log_session_id
        }
        self.session_data[session_id] = session
        logger.info(f"Initialized session {session_id} with log ID {log_session_id}")

        # The event loop only keeps weak references to tasks, so hold a
        # strong one for as long as the session exists.
        session['processor_task'] = asyncio.create_task(self._process_session_queue(session_id))

    def end_session(self, session_id: str):
        """Log end-of-session, drop queued inputs and forget the session."""
        if session_id in self.session_data:
            session = self.session_data[session_id]
            log_session_id = session.get('log_session_id', session_id)  # Fallback to session_id if not found
            log_interaction(log_session_id, {}, is_end_of_session=True)

            self._drain_queue(session['input_queue'])
            del self.session_data[session_id]

        if self.current_session == session_id:
            self.current_session = None
        logger.info(f"Ended session {session_id}")

    async def _process_session_queue(self, session_id: str):
        """Poll one session's input queue until that session is removed or replaced."""
        session = self.session_data.get(session_id)
        # Compare by identity: if the same id is re-initialized, the new
        # session gets its own processor and this one must stop instead of
        # becoming a second consumer of the new queue.
        while session is not None and self.session_data.get(session_id) is session:
            try:
                # Wait for input to be available
                if session['input_queue'].empty():
                    await asyncio.sleep(0.01)  # Small delay to prevent busy waiting
                    continue
                # If already processing, skip
                if session['is_processing']:
                    await asyncio.sleep(0.01)
                    continue

                session['is_processing'] = True
                try:
                    await self._process_next_input(session_id)
                finally:
                    session['is_processing'] = False
            except Exception as e:
                logger.exception(f"Error in session queue processing for {session_id}: {e}")
                await asyncio.sleep(1)  # Prevent tight error loop
        logger.info(f"Session queue processor ended for {session_id}")

    async def _process_next_input(self, session_id: str):
        """Consume the queue up to the first "interesting" input and process it.

        An input is interesting when it carries a click, a key press or
        release, or a non-zero wheel delta. Plain movements before it are
        dropped; if the queue holds only movements, the most recent one is
        processed.
        """
        session = self.session_data[session_id]
        input_queue = session['input_queue']
        if input_queue.empty():
            return

        queue_size = input_queue.qsize()
        logger.info(f"Processing next input for session {session_id}. Queue size: {queue_size}")
        try:
            skipped = 0
            latest_input = None
            while not input_queue.empty():
                current_input = await input_queue.get()
                input_queue.task_done()
                latest_input = current_input

                is_interesting = (current_input.get("is_left_click") or
                                  current_input.get("is_right_click") or
                                  (current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
                                  (current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
                                  current_input.get("wheel_delta_x", 0) != 0 or
                                  current_input.get("wheel_delta_y", 0) != 0)
                if is_interesting:
                    logger.info(f"Found interesting input for session {session_id} (skipped {skipped} events)")
                    await self._process_single_input(session_id, current_input)
                    return

                skipped += 1
                # Queue exhausted without an interesting input: use the last movement.
                if input_queue.empty():
                    logger.info(f"No interesting inputs for session {session_id}, processing latest movement (skipped {skipped-1} events)")
                    await self._process_single_input(session_id, latest_input)
                    return
        except Exception as e:
            logger.exception(f"Error in _process_next_input for session {session_id}: {e}")

    async def process_input(self, session_id: str, data: dict) -> dict:
        """Handle one message for a session.

        Control messages (``reset``, ``update_sampling_steps``,
        ``update_use_rnn``, ``get_settings``, ``heartbeat``) are answered
        immediately. Anything else is queued for the session processor and a
        ``{"type": "queued", ...}`` acknowledgement is returned.
        """
        if session_id not in self.session_data:
            self.initialize_session(session_id)  # Fallback initialization without client_id
        session = self.session_data[session_id]
        msg_type = data.get("type")

        if msg_type == "reset":
            logger.info(f"Received reset command for session {session_id}")
            log_session_id = session.get('log_session_id', session_id)  # Fallback to session_id if not found
            log_interaction(log_session_id, data, is_reset=True)

            self._drain_queue(session['input_queue'])
            session['previous_frame'] = self.padding_image
            session['hidden_states'] = None
            session['keys_down'] = set()
            session['frame_num'] = -1
            return {"type": "reset_confirmed"}

        if msg_type == "update_sampling_steps":
            steps = data.get("steps", 32)
            # Reject non-integers as well: comparing e.g. a string or None
            # with 1 would raise instead of producing the error reply.
            if not isinstance(steps, int) or isinstance(steps, bool) or steps < 1:
                return {"type": "error", "message": "Invalid sampling steps value"}
            session['client_settings']['sampling_steps'] = steps
            logger.info(f"Updated sampling steps to {steps} for session {session_id}")
            return {"type": "steps_updated", "steps": steps}

        if msg_type == "update_use_rnn":
            use_rnn = data.get("use_rnn", False)
            session['client_settings']['use_rnn'] = use_rnn
            logger.info(f"Updated USE_RNN to {use_rnn} for session {session_id}")
            return {"type": "rnn_updated", "use_rnn": use_rnn}

        if msg_type == "get_settings":
            return {
                "type": "settings",
                "sampling_steps": session['client_settings']['sampling_steps'],
                "use_rnn": session['client_settings']['use_rnn']
            }

        if msg_type == "heartbeat":
            return {"type": "heartbeat_response"}

        # Regular input: enqueue and return; the frame itself is delivered
        # later through _send_result_to_dispatcher.
        await session['input_queue'].put(data)
        queue_size = session['input_queue'].qsize()
        logger.info(f"Added input to queue for session {session_id}. Queue size: {queue_size}")
        return {"type": "queued", "queue_size": queue_size}

    async def _process_single_input(self, session_id: str, data: dict):
        """Generate a frame for one input and send it (or an error) to the dispatcher."""
        session = self.session_data[session_id]
        try:
            session['frame_num'] += 1

            # Extract input data; a present-but-null field is treated like a missing one.
            raw_x = data.get("x", 0)
            raw_y = data.get("y", 0)
            x = max(0, min(raw_x, self.SCREEN_WIDTH - 1)) if raw_x is not None else 0
            y = max(0, min(raw_y, self.SCREEN_HEIGHT - 1)) if raw_y is not None else 0
            is_left_click = data.get("is_left_click", False)
            is_right_click = data.get("is_right_click", False)
            keys_down_list = data.get("keys_down") or []
            keys_up_list = data.get("keys_up") or []
            wheel_delta_x = data.get("wheel_delta_x", 0)
            wheel_delta_y = data.get("wheel_delta_y", 0)

            # Update the set of keys currently held down
            for key in keys_down_list:
                session['keys_down'].add(self._normalize_key(key))
            for key in keys_up_list:
                session['keys_down'].discard(self._normalize_key(key))

            # Handle debug modes
            if self.DEBUG_MODE:
                logger.info("DEBUG MODE, REMOVING HIDDEN STATES")
                session['previous_frame'] = self.padding_image
            if self.DEBUG_MODE_2:
                if session['frame_num'] > self.NUM_MAX_FRAMES-1:
                    logger.info("DEBUG MODE_2, REMOVING HIDDEN STATES")
                    session['previous_frame'] = self.padding_image
                    session['frame_num'] = 0

            inputs = self.prepare_model_inputs(
                session['previous_frame'],
                session['hidden_states'],
                x, y, is_right_click, is_left_click,
                list(session['keys_down']),
                session['frame_num']
            )

            logger.info(f"Processing frame {session['frame_num']} for session {session_id}: "
                        f"pos=({x},{y}), clicks=(L:{is_left_click},R:{is_right_click}), "
                        f"keys_down={keys_down_list}, keys_up={keys_up_list}, "
                        f"wheel=({wheel_delta_x},{wheel_delta_y})")

            sample_latent, sample_img, hidden_states, timing_info = await self.process_frame(
                inputs,
                use_rnn=session['client_settings']['use_rnn'],
                num_sampling_steps=session['client_settings']['sampling_steps']
            )

            # The generated latent and hidden states seed the next frame.
            session['previous_frame'] = sample_latent
            session['hidden_states'] = hidden_states

            # Convert image to base64-encoded PNG
            buffered = io.BytesIO()
            Image.fromarray(sample_img).save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            total = timing_info['total']
            fps = 1.0 / total if total > 0 else float('inf')  # avoid ZeroDivisionError
            logger.info(f"Frame {session['frame_num']} processed in {total:.4f}s (FPS: {fps:.2f})")

            log_session_id = session.get('log_session_id', session_id)  # Fallback to session_id if not found
            log_interaction(log_session_id, data, generated_frame=sample_img)

            await self._send_result_to_dispatcher(session_id, {"image": img_str})
        except Exception as e:
            logger.exception(f"Error processing input for session {session_id}: {e}")
            await self._send_result_to_dispatcher(session_id, {"type": "error", "message": str(e)})

    async def _send_result_to_dispatcher(self, session_id: str, result: dict):
        """POST ``result`` for ``session_id`` to the dispatcher; failures are logged."""
        try:
            async with aiohttp.ClientSession() as client_session:
                async with client_session.post(f"{self.dispatcher_url}/worker_result", json={
                    "session_id": session_id,
                    "worker_id": self.worker_id,
                    "result": result
                }) as response:
                    # Surface a rejected result in the log instead of dropping it silently.
                    response.raise_for_status()
        except Exception as e:
            logger.error(f"Failed to send result to dispatcher: {e}")
# HTTP application served by this worker process.
app = FastAPI()

# The process-wide GPUWorker; stays None until startup_worker() assigns it.
worker: Optional[GPUWorker] = None
def log_interaction(log_session_id, data, generated_frame=None, is_end_of_session=False, is_reset=False):
    """Append one interaction record to the session's JSONL log.

    Records go to ``interaction_logs/session_<id>.jsonl``. For regular
    inputs, a provided ``generated_frame`` is also saved as a PNG under
    ``interaction_logs/frames_<id>/`` named by the record's timestamp.

    Args:
        log_session_id: Identifier stored in the record and used in file names.
        data: Input message; ``None`` is treated as an empty dict.
        generated_frame: Optional image array passed to ``Image.fromarray``.
        is_end_of_session: Mark the record as end-of-session (no inputs stored).
        is_reset: Mark the record as a reset (no inputs stored).
    """
    timestamp = time.time()
    data = data or {}
    is_control = is_end_of_session or is_reset

    # The ID is supplied from outside this process. Neutralize path
    # separators so it cannot address anything outside interaction_logs/.
    # The record itself keeps the original ID.
    safe_id = str(log_session_id).replace("/", "_").replace("\\", "_")

    # Create directory structure if it doesn't exist
    os.makedirs("interaction_logs", exist_ok=True)

    log_entry = {
        "timestamp": timestamp,
        "session_id": log_session_id,
        "is_eos": is_end_of_session,
        "is_reset": is_reset
    }
    # Include type if present (for reset, etc.)
    if data.get("type"):
        log_entry["type"] = data.get("type")

    if is_control:
        # For EOS/reset records there are no inputs to store.
        log_entry["inputs"] = None
    else:
        log_entry["inputs"] = {
            "x": data.get("x"),
            "y": data.get("y"),
            "is_left_click": data.get("is_left_click"),
            "is_right_click": data.get("is_right_click"),
            "keys_down": data.get("keys_down", []),
            "keys_up": data.get("keys_up", []),
            "wheel_delta_x": data.get("wheel_delta_x", 0),
            "wheel_delta_y": data.get("wheel_delta_y", 0),
            "is_auto_input": data.get("is_auto_input", False)
        }

    session_file = f"interaction_logs/session_{safe_id}.jsonl"
    with open(session_file, "a") as f:
        f.write(json.dumps(log_entry) + "\n")

    # Optionally save the frame if provided
    if generated_frame is not None and not is_control:
        frame_dir = f"interaction_logs/frames_{safe_id}"
        os.makedirs(frame_dir, exist_ok=True)
        frame_file = f"{frame_dir}/{timestamp:.6f}.png"
        Image.fromarray(generated_frame).save(frame_file)
# NOTE(review): no route registration for this handler was visible, which
# leaves it unreachable over HTTP. The path below is derived from the
# function name - confirm it matches what the dispatcher calls.
@app.post("/process_input")
async def process_input_endpoint(request: dict):
    """Forward one dispatcher message to the worker.

    Expects a JSON object with ``session_id`` and ``data``.

    Raises:
        HTTPException: 500 when the worker is not initialized, 400 when
            ``session_id`` or ``data`` is missing or empty.
    """
    if not worker:
        raise HTTPException(status_code=500, detail="Worker not initialized")

    session_id = request.get("session_id")
    data = request.get("data")
    if not session_id or not data:
        raise HTTPException(status_code=400, detail="Missing session_id or data")

    return await worker.process_input(session_id, data)
# NOTE(review): no route registration for this handler was visible, which
# leaves it unreachable over HTTP. The path below is derived from the
# function name - confirm it matches what the dispatcher calls.
@app.post("/init_session")
async def init_session_endpoint(request: dict):
    """Create a session on the worker, using the dispatcher's ``client_id`` for logs.

    Expects a JSON object with ``session_id`` and optionally ``client_id``.

    Raises:
        HTTPException: 500 when the worker is not initialized, 400 when
            ``session_id`` is missing.
    """
    if not worker:
        raise HTTPException(status_code=500, detail="Worker not initialized")

    session_id = request.get("session_id")
    if not session_id:
        raise HTTPException(status_code=400, detail="Missing session_id")

    worker.initialize_session(session_id, request.get("client_id"))
    return {"status": "session_initialized"}
# NOTE(review): no route registration for this handler was visible, which
# leaves it unreachable over HTTP. The path below is derived from the
# function name - confirm it matches what the dispatcher calls.
@app.post("/end_session")
async def end_session_endpoint(request: dict):
    """End a session on the worker.

    Expects a JSON object with ``session_id``.

    Raises:
        HTTPException: 500 when the worker is not initialized, 400 when
            ``session_id`` is missing.
    """
    if not worker:
        raise HTTPException(status_code=500, detail="Worker not initialized")

    session_id = request.get("session_id")
    if not session_id:
        raise HTTPException(status_code=400, detail="Missing session_id")

    worker.end_session(session_id)
    return {"status": "session_ended"}
# NOTE(review): no route registration for this handler was visible, which
# leaves it unreachable over HTTP. The path and method below are assumed -
# confirm them against whatever polls this worker.
@app.get("/health")
async def health_check():
    """Report worker identity and current session; fields are None before startup."""
    if not worker:
        return {
            "status": "healthy",
            "worker_id": None,
            "gpu_id": None,
            "current_session": None
        }
    return {
        "status": "healthy",
        "worker_id": worker.worker_id,
        "gpu_id": worker.gpu_id,
        "current_session": worker.current_session
    }
async def startup_worker(gpu_id: int, dispatcher_url: str):
    """Create the global worker, register it, and start the periodic ping.

    Args:
        gpu_id: Device index handed to ``GPUWorker``.
        dispatcher_url: Base URL of the dispatcher.
    """
    global worker
    worker = GPUWorker(gpu_id, dispatcher_url)

    # Register with dispatcher
    await worker.register_with_dispatcher()

    # Keep a strong reference to the ping task: the event loop holds tasks
    # only weakly, so an unreferenced task may be garbage-collected before
    # it finishes.
    worker.ping_task = asyncio.create_task(worker.ping_dispatcher())
if __name__ == "__main__":
    import uvicorn

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="GPU Worker for Neural OS")
    parser.add_argument("--gpu-id", type=int, required=True, help="GPU ID to use")
    parser.add_argument("--dispatcher-url", type=str, default="http://localhost:8000", help="Dispatcher URL")
    args = parser.parse_args()

    # Each GPU gets its own port: 8001 for GPU 0, 8002 for GPU 1, ...
    port = 8001 + args.gpu_id

    # Register the hook with the app; a bare definition is never invoked,
    # which would leave `worker` as None and make every endpoint answer 500.
    @app.on_event("startup")
    async def startup_event():
        await startup_worker(args.gpu_id, args.dispatcher_url)

    logger.info(f"Starting worker on GPU {args.gpu_id}, port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)