from typing import overload, Tuple, Optional
import os

import cv2
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from glob import glob
from PIL import Image
from einops import rearrange

from model.cldm import ControlLDM
from model.gaussian_diffusion import Diffusion
from model.bsrnet import RRDBNet
from model.swinir import SwinIR
from model.scunet import SCUNet
from utils.sampler import SpacedSampler
from utils.cond_fn import Guidance
from utils.video_visualizer import VideoVisualizer
from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage
import vidtome
from GMFlow.gmflow.gmflow import GMFlow
from utils.flow_utils import get_warped_and_mask
def save_video(input_folder, out_path, output_name, fps=25):
    video_visualizer = VideoVisualizer(
        path=os.path.join(out_path, output_name),
        frame_size=None,
        fps=fps
    )
    input_folder = os.path.join(out_path, input_folder)
    imgs = sorted(
        [filename for filename in os.listdir(input_folder) if filename.endswith(('.png', '.jpg'))],
        key=lambda x: int(x.split('.')[0])
    )
    for img in imgs:
        img_pth = os.path.join(input_folder, img)
        image = cv2.imread(img_pth)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        video_visualizer.add(image)
    video_visualizer.save()
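# Illustrative call (not invoked here): assuming restored frames were written to
# <out_path>/samples as 0.png, 1.png, ..., the following would mux them into a video.
# The folder name "samples" and file name "restored.mp4" are placeholders:
#   save_video("samples", out_path, "restored.mp4", fps=25)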
def batch_bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
    if scale != 1:
        # resized frames change shape when scale != 1, so stack the results into a new
        # array instead of assigning back into the original buffer in place
        img = np.stack([bicubic_resize(img[i], scale) for i in range(img.shape[0])], axis=0)
    return img
def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
    if scale == 1:
        return img
    pil = Image.fromarray(img)
    res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC)
    return np.array(res)
def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h == w:
        new_h, new_w = size, size
    elif h < w:
        new_h, new_w = size, int(w * (size / h))
    else:
        new_h, new_w = int(h * (size / w)), size
    return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True)
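# Illustrative example (sizes assumed, not taken from the pipeline): a 360x640 batch
# resized with size=512 keeps the aspect ratio and becomes 512 x int(640 * 512 / 360) = 512x910.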
def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h % multiple == 0 and w % multiple == 0:
        return imgs.clone()
    # get_pad = lambda x: (x // multiple + 1) * multiple - x
    get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x
    ph, pw = get_pad(h), get_pad(w)
    return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0)
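# Worked example (illustrative sizes, not from the pipeline): for a 720x1280 frame and
# multiple=64, get_pad(720) = 48 and get_pad(1280) = 0, so the frame is zero-padded at the
# bottom/right to 768x1280; run_stage2 later crops back to the original size via
# sample = x[:, :, :ori_h, :ori_w].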
class Pipeline:

    def __init__(self, stage1_model: nn.Module, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        self.stage1_model = stage1_model
        self.cldm = cldm
        self.diffusion = diffusion
        self.cond_fn = cond_fn
        self.device = device
        self.final_size: Optional[Tuple[int, int]] = None

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (h, w)

    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # implemented by the subclasses below (BSRNet / SwinIR / SCUNet restoration front-ends)
        ...
    def run_stage2(
        self,
        clean: torch.Tensor,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool,
        index: int = 0,
        input: Optional[str] = None
    ) -> torch.Tensor:
        ### preprocess
        bs, _, ori_h, ori_w = clean.shape
        # pad: ensure that height & width are multiples of 64
        pad_clean = pad_to_multiples_of(clean, multiple=64)
        h, w = pad_clean.shape[2:]
        if self.cldm.controller is not None:
            self.cldm.controller.cldm = self.cldm
            self.cldm.controller.non_pad_ratio = (ori_h / h, ori_w / w)
            self.cldm.vae.decoder.controller = self.cldm.controller
        # prepare condition
        if not tiled:
            cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs)
            uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs)
        else:
            cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride)
            uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride)
        if self.cond_fn:
            self.cond_fn.load_target(pad_clean * 2 - 1)
        old_control_scales = self.cldm.control_scales
        self.cldm.control_scales = [strength] * 13
        if better_start:
            # use the noised low-frequency part of the condition as a better starting point
            # for reverse sampling, which prevents the model from generating noise in the
            # image background
            _, low_freq = wavelet_decomposition(pad_clean)
            # low_freq = pad_clean
            if not tiled:
                x_0 = self.cldm.vae_encode(low_freq, batch_size=5)
            else:
                x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride)
            x_T = self.diffusion.q_sample(
                x_0,
                torch.full((bs,), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device),
                torch.randn(x_0.shape, dtype=torch.float32, device=self.device)
            )
            # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}")
        else:
            if self.cldm.latent_control:
                print(f"[INFO] random initialize {bs} same latents")
                x_T = 1 * torch.randn((1, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)
                x_T = x_T.repeat(bs, 1, 1, 1)
            else:
                print(f"[INFO] random initialize {bs} latents")
                x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)
        ''' loaded latents '''
        # t = 981
        # latent_fname = f'noisy_latents_{t}.pt'
        # # model_key = config.model_key.split('/')[-1]
        # model_key = "stable-diffusion-2-1-base"
        # inversion_path = os.path.join("latents", os.path.basename(input), "latents")
        # # outputs/bear_4_BD/latents/stable-diffusion-v1-5/noisy_latents_981.pt
        # lp = os.path.join(inversion_path, model_key, latent_fname)
        # latents = torch.load(lp)
        # # init_noise = latents.to(dtype).to(args.device)
        # x_T = latents[index][None].to(torch.float32).to(self.device)
        # print(f"[INFO] loaded latents[{index}]")
        ''' loaded latent ended '''
        ### run sampler
        sampler = SpacedSampler(self.diffusion.betas)
        z = sampler.sample(
            model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8),
            cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
            progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            non_pad_ratio=(ori_h / h, ori_w / w)
        )
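        # cfg_scale controls classifier-free guidance; under the usual formulation the sampler
        # combines conditional and unconditional predictions roughly as
        #   eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
        # (an assumption about SpacedSampler's internals, stated here only for orientation).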
        if not tiled:
            if ori_w > 1500:
                x = self.cldm.vae_decode(z, batch_size=2)
            else:
                x = self.cldm.vae_decode(z, batch_size=5)
        else:
            x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8)
        ### postprocess
        self.cldm.control_scales = old_control_scales
        sample = x[:, :, :ori_h, :ori_w]
        return sample
    def run(
        self,
        lq: np.ndarray,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool,
        index: int = 0,
        input: Optional[str] = None,
        final_size: Optional[Tuple[int, int]] = None,
        flow_model: Optional[GMFlow] = None,
        hq: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        # image to tensor
        lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
        lq = rearrange(lq, "n h w c -> n c h w").contiguous()
        # set pipeline output size
        if final_size is None:
            self.set_final_size(lq)
        else:
            self.final_size = final_size
        clean = self.run_stage1(lq)
        print(f"[INFO] {clean.shape}")
        # import ipdb; ipdb.set_trace()
        # clean = F.interpolate(lq, size=clean.shape[-2:], mode='bicubic', align_corners=False)
        ''' hq flow & occlusion mask '''
        # hq = torch.tensor((hq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
        # hq = rearrange(hq, "n h w c -> n c h w").contiguous()
        # hq = resize_short_edge_to(hq, size=512)
        # pre_keyframe_lq = None
        # if self.cldm.controller is not None and \
        #         self.cldm.controller.step_store["pre_keyframe_lq"] is not None:
        #     pre_keyframe_lq = self.cldm.controller.step_store["pre_keyframe_lq"]
        #     pre_keyframe_lq = torch.tensor((pre_keyframe_lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
        #     pre_keyframe_lq = rearrange(pre_keyframe_lq, "n h w c -> n c h w").contiguous()
        #     pre_keyframe_lq = resize_short_edge_to(pre_keyframe_lq, size=512)
        #     pre_keyframe_clean = pre_keyframe_lq[0]
        #     # pre_keyframe_clean = self.run_stage1(pre_keyframe_lq)[0]
        # flows, masks, confids = [], [], []
        # mid = lq.shape[0] // 2
        # for k in range(lq.shape[0]):
        #     if k == mid:
        #         if pre_keyframe_lq is not None:
        #             tar_img = (torch.clamp(hq[mid], 0, 1) * 255).float().to(self.device)
        #             src_img = (torch.clamp(pre_keyframe_clean, 0, 1) * 255).float().to(self.device)
        #         else:
        #             flows.append(None)
        #             masks.append(None)
        #             confids.append(None)
        #             continue
        #     else:
        #         tar_img = (torch.clamp(hq[k], 0, 1) * 255).float().to(self.device)
        #         src_img = (torch.clamp(hq[mid], 0, 1) * 255).float().to(self.device)
        #         # tar_img = stage1_x[0].float().to(args.device)
        #     _, bwd_occ, bwd_flow, bwd_confid = get_warped_and_mask(
        #         flow_model, src_img, tar_img, image3=None, pixel_consistency=False, return_confidence=True)
        #     blend_mask = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))(
        #         F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
        #     blend_mask = torch.clamp(blend_mask + bwd_occ, 0, 1)
        #     blend_mask = 1 - F.max_pool2d(blend_mask, kernel_size=8)
        #     bwd_confid = F.max_pool2d(bwd_confid, kernel_size=8)
        #     bwd_flow = F.interpolate(bwd_flow / 8.0, scale_factor=1. / 8, mode='bilinear')
        #     # _, _, h, w = bwd_flow.shape
        #     # bwd_flow = pad_to_multiples_of(bwd_flow, 8)
        #     # padding_ratio = w / bwd_flow.shape[3]
        #     blend_mask = pad_to_multiples_of(blend_mask[None], 8)[0]
        #     # bwd_confid = pad_to_multiples_of(bwd_confid[None], 8)[0]
        #     flows.append(bwd_flow)
        #     masks.append(blend_mask)
        #     confids.append(bwd_confid)
        # if self.cldm.controller is not None:
        #     self.cldm.controller.set_warp(flows, masks, flow_confids=confids)
| ''' flow & occlusion mask ''' | |
| pre_keyframe_lq = None | |
| if self.cldm.controller is not None and \ | |
| self.cldm.controller.step_store["pre_keyframe_lq"] is not None: | |
| pre_keyframe_lq = self.cldm.controller.step_store["pre_keyframe_lq"] | |
| pre_keyframe_lq = torch.tensor((pre_keyframe_lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device) | |
| pre_keyframe_lq = rearrange(pre_keyframe_lq, "n h w c -> n c h w").contiguous() | |
| pre_keyframe_clean = self.run_stage1(pre_keyframe_lq)[0] | |
| flows, masks, confids = [], [], [] | |
| flows2, confids2 = [], [] | |
| mid = lq.shape[0] // 2 | |
| for k in range(lq.shape[0]): | |
| if k == mid: | |
| if pre_keyframe_lq is not None: | |
| tar_img = (torch.clamp(clean[mid], 0 ,1) * 255).float().to(self.device) | |
| src_img = (torch.clamp(pre_keyframe_clean, 0 ,1) * 255).float().to(self.device) | |
| else: | |
| flows.append(None) | |
| masks.append(None) | |
| confids.append(None) | |
| continue | |
| else: | |
| tar_img = (torch.clamp(clean[k], 0 ,1) * 255).float().to(self.device) | |
| src_img = (torch.clamp(clean[mid], 0 ,1) * 255).float().to(self.device) | |
| # tar_img = stage1_x[0].float().to(args.device) | |
| _, bwd_occ, bwd_flow, bwd_confid = get_warped_and_mask( | |
| flow_model, src_img, tar_img, image3=None, pixel_consistency=False, return_confidence=True) | |
| blend_mask = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))( | |
| F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4)) | |
| blend_mask = torch.clamp(blend_mask + bwd_occ, 0, 1) | |
| blend_mask = 1 - F.max_pool2d(blend_mask, kernel_size=8) | |
| blend_mask = 1 - F.max_pool2d(bwd_occ, kernel_size=8) | |
| bwd_confid2 = F.max_pool2d(bwd_confid, kernel_size=16) | |
| bwd_flow2 = F.interpolate(bwd_flow / 16.0, scale_factor=1. / 16, mode='bilinear') | |
| bwd_confid = F.max_pool2d(bwd_confid, kernel_size=8) | |
| bwd_flow = F.interpolate(bwd_flow / 8.0, scale_factor=1. / 8, mode='bilinear') | |
| # _, _, h, w = bwd_flow.shape | |
| # bwd_flow = pad_to_multiples_of(bwd_flow, 8) | |
| # padding_ratio = w / bwd_flow.shape[3] | |
| blend_mask = pad_to_multiples_of(blend_mask[None], 8)[0] | |
| # bwd_confid = pad_to_multiples_of(bwd_confid[None], 8)[0] | |
| flows.append(bwd_flow) | |
| masks.append(blend_mask) | |
| confids.append(bwd_confid) | |
| flows2.append(bwd_flow2) | |
| confids2.append(bwd_confid2) | |
| if self.cldm.controller is not None: | |
| self.cldm.controller.set_warp(flows, masks, flow_confids=confids) | |
            # import ipdb; ipdb.set_trace()
            _, H, W = confids[0].shape
            self.cldm.controller.set_flow_correspondence(lq.shape[0], H, W, lq.shape[0] // 2, confids, flows)
            _, H, W = confids2[0].shape
            self.cldm.controller.set_flow_correspondence(lq.shape[0], H, W, lq.shape[0] // 2, confids2, flows2)
            for j, flow in enumerate(self.cldm.controller.step_store["flows"]):
                if flow is not None:
                    self.cldm.controller.step_store["flows"][j] = pad_to_multiples_of(self.cldm.controller.step_store["flows"][j], 8)
            # self.cldm.controller.set_warp2(flows2, confids2)
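        # The two flow/confidence sets above are downsampled by 8x and 16x. A plausible reading
        # (an assumption, not verified against the controller implementation) is that the 8x maps
        # match the VAE latent resolution, while the 16x maps serve coarser feature maps inside
        # the denoiser.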
| ''' flow & occlusion mask ended ''' | |
| sample = self.run_stage2( | |
| clean, steps, strength, tiled, tile_size, tile_stride, | |
| pos_prompt, neg_prompt, cfg_scale, better_start, | |
| index=index, input=input | |
| ) | |
| if self.cldm.controller is not None: | |
| print(f"[INFO] clearing controller correspondence scores ... ") | |
| self.cldm.controller.step_store["corres_scores"] = None | |
| # colorfix (borrowed from StableSR, thanks for their work) | |
| sample = (sample + 1) / 2 | |
| sample = wavelet_reconstruction(sample, clean) | |
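        # Color fix, as a rough sketch of the idea: keep the high-frequency (detail) component of
        # the diffusion sample and take the low-frequency (color/illumination) component from the
        # stage-1 output, assuming wavelet_reconstruction follows the usual StableSR-style
        # decomposition: out = high_freq(sample) + low_freq(clean).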
        # resize to desired output size
        sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True)
        clean = F.interpolate(clean, size=self.final_size, mode="bilinear", antialias=True)
        # tensor to image
        sample = rearrange(sample * 255., "n c h w -> n h w c")
        sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
        clean = rearrange(clean * 255., "n c h w -> n h w c")
        clean = clean.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
        return sample, clean
class BSRNetPipeline(Pipeline):

    def __init__(self, bsrnet: RRDBNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str, upscale: float) -> None:
        super().__init__(bsrnet, cldm, diffusion, cond_fn, device)
        self.upscale = upscale

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (int(h * self.upscale), int(w * self.upscale))

    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: upscale is always set to 4 in our experiments
        if lq.shape[-2] > 1000:
            # process large inputs frame by frame to limit VRAM usage
            clean = []
            for i in range(lq.shape[0]):
                torch.cuda.empty_cache()
                clean.append(self.stage1_model(lq[i:i+1]))
            clean = torch.cat(clean, dim=0)
        else:
            clean = self.stage1_model(lq)
        # if self.final_size[0] < 512 and self.final_size[1] < 512:
        if min(self.final_size) < 512:
            clean = resize_short_edge_to(clean, size=512)
        else:
            clean = F.interpolate(clean, size=self.final_size, mode="bicubic", antialias=True)
        return clean
class SwinIRPipeline(Pipeline):

    def __init__(self, swinir: SwinIR, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(swinir, cldm, diffusion, cond_fn, device)

    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: lq size is always equal to 512 in our experiments
        # resize: ensure the input lq size is at least 512, since SwinIR is trained on 512 resolution
        if min(lq.shape[2:]) < 512:
            lq = resize_short_edge_to(lq, size=512)
        ori_h, ori_w = lq.shape[2:]
        # pad: ensure that height & width are multiples of 64
        pad_lq = pad_to_multiples_of(lq, multiple=64)
        # run
        clean = self.stage1_model(pad_lq)
        # remove padding
        clean = clean[:, :, :ori_h, :ori_w]
        return clean
class SCUNetPipeline(Pipeline):

    def __init__(self, scunet: SCUNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(scunet, cldm, diffusion, cond_fn, device)

    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        if lq.shape[-1] > 1500:
            # split wide inputs into mini-batches of 2 frames to limit VRAM usage
            clean = []
            batch_lq = lq.split(2, dim=0)
            for lq_ in batch_lq:
                clean.append(self.stage1_model(lq_))
                torch.cuda.empty_cache()
            clean = torch.cat(clean)
        else:
            clean = self.stage1_model(lq)
        if min(clean.shape[2:]) < 512:
            clean = resize_short_edge_to(clean, size=512)
        # import ipdb; ipdb.set_trace()
        return clean
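# Usage sketch (illustrative only; model construction, checkpoint loading and the keyframe /
# controller setup are omitted, and the argument values below are assumptions rather than
# recommendations from this file):
#
#   pipeline = SwinIRPipeline(swinir, cldm, diffusion, cond_fn=None, device="cuda")
#   sample, clean = pipeline.run(
#       lq,                        # (N, H, W, 3) uint8 frames of one clip
#       steps=50, strength=1.0,
#       tiled=False, tile_size=512, tile_stride=256,
#       pos_prompt="", neg_prompt="low quality, blurry",
#       cfg_scale=4.0, better_start=True,
#       flow_model=gmflow,         # a loaded GMFlow model for temporal alignment
#   )
#   # `sample` holds the restored frames and `clean` the stage-1 outputs, both as uint8 arrays.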