import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
from typing import Tuple
import scipy.ndimage
import cv2
from train.src.condition.util import HWC3, common_input_validate


def check_image_mask(image, mask, name):
    """Normalize an image/mask pair to [B, H, W, C] / [B, H, W] with matching batch sizes."""
    if len(image.shape) < 4:
        # image tensor shape should be [B, H, W, C], but the batch dim is missing
        image = image[None, :, :, :]

    if len(mask.shape) > 3:
        # mask tensor shape should be [B, H, W] but we got [B, H, W, C] (an image, perhaps?);
        # take the red channel as the mask
        mask = mask[:, :, :, 0]
    elif len(mask.shape) < 3:
        # mask tensor shape should be [B, H, W], but the batch dim is missing
        mask = mask[None, :, :]

    if image.shape[0] > mask.shape[0]:
        print(name, "gets a batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
        if mask.shape[0] == 1:
            print(name, "will copy the mask to fill the batch")
            mask = torch.cat([mask] * image.shape[0], dim=0)
        else:
            print(name, "will add empty masks to fill the batch")
            empty_mask = torch.zeros(
                [image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]],
                dtype=mask.dtype, device=mask.device,
            )
            mask = torch.cat([mask, empty_mask], dim=0)
    elif image.shape[0] < mask.shape[0]:
        print(name, "gets a batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
        mask = mask[:image.shape[0], :, :]

    return (image, mask)


def cv2_resize_shortest_edge(image, size):
    """Resize so the shortest edge equals `size`, preserving aspect ratio."""
    h, w = image.shape[:2]
    if h < w:
        new_h = size
        new_w = int(round(w / h * size))
    else:
        new_w = size
        new_h = int(round(h / w * size))
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)


def apply_color(img, res=512):
    """Build a coarse color map: downscale to 1/64 resolution, then upscale back with nearest."""
    img = cv2_resize_shortest_edge(img, res)
    h, w = img.shape[:2]
    input_img_color = cv2.resize(img, (w // 64, h // 64), interpolation=cv2.INTER_CUBIC)
    input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
    return input_img_color
# The Color T2I adapter expects this grid-of-64 layout; the downscale/upscale
# interpolation methods (cubic down, nearest up) are fixed.
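

# A minimal usage sketch for the helpers above; "input.png" is a hypothetical
# path, not part of this module.
def _demo_apply_color():
    img = cv2.imread("input.png")            # BGR uint8 array, shape [H, W, 3]
    color_map = apply_color(img, res=512)    # coarse grid of flat color cells
    cv2.imwrite("color_map.png", color_map)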
class ColorDetector:
    """Color preprocessor: reduces the input to a coarse color grid (T2I-Adapter color)."""

    def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image = HWC3(input_image)
        detected_map = HWC3(apply_color(input_image, detect_resolution))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map


class InpaintPreprocessor:
    """Marks masked pixels so an inpainting ControlNet can recognize the hole."""

    def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
        mask = torch.nn.functional.interpolate(
            mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
            size=(image.shape[1], image.shape[2]), mode="bilinear",
        )
        mask = mask.movedim(1, -1).expand((-1, -1, -1, 3))
        image = image.clone()
        # Xinsir ControlNets expect masked pixels to be 0.0; others use -1.0.
        masked_pixel = 0.0 if black_pixel_for_xinsir_cn else -1.0
        image[mask > 0.5] = masked_pixel
        return (image,)


class BlendInpaint:
    def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask,
                      kernel: int, sigma: int, origin=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Composite `inpaint` over `original` using a Gaussian-blurred mask as the blend weight."""
        original, mask = check_image_mask(original, mask, 'Blend Inpaint')
        if len(inpaint.shape) < 4:
            # image tensor shape should be [B, H, W, C], but the batch dim is missing
            inpaint = inpaint[None, :, :, :]

        if inpaint.shape[0] < original.shape[0]:
            print("Blend Inpaint gets a batch of original images (%d) but only (%d) inpaint images"
                  % (original.shape[0], inpaint.shape[0]))
            original = original[:inpaint.shape[0]]
            mask = mask[:inpaint.shape[0]]
        if inpaint.shape[0] > original.shape[0]:
            # more inpaint images than originals: cycle originals/masks to match the inpaint batch
            count = 0
            original_list, mask_list, origin_list = [], [], []
            while count < inpaint.shape[0]:
                for i in range(original.shape[0]):
                    original_list.append(original[i][None, :, :, :])
                    mask_list.append(mask[i][None, :, :])
                    if origin is not None:
                        origin_list.append(origin[i][None, :])
                    count += 1
                    if count >= inpaint.shape[0]:
                        break
            original = torch.cat(original_list, dim=0)
            mask = torch.cat(mask_list, dim=0)
            if origin is not None:
                origin = torch.cat(origin_list, dim=0)

        if kernel % 2 == 0:
            kernel += 1  # GaussianBlur requires an odd kernel size
        transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))

        ret = []
        blurred = []
        for i in range(inpaint.shape[0]):
            if origin is None:
                blurred_mask = transform(mask[i][None, None, :, :]).to(original.device).to(original.dtype)
                blurred.append(blurred_mask[0][0])
                # resize the inpainted image to match the original
                result = torch.nn.functional.interpolate(
                    inpaint[i][None, :, :, :].permute(0, 3, 1, 2),
                    size=(original[i].shape[0], original[i].shape[1]),
                ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
            else:
                # mask came from CutForInpaint: pad the crop back into the original frame
                height, width, _ = original[i].shape
                x0 = origin[i][0].item()
                y0 = origin[i][1].item()
                if mask[i].shape[0] < height or mask[i].shape[1] < width:
                    padded_mask = F.pad(
                        input=mask[i],
                        pad=(x0, width - x0 - mask[i].shape[1], y0, height - y0 - mask[i].shape[0]),
                        mode='constant', value=0,
                    )
                else:
                    padded_mask = mask[i]
                blurred_mask = transform(padded_mask[None, None, :, :]).to(original.device).to(original.dtype)
                blurred.append(blurred_mask[0][0])
                result = F.pad(
                    input=inpaint[i],
                    pad=(0, 0, x0, width - x0 - inpaint[i].shape[1], y0, height - y0 - inpaint[i].shape[0]),
                    mode='constant', value=0,
                )
                result = result[None, :, :, :].to(original.device).to(original.dtype)
            # linear blend: the blurred mask weights the inpainted pixels, its complement the originals
            ret.append(original[i] * (1.0 - blurred_mask[0][0][:, :, None])
                       + result[0] * blurred_mask[0][0][:, :, None])
        return (torch.stack(ret), torch.stack(blurred))


def resize_mask(mask, shape):
    return torch.nn.functional.interpolate(
        mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
        size=(shape[0], shape[1]), mode="bilinear",
    ).squeeze(1)
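

# A minimal usage sketch for BlendInpaint, assuming [B, H, W, C] float images
# in [0, 1] and a [B, H, W] float mask (the shapes the code above normalizes
# toward). All tensors here are synthetic placeholders.
def _demo_blend_inpaint():
    original = torch.rand(1, 512, 512, 3)
    inpaint = torch.rand(1, 512, 512, 3)
    mask = torch.zeros(1, 512, 512)
    mask[:, 128:384, 128:384] = 1.0  # square region to blend the inpaint into
    blended, blurred_mask = BlendInpaint().blend_inpaint(
        inpaint, original, mask, kernel=15, sigma=5)
    print(blended.shape, blurred_mask.shape)  # [1, 512, 512, 3], [1, 512, 512]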
class JoinImageWithAlpha:
    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
        """Attach a mask as the alpha channel of an RGB image (mask=1 becomes transparent)."""
        batch_size = min(len(image), len(alpha))
        out_images = []
        alpha = 1.0 - resize_mask(alpha, image.shape[1:])
        for i in range(batch_size):
            out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
        return (torch.stack(out_images),)


class GrowMask:
    def expand_mask(self, mask, expand, tapered_corners):
        """Dilate (expand > 0) or erode (expand < 0) a mask by |expand| pixels."""
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        out = []
        for m in mask:
            output = m.numpy()
            for _ in range(abs(expand)):
                if expand < 0:
                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
                else:
                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
            out.append(torch.from_numpy(output))
        return (torch.stack(out, dim=0),)


class InvertMask:
    def invert(self, mask):
        return (1.0 - mask,)
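

# A minimal usage sketch chaining the mask utilities above; the tensors are
# synthetic placeholders.
def _demo_mask_utils():
    mask = torch.zeros(1, 64, 64)
    mask[:, 24:40, 24:40] = 1.0
    (grown,) = GrowMask().expand_mask(mask, expand=4, tapered_corners=True)
    (inverted,) = InvertMask().invert(grown)
    image = torch.rand(1, 64, 64, 3)
    (rgba,) = JoinImageWithAlpha().join_image_with_alpha(image, inverted)
    print(grown.shape, rgba.shape)  # [1, 64, 64], [1, 64, 64, 4]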