# mario_controller_eval.py
"""Evaluate a trained Super Mario Bros DQN checkpoint with a live controller view.

The script loads a checkpoint into ``CompatibleDQN``, plays one episode
greedily (argmax over Q-values), mirrors every chosen action on a pygame
drawing of a NES controller, and finally writes the action history to a
text file in the current directory.
"""
import os
import sys
import time

import gym_super_mario_bros
import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT, RIGHT_ONLY, SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace
from pygame.locals import *

# Import your wrapper function
from wrappers import wrap_mario

# Initialize pygame for controller visualization
pygame.init()


# Buttons held for each action index, per action space.  These are the same
# tables the script has always hard-coded for display purposes.
# NOTE(review): presumably they mirror the ordering of RIGHT_ONLY /
# SIMPLE_MOVEMENT / COMPLEX_MOVEMENT in gym_super_mario_bros.actions - confirm.
_RIGHT_ONLY_BUTTONS = (
    ('NOOP',),
    ('RIGHT',),
    ('RIGHT', 'A'),
    ('RIGHT', 'B'),
    ('RIGHT', 'A', 'B'),
)
_SIMPLE_BUTTONS = _RIGHT_ONLY_BUTTONS + (
    ('A',),
    ('LEFT',),
)
_COMPLEX_BUTTONS = _SIMPLE_BUTTONS + (
    ('LEFT', 'A'),
    ('LEFT', 'B'),
    ('LEFT', 'A', 'B'),
    ('DOWN',),
    ('UP',),
)
_ACTION_BUTTONS = {
    'RIGHT_ONLY': _RIGHT_ONLY_BUTTONS,
    'SIMPLE': _SIMPLE_BUTTONS,
}

# Buttons drawn with the small d-pad radius instead of the regular one.
_DPAD_BUTTONS = frozenset({'UP', 'DOWN', 'LEFT', 'RIGHT'})

# Checkpoint keys that already use this model's attribute names.
_DIRECT_KEYS = tuple(
    f"{layer}.{kind}"
    for layer in ('layer1', 'layer2', 'layer3', 'fc', 'q', 'v', 'a')
    for kind in ('weight', 'bias')
)


def _buttons_for(action_space_type):
    """Return the per-action button table; unknown types fall back to COMPLEX."""
    return _ACTION_BUTTONS.get(action_space_type, _COMPLEX_BUTTONS)


def _action_label(action, action_names):
    """Return the display name for ``action`` or ``"UNKNOWN"`` if out of range."""
    if 0 <= action < len(action_names):
        return action_names[action]
    return "UNKNOWN"


class NintendoControllerVisualizer:
    """Pygame window that draws a NES controller and highlights pressed buttons.

    Attributes:
        pressed_buttons: set of button names highlighted on the next redraw.
        last_action: the most recent action index passed to
            ``map_action_to_buttons`` (``None`` before the first call).
        action_history: every action index seen so far, in order.
    """

    def __init__(self):
        self.screen_width = 400
        self.screen_height = 300
        self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
        pygame.display.set_caption("Mario Controller - Live Input")

        # Colors
        self.bg_color = (40, 40, 40)
        self.controller_color = (60, 60, 80)
        self.button_color = (180, 180, 200)
        self.pressed_color = (255, 255, 0)  # Yellow for pressed buttons
        self.text_color = (255, 255, 255)
        self.dpad_color = (100, 100, 120)

        # Font
        self.font = pygame.font.Font(None, 24)
        self.small_font = pygame.font.Font(None, 18)

        # Controller layout coordinates
        self.controller_rect = pygame.Rect(50, 50, 300, 200)

        # Button positions (relative to controller_rect)
        self.buttons = {
            'A': (250, 100),
            'B': (200, 120),
            'START': (175, 150),
            'SELECT': (125, 150),
            'UP': (80, 80),
            'DOWN': (80, 120),
            'LEFT': (60, 100),
            'RIGHT': (100, 100),
        }

        # Button radii
        self.button_radius = 15
        self.dpad_radius = 8

        # Current pressed state
        self.pressed_buttons = set()
        self.last_action = None
        self.action_history = []

    def update_display(self, action, action_names):
        """Redraw the controller and the current action text.

        Args:
            action: index of the action being shown.
            action_names: list of human-readable names indexed by action.
                An index outside the list is shown as ``UNKNOWN`` instead of
                raising ``IndexError``.
        """
        self.screen.fill(self.bg_color)

        # Draw controller background
        pygame.draw.rect(
            self.screen, self.controller_color, self.controller_rect, border_radius=20
        )

        for button_name, (x, y) in self.buttons.items():
            center = (self.controller_rect.x + x, self.controller_rect.y + y)

            if button_name in _DPAD_BUTTONS:
                idle_color, radius = self.dpad_color, self.dpad_radius
            else:
                idle_color, radius = self.button_color, self.button_radius
            pressed = button_name in self.pressed_buttons
            color = self.pressed_color if pressed else idle_color
            pygame.draw.circle(self.screen, color, center, radius)

            # Draw button label centred on the button
            label = self.small_font.render(button_name, True, self.text_color)
            self.screen.blit(label, label.get_rect(center=center))

        # Draw action info
        action_text = self.font.render(f"Action: {action}", True, self.text_color)
        action_name_text = self.font.render(
            f"Input: {_action_label(action, action_names)}", True, (255, 200, 100)
        )
        self.screen.blit(action_text, (20, 20))
        self.screen.blit(action_name_text, (20, 260))

        pygame.display.flip()

    def map_action_to_buttons(self, action, action_space_type='COMPLEX'):
        """Record ``action`` and update ``pressed_buttons`` accordingly.

        Args:
            action: action index chosen by the agent.
            action_space_type: ``'RIGHT_ONLY'``, ``'SIMPLE'``; anything else
                is treated as the 12-action COMPLEX table.

        Returns:
            The list of button names for ``action``, or ``['UNKNOWN']`` when
            the index is outside the table (no button is highlighted then).
        """
        table = _buttons_for(action_space_type)
        in_range = 0 <= action < len(table)
        combo = list(table[action]) if in_range else []

        # 'NOOP' is stored too; it matches no drawn button, so nothing lights up.
        self.pressed_buttons = set(combo)
        self.last_action = action
        self.action_history.append(action)
        return combo if in_range else ['UNKNOWN']


class CompatibleDQN(nn.Module):
    """Convolutional Q-network carrying both a plain and a dueling head.

    The three conv layers are reachable both as ``layer1..layer3`` and as
    entries 0/2/4 of ``seq`` (the same module objects), so the model's state
    dict exposes each conv under both names.

    Args:
        input_shape: ``(channels, height, width)`` of a single observation.
        num_actions: number of Q-values to output.

    Attributes:
        use_dueling: when True, ``forward`` combines the ``v``/``a`` heads;
            otherwise it uses the plain ``q`` head.  Defaults to False and is
            switched on by ``load_checkpoint_compatible`` for checkpoints
            that only provide dueling heads.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions

        self.layer1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4)
        self.layer2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.layer3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Also create sequential layers for compatibility
        self.seq = nn.Sequential(
            self.layer1,
            nn.ReLU(),
            self.layer2,
            nn.ReLU(),
            self.layer3,
            nn.ReLU(),
        )

        # Regular DQN head
        self.fc = nn.Linear(self._feature_size(), 512)
        self.q = nn.Linear(512, num_actions)

        # Dueling DQN heads (stay randomly initialised if not in checkpoint)
        self.v = nn.Linear(512, 1)
        self.a = nn.Linear(512, num_actions)

        self.use_dueling = False

    def _feature_size(self):
        """Return the flattened size of the conv stack's output for one sample."""
        with torch.no_grad():  # shape probe only; no graph needed
            probe = self.seq(torch.zeros(1, *self.input_shape))
        return probe.view(1, -1).size(1)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for input ``x``."""
        x = self.seq(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))

        if self.use_dueling:
            value = self.v(x)
            advantage = self.a(x)
            return value + advantage - advantage.mean(1, keepdim=True)
        return self.q(x)


def _unwrap_state_dict(checkpoint):
    """Return the parameter mapping held by ``checkpoint``.

    Handles a pickled ``nn.Module`` and dict checkpoints that nest the
    weights under a common wrapper key; anything else is returned unchanged.
    """
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for wrapper_key in ('state_dict', 'model_state_dict', 'model'):
            inner = checkpoint.get(wrapper_key)
            if isinstance(inner, dict):
                return inner
    return checkpoint


def _sequential_aliases(state_dict):
    """Map ``seq.<i>.*`` checkpoint keys onto ``layer1..layer3``.

    A Sequential with activations between the convs stores them at indices
    0/2/4; one holding only the convs stores them at 0/1/2.  The presence of
    ``seq.4.*`` tells the two layouts apart.
    """
    interleaved = any(key.startswith('seq.4.') for key in state_dict)
    indices = (0, 2, 4) if interleaved else (0, 1, 2)
    return {
        f"seq.{index}.{kind}": f"layer{number}.{kind}"
        for number, index in enumerate(indices, start=1)
        for kind in ('weight', 'bias')
    }


def load_checkpoint_compatible(model, checkpoint_path, device):
    """Load as much of a checkpoint into ``model`` as its keys allow.

    Keys named like the model's own layers are copied directly; ``seq.<i>``
    keys are translated to ``layer1..layer3``.  Tensors whose shape does not
    match the target parameter are skipped with a warning, because
    ``load_state_dict(strict=False)`` still raises on size mismatches.  If
    the checkpoint supplies the dueling heads (``v`` and ``a``) but no ``q``
    head, ``model.use_dueling`` is set so that ``forward`` uses the loaded
    heads rather than an untrained ``q``.

    Args:
        model: module to load into (normally a ``CompatibleDQN``).
        checkpoint_path: path of the file passed to ``torch.load``.
        device: ``map_location`` for ``torch.load``.

    Returns:
        The same ``model`` instance.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code - only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = _unwrap_state_dict(checkpoint)

    key_mapping = {key: key for key in _DIRECT_KEYS}
    key_mapping.update(_sequential_aliases(state_dict))

    target_shapes = {
        name: tuple(tensor.shape) for name, tensor in model.state_dict().items()
    }

    new_state_dict = {}
    for old_key, new_key in key_mapping.items():
        # Directly named keys come first in key_mapping and win over aliases.
        if old_key not in state_dict or new_key in new_state_dict:
            continue
        tensor = state_dict[old_key]
        expected = target_shapes.get(new_key)
        if expected is not None and tuple(tensor.shape) != expected:
            print(
                f"Warning: skipping checkpoint key '{old_key}': shape "
                f"{tuple(tensor.shape)} does not match '{new_key}' {expected}"
            )
            continue
        new_state_dict[new_key] = tensor

    if not new_state_dict:
        print(
            "Warning: no checkpoint tensors matched the model; "
            "it keeps its random initialisation"
        )

    # Load whatever we can match
    model.load_state_dict(new_state_dict, strict=False)

    has_q_head = 'q.weight' in new_state_dict
    has_dueling_heads = 'v.weight' in new_state_dict and 'a.weight' in new_state_dict
    if has_dueling_heads and not has_q_head:
        model.use_dueling = True

    return model


def arange(s):
    """Convert one channels-last observation into a batched channels-first array.

    Args:
        s: array-like with exactly three dimensions, ordered ``(d0, d1, d2)``.

    Returns:
        ``numpy.ndarray`` of shape ``(1, d2, d0, d1)``.

    Raises:
        ValueError: if ``s`` does not have exactly three dimensions.
    """
    if not isinstance(s, np.ndarray):
        s = np.array(s)
    if s.ndim != 3:
        raise ValueError(f"expected a 3-D observation, got shape {s.shape}")
    channels_first = np.transpose(s, (2, 0, 1))
    return np.expand_dims(channels_first, 0)


def get_action_space_type(num_actions):
    """Determine action space type based on number of actions.

    Returns ``'RIGHT_ONLY'`` for 5, ``'SIMPLE'`` for 7, ``'COMPLEX'`` for 12
    and ``'CUSTOM'`` for any other count.
    """
    known_sizes = {5: 'RIGHT_ONLY', 7: 'SIMPLE', 12: 'COMPLEX'}
    return known_sizes.get(num_actions, 'CUSTOM')


def get_action_names(action_space_type):
    """Get human-readable action names for display.

    Each name joins the buttons of one action with ``" + "`` (for example
    ``"RIGHT + A"``).  Types other than ``'RIGHT_ONLY'`` and ``'SIMPLE'``
    yield the 12 COMPLEX names.
    """
    return [" + ".join(combo) for combo in _buttons_for(action_space_type)]


def _print_summary(ckpt_path, total_score, info, step_count, total_time,
                   action_space_type, num_actions):
    """Print the end-of-run report for one evaluation episode."""
    stage = info.get('stage', 1)
    world = info.get('world', 1)
    # Guard against a zero-length run (e.g. window closed immediately).
    steps_per_sec = step_count / total_time if total_time > 0 else 0.0

    print("\n" + "=" * 50)
    print("šŸŽÆ EVALUATION COMPLETE")
    print(f"šŸ“Š Checkpoint: {ckpt_path}")
    print(f"šŸ† Total score: {total_score:.1f}")
    print(f"šŸŒ World: {world}-{stage}")
    print(f"šŸ‘£ Steps: {step_count}")
    print(f"ā±ļø Time: {total_time:.1f}s")
    print(f"⚔ Steps/sec: {steps_per_sec:.1f}")
    print(f"šŸŽ® Action space: {action_space_type} ({num_actions} actions)")
    print("=" * 50)


def _save_action_history(ckpt_path, step_count, total_score, action_space_type,
                         action_history, action_names):
    """Write the action history to a text file and return the file name."""
    history_filename = (
        f"action_history_{os.path.basename(ckpt_path).replace('.pth', '')}.txt"
    )
    with open(history_filename, 'w', encoding='utf-8') as f:
        f.write(f"Checkpoint: {ckpt_path}\n")
        f.write(f"Total steps: {step_count}\n")
        f.write(f"Final score: {total_score}\n")
        f.write(f"Action space: {action_space_type}\n\n")
        f.write("Action History:\n")
        f.writelines(
            f"Step {step}: Action {action} - {_action_label(action, action_names)}\n"
            for step, action in enumerate(action_history, start=1)
        )
    return history_filename


def run_evaluation_with_controller(ckpt_path, num_actions=5):
    """Run evaluation with visual controller overlay.

    Plays one episode of ``SuperMarioBros-v0`` greedily with the checkpoint
    at ``ckpt_path`` until the environment reports ``done``, the pygame
    window is closed, ESC is pressed, or Ctrl-C is hit.  Afterwards a summary
    is printed and the action history is saved to
    ``action_history_<checkpoint name>.txt``.

    Args:
        ckpt_path: path of the checkpoint file to load.
        num_actions: size of the model's output.  5 selects ``RIGHT_ONLY``,
            7 selects ``SIMPLE_MOVEMENT``, any other value selects
            ``COMPLEX_MOVEMENT`` for the environment.
    """
    print(f"šŸŽ® Loading checkpoint: {ckpt_path}")

    n_frame = 4
    env = gym_super_mario_bros.make("SuperMarioBros-v0")

    # Determine action space based on number of actions
    if num_actions == 5:
        env = JoypadSpace(env, RIGHT_ONLY)
        action_space_type = 'RIGHT_ONLY'
    elif num_actions == 7:
        env = JoypadSpace(env, SIMPLE_MOVEMENT)
        action_space_type = 'SIMPLE'
    else:
        env = JoypadSpace(env, COMPLEX_MOVEMENT)
        action_space_type = 'COMPLEX'

    env = wrap_mario(env)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"šŸŽÆ Using device: {device}, Action space: {num_actions} ({action_space_type})")

    # Initialize controller visualizer
    controller = NintendoControllerVisualizer()
    action_names = get_action_names(action_space_type)

    # Use compatible model
    q = CompatibleDQN(input_shape=(n_frame, 84, 84), num_actions=num_actions).to(device)

    # Load with flexible loading
    q = load_checkpoint_compatible(q, ckpt_path, device)
    q.eval()

    print("āœ… Checkpoint loaded successfully!")
    print("šŸŽ® Controller visualization started - Yellow buttons show current input")

    total_score = 0.0
    done = False
    s = arange(env.reset())
    step_count = 0
    start_time = time.time()
    # Bound before the loop so the final report works even if no step ran.
    info = {}

    print("šŸš€ Starting evaluation with live controller visualization...")

    try:
        while not done:
            # Handle pygame events (to keep window responsive)
            for event in pygame.event.get():
                if event.type == QUIT:
                    done = True
                elif event.type == KEYDOWN and event.key == K_ESCAPE:
                    done = True

            if done:
                break

            # Get action from model
            s_tensor = torch.from_numpy(s).float().to(device)
            with torch.no_grad():
                q_values = q(s_tensor)
            action = torch.argmax(q_values).item()

            # Update controller display
            controller.map_action_to_buttons(action, action_space_type)
            controller.update_display(action, action_names)

            # Execute action in environment
            s_prime, reward, done, info = env.step(action)
            s = arange(s_prime)
            total_score += reward
            step_count += 1

            # Print progress occasionally
            if step_count % 50 == 0:
                current_time = time.time() - start_time
                print(f"ā±ļø Step {step_count}, Score: {total_score:.1f}, Time: {current_time:.1f}s")
                print(f" šŸŽ® Last action: {action} - {_action_label(action, action_names)}")

            # Small delay for visualization
            time.sleep(0.05)

            # Render the game (optional - can be commented out for better performance)
            env.render()

    except KeyboardInterrupt:
        print("\nā¹ļø Evaluation interrupted by user")

    finally:
        try:
            # Final results
            total_time = time.time() - start_time
            _print_summary(
                ckpt_path, total_score, info, step_count, total_time,
                action_space_type, num_actions,
            )

            # Save action history
            history_filename = _save_action_history(
                ckpt_path, step_count, total_score, action_space_type,
                controller.action_history, action_names,
            )
            print(f"šŸ’¾ Action history saved to: {history_filename}")
        finally:
            # Release the emulator and the window even if reporting failed.
            env.close()
            pygame.quit()


def main():
    """Main function to run evaluation with controller visualization.

    Command line: ``mario_controller_eval.py <checkpoint_path> [num_actions]``.
    Without a checkpoint argument the ``.pth`` files of the current directory
    are listed and the user is prompted for a choice or a path.
    """
    if len(sys.argv) < 2:
        print("Usage: python mario_controller_eval.py <checkpoint_path> [num_actions]")
        print("Available checkpoints:")

        # List available .pth files
        pth_files = [f for f in os.listdir('.') if f.endswith('.pth')]
        for i, file in enumerate(pth_files):
            print(f" {i+1}. {file}")

        if pth_files:
            choice = input(f"\nSelect checkpoint (1-{len(pth_files)} or path): ").strip()
            if choice.isdigit() and 1 <= int(choice) <= len(pth_files):
                ckpt_path = pth_files[int(choice) - 1]
            else:
                ckpt_path = choice
        else:
            ckpt_path = input("Enter checkpoint path: ").strip()
    else:
        ckpt_path = sys.argv[1]

    # Get number of actions (default to 5 for RIGHT_ONLY)
    num_actions = 5
    if len(sys.argv) > 2:
        try:
            requested = int(sys.argv[2])
        except ValueError:
            requested = None
        if requested is None or requested < 1:
            # A non-positive count cannot size the network's output layer.
            print("āš ļø Invalid number of actions, using default (5)")
        else:
            num_actions = requested

    # Validate checkpoint exists
    if not os.path.exists(ckpt_path):
        print(f"āŒ Checkpoint not found: {ckpt_path}")
        return

    print("šŸŽ® Mario AI Evaluation with Live Controller Visualization")
    print(" - Yellow buttons show current input")
    print(" - Press ESC or close window to stop")
    print(" - Action history will be saved to file")
    print()

    run_evaluation_with_controller(ckpt_path, num_actions)


if __name__ == "__main__":
    main()