| import os |
| import sys |
| import json |
|
|
| |
| import jax |
| import numpy as np |
| import tensorflow as tf |
| from scipy.special import expit as sigmoid |
|
|
| import skimage |
| from skimage import io as skimage_io |
| from skimage import transform as skimage_transform |
| import matplotlib as mpl |
| from matplotlib import pyplot as plt |
|
|
# Local checkout of big_vision; it is appended to sys.path before the scenic
# imports that follow this block.
_BIG_VISION_ROOT = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision'
sys.path.append(_BIG_VISION_ROOT)

# Give TensorFlow an empty list of visible GPU devices.
tf.config.experimental.set_visible_devices(devices=[], device_type='GPU')
|
|
| from scenic.projects.owl_vit import configs |
| from scenic.projects.owl_vit import models |
|
|
| |
| from owlv2_helper_functions import read_images, preprocess_images |
| from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image, plot_boxes_on_image |
| from owlv2_helper_functions import top_object_index |
| from owlv2_helper_functions import rescale_detection_box |
|
|
|
|
|
|
|
|
| """ |
| Prepare OWLv2 pretrained model |
| """ |
# Configuration of the OWLv2 CLIP L/14 variant, requested with
# init_mode='canonical_checkpoint'.
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')

# Build the detection module from the `model` sub-config.
_model_config = config.model
module = models.TextZeroShotDetectionModule(
    body_configs=_model_config.body,
    objectness_head_configs=_model_config.objectness_head,
    normalize=_model_config.normalize,
    box_bias=_model_config.box_bias,
)

# Load the variables from the checkpoint path named in the config.
_checkpoint_path = config.init_from.checkpoint_path
variables = module.load_variables(_checkpoint_path)
|
|
|
|
|
|
|
|
| """ |
| Wrapped model components |
| """ |
import functools


def _jit_bound(method, **fixed_kwargs):
    """Return a jitted ``module.apply`` bound to ``variables`` and ``method``.

    Any ``fixed_kwargs`` are frozen into the partial as well, so callers only
    pass the remaining inputs of the selected method.
    """
    bound = functools.partial(module.apply, variables, method=method, **fixed_kwargs)
    return jax.jit(bound)


# Only the image embedder is bound with an explicit train=False.
image_embedder = _jit_bound(module.image_embedder, train=False)
objectness_predictor = _jit_bound(module.objectness_predictor)
box_predictor = _jit_bound(module.box_predictor)
class_predictor = _jit_bound(module.class_predictor)
|
|
|
|
|
|
|
|
| """ |
| Detect the main object on instances' images |
| """ |
# Input directory with the instance images and output directory for the
# per-instance detection plots.
INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_0'
INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections_0'

# Read every instance image and preprocess it to the model input size.
model_input_size = config.dataset_configs.input_size
images, source_images_names = read_images(INSTANCE_DIR)
source_images = preprocess_images(images, model_input_size)

# Embed the whole batch, then flatten the two spatial axes of the feature
# map into a single per-image token axis.
feature_map = image_embedder(source_images)
num_images, grid_h, grid_w, embed_dim = feature_map.shape
image_features = feature_map.reshape(num_images, grid_h * grid_w, embed_dim)

# Run the three prediction heads on the flattened features.
objectness_outputs = objectness_predictor(image_features)
objectnesses = objectness_outputs['objectness_logits']

box_outputs = box_predictor(image_features=image_features, feature_map=feature_map)
bboxes = box_outputs['pred_boxes']

# Called without query embeddings; the 'class_embeddings' entry is kept.
class_outputs = class_predictor(image_features=image_features)
source_class_embeddings = class_outputs['class_embeddings']
|
|
| |
| |
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Select one query embedding per instance image
# ---------------------------------------------------------------------------
# Turn the objectness logits into probabilities and take the per-image maximum
# (axis 1 is the per-image token axis produced by the reshape above).
objectnesses = sigmoid(objectnesses)
top_objectnesses = np.max(objectnesses, axis=1)

# Make sure the plot destination exists before output paths inside it are
# handed to the plotting helper.
os.makedirs(INSTANCE_DETECTION, exist_ok=True)

instances, query_embeddings, indexes = [], [], []
for i, source_image_name in enumerate(source_images_names):
    # NOTE(review): assumes top_object_index returns a single token index
    # within image i -- confirm against owlv2_helper_functions.
    index = top_object_index(objectnesses[i], top_objectnesses[i])

    # Fix: select image i first, then the token. The previous code indexed the
    # batch axis directly with the token index, which picks the wrong
    # embedding (an entire image's embeddings instead of one token's).
    query_embedding = source_class_embeddings[i][index]

    indexes.append(index)
    # The instance label is the file-name part before the first underscore.
    instances.append(source_image_name.split('_')[0])
    query_embeddings.append(query_embedding)

    output_file = os.path.join(INSTANCE_DETECTION, source_image_name)
    plot_bbox_on_image(source_images[i], bboxes[i], objectnesses[i], top_objectnesses[i], output_file)

if not query_embeddings:
    raise ValueError(f"No instance images found in {INSTANCE_DIR!r}; nothing to query with.")


# ---------------------------------------------------------------------------
# Detect every instance on each target image
# ---------------------------------------------------------------------------
IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/data_sample'
OUTPUT_DIR = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/bliu75_output/test_output/batch_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

images, target_images_names = read_images(IMAGE_DIR)
target_images = preprocess_images(images, model_input_size)

# Fix: query with ALL collected instance embeddings. The previous code reused
# the `query_embedding` variable left over from the loop above, so only the
# last instance was ever searched for while the full `instances` label list
# was still passed to the plot. With a single instance the result is unchanged.
stacked_queries = np.stack([np.asarray(embedding) for embedding in query_embeddings])
query_columns = np.arange(len(instances))

for target_image, target_image_name, image in zip(target_images, target_images_names, images):
    # Embed a batch of one image and flatten the spatial axes once; both
    # prediction heads share the flattened features.
    feature_map = image_embedder(target_image[None, ...])
    b, h, w, d = feature_map.shape
    target_features = feature_map.reshape(b, h * w, d)

    target_boxes = box_predictor(image_features=target_features, feature_map=feature_map)['pred_boxes']
    target_class_predictions = class_predictor(
        image_features=target_features,
        query_embeddings=stacked_queries,
    )

    # Drop the batch axis: logits has one row per image token and one column
    # per instance query.
    logits = np.array(target_class_predictions['pred_logits'][0])
    raw_boxes = np.array(target_boxes[0])

    # Best-matching token and its sigmoid score for every instance query.
    top_inds = np.argmax(logits, axis=0)
    scores = sigmoid(logits[top_inds, query_columns])

    # Rescale all predicted boxes with the original image, then keep one box
    # per instance so that `instances`, `scores` and `boxes` line up.
    boxes = np.asarray(rescale_detection_box(raw_boxes, image))[top_inds]

    image_based_plot_boxes_on_image(image, instances, scores, boxes, target_image_name, OUTPUT_DIR)

    # NOTE(review): the original indentation was lost; these debug prints are
    # kept inside the loop (one dump per target image) -- confirm.
    print("Debug: target instance detection")
    print(f" target_logits: {logits.shape}")
    print(logits)
| |
| |
| |
|
|
| |
|
|