# Standard library imports
import logging
import os
import os.path as osp
from functools import partial
from pathlib import Path

# Related third-party imports
from hydra.utils import instantiate, get_original_cwd
import cv2
import numpy as np
import torch
import tqdm
import wandb
from easydict import EasyDict as edict
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from models.SpaTrackV2.metric import (
    camera_to_rel_deg, calculate_auc, calculate_auc_np, rotation_angle, translation_angle  # ,camera_to_rel_deg_pair
)
from models.SpaTrackV2.train_utils import *
# from inference import run_inference
from models.SpaTrackV2.models.utils import (
    vis_result, procrustes_analysis, loss_fn, pose_encoding_to_camera, AverageMeter
)
from models.SpaTrackV2.datasets.tartan_utils.evaluate import ATEEvaluator, RPEEvaluator
from models.SpaTrackV2.datasets.tartan_utils.eval_ate_scale import align, plot_traj
from models.vggt.vggt.utils.load_fn import preprocess_image

eval_ate = ATEEvaluator()
eval_rpe = RPEEvaluator()
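# API assumption (inferred from the call sites below, not from the
# tartan_utils source): both evaluators consume camera-to-world trajectories
# as (N, 4, 4) pose arrays; ATEEvaluator.evaluate returns
# (ate_error, gt_traj, est_traj_aligned) and RPEEvaluator.evaluate returns
# (rot_err_rad, trans_err).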
def has_nan_gradients(model):
    """Return True if any parameter of `model` currently holds a NaN gradient."""
    for param in model.parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            return True
    return False
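# Usage sketch (hypothetical handling, not what train_fn below does — it
# saves a debug checkpoint instead): call between backward() and
# optimizer.step() to skip a corrupt update.
#   lite.backward(loss)
#   if has_nan_gradients(model):
#       optimizer.zero_grad()  # drop this step's gradients entirely
#   else:
#       optimizer.step()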
def adjust_size(height, width, macro_block_size=16):
    """
    Adjust a frame size to satisfy the codec's macro-block constraint.
    :param height: original height
    :param width: original width
    :param macro_block_size: macro-block size (usually 16)
    :return: adjusted height and width
    """
    new_height = (height + macro_block_size - 1) // macro_block_size * macro_block_size
    new_width = (width + macro_block_size - 1) // macro_block_size * macro_block_size
    return new_height, new_width
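# Example: each dimension is rounded up to the next multiple of the
# macro-block size, so adjust_size(479, 852) == (480, 864), while already
# aligned sizes pass through unchanged: adjust_size(480, 864) == (480, 864).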
def vis_video_depth(video_depth: torch.Tensor, output_path: str = None,
                    frame_rate: int = 60, output_width: int = 518, frame_height: int = 518,
                    wandb_hw: int = 224):
    """
    Visualize a depth video.
    video_depth: T H W
    """
    # the tensor is T H W, so height is dim 1 and width is dim 2
    frame_height, output_width = video_depth.shape[1], video_depth.shape[2]
    filt_max = video_depth.max()
    # normalize to [0, 255]; the epsilon guards against a constant depth map
    video_depth = (video_depth - video_depth.min()) / (filt_max - video_depth.min() + 1e-8) * 255.0
    new_height, new_width = None, None
    if output_path is None:
        vid_rgb = []
        for i in range(video_depth.shape[0]):
            depth = video_depth[i].squeeze().detach().cpu().numpy().astype(np.uint8)
            depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
            depth_color = cv2.resize(depth_color, (wandb_hw, wandb_hw))
            depth_color = cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB)
            # adjust the size to the macro-block constraint
            if new_height is None or new_width is None:
                new_height, new_width = adjust_size(depth_color.shape[0], depth_color.shape[1])
            # resize the frame
            resized_depth_color = cv2.resize(depth_color, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
            vid_rgb.append(resized_depth_color)
        vid_rgb = np.stack(vid_rgb, axis=0)
        return vid_rgb
    else:
        # out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
        #                       frame_rate, (output_width, frame_height))
        video_depth_np = []
        for i in range(video_depth.shape[0]):
            depth = video_depth[i].squeeze().detach().cpu().numpy().astype(np.uint8)
            depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
            # adjust the size to the macro-block constraint
            if new_height is None or new_width is None:
                new_height, new_width = adjust_size(depth_color.shape[0], depth_color.shape[1])
            # resize the frame
            resized_depth_color = cv2.resize(depth_color, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
            resized_depth_color = cv2.cvtColor(resized_depth_color, cv2.COLOR_BGR2RGB)
            video_depth_np.append(resized_depth_color)
        video_depth_np = np.stack(video_depth_np, axis=0)
        import imageio
        imageio.mimwrite(output_path, video_depth_np, fps=8)
        # out.release()
        return None
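# Usage sketch (`depth_seq` is a hypothetical tensor, e.g. the clamped
# inverse depth prepared in train_fn below):
#   depth_seq = torch.rand(16, 240, 320)          # T x H x W
#   frames = vis_video_depth(depth_seq)           # (T, H', W', 3) uint8 for wandb/TensorBoard
#   vis_video_depth(depth_seq, "depth_vis.mp4")   # or write an mp4 via imageio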
def train_fn(
    model, dataloader, cfg, optimizer, lite, scaler,
    lr_scheduler, training=True, viz=None, epoch=-1, logger_tf=None, logger=None,
    logger_wb=None
):
    ln_depth_glob = AverageMeter("ln_depth_glob", ":.4f")
    ln_edge = AverageMeter("ln_edge", ":.4f")
    ln_normal = AverageMeter("ln_normal", ":.4f")
    ln_cons = AverageMeter("ln_cons", ":.4f")
    ln_scale_shift = AverageMeter("ln_scale_shift", ":.4f")
    ln_pose = AverageMeter("ln_pose", ":.4f")
    ln_msk_l2 = AverageMeter("ln_msk_l2", ":.4f")
    rot_err = AverageMeter("rot_error", ":.4f")
    trans_err = AverageMeter("tran_error", ":.4f")

    if training:
        model.train()
        # model.base_model.to(model.device)

    for step, batch_raw in enumerate(tqdm.tqdm(dataloader, disable=not lite.is_global_zero)):
        # get the global step
        global_step = epoch * len(dataloader) + step
        if training:
            optimizer.zero_grad()
        # move data to the model's device
        batch = {k: v.to(model.device) for k, v in batch_raw.items() if isinstance(v, torch.Tensor)}
        dtype = torch.bfloat16
        # model = model.to(dtype)
        annots = {
            "poses_gt": batch["pose_enc"].to(dtype),
            "depth_gt": batch["depths"].to(dtype),
            "metric": True,
            "traj_3d": batch["traj_3d"].to(dtype),
            "vis": batch["vis"],
            "syn_real": batch["syn_real"],
            "metric_rel": batch["metric_rel"],
            "traj_mat": batch["traj_mat"].to(dtype),
            "iters": global_step,
            "data_dir": batch_raw["data_dir"],
            "intrs": batch["intrs"]
        }
        if batch["depths"].to(dtype).sum() == 0:
            # no GT depth available: drop the supervised annotations
            annots.pop("depth_gt")
            annots.pop("poses_gt")
            annots.pop("traj_mat")
            annots.pop("intrs")
            annots.update({"custom_vid": True})
        batch["rgbs"] = batch["rgbs"].to(dtype)
        with lite.autocast():
            kwargs = {"est_depth": cfg.train.est_depth,
                      "est_pose": cfg.train.est_pose,
                      "stage": cfg.train.stage}
            B, T, C, H, W = batch["rgbs"].shape
            video_tensor = batch["rgbs"]
            if hasattr(model, 'module'):
                ret = model.module(video_tensor / 255.0, annots, **kwargs)
            else:
                ret = model(video_tensor / 255.0, annots, **kwargs)
        loss = ret["loss"]["loss"]
        # calculate the camera pose error
        c2w_traj_est = ret["poses_pred"]
        try:
            # _, _, est_traj_aligned = eval_ate.evaluate(c2w_traj_est[0].detach().cpu().numpy(),
            #                                            batch["traj_mat"][0].cpu().numpy(), scale=True)
            # align the estimated trajectory with the GT scale
            est_traj_aligned = c2w_traj_est[0].clone().detach().cpu().numpy()
            est_traj_aligned[:, :3, 3] *= ret["loss"]["norm_scale"].item()
            rpe_error = eval_rpe.evaluate(batch["traj_mat"][0].cpu().numpy(), est_traj_aligned)
            traj_pts = batch["traj_mat"][0][:, :3, 3]
            traj_pts = ((traj_pts[1:] - traj_pts[:-1]).norm(dim=-1)).sum()
            rot_error = rpe_error[0] / np.pi * 180  # degrees
            if traj_pts.abs() < 5e-2:
                # nearly static trajectory: avoid dividing by ~0
                traj_pts = 1
            tran_error = rpe_error[1] / traj_pts * 100  # percentage
        except Exception:
            rot_error = 0
            tran_error = 0
        if "custom_vid" in annots:
            rot_err.update(0, batch["rgbs"].shape[0])
            trans_err.update(0, batch["rgbs"].shape[0])
        else:
            rot_err.update(rot_error, batch["rgbs"].shape[0])
            trans_err.update(tran_error, batch["rgbs"].shape[0])
        if ret["loss"] is not None:
            if "ln_depth_glob" in ret["loss"]:
                ln_depth_glob.update(ret["loss"]["ln_depth_glob"].item(), batch["rgbs"].shape[0])
            if "ln_edge" in ret["loss"]:
                ln_edge.update(ret["loss"]["ln_edge"].item(), batch["rgbs"].shape[0])
            if "ln_normal" in ret["loss"]:
                ln_normal.update(ret["loss"]["ln_normal"].item(), batch["rgbs"].shape[0])
            if "ln_cons" in ret["loss"]:
                ln_cons.update(ret["loss"]["ln_cons"].item(), batch["rgbs"].shape[0])
            if "ln_scale_shift" in ret["loss"]:
                ln_scale_shift.update(ret["loss"]["ln_scale_shift"].item(), batch["rgbs"].shape[0])
            if "ln_pose" in ret["loss"]:
                ln_pose.update(ret["loss"]["ln_pose"].item(), batch["rgbs"].shape[0])
            if "ln_msk_l2" in ret["loss"]:
                ln_msk_l2.update(ret["loss"]["ln_msk_l2"].item(), batch["rgbs"].shape[0])

        lite.barrier()
        lite.backward(loss)
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"the norm of gradient of {name}: {param.grad.norm()}")
        # import pdb; pdb.set_trace()
        if cfg.train.clip_grad > 0:
            if cfg.clip_type == "norm":
                # NOTE: clipping the gradient norm must happen after the optimizer's
                # gradients have been unscaled
                # NOTE: clip the gradient norm per parameter group
                for group in optimizer.param_groups:
                    torch.nn.utils.clip_grad_norm_(group['params'], cfg.train.clip_grad)
            elif cfg.clip_type == "value":
                torch.nn.utils.clip_grad_value_(model.parameters(), cfg.train.clip_grad)
        if has_nan_gradients(model):
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"NaN gradient found in parameter {name}")
            if lite.is_global_zero:
                save_path = Path(
                    f"{cfg.exp_dir}/nan_model.pth"
                )
                save_dict = {
                    "model": model.module.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": lr_scheduler.state_dict(),
                    "total_steps": len(dataloader) * epoch,
                }
                logging.info(f"Saving file {save_path}")
                torch.save(save_dict, save_path)
            torch.cuda.empty_cache()
        # the step proceeds even after a NaN dump; the checkpoint above is for debugging
        optimizer.step()
        lr_scheduler.step()
        # log the loss and the depth info
        if lite.is_global_zero and (global_step % cfg.train.print_interval == 0):
            with torch.no_grad():
                logger_tf.add_scalar("train/loss", loss, global_step)
                logger_tf.add_scalar("train/learning_rate",
                                     lr_scheduler.get_lr()[0], global_step)
                logger_tf.add_scalar("train/ln_depth_glob", ln_depth_glob.avg, global_step)
                logger_tf.add_scalar("train/ln_edge", ln_edge.avg, global_step)
                logger_tf.add_scalar("train/ln_normal", ln_normal.avg, global_step)
                logger_tf.add_scalar("train/ln_cons", ln_cons.avg, global_step)
                logger_tf.add_scalar("train/ln_scale_shift", ln_scale_shift.avg, global_step)
                logger_tf.add_scalar("train/ln_pose", ln_pose.avg, global_step)
                logger_tf.add_scalar("train/ln_msk_l2", ln_msk_l2.avg, global_step)
                logger_tf.add_scalar("train/rot_error", rot_err.avg, global_step)
                logger_tf.add_scalar("train/tran_error", trans_err.avg, global_step)
                logger.info(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}, "
                            f"Depth Glob Loss: {ln_depth_glob.avg:.4f}, "
                            f"Edge Loss: {ln_edge.avg:.4f}, Normal Loss: {ln_normal.avg:.4f}, "
                            f"Cons Loss: {ln_cons.avg:.4f}, Scale Shift Loss: {ln_scale_shift.avg:.4f}, "
                            f"Pose Loss: {ln_pose.avg:.4f}, Mask L2 Loss: {ln_msk_l2.avg:.4f}, "
                            f"Rot Error: {rot_err.avg:.4f}, "
                            f"Trans Error: {trans_err.avg:.4f}")
                # log with wandb
                if logger_wb is not None:
                    logger_wb.log({
                        "train/loss": loss,
                        "train/learning_rate": lr_scheduler.get_lr()[0],
                        "train/ln_depth_glob": ln_depth_glob.avg,
                        "train/ln_edge": ln_edge.avg,
                        "train/ln_normal": ln_normal.avg,
                        "train/ln_cons": ln_cons.avg,
                        "train/ln_scale_shift": ln_scale_shift.avg,
                        "train/ln_pose": ln_pose.avg,
                        "train/ln_msk_l2": ln_msk_l2.avg,
                        "train/rot_error": rot_err.avg,
                        "train/tran_error": trans_err.avg,
                    }, step=global_step)
                # reset the meters
                ln_depth_glob.reset()
                ln_edge.reset()
                ln_normal.reset()
                ln_cons.reset()
                ln_scale_shift.reset()
                ln_pose.reset()
                ln_msk_l2.reset()
                rot_err.reset()
                trans_err.reset()
        exception_case = (tran_error > 20) and (cfg.train.stage == 1)
        exception_case = False  # exception dumping is currently disabled
        exception_name = ["norm", "abnorm"]
        if ((lite.is_global_zero and global_step % 200 == 0) or exception_case) \
                and not annots.get("custom_vid", False):
            with torch.no_grad():
                # visualize the GT inverse depth
                os.makedirs("vis_depth", exist_ok=True)
                gt_rel_dp_vis = 1 / (annots["depth_gt"][0].squeeze().float().clamp(min=1e-5))
                # gt_rel_dp_vis = (gt_rel_dp_vis-gt_rel_dp_vis.min())/(gt_rel_dp_vis.max()-gt_rel_dp_vis.min())
                gt_rel_dp_vis[gt_rel_dp_vis > 30] = 0.0
                gt_rel_dp_vis = gt_rel_dp_vis.clamp(min=0.0, max=20.0)
                gt_reldepth = vis_video_depth(gt_rel_dp_vis)
                if logger_wb is not None:
                    logger_wb.log({
                        f"gt_relvideo_{exception_name[int(exception_case)]}": wandb.Video(gt_reldepth.transpose(0, 3, 1, 2), fps=4, caption="Training Progress"),
                    }, step=global_step)
                if logger_tf is not None:
                    logger_tf.add_video(f"train/gt_relvideo_{exception_name[int(exception_case)]}", gt_reldepth.transpose(0, 3, 1, 2)[None], global_step)
                if (cfg.train.save_4d and global_step % 200 == 0) or exception_case:
                    viser4d_dir = os.path.join(cfg.exp_dir, f"viser_4d_{global_step}_{exception_name[int(exception_case)]}")
                    os.makedirs(viser4d_dir, exist_ok=True)
                    # save the 4d results
                    # depth_est = ret["depth"]
                    unc_metric = ret["unc_metric"]
                    mask = (unc_metric > 0.5).squeeze(1)
                    # pose_est = batch["traj_mat"].squeeze(0).float()
                    # pose_est[:, :3, 3] /= ret["loss"]["norm_scale"].item()
                    pose_est = ret["poses_pred"][0]
                    intrinsics = ret["intrs"].squeeze(0).float()
                    for i in range(ret["points_map"].shape[0]):
                        img_i = ret["images"][0][i].permute(1, 2, 0).float().cpu().numpy()
                        img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
                        cv2.imwrite(osp.join(viser4d_dir, f'frame_{i:04d}.png'), img_i)
                        point_map = ret["points_map"][i].float().permute(2, 0, 1).cpu().numpy()
                        np.save(osp.join(viser4d_dir, f'point_{i:04d}.npy'), point_map)
                    # per-sequence metadata only needs to be written once, after the loop
                    np.save(os.path.join(viser4d_dir, 'intrinsics.npy'), intrinsics.cpu().numpy())
                    np.save(os.path.join(viser4d_dir, 'extrinsics.npy'), pose_est.cpu().numpy())
                    np.save(os.path.join(viser4d_dir, 'conf.npy'), mask.float().cpu().numpy())

    return True
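# Invocation sketch (hypothetical names; the real driver script wires these
# up from the Hydra config and a Lightning Fabric/Lite launcher, which is
# what the `lite` argument appears to be):
#   train_fn(model, train_loader, cfg, optimizer, lite, scaler, lr_scheduler,
#            training=True, epoch=epoch, logger_tf=tb_writer,
#            logger=logger, logger_wb=wandb_run)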
def eval_fn(
    model, dataloader, cfg, optimizer, lite, scaler,
    lr_scheduler, training=True, viz=None, epoch=-1, logger_tf=None, logger=None,
    logger_wb=None
):
    model.eval()
    # model = model.float()
    # define the metrics
    ATE = AverageMeter("ATE", ":.4f")
    RPE0 = AverageMeter("Relative Rot error", ":.4f")
    RPE1 = AverageMeter("Relative Trans error", ":.4f")
    DMSE = AverageMeter("depth_loss", ":.4f")
    FDE = AverageMeter("focal_loss", ":.4f")
    # calculate the metrics
    if lite.is_global_zero:
        for step, batch in enumerate(tqdm.tqdm(dataloader, disable=not lite.is_global_zero)):
            batch = {k: v.to(model.device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            annots = {
                "poses_gt": batch["pose_enc"].float(),
                "depth_gt": batch["depths"].float(),
                "traj_mat": batch["traj_mat"].float(),
                "metric": True
            }
            batch["rgbs"] = batch["rgbs"].to(torch.bfloat16)
            with torch.no_grad():
                ret = model(batch["rgbs"])
            pred_SEs = ret["poses_pred"][0]
            pred_focal = ret["focal"]
            gt_SEs = annots["traj_mat"][0]
            # ATE (with scale alignment)
            pred_SEs_r = pred_SEs
            error, gt_traj, est_traj_aligned = eval_ate.evaluate(gt_SEs.float().cpu().numpy(),
                                                                 pred_SEs_r.float().cpu().numpy(), scale=True)
            # RPE
            rpe_error = eval_rpe.evaluate(gt_SEs.float().cpu().numpy(),
                                          est_traj_aligned)
            # focal error
            H, W = batch["rgbs"].shape[-2:]
            focal_loss = torch.nn.functional.l1_loss(pred_focal,
                                                     annots["poses_gt"][:, :, -1]) * H
            # depth error
            depth_loss = torch.nn.functional.l1_loss(ret["inv_depth"].squeeze(),
                                                     annots["depth_gt"].squeeze())
            # update the metrics
            ATE.update(error, batch["rgbs"].shape[0])
            RPE0.update(rpe_error[0] / np.pi * 180, batch["rgbs"].shape[0])
            RPE1.update(rpe_error[1], batch["rgbs"].shape[0])
            FDE.update(focal_loss, batch["rgbs"].shape[0])
            DMSE.update(depth_loss, batch["rgbs"].shape[0])
        logger.info(f"Epoch {epoch}, ATE: {ATE.avg:.4f} mm, RPE0: {RPE0.avg:.4f} degrees, "
                    f"RPE1: {RPE1.avg:.4f} mm, FDE: {FDE.avg:.4f} pix, DMSE: {DMSE.avg:.4f}")
        logger_tf.add_scalar("eval/ATE", ATE.avg, epoch)
        logger_tf.add_scalar("eval/RPE0", RPE0.avg, epoch)
        logger_tf.add_scalar("eval/RPE1", RPE1.avg, epoch)
        logger_tf.add_scalar("eval/FDE", FDE.avg, epoch)
        logger_tf.add_scalar("eval/DMSE", DMSE.avg, epoch)
        if logger_wb is not None:
            # log with wandb
            logger_wb.log({
                "eval/ATE": ATE.avg,
                "eval/RPE0": RPE0.avg,
                "eval/RPE1": RPE1.avg,
                "eval/FDE": FDE.avg,
                "eval/DMSE": DMSE.avg,
            }, step=epoch)
    return True