Spaces:
Build error
Build error
| # python3.7 | |
| """Contains utility functions for image processing. | |
The module is primarily built on `cv2`. Unlike `cv2`, however, all color images
are assumed to be in `RGB` channel order by default, and all gray-scale images
are assumed to have shape [height, width, 1].
| """ | |
| import os | |
| import cv2 | |
| import numpy as np | |
# Recognized image file extensions, each with its leading dot (GIFs are
# deliberately left out).
IMAGE_EXTENSIONS = (
    '.bmp',
    '.ppm',
    '.pgm',
    '.jpeg',
    '.jpg',
    '.jpe',
    '.jp2',
    '.png',
    '.webp',
    '.tiff',
    '.tif',
)
def check_file_ext(filename, *ext_list):
    """Tells whether `filename` ends with one of the given extensions.

    The comparison is case-insensitive, and each candidate extension may be
    given with or without its leading dot.

    NOTE: With no candidate extensions, the result is always `False`.

    Args:
        filename: Filename (optionally with a directory prefix) to check.
        *ext_list: Candidate extensions, e.g. `'png'` or `'.png'`.

    Returns:
        `True` if the extension of the file's basename matches one of the
            candidates, otherwise `False`.
    """
    if not ext_list:
        return False

    # Normalize every candidate to the lower-cased, dot-prefixed form.
    candidates = []
    for candidate in ext_list:
        dotted = candidate if candidate.startswith('.') else '.' + candidate
        candidates.append(dotted.lower())

    # Only the last path component decides the extension.
    _, file_ext = os.path.splitext(os.path.basename(filename))
    return file_ext.lower() in candidates
def _check_2d_image(image):
    """Asserts that `image` is a valid single image.

    A valid image is a `np.ndarray` of dtype `uint8` with one of the shapes:

    (1) (height, width, 1)  # gray-scale image.
    (2) (height, width, 3)  # color image.
    (3) (height, width, 4)  # color image with transparency (RGBA).

    Raises:
        AssertionError: If any of the above requirements is violated.
    """
    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    # The rank is verified first so that indexing `shape[2]` is safe.
    assert image.ndim == 3
    assert image.shape[2] in (1, 3, 4)
def get_blank_image(height, width, channels=3, use_black=True):
    """Creates a uniformly filled image, either black or white.

    NOTE: The returned image has dtype `uint8`, i.e. pixel range [0, 255].

    Args:
        height: Height of the returned image.
        width: Width of the returned image.
        channels: Number of channels. (default: 3)
        use_black: Whether to fill with black (0) instead of white (255).
            (default: True)

    Returns:
        A `np.ndarray` with shape (height, width, channels) and dtype `uint8`.
    """
    fill_value = 0 if use_black else 255
    return np.full((height, width, channels), fill_value, dtype=np.uint8)
def load_image(path):
    """Reads an image from disk.

    The file is decoded with `cv2.IMREAD_UNCHANGED`, so any alpha channel is
    kept. Gray-scale images gain a trailing channel axis, while color images
    are converted from the `BGR(A)` order used by `cv2` to `RGB(A)`.

    NOTE: The decoded image must pass `_check_2d_image()`, i.e. be `uint8`
    with 1, 3 or 4 channels.

    Args:
        path: Path to load the image from.

    Returns:
        The image as `np.ndarray` with shape (height, width, channels), or
            `None` if `cv2.imread()` returns `None` (e.g. the file is
            missing).
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        return None

    # `cv2` decodes gray-scale images as (height, width); restore the axis.
    if raw.ndim == 2:
        raw = raw[:, :, np.newaxis]
    _check_2d_image(raw)

    # Channel count -> conversion from the `cv2` order to RGB(A).
    to_rgb = {3: cv2.COLOR_BGR2RGB, 4: cv2.COLOR_BGRA2RGBA}
    conversion = to_rgb.get(raw.shape[2])
    if conversion is None:
        return raw
    return cv2.cvtColor(raw, conversion)
def save_image(path, image):
    """Writes an image to disk.

    NOTE: A color image is assumed to be in `RGB(A)` channel order with pixel
    range [0, 255]; it is converted to the `BGR(A)` order that `cv2` expects
    before writing. Nothing is written when `image` is `None`.

    Args:
        path: Path to save the image to.
        image: Image to save, with shape (height, width, channels).
    """
    if image is None:
        return

    _check_2d_image(image)
    num_channels = image.shape[2]

    # Gray-scale images have no channel order to fix up.
    if num_channels == 1:
        cv2.imwrite(path, image)
        return

    # Channel count -> conversion from RGB(A) to the `cv2` order.
    to_bgr = {3: cv2.COLOR_RGB2BGR, 4: cv2.COLOR_RGBA2BGRA}
    conversion = to_bgr.get(num_channels)
    if conversion is not None:
        cv2.imwrite(path, cv2.cvtColor(image, conversion))
def resize_image(image, *args, **kwargs):
    """Resizes an image through `cv2.resize()`.

    NOTE: The channel order of the input image is left untouched.

    Args:
        image: Image to resize, with shape (height, width, channels).
        *args: Positional arguments forwarded to `cv2.resize()`.
        **kwargs: Keyword arguments forwarded to `cv2.resize()`.

    Returns:
        The resized image as `np.ndarray`, or `None` if `image` is `None`.
    """
    if image is None:
        return None

    _check_2d_image(image)
    resized = cv2.resize(image, *args, **kwargs)
    if image.shape[2] == 1:
        # Re-expand the channel axis, which the resize squeezes out for
        # gray-scale input.
        resized = resized[:, :, np.newaxis]
    return resized
def add_text_to_image(image,
                      text='',
                      position=None,
                      font=cv2.FONT_HERSHEY_TRIPLEX,
                      font_size=1.0,
                      line_type=cv2.LINE_8,
                      line_width=1,
                      color=(255, 255, 255)):
    """Overlays text on given image.

    NOTE: The input image is assumed to be with `RGB` channel order. The text
    is drawn onto `image` itself, and that same array is returned.

    Args:
        image: The image to overlay text on.
        text: Text content to overlay on the image. (default: empty)
        position: Target position, as `(x, y)`, of the bottom-left corner of
            the text. If not set, center of the image will be used by
            default. (default: None)
        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
        font_size: Font size of the text added. (default: 1.0)
        line_type: Line type used to depict the text. (default: cv2.LINE_8)
        line_width: Line width used to depict the text. (default: 1)
        color: Color of the text added in `RGB` channel order. (default:
            (255, 255, 255))

    Returns:
        An image with target text overlaid on. The input is returned
            untouched if it is `None` or if `text` is empty.
    """
    if image is None or not text:
        return image

    _check_2d_image(image)

    if position is None:
        # Honor the documented default: anchor the text at the image center.
        # `cv2.putText()` takes the origin as (x, y), i.e. (column, row).
        position = (image.shape[1] // 2, image.shape[0] // 2)

    cv2.putText(img=image,
                text=text,
                org=position,
                fontFace=font,
                fontScale=font_size,
                color=color,
                thickness=line_width,
                lineType=line_type,
                bottomLeftOrigin=False)
    return image
def preprocess_image(image, min_val=-1.0, max_val=1.0):
    """Pre-processes image by adjusting the pixel range and to dtype `float32`.

    This function is particularly used to convert an image or a batch of images
    to `NCHW` format, which matches the data type commonly used in deep models.

    NOTE: The input image is assumed to be with pixel range [0, 255] and with
    format `HWC` or `NHWC`. The returned image will always be with format
    `NCHW`.

    Args:
        image: The input image for pre-processing.
        min_val: Minimum value of the output image.
        max_val: Maximum value of the output image.

    Returns:
        The pre-processed image, with dtype `float32` and format `NCHW`.

    Raises:
        AssertionError: If `image` is not a `np.ndarray`, or if it is not
            `HWC`/`NHWC` with 1, 3 or 4 channels.
    """
    assert isinstance(image, np.ndarray)

    # Rescale in double precision, so that only the final cast rounds.
    image = image.astype(np.float64)
    image = image / 255.0 * (max_val - min_val) + min_val

    if image.ndim == 3:
        # Promote a single `HWC` image to a batch of one.
        image = image[np.newaxis]
    assert image.ndim == 4 and image.shape[3] in [1, 3, 4]

    # Cast to `float32` as documented; the value was previously returned as
    # `float64`, contradicting the contract stated above.
    return image.transpose(0, 3, 1, 2).astype(np.float32)
def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Maps model outputs back to `uint8` images with pixel range [0, 255].

    Values are linearly rescaled from [min_val, max_val] to [0, 255], rounded
    to the nearest integer, and clipped into the valid pixel range.

    NOTE: The input is assumed to be with format `NCHW`, and the returned
    image is always with format `NHWC`.

    Args:
        image: The input image (batch) for post-processing.
        min_val: Expected minimum value of the input image.
        max_val: Expected maximum value of the input image.

    Returns:
        The post-processed image batch, with dtype `uint8` and format `NHWC`.
    """
    assert isinstance(image, np.ndarray)

    value_range = max_val - min_val
    scaled = (image.astype(np.float64) - min_val) / value_range * 255
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    pixels = np.clip(scaled + 0.5, 0, 255).astype(np.uint8)

    assert pixels.ndim == 4 and pixels.shape[1] in [1, 3, 4]
    return np.transpose(pixels, (0, 2, 3, 1))
def parse_image_size(obj):
    """Parses an object to a pair of image size, i.e., (height, width).

    Supported inputs:

    (1) `None` or an empty string: parsed as (0, 0).
    (2) An integer (Python or NumPy): used as both height and width.
    (3) A string of comma-separated integers, e.g. `'256'` or `'256, 512'`.
    (4) A list, tuple or `np.ndarray` holding zero, one or two numbers.

    Negative values are clamped to 0.

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image height and width respectively.

    Raises:
        ValueError: If the input is of an unsupported type, holds more than
            two elements, or is a string that cannot be parsed as integers.
    """
    # The emptiness test must be restricted to strings: comparing an
    # `np.ndarray` against `''` yields an array on recent NumPy versions,
    # whose truth value is ambiguous and raises.
    if obj is None or (isinstance(obj, str) and obj == ''):
        return (0, 0)

    if isinstance(obj, (int, np.integer)):
        height = width = int(obj)
    elif isinstance(obj, (list, tuple, str, np.ndarray)):
        if isinstance(obj, str):
            parts = obj.replace(' ', '').split(',')
            numbers = tuple(int(part) for part in parts)
        else:
            numbers = tuple(obj)

        if len(numbers) > 2:
            raise ValueError('At most two elements for image size.')
        if not numbers:
            height = width = 0
        else:
            # A single number stands for both the height and the width.
            height = int(numbers[0])
            width = int(numbers[-1])
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    return (max(0, height), max(0, width))
def get_grid_shape(size, height=0, width=0, is_portrait=False):
    """Computes a (height, width) grid shape holding exactly `size` cells.

    A usable hint is honored first: `height` and `width` together if their
    product equals `size`, otherwise a single given `height` or `width` that
    divides `size`. If both are given but their product differs from `size`,
    both are dropped.

    Without a usable hint, the grid is made as square as possible. With
    `is_portrait` set as `False`, the height is then equal to or smaller than
    the width, e.g. `size = 16` gives `(4, 4)` and `size = 15` gives `(3, 5)`.
    Otherwise, the height is equal to or larger than the width.

    Args:
        size: Size (height * width) of the target grid.
        height: Expected height. If `size % height != 0`, this field will be
            ignored. (default: 0)
        width: Expected width. If `size % width != 0`, this field will be
            ignored. (default: 0)
        is_portrait: Whether to return a portrait shape or a landscape shape.
            Only affects the automatically derived shape. (default: False)

    Returns:
        A two-element tuple, representing height and width respectively, or
            `(0, 0)` if `size` is not positive.
    """
    assert isinstance(size, int)
    assert isinstance(height, int)
    assert isinstance(width, int)

    if size <= 0:
        return (0, 0)

    both_given = height > 0 and width > 0
    if both_given:
        if height * width == size:
            return (height, width)
        # Inconsistent hints: fall through to the automatic search below.
    else:
        if height > 0 and size % height == 0:
            return (height, size // height)
        if width > 0 and size % width == 0:
            return (size // width, width)

    # Largest divisor not exceeding sqrt(size). The search always succeeds
    # because 1 divides every positive size.
    short_side = next(candidate
                      for candidate in range(int(np.sqrt(size)), 0, -1)
                      if size % candidate == 0)
    long_side = size // short_side
    if is_portrait:
        return (long_side, short_side)
    return (short_side, long_side)
def list_images_from_dir(directory):
    """Collects the image files found directly inside `directory`.

    An entry is selected when its extension is listed in `IMAGE_EXTENSIONS`.

    NOTE: Sub-directories are NOT searched recursively.

    Args:
        directory: The directory to find images from.

    Returns:
        A sorted list of filenames, each prefixed with `directory`.
    """
    return sorted(
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if check_file_ext(entry, *IMAGE_EXTENSIONS)
    )