File size: 1,680 Bytes
c682ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
Example usage of the Pi-0 Bolt Nut Sort model
"""

from openpi.policies import policy_config
from openpi.training import config
import numpy as np

def load_model(
    checkpoint_path: str,
    *,
    config_name: str = "pi0_bns",
    default_prompt: str = "sort the bolts and the nuts into separate baskets",
):
    """Load the Pi-0 bolt nut sort model.

    Args:
        checkpoint_path: Checkpoint location forwarded to
            ``policy_config.create_trained_policy``.
        config_name: Name of the training config resolved with
            ``config.get_config``. Defaults to ``"pi0_bns"``.
        default_prompt: Forwarded as ``default_prompt`` to
            ``create_trained_policy``. NOTE(review): presumably used only
            when an observation carries no ``"prompt"`` of its own; confirm
            against the openpi version in use.

    Returns:
        The policy object created by ``policy_config.create_trained_policy``.
    """
    train_config = config.get_config(config_name)
    return policy_config.create_trained_policy(
        train_config,
        checkpoint_path,
        default_prompt=default_prompt,
    )

def create_observation(
    images,
    joint_positions,
    *,
    prompt: str = "sort the bolts and the nuts into separate baskets",
):
    """Create observation dict for the model.

    Args:
        images: Mapping that must contain the keys ``"high"``,
            ``"left_wrist"`` and ``"right_wrist"``. Each value is stored
            unchanged under the matching ``cam_*`` key. The ``__main__``
            example passes (224, 224, 3) uint8 arrays.
        joint_positions: Robot state, stored unchanged under ``"state"``.
            The ``__main__`` example passes a length-14 float32 array.
        prompt: Language instruction stored under ``"prompt"``. Defaults to
            the bolt/nut sorting instruction.

    Returns:
        A dict with ``"images"``, ``"state"`` and ``"prompt"`` entries.

    Raises:
        KeyError: If ``images`` is missing one of the three camera keys.
    """
    return {
        # Rename the caller-facing camera names to the cam_* keys sent to the policy.
        "images": {
            "cam_high": images["high"],
            "cam_left_wrist": images["left_wrist"],
            "cam_right_wrist": images["right_wrist"],
        },
        "state": joint_positions,
        "prompt": prompt,
    }

# Example usage
if __name__ == "__main__":
    # Load the policy from a local checkpoint directory.
    policy = load_model("./checkpoint")

    # Create a dummy observation. np.random.randint's upper bound is
    # exclusive, so 256 is required for the random pixels to span the full
    # uint8 range 0-255 (the previous bound of 255 never produced 255).
    images = {
        camera: np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        for camera in ("high", "left_wrist", "right_wrist")
    }
    joint_positions = np.random.randn(14).astype(np.float32)

    obs = create_observation(images, joint_positions)

    # Run inference and pull the predicted action chunk out of the result.
    result = policy.infer(obs)
    actions = result["actions"]  # expected [50, 14]: 50 timesteps of 14-DoF actions

    print(f"Generated actions shape: {actions.shape}")