from collections.abc import Mapping
from functools import lru_cache
from typing import Unpack, cast

import numpy as np
import torch
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import to_pil_image
from transformers.image_utils import ImageInput, make_flat_list_of_images
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TextInput


class HeronImagesKwargs(ImagesKwargs):
    min_tiles: int | None
    max_tiles: int | None


class HeronProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: HeronImagesKwargs
    _defaults = {
        "text_kwargs": {
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "min_tiles": 1,
            "max_tiles": 12,
        },
    }


class HeronProcessor(ProcessorMixin):
    """Processor that pairs an image processor with a Qwen2 tokenizer for Heron vision-language models."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
    image_processor: BaseImageProcessor
    tokenizer: PreTrainedTokenizerBase

    def __init__(self, image_processor, tokenizer, chat_template=None, num_image_features: int = 256, **kwargs):
        # Resolve the image placeholder token: an explicit kwarg wins, then the tokenizer's own
        # `image_token` attribute, then the default "<image>" string.
        image_token = kwargs.pop("image_token", None)
        if image_token is None:
            image_token = "<image>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        assert isinstance(image_token, str)

        image_token_id = tokenizer.convert_tokens_to_ids(image_token)
        if image_token_id is None:
            raise ValueError(f"tokenizer does not contain {image_token!r} token")

        self.num_image_features = num_image_features
        self.image_token = tokenizer.image_token = image_token
        self.image_token_id = tokenizer.image_token_id = image_token_id

        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

    def __call__(
        self,
        text: TextInput | list[TextInput],
        images: ImageInput | None = None,
        **kwargs: Unpack[HeronProcessorKwargs],
    ) -> BatchFeature:
        output_kwargs = self._merge_kwargs(
            HeronProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if not isinstance(text, list):
            text = [text]
        assert isinstance(text, list)
        if images is not None:
            images = cast(list, make_flat_list_of_images(images))
            images = [to_pil_image(image) for image in images]

        image_inputs: Mapping = {}
        num_image_tiles = None
        if images is not None:
            if sum(s.count(self.image_token) for s in text) != len(images):
                raise ValueError("the number of images does not match the number of image tokens in the text")
            image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
            # pixel_values has shape (num_images, num_tiles, channels, height, width).
            num_image_tiles = image_inputs["pixel_values"].shape[1]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.process_text(text, num_image_tiles, **output_kwargs["text_kwargs"])

        if return_mm_token_type_ids:
            # Mark image-token positions with 1 and every other position with 0.
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def process_text(
        self, text: list[str], num_image_tiles: int | None = None, **kwargs: Unpack[TextKwargs]
    ) -> BatchEncoding:
        if all(self.image_token not in prompt for prompt in text):
            return self.tokenizer(text, **kwargs)

        if num_image_tiles is None:
            raise ValueError("num_image_tiles must be specified when processing image tokens")

        image_feature_placeholder = self.image_token * self.num_image_features

        # Append the EOS token as a temporary tile terminator; it is rewritten to a newline token id
        # after tokenization so that each block of image tokens ends in exactly one "\n" token.
        assert self.tokenizer.eos_token is not None
        image_feature_placeholder += self.tokenizer.eos_token

        # Expand every image token in each prompt into one placeholder block per tile.
        processed_text = []
        for prompt in text:
            new_prompt = prompt
            assert "<placeholder>" not in new_prompt
            replace_strings = []
            while self.image_token in new_prompt:
                replace_strings.append("\n".join([image_feature_placeholder] * num_image_tiles))
                new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
            for s in replace_strings:
                new_prompt = new_prompt.replace("<placeholder>", s, 1)
            processed_text.append(new_prompt)

        encoding = self.tokenizer(processed_text, **kwargs)

        # Replace the temporary EOS terminator that follows each run of image tokens with the newline token id.
        token_ids = self.tokenizer.encode("\n")
        assert len(token_ids) == 1
        newline_token_id = token_ids[0]
        for input_ids in encoding.input_ids:
            i, n = 0, len(input_ids)
            while i < n:
                if input_ids[i] != self.image_token_id:
                    i += 1
                    continue
                i += self.num_image_features
                input_ids[i] = newline_token_id
                i += 1

        self._check_special_mm_tokens(processed_text, encoding, modalities=["image"])
        return encoding

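    # Illustrative expansion (a sketch; the concrete eos string "<|endoftext|>" below is an assumption
    # that depends on the tokenizer, as do num_image_features=256 and num_image_tiles=2): the prompt
    #
    #   "<image>\nWhat is shown?"
    #
    # is expanded by process_text into
    #
    #   "<image>" * 256 + "<|endoftext|>" + "\n" + "<image>" * 256 + "<|endoftext|>" + "\nWhat is shown?"
    #
    # and, after tokenization, each "<|endoftext|>" id that terminates a run of 256 image-token ids is
    # overwritten with the "\n" token id.
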
    def process_images(
        self,
        images: list[Image],
        min_tiles: int = 1,
        max_tiles: int = 12,
        **kwargs: Unpack[ImagesKwargs],
    ) -> BatchFeature:
        assert isinstance(min_tiles, int) and isinstance(max_tiles, int)

        crop_size = self.image_processor.size
        assert crop_size["height"] == crop_size["width"]

        return_tensors = kwargs.pop("return_tensors", None)
        pixel_values = []
        for image in images:
            # Split each image into square tiles (plus an optional thumbnail), then run them through
            # the underlying image processor.
            image_tiles = _dynamic_preprocess(
                image, min_num=min_tiles, max_num=max_tiles, image_size=crop_size["height"]
            )
            pixel_values.append(self.image_processor(image_tiles, **kwargs, return_tensors="pt")["pixel_values"])

        # Stacking assumes every image in the batch yields the same number of tiles.
        return BatchFeature({"pixel_values": torch.stack(pixel_values)}, tensor_type=return_tensors)


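# Example usage (a minimal sketch, not executed here; the checkpoint name and the `pil_image` variable
# are hypothetical, and in practice the processor would normally be loaded via AutoProcessor once registered):
#
#   from transformers import AutoImageProcessor, AutoTokenizer
#
#   image_processor = AutoImageProcessor.from_pretrained("org/heron-checkpoint")
#   tokenizer = AutoTokenizer.from_pretrained("org/heron-checkpoint")
#   processor = HeronProcessor(image_processor, tokenizer)
#   batch = processor(text="<image>\nDescribe this image.", images=[pil_image], return_tensors="pt")
#   # batch contains input_ids, attention_mask, and pixel_values of shape (1, num_tiles, 3, H, W).

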
def _dynamic_preprocess(
    image: Image, min_num: int, max_num: int, image_size: int, use_thumbnail: bool = True
) -> list[Image]:
    if image.mode != "RGB":
        image = image.convert("RGB")

    orig_width, orig_height = image.size
    (target_width, target_height), crop_boxes = _calculate_crops(orig_width, orig_height, image_size, min_num, max_num)

    resized_img = image.resize((target_width, target_height))
    processed_images = [resized_img.crop(box) for box in crop_boxes]

    # Append a full-image thumbnail whenever the image was split into more than one tile.
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


@lru_cache(maxsize=32)
def _calculate_crops(
    width: int, height: int, crop_size: int, min_num: int, max_num: int
) -> tuple[tuple[int, int], list[tuple[int, int, int, int]]]:
    aspect_ratio = width / height

    # Enumerate every (columns, rows) grid whose tile count falls within [min_num, max_num].
    target_ratio_set = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    target_ratios = sorted(target_ratio_set, key=lambda x: x[0] * x[1])

    # Pick the grid whose aspect ratio is closest to that of the original image.
    target_aspect_ratio = _find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width,
        height,
        crop_size,
    )

    target_width = crop_size * target_aspect_ratio[0]
    target_height = crop_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Lay the crop boxes out row-major over the resized image.
    crop_boxes = []
    for i in range(blocks):
        box = (
            (i % (target_width // crop_size)) * crop_size,
            (i // (target_width // crop_size)) * crop_size,
            ((i % (target_width // crop_size)) + 1) * crop_size,
            ((i // (target_width // crop_size)) + 1) * crop_size,
        )
        crop_boxes.append(box)

    return (target_width, target_height), crop_boxes


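# Worked example (illustrative, assuming crop_size=448, min_num=1, max_num=12): a 1920x1080 image has
# aspect ratio ~1.78; the closest admissible grids are (2, 1) and (4, 2), both off by ~0.22, and the
# tie-break in _find_closest_aspect_ratio keeps the larger grid because 1920 * 1080 exceeds
# 0.5 * 448 * 448 * 8. The image is therefore resized to 1792x896 and cut into 8 tiles of 448x448,
# with a 448x448 thumbnail appended by _dynamic_preprocess, for 9 crops in total.

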
def _find_closest_aspect_ratio(
    aspect_ratio: float, target_ratios: list[tuple[int, int]], width: int, height: int, image_size: int
) -> tuple[int, int]:
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # On a tie, prefer the larger grid only if the original image has enough area to fill it.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio