import logging
import os
import time

import hydra
import torch
import numpy as np
import pandas as pd
import wandb
import matplotlib.pyplot as plt

from torch.func import vmap
from tqdm import tqdm
from omegaconf import OmegaConf
from pathlib import Path
from setproctitle import setproctitle

from torchrl.data import CompositeSpec
from torchrl.envs.utils import set_exploration_type, ExplorationType
from torchrl.envs.transforms import TransformedEnv, InitTracker, Compose

from marinegym import init_simulation_app
from marinegym.utils.torchrl import SyncDataCollector, RenderCallback, EpisodeStats
from marinegym.utils.torchrl.transforms import (
    FromMultiDiscreteAction,
    FromDiscreteAction,
    ravel_composite,
    AttitudeController,
    RateController,
)
from marinegym.utils.wandb import init_wandb
from marinegym.learning import ALGOS


@hydra.main(version_base=None, config_path=".", config_name="train")
def main(cfg):
    OmegaConf.register_new_resolver("eval", eval)
    OmegaConf.resolve(cfg)
    OmegaConf.set_struct(cfg, False)
    simulation_app = init_simulation_app(cfg)
    run = init_wandb(cfg)
    setproctitle(run.name)
    print(OmegaConf.to_yaml(cfg))

    from marinegym.envs.isaac_env import IsaacEnv

    env_class = IsaacEnv.REGISTRY[cfg.task.name]
    base_env = env_class(cfg, headless=cfg.headless)

    transforms = [InitTracker()]

    # A CompositeSpec is by default processed by an entity-based encoder;
    # ravel it to use an MLP encoder instead.
    if cfg.task.get("ravel_obs", False):
        transform = ravel_composite(base_env.observation_spec, ("agents", "observation"))
        transforms.append(transform)
    if cfg.task.get("ravel_obs_central", False):
        transform = ravel_composite(base_env.observation_spec, ("agents", "observation_central"))
        transforms.append(transform)

    # Optionally discretize the action space or use a controller.
    action_transform: str = cfg.task.get("action_transform", None)
    if action_transform is not None:
        if action_transform.startswith("multidiscrete"):
            nbins = int(action_transform.split(":")[1])
            transform = FromMultiDiscreteAction(nbins=nbins)
            transforms.append(transform)
        elif action_transform.startswith("discrete"):
            nbins = int(action_transform.split(":")[1])
            transform = FromDiscreteAction(nbins=nbins)
            transforms.append(transform)
        else:
            raise NotImplementedError(f"Unknown action transform: {action_transform}")

    env = TransformedEnv(base_env, Compose(*transforms)).train()
    env.set_seed(cfg.seed)

    try:
        policy = ALGOS[cfg.algo.name.lower()](
            cfg.algo,
            env.observation_spec,
            env.action_spec,
            env.reward_spec,
            device=base_env.device,
        )
    except KeyError:
        raise NotImplementedError(f"Unknown algorithm: {cfg.algo.name}")

    frames_per_batch = env.num_envs * int(cfg.algo.train_every)
    total_frames = cfg.get("total_frames", -1) // frames_per_batch * frames_per_batch
    max_iters = cfg.get("max_iters", -1)
    eval_interval = cfg.get("eval_interval", -1)
    save_interval = cfg.get("save_interval", -1)

    # Track the per-episode statistics that the task exposes under the "stats" key.
    stats_keys = [
        k for k in base_env.observation_spec.keys(True, True)
        if isinstance(k, tuple) and k[0] == "stats"
    ]
    episode_stats = EpisodeStats(stats_keys)

    collector = SyncDataCollector(
        env,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        device=cfg.sim.device,
        return_same_td=True,
    )

    @torch.no_grad()
    def evaluate(
        seed: int = 0,
        exploration_type: ExplorationType = ExplorationType.MODE,
    ):
        """Roll out the current policy with rendering enabled and collect eval stats."""
        base_env.enable_render(True)
        base_env.eval()
        env.eval()
        env.set_seed(seed)

        render_callback = RenderCallback(interval=2)
        with set_exploration_type(exploration_type):
            trajs = env.rollout(
                max_steps=base_env.max_episode_length,
                policy=policy,
                callback=render_callback,
                auto_reset=True,
                break_when_any_done=False,
                return_contiguous=False,
            )
        base_env.enable_render(not cfg.headless)
        env.reset()

        done = trajs.get(("next", "done"))
        first_done = torch.argmax(done.long(), dim=1).cpu()

        def take_first_episode(tensor: torch.Tensor):
            indices = first_done.reshape(first_done.shape + (1,) * (tensor.ndim - 2))
            return torch.take_along_dim(tensor, indices, dim=1).reshape(-1)

        traj_stats = {
            k: take_first_episode(v)
            for k, v in trajs[("next", "stats")].cpu().items()
        }

        info = {
            "eval/stats." + k: torch.mean(v.float()).item()
            for k, v in traj_stats.items()
        }

        # Log the rendered rollout as a video.
        info["recording"] = wandb.Video(
            render_callback.get_video_array(axes="t c h w"),
            fps=0.5 / (cfg.sim.dt * cfg.sim.substeps),
            format="mp4",
        )

        # Log distributions.
        # df = pd.DataFrame(traj_stats)
        # table = wandb.Table(dataframe=df)
        # info["eval/return"] = wandb.plot.histogram(table, "return")
        # info["eval/episode_len"] = wandb.plot.histogram(table, "episode_len")

        return info

    pbar = tqdm(collector, total=total_frames // frames_per_batch)
    env.train()
    for i, data in enumerate(pbar):
        info = {"env_frames": collector._frames, "rollout_fps": collector._fps}
        episode_stats.add(data.to_tensordict())

        if len(episode_stats) >= base_env.num_envs:
            stats = {
                "train/" + (".".join(k) if isinstance(k, tuple) else k): torch.mean(v.float()).item()
                for k, v in episode_stats.pop().items(True, True)
            }
            info.update(stats)

        # Run one training update on the collected batch.
        info.update(policy.train_op(data.to_tensordict()))

        if eval_interval > 0 and i % eval_interval == 0:
            logging.info(f"Eval at {collector._frames} steps.")
            info.update(evaluate())
            env.train()
            base_env.train()

        if save_interval > 0 and i % save_interval == 0:
            try:
                ckpt_path = os.path.join(run.dir, f"checkpoint_{collector._frames}.pt")
                torch.save(policy.state_dict(), ckpt_path)
                logging.info(f"Saved checkpoint to {str(ckpt_path)}")
            except AttributeError:
                logging.warning(f"Policy {policy} does not implement `.state_dict()`")

        run.log(info)
        print(OmegaConf.to_yaml({k: v for k, v in info.items() if isinstance(v, float)}))

        pbar.set_postfix({"rollout_fps": collector._fps, "frames": collector._frames})

        if max_iters > 0 and i >= max_iters - 1:
            break

    # logging.info(f"Final Eval at {collector._frames} steps.")
    # info = {"env_frames": collector._frames}
    # info.update(evaluate())
    # run.log(info)

    # Save the final checkpoint and log it as a wandb artifact.
    try:
        ckpt_path = os.path.join(run.dir, "checkpoint_final.pt")
        torch.save(policy.state_dict(), ckpt_path)

        model_artifact = wandb.Artifact(
            f"{cfg.task.name}-{cfg.algo.name.lower()}",
            type="model",
            description=f"{cfg.task.name}-{cfg.algo.name.lower()}",
            metadata=dict(cfg),
        )
        model_artifact.add_file(ckpt_path)
        wandb.save(ckpt_path)
        run.log_artifact(model_artifact)

        logging.info(f"Saved checkpoint to {str(ckpt_path)}")
    except AttributeError:
        logging.warning(f"Policy {policy} does not implement `.state_dict()`")

    # Optionally upload the trained model to the HuggingFace Hub.
    if cfg.get("upload_model", False):
        from marinegym.utils.huggingface import push_to_hub

        repo_name = f"{cfg.task.name}-{cfg.algo.name.lower()}-seed{cfg.seed}"
        logging.info(f"Uploading model to HuggingFace: {repo_name}")
        logging.info(f"Check: {cfg.task.name}")
        repo_id = f"{cfg.hf_entity}/{repo_name}" if cfg.get("hf_entity") else repo_name
        # episodic_returns = episode_stats.get("stats/return", [])

        video_folder = f"{run.dir}/videos"
        if not os.path.exists(video_folder) or not any(Path(video_folder).glob("*.mp4")):
            logging.warning(f"No video found in {video_folder}. Skipping video upload.")
Skipping video upload.") video_folder = "" push_to_hub( cfg, [0.], # TODO: fix this repo_id, cfg.algo.name, run.dir, video_folder, create_pr=cfg.get("create_pr", False), private=cfg.get("hf_private", False) ) wandb.finish() simulation_app.close() if __name__ == "__main__": main()