from collections.abc import Mapping
from functools import lru_cache
from typing import Unpack, cast

import numpy as np
import torch
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import to_pil_image
from transformers.image_utils import ImageInput, make_flat_list_of_images
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TextInput


class HeronImagesKwargs(ImagesKwargs):
    min_tiles: int | None
    max_tiles: int | None


class HeronProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: HeronImagesKwargs  # type: ignore[misc]
    _defaults = {  # type: ignore
        "text_kwargs": {
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "min_tiles": 1,
            "max_tiles": 12,
        },
    }


class HeronProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")  # type: ignore[assignment]

    image_processor: BaseImageProcessor
    tokenizer: PreTrainedTokenizerBase

    def __init__(self, image_processor, tokenizer, chat_template=None, num_image_features: int = 256, **kwargs):
        image_token = kwargs.pop("image_token", None)
        if image_token is None:
            image_token = "<image>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        assert isinstance(image_token, str)
        image_token_id = tokenizer.convert_tokens_to_ids(image_token)
        if image_token_id is None:
            raise ValueError(f"tokenizer does not contain {image_token!r} token")

        self.num_image_features = num_image_features
        self.image_token = tokenizer.image_token = image_token
        self.image_token_id = tokenizer.image_token_id = image_token_id

        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

    def __call__(  # type: ignore[override]
        self,
        text: TextInput | list[TextInput],
        images: ImageInput | None = None,
        **kwargs: Unpack[HeronProcessorKwargs],
    ) -> BatchFeature:
        output_kwargs = self._merge_kwargs(
            HeronProcessorKwargs,  # type: ignore[arg-type]
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if not isinstance(text, list):
            text = [text]
        assert isinstance(text, list)

        if images is not None:
            images = cast(list, make_flat_list_of_images(images))
            images = [to_pil_image(image) for image in images]

        image_inputs: Mapping = {}
        num_image_tiles = None
        if images is not None:
            if sum(s.count(self.image_token) for s in text) != len(images):
                raise ValueError("the number of images does not match the number of image tokens in the text")
            image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
            num_image_tiles = image_inputs["pixel_values"].shape[1]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.process_text(text, num_image_tiles, **output_kwargs["text_kwargs"])
        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
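    # NOTE (descriptive, derived from process_text below): after expansion, every image
    # occupies num_image_tiles * (num_image_features + 1) tokens in input_ids, plus
    # (num_image_tiles - 1) separator newline tokens, because each tile's features are
    # followed by one reserved token that is later rewritten to "\n".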
    def process_text(
        self, text: list[str], num_image_tiles: int | None = None, **kwargs: Unpack[TextKwargs]
    ) -> BatchEncoding:
        if all(self.image_token not in prompt for prompt in text):
            return self.tokenizer(text, **kwargs)

        if num_image_tiles is None:
            raise ValueError("num_image_tiles must be specified when processing image tokens")

        image_feature_placeholder = self.image_token * self.num_image_features
        # NOTE: The original implementation appends an extra newline after the image features.
        # https://github.com/NVlabs/VILA/blob/36f6adcd11a10be1580caeb7e647e1b6f8517f89/llava/model/encoders/image/basic.py#L38
        # Instead of appending a newline character, this implementation reserves one extra token to be replaced later.
        # This treatment is needed because "\n\n" is tokenized as [271] instead of [198, 198].
        assert self.tokenizer.eos_token is not None
        image_feature_placeholder += self.tokenizer.eos_token

        # Expand image tokens according to the number of image features and tiles
        processed_text = []
        for prompt in text:
            new_prompt = prompt
            assert "<placeholder>" not in new_prompt
            replace_strings = []
            while self.image_token in new_prompt:
                replace_strings.append("\n".join([image_feature_placeholder] * num_image_tiles))
                new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
            for s in replace_strings:
                new_prompt = new_prompt.replace("<placeholder>", s, 1)
            processed_text.append(new_prompt)

        encoding = self.tokenizer(processed_text, **kwargs)

        # Replace the last token of every image tile with the newline token
        token_ids = self.tokenizer.encode("\n")
        assert len(token_ids) == 1
        newline_token_id = token_ids[0]
        for input_ids in encoding.input_ids:
            i, n = 0, len(input_ids)
            while i < n:
                if input_ids[i] != self.image_token_id:
                    i += 1
                    continue
                i += self.num_image_features
                input_ids[i] = newline_token_id
                i += 1

        self._check_special_mm_tokens(processed_text, encoding, modalities=["image"])  # type: ignore[arg-type]

        return encoding

    def process_images(
        self,
        images: list[Image],
        min_tiles: int = 1,
        max_tiles: int = 12,
        **kwargs: Unpack[ImagesKwargs],
    ) -> BatchFeature:
        assert isinstance(min_tiles, int) and isinstance(max_tiles, int)
        crop_size = self.image_processor.size  # type: ignore[attr-defined]
        assert crop_size["height"] == crop_size["width"]

        return_tensors = kwargs.pop("return_tensors", None)  # type: ignore[typeddict-item]
        pixel_values = []
        for image in images:
            image_tiles = _dynamic_preprocess(
                image, min_num=min_tiles, max_num=max_tiles, image_size=crop_size["height"]
            )
            pixel_values.append(self.image_processor(image_tiles, **kwargs, return_tensors="pt")["pixel_values"])

        return BatchFeature({"pixel_values": torch.stack(pixel_values)}, tensor_type=return_tensors)
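
# Worked example of the tiling helpers below (a sketch: the 448px crop size is an
# assumption for illustration; the real value is read from image_processor.size at
# runtime). A 1024x768 image has aspect ratio 4:3, so with min_num=1 and max_num=12
# the closest admissible grid is (4, 3). The image is resized to 1792x1344, cut into
# twelve 448x448 crops, and a 448x448 thumbnail of the full image is appended,
# yielding 13 tiles for that image.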

# Adapted from https://github.com/NVlabs/VILA/blob/36f6adcd11a10be1580caeb7e647e1b6f8517f89/llava/mm_utils.py#L296
def _dynamic_preprocess(
    image: Image, min_num: int, max_num: int, image_size: int, use_thumbnail: bool = True
) -> list[Image]:
    if image.mode != "RGB":
        image = image.convert("RGB")

    orig_width, orig_height = image.size
    (target_width, target_height), crop_boxes = _calculate_crops(orig_width, orig_height, image_size, min_num, max_num)

    resized_img = image.resize((target_width, target_height))
    processed_images = [resized_img.crop(box) for box in crop_boxes]
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


@lru_cache(maxsize=32)
def _calculate_crops(
    width: int, height: int, crop_size: int, min_num: int, max_num: int
) -> tuple[tuple[int, int], list[tuple[int, int, int, int]]]:
    aspect_ratio = width / height  # calculate the existing image aspect ratio

    target_ratio_set = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    target_ratios = sorted(target_ratio_set, key=lambda x: x[0] * x[1])

    # Find the closest aspect ratio to the target
    target_aspect_ratio = _find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width,
        height,
        crop_size,
    )

    # calculate the target width and height
    target_width = crop_size * target_aspect_ratio[0]
    target_height = crop_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    crop_boxes = []
    for i in range(blocks):
        box = (
            (i % (target_width // crop_size)) * crop_size,
            (i // (target_width // crop_size)) * crop_size,
            ((i % (target_width // crop_size)) + 1) * crop_size,
            ((i // (target_width // crop_size)) + 1) * crop_size,
        )
        crop_boxes.append(box)

    return (target_width, target_height), crop_boxes


# Copied from https://github.com/NVlabs/VILA/blob/36f6adcd11a10be1580caeb7e647e1b6f8517f89/llava/mm_utils.py#L280
def _find_closest_aspect_ratio(
    aspect_ratio: float, target_ratios: list[tuple[int, int]], width: int, height: int, image_size: int
) -> tuple[int, int]:
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
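
# Example usage (an illustrative sketch, not part of the processor itself). The
# checkpoint path and image file are placeholders; loading via AutoProcessor
# assumes the checkpoint registers this class through auto_map.
if __name__ == "__main__":
    from PIL import Image as PILImage
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("path/to/heron-checkpoint", trust_remote_code=True)
    prompt = f"{processor.image_token}\nDescribe this image."  # one image token per image
    image = PILImage.open("example.jpg")
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # input_ids now contain the expanded per-tile image tokens; pixel_values has
    # shape (num_images, num_tiles, num_channels, height, width).
    print(inputs["input_ids"].shape, inputs["pixel_values"].shape)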