import torch import torch.nn as nn import torch.nn.functional as F import gym_super_mario_bros from gym_super_mario_bros.actions import COMPLEX_MOVEMENT from nes_py.wrappers import JoypadSpace import numpy as np import time import sys # Import your wrapper function from wrappers import wrap_mario class CompatibleDQN(nn.Module): def __init__(self, input_shape, num_actions): super(CompatibleDQN, self).__init__() self.input_shape = input_shape self.num_actions = num_actions # Try to match the checkpoint architecture based on the errors # From the errors, it seems to have both regular and dueling components self.layer1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4) self.layer2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) self.layer3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) # Also create sequential layers for compatibility self.seq = nn.Sequential( self.layer1, nn.ReLU(), self.layer2, nn.ReLU(), self.layer3, nn.ReLU() ) # Regular DQN head self.fc = nn.Linear(self._feature_size(), 512) self.q = nn.Linear(512, num_actions) # Dueling DQN heads (will be ignored if not in checkpoint) self.v = nn.Linear(512, 1) self.a = nn.Linear(512, num_actions) def _feature_size(self): # Calculate feature size x = self.seq[0](torch.zeros(1, *self.input_shape)) x = self.seq[1](x) x = self.seq[2](x) x = self.seq[3](x) x = self.seq[4](x) x = self.seq[5](x) return x.view(1, -1).size(1) def forward(self, x): x = self.seq(x) x = x.view(x.size(0), -1) x = F.relu(self.fc(x)) # Try both output methods if hasattr(self, 'q') and self.q.weight.shape[0] == self.num_actions: return self.q(x) else: # Dueling DQN output value = self.v(x) advantage = self.a(x) return value + advantage - advantage.mean(1, keepdim=True) def load_checkpoint_compatible(model, checkpoint_path, device): """Load checkpoint with flexible key mapping""" checkpoint = torch.load(checkpoint_path, map_location=device) state_dict = checkpoint # Create a mapping for different layer naming schemes new_state_dict = {} # Map all possible key combinations key_mapping = { # Regular DQN to our structure 'layer1.weight': 'layer1.weight', 'layer1.bias': 'layer1.bias', 'layer2.weight': 'layer2.weight', 'layer2.bias': 'layer2.bias', 'layer3.weight': 'layer3.weight', 'layer3.bias': 'layer3.bias', 'fc.weight': 'fc.weight', 'fc.bias': 'fc.bias', 'q.weight': 'q.weight', 'q.bias': 'q.bias', # Sequential mappings 'seq.0.weight': 'layer1.weight', 'seq.0.bias': 'layer1.bias', 'seq.1.weight': 'layer2.weight', 'seq.1.bias': 'layer2.bias', 'seq.2.weight': 'layer3.weight', 'seq.2.bias': 'layer3.bias', # Dueling DQN mappings 'v.weight': 'v.weight', 'v.bias': 'v.bias', } # Apply mapping for old_key, new_key in key_mapping.items(): if old_key in state_dict: new_state_dict[new_key] = state_dict[old_key] # Load whatever we can match model.load_state_dict(new_state_dict, strict=False) return model def arange(s): if not isinstance(s, np.ndarray): s = np.array(s) assert len(s.shape) == 3 ret = np.transpose(s, (2, 0, 1)) return np.expand_dims(ret, 0) if __name__ == "__main__": ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "mario_q_target.pth" print(f"Load ckpt from {ckpt_path}") n_frame = 4 env = gym_super_mario_bros.make("SuperMarioBros-v0") env = JoypadSpace(env, COMPLEX_MOVEMENT) env = wrap_mario(env) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Use compatible model num_actions = 5 q = CompatibleDQN(input_shape=(n_frame, 84, 84), num_actions=num_actions).to(device) # Load with flexible loading q = load_checkpoint_compatible(q, ckpt_path, device) q.eval() print("Checkpoint loaded successfully!") total_score = 0.0 done = False s = arange(env.reset()) i = 0 print("Starting evaluation...") while not done: env.render() # Convert to tensor and get action s_tensor = torch.from_numpy(s).float().to(device) with torch.no_grad(): q_values = q(s_tensor) a = torch.argmax(q_values).item() s_prime, r, done, info = env.step(a) s_prime = arange(s_prime) total_score += r s = s_prime i += 1 if i % 100 == 0: print(f"Step {i}, Score: {total_score}") time.sleep(0.001) stage = info.get('stage', 1) world = info.get('world', 1) print(f"Total score: {total_score} | World: {world}-{stage} | Steps: {i}") env.close()