import logging
import os

import hydra
import torch
import wandb

from tqdm import tqdm
from omegaconf import OmegaConf
from pathlib import Path
from setproctitle import setproctitle

from torchrl.envs.transforms import TransformedEnv, InitTracker, Compose
from torchrl.envs.utils import set_exploration_type, ExplorationType

from marinegym import init_simulation_app
from marinegym.utils.torchrl import SyncDataCollector, RenderCallback, EpisodeStats
from marinegym.utils.torchrl.transforms import (
    FromMultiDiscreteAction,
    FromDiscreteAction,
    ravel_composite,
)
from marinegym.utils.wandb import init_wandb
from marinegym.learning import ALGOS


@hydra.main(version_base=None, config_path=".", config_name="train")
def main(cfg):
    OmegaConf.register_new_resolver("eval", eval)
    OmegaConf.resolve(cfg)
    OmegaConf.set_struct(cfg, False)
    simulation_app = init_simulation_app(cfg)
    run = init_wandb(cfg)
    setproctitle(run.name)
    print(OmegaConf.to_yaml(cfg))

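    # deferred import: Isaac-based env modules can only be imported
    # after the simulation app has been initialized above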
    from marinegym.envs.isaac_env import IsaacEnv

    env_class = IsaacEnv.REGISTRY[cfg.task.name]
    base_env = env_class(cfg, headless=cfg.headless)

    transforms = [InitTracker()]

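    # optionally flatten the nested observation specs so that
    # flat (MLP-style) encoders receive a single tensor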
if cfg.task.get("ravel_obs", False): |
|
|
transform = ravel_composite(base_env.observation_spec, ("agents", "observation")) |
|
|
transforms.append(transform) |
|
|
if cfg.task.get("ravel_obs_central", False): |
|
|
transform = ravel_composite(base_env.observation_spec, ("agents", "observation_central")) |
|
|
transforms.append(transform) |
|
|
|
|
|
|
|
|
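    # optionally discretize the continuous action space,
    # e.g. `action_transform: discrete:5` or `multidiscrete:5`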
    action_transform: str = cfg.task.get("action_transform", None)
    if action_transform is not None:
        if action_transform.startswith("multidiscrete"):
            nbins = int(action_transform.split(":")[1])
            transform = FromMultiDiscreteAction(nbins=nbins)
            transforms.append(transform)
        elif action_transform.startswith("discrete"):
            nbins = int(action_transform.split(":")[1])
            transform = FromDiscreteAction(nbins=nbins)
            transforms.append(transform)
        else:
            raise NotImplementedError(f"Unknown action transform: {action_transform}")

    env = TransformedEnv(base_env, Compose(*transforms)).train()
    env.set_seed(cfg.seed)

    try:
        policy = ALGOS[cfg.algo.name.lower()](
            cfg.algo,
            env.observation_spec,
            env.action_spec,
            env.reward_spec,
            device=base_env.device,
        )
    except KeyError:
        raise NotImplementedError(f"Unknown algorithm: {cfg.algo.name}")

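    # round the total frame budget down to a whole number of collection batches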
    frames_per_batch = env.num_envs * int(cfg.algo.train_every)
    total_frames = cfg.get("total_frames", -1) // frames_per_batch * frames_per_batch
    max_iters = cfg.get("max_iters", -1)
    eval_interval = cfg.get("eval_interval", -1)
    save_interval = cfg.get("save_interval", -1)

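    # track the per-episode statistics that the task exposes
    # under the ("stats", ...) keys of its observation spec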
    stats_keys = [
        k for k in base_env.observation_spec.keys(True, True)
        if isinstance(k, tuple) and k[0] == "stats"
    ]
    episode_stats = EpisodeStats(stats_keys)
    collector = SyncDataCollector(
        env,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        device=cfg.sim.device,
        return_same_td=True,
    )

    @torch.no_grad()
    def evaluate(
        seed: int = 0,
        exploration_type: ExplorationType = ExplorationType.MODE,
    ):
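        """Roll out the policy once with deterministic (mode) actions, then
        log per-episode statistics and a rendered video."""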
        base_env.enable_render(True)
        base_env.eval()
        env.eval()
        env.set_seed(seed)

        render_callback = RenderCallback(interval=2)

        with set_exploration_type(exploration_type):
            trajs = env.rollout(
                max_steps=base_env.max_episode_length,
                policy=policy,
                callback=render_callback,
                auto_reset=True,
                break_when_any_done=False,
                return_contiguous=False,
            )
        base_env.enable_render(not cfg.headless)
        env.reset()

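        # the rollout may span several episodes per env; keep only the
        # stats recorded at each env's first `done` step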
        done = trajs.get(("next", "done"))
        first_done = torch.argmax(done.long(), dim=1).cpu()

        def take_first_episode(tensor: torch.Tensor):
            indices = first_done.reshape(first_done.shape + (1,) * (tensor.ndim - 2))
            return torch.take_along_dim(tensor, indices, dim=1).reshape(-1)

        traj_stats = {
            k: take_first_episode(v)
            for k, v in trajs[("next", "stats")].cpu().items()
        }

        info = {
            "eval/stats." + k: torch.mean(v.float()).item()
            for k, v in traj_stats.items()
        }

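        # log the rendered frames as a video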
info["recording"] = wandb.Video( |
|
|
render_callback.get_video_array(axes="t c h w"), |
|
|
fps=0.5 / (cfg.sim.dt * cfg.sim.substeps), |
|
|
format="mp4" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return info |
|
|
|
|
|
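    # main training loop: collect a batch of experience, update the policy,
    # and periodically evaluate and checkpoint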
    pbar = tqdm(collector, total=total_frames // frames_per_batch)
    env.train()
    for i, data in enumerate(pbar):
        info = {"env_frames": collector._frames, "rollout_fps": collector._fps}
        episode_stats.add(data.to_tensordict())

        if len(episode_stats) >= base_env.num_envs:
            stats = {
                "train/" + (".".join(k) if isinstance(k, tuple) else k): torch.mean(v.float()).item()
                for k, v in episode_stats.pop().items(True, True)
            }
            info.update(stats)

        info.update(policy.train_op(data.to_tensordict()))

        if eval_interval > 0 and i % eval_interval == 0:
            logging.info(f"Eval at {collector._frames} steps.")
            info.update(evaluate())
            env.train()
            base_env.train()

        if save_interval > 0 and i % save_interval == 0:
            try:
                ckpt_path = os.path.join(run.dir, f"checkpoint_{collector._frames}.pt")
                torch.save(policy.state_dict(), ckpt_path)
                logging.info(f"Saved checkpoint to {str(ckpt_path)}")
            except AttributeError:
                logging.warning(f"Policy {policy} does not implement `.state_dict()`")

        run.log(info)
        print(OmegaConf.to_yaml({k: v for k, v in info.items() if isinstance(v, float)}))

        pbar.set_postfix({"rollout_fps": collector._fps, "frames": collector._frames})

        if max_iters > 0 and i >= max_iters - 1:
            break

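    # save the final checkpoint and log it as a wandb artifact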
    try:
        ckpt_path = os.path.join(run.dir, "checkpoint_final.pt")
        torch.save(policy.state_dict(), ckpt_path)

        model_artifact = wandb.Artifact(
            f"{cfg.task.name}-{cfg.algo.name.lower()}",
            type="model",
            description=f"{cfg.task.name}-{cfg.algo.name.lower()}",
            metadata=dict(cfg),
        )

        model_artifact.add_file(ckpt_path)
        wandb.save(ckpt_path)
        run.log_artifact(model_artifact)

        logging.info(f"Saved checkpoint to {str(ckpt_path)}")
    except AttributeError:
        logging.warning(f"Policy {policy} does not implement `.state_dict()`")

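    # optionally push the trained model (and eval videos, if any) to the HuggingFace Hub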
if cfg.get("upload_model", False): |
|
|
from marinegym.utils.huggingface import push_to_hub |
|
|
|
|
|
repo_name = f"{cfg.task.name}-{cfg.algo.name.lower()}-seed{cfg.seed}" |
|
|
logging.info(f"Uploading model to HuggingFace: {repo_name}") |
|
|
logging.info(f"Check: {cfg.task.name}") |
|
|
repo_id = f"{cfg.hf_entity}/{repo_name}" if cfg.get("hf_entity") else repo_name |
|
|
|
|
|
|
|
|
|
|
|
video_folder = f"{run.dir}/videos" |
|
|
if not os.path.exists(video_folder) or not any(Path(video_folder).glob("*.mp4")): |
|
|
logging.warning(f"Warning: No video found in {video_folder}. Skipping video upload.") |
|
|
video_folder = "" |
|
|
|
|
|
        push_to_hub(
            cfg,
            [0.],
            repo_id,
            cfg.algo.name,
            run.dir,
            video_folder,
            create_pr=cfg.get("create_pr", False),
            private=cfg.get("hf_private", False),
        )

    wandb.finish()

    simulation_app.close()


if __name__ == "__main__":
    main()