# Heron-NVILA-Lite-1B-hf / processing_heron.py
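"""Processor for Heron-NVILA: pairs a Qwen2 tokenizer with an image processor,
tiling high-resolution images via dynamic preprocessing (adapted from NVlabs/VILA)
and expanding each ``<image>`` placeholder into per-tile image-feature tokens."""
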
from collections.abc import Mapping
from functools import lru_cache
from typing import Unpack, cast
import numpy as np
import torch
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import to_pil_image
from transformers.image_utils import ImageInput, make_flat_list_of_images
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TextInput


class HeronImagesKwargs(ImagesKwargs):
min_tiles: int | None
max_tiles: int | None


class HeronProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: HeronImagesKwargs # type: ignore[misc]
_defaults = { # type: ignore
"text_kwargs": {
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"min_tiles": 1,
"max_tiles": 12,
},
}


class HeronProcessor(ProcessorMixin):
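    """Combines an ``AutoImageProcessor`` and a Qwen2 tokenizer into the single
    preprocessing pipeline used by Heron-NVILA models."""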
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") # type: ignore[assignment]
image_processor: BaseImageProcessor
tokenizer: PreTrainedTokenizerBase

    def __init__(self, image_processor, tokenizer, chat_template=None, num_image_features: int = 256, **kwargs):
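        """``num_image_features`` is the number of placeholder tokens emitted per image
        tile. The image token may be passed via ``kwargs``, taken from the tokenizer,
        or defaults to ``"<image>"``; it must already exist in the tokenizer's vocabulary.
        """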
image_token = kwargs.pop("image_token", None)
if image_token is None:
image_token = "<image>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
assert isinstance(image_token, str)
image_token_id = tokenizer.convert_tokens_to_ids(image_token)
if image_token_id is None:
raise ValueError(f"tokenizer does not contain {image_token!r} token")
self.num_image_features = num_image_features
self.image_token = tokenizer.image_token = image_token
self.image_token_id = tokenizer.image_token_id = image_token_id
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

    def __call__(  # type: ignore[override]
self,
text: TextInput | list[TextInput],
images: ImageInput | None = None,
**kwargs: Unpack[HeronProcessorKwargs],
) -> BatchFeature:
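        """Tokenize ``text`` and preprocess ``images`` into a single ``BatchFeature``.

        Each ``<image>`` token in ``text`` must correspond to exactly one image in
        ``images`` and is expanded into ``num_image_features`` placeholders per tile.
        """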
output_kwargs = self._merge_kwargs(
HeronProcessorKwargs, # type: ignore[arg-type]
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if not isinstance(text, list):
text = [text]
assert isinstance(text, list)
if images is not None:
images = cast(list, make_flat_list_of_images(images))
images = [to_pil_image(image) for image in images]
image_inputs: Mapping = {}
num_image_tiles = None
if images is not None:
if sum(s.count(self.image_token) for s in text) != len(images):
raise ValueError("the number of images does not match the number of image tokens in the text")
image_inputs = self.process_images(images, **output_kwargs["images_kwargs"])
num_image_tiles = image_inputs["pixel_values"].shape[1]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.process_text(text, num_image_tiles, **output_kwargs["text_kwargs"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def process_text(
self, text: list[str], num_image_tiles: int | None = None, **kwargs: Unpack[TextKwargs]
) -> BatchEncoding:
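        """Tokenize ``text``, expanding every ``<image>`` token into ``num_image_features``
        placeholder tokens per tile, with consecutive tiles separated by a newline token."""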
if all(self.image_token not in prompt for prompt in text):
return self.tokenizer(text, **kwargs)
if num_image_tiles is None:
raise ValueError("num_image_tiles must be specified when processing image tokens")
image_feature_placeholder = self.image_token * self.num_image_features
# NOTE: Original implementation appends an extra newline after image features.
# https://github.com/NVlabs/VILA/blob/36f6adcd11a10be1580caeb7e647e1b6f8517f89/llava/model/encoders/image/basic.py#L38
        # Instead of appending a newline character directly, this implementation reserves one
        # extra token (EOS) per tile and swaps it for the newline token after tokenization.
        # This is needed because a literal newline would merge with adjacent text during
        # tokenization: e.g., "\n\n" is encoded as the single token [271], not [198, 198].
assert self.tokenizer.eos_token is not None
image_feature_placeholder += self.tokenizer.eos_token
# Expand image tokens according to the number of image features and tiles
processed_text = []
for prompt in text:
new_prompt = prompt
assert "<placeholder>" not in new_prompt
replace_strings = []
            # Two-phase replacement: swap each image token for a temporary placeholder first,
            # since the expansion itself contains image tokens and would otherwise loop forever
            while self.image_token in new_prompt:
                replace_strings.append("\n".join([image_feature_placeholder] * num_image_tiles))
                new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
for s in replace_strings:
new_prompt = new_prompt.replace("<placeholder>", s, 1)
processed_text.append(new_prompt)
encoding = self.tokenizer(processed_text, **kwargs)
        # Replace the reserved EOS token at the end of every image tile with the newline token
        token_ids = self.tokenizer.encode("\n")
assert len(token_ids) == 1
newline_token_id = token_ids[0]
for input_ids in encoding.input_ids:
i, n = 0, len(input_ids)
while i < n:
if input_ids[i] != self.image_token_id:
i += 1
continue
                i += self.num_image_features  # jump past the image features to the reserved EOS token
                input_ids[i] = newline_token_id
i += 1
self._check_special_mm_tokens(processed_text, encoding, modalities=["image"]) # type: ignore[arg-type]
return encoding

    def process_images(
self,
images: list[Image],
min_tiles: int = 1,
max_tiles: int = 12,
**kwargs: Unpack[ImagesKwargs],
) -> BatchFeature:
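        """Tile each image via dynamic preprocessing and run the image processor.

        The returned ``pixel_values`` has shape
        ``(num_images, num_tiles, channels, height, width)``.
        """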
assert isinstance(min_tiles, int) and isinstance(max_tiles, int)
crop_size = self.image_processor.size # type: ignore[attr-defined]
assert crop_size["height"] == crop_size["width"]
return_tensors = kwargs.pop("return_tensors", None) # type: ignore[typeddict-item]
pixel_values = []
for image in images:
image_tiles = _dynamic_preprocess(
image, min_num=min_tiles, max_num=max_tiles, image_size=crop_size["height"]
)
pixel_values.append(self.image_processor(image_tiles, **kwargs, return_tensors="pt")["pixel_values"])
return BatchFeature({"pixel_values": torch.stack(pixel_values)}, tensor_type=return_tensors)


# Adapted from https://github.com/NVlabs/VILA/blob/36f6adcd11a10be1580caeb7e647e1b6f8517f89/llava/mm_utils.py#L296
def _dynamic_preprocess(
image: Image, min_num: int, max_num: int, image_size: int, use_thumbnail: bool = True
) -> list[Image]:
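    """Resize ``image`` to the best-matching tiling grid and cut it into
    ``image_size`` x ``image_size`` crops.

    Worked example (min_num=1, max_num=12, image_size=448): an 800x600 image has
    aspect ratio 4:3, matched exactly by the (4, 3) grid, so the image is resized
    to 1792x1344 and cut into 12 crops; with ``use_thumbnail=True`` a 448x448
    thumbnail of the full image is appended, giving 13 tiles in total.
    """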
if image.mode != "RGB":
image = image.convert("RGB")
orig_width, orig_height = image.size
(target_width, target_height), crop_boxes = _calculate_crops(orig_width, orig_height, image_size, min_num, max_num)
resized_img = image.resize((target_width, target_height))
processed_images = [resized_img.crop(box) for box in crop_boxes]
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images


@lru_cache(maxsize=32)
def _calculate_crops(
width: int, height: int, crop_size: int, min_num: int, max_num: int
) -> tuple[tuple[int, int], list[tuple[int, int, int, int]]]:
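    """Pick the tiling grid closest to the image's aspect ratio and return the resize
    target ``(target_width, target_height)`` along with one ``(left, upper, right,
    lower)`` crop box per tile."""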
    # Calculate the existing image aspect ratio
    aspect_ratio = width / height
    # Enumerate all candidate grids (columns, rows) whose tile count lies in [min_num, max_num]
    target_ratio_set = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
target_ratios = sorted(target_ratio_set, key=lambda x: x[0] * x[1])
# Find the closest aspect ratio to the target
target_aspect_ratio = _find_closest_aspect_ratio(
aspect_ratio,
target_ratios,
width,
height,
crop_size,
)
    # Calculate the target width and height of the resized image
    target_width = crop_size * target_aspect_ratio[0]
    target_height = crop_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # Crop boxes (left, upper, right, lower) are laid out row by row, left to right
    cols = target_width // crop_size
    crop_boxes = []
    for i in range(blocks):
        box = (
            (i % cols) * crop_size,
            (i // cols) * crop_size,
            ((i % cols) + 1) * crop_size,
            ((i // cols) + 1) * crop_size,
        )
        crop_boxes.append(box)
return (target_width, target_height), crop_boxes


# Copied from https://github.com/NVlabs/VILA/blob/36f6adcd11a10be1580caeb7e647e1b6f8517f89/llava/mm_utils.py#L280
def _find_closest_aspect_ratio(
aspect_ratio: float, target_ratios: list[tuple[int, int]], width: int, height: int, image_size: int
) -> tuple[int, int]:
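    """Return the grid from ``target_ratios`` whose aspect ratio is closest to
    ``aspect_ratio``. Since ``target_ratios`` is sorted by tile count, ties favor
    grids with more tiles when the original image area is large enough to fill them."""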
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the later candidate (at least as many tiles) when the
            # image has enough area to fill that grid
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
return best_ratio
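

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the processor): load the processor from
# the Hub and preprocess one image-text pair. The repository id below is a
# hypothetical placeholder inferred from this file's location; adjust as needed.
if __name__ == "__main__":
    from PIL import Image as PILImage

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "<namespace>/Heron-NVILA-Lite-1B-hf",  # hypothetical repo id
        trust_remote_code=True,
    )
    image = PILImage.new("RGB", (800, 600), color="white")  # dummy 4:3 image
    inputs = processor(
        text="<image>\nDescribe this image.",
        images=image,
        return_tensors="pt",
    )
    # input_ids contains num_image_features placeholders per tile; pixel_values has
    # shape (num_images, num_tiles, channels, crop_size, crop_size)
    print(inputs["input_ids"].shape)
    print(inputs["pixel_values"].shape)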