"""Image-conditioned object detection with OWLv2 (Scenic).

Step 1: for every image in INSTANCE_DIR, find the most object-like patch
(highest sigmoid objectness) and keep its class embedding as a query.
Step 2: for every image in IMAGE_DIR, score all patches against every
instance query, keep the best-matching box per query, and plot it.
"""
import os
import sys
import json
import functools

# pip install ott-jax==0.2.0
import jax
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import matplotlib as mpl
from matplotlib import pyplot as plt

# Local big_vision checkout; presumably required by the scenic imports below,
# so this must run before them.
sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision')
# Hide GPUs from TensorFlow (NOTE(review): presumably so TF does not reserve GPU
# memory that JAX needs — confirm).
tf.config.experimental.set_visible_devices([], 'GPU')

from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models

from owlv2_helper_functions import read_images, preprocess_images
from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image, plot_boxes_on_image
from owlv2_helper_functions import top_object_index
from owlv2_helper_functions import rescale_detection_box

INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_0'
INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections_0'
IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/data_sample'
OUTPUT_DIR = '/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/bliu75_output/test_output/batch_results'

# ---------------------------------------------------------------------------
# Prepare the OWLv2 pretrained model
# ---------------------------------------------------------------------------
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    objectness_head_configs=config.model.objectness_head,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias,
)
variables = module.load_variables(config.init_from.checkpoint_path)

# Jitted model components with the pretrained variables bound in.
image_embedder = jax.jit(
    functools.partial(module.apply, variables, train=False, method=module.image_embedder))
objectness_predictor = jax.jit(
    functools.partial(module.apply, variables, method=module.objectness_predictor))
box_predictor = jax.jit(
    functools.partial(module.apply, variables, method=module.box_predictor))
class_predictor = jax.jit(
    functools.partial(module.apply, variables, method=module.class_predictor))

model_input_size = config.dataset_configs.input_size

os.makedirs(INSTANCE_DETECTION, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Step 1: detect the main object on each instance image -> query embeddings
# ---------------------------------------------------------------------------
source_raw_images, source_images_names = read_images(INSTANCE_DIR)
if not source_images_names:
    sys.exit(f"No instance images found in {INSTANCE_DIR}")
source_images = preprocess_images(source_raw_images, model_input_size)

feature_map = image_embedder(source_images)
b, h, w, d = feature_map.shape
image_features = feature_map.reshape(b, h * w, d)  # [batch, patches, d]

objectnesses = objectness_predictor(image_features)['objectness_logits']
bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes']
# NOTE(review): assumed [batch, patches, d] — confirm against scenic ClassPredictor.
source_class_embeddings = np.asarray(
    class_predictor(image_features=image_features)['class_embeddings'])

objectnesses = sigmoid(objectnesses)
top_objectnesses = np.max(objectnesses, axis=1)

instances, query_embeddings, indexes = [], [], []
for i, source_name in enumerate(source_images_names):
    index = top_object_index(objectnesses[i], top_objectnesses[i])
    # `index` is a patch index within image i, so index image i first.
    # Indexing the batch axis with it (as before) picks the wrong image's
    # embeddings, and JAX clamps out-of-range indices instead of raising.
    query_embeddings.append(source_class_embeddings[i, index])
    indexes.append(index)
    instances.append(source_name.split('_')[0])
    plot_bbox_on_image(source_images[i], bboxes[i], objectnesses[i], top_objectnesses[i],
                       os.path.join(INSTANCE_DETECTION, source_name))

# One query per instance, in the same order as `instances`: [num_instances, d].
query_embeddings = np.stack(query_embeddings)
query_columns = np.arange(len(instances))

# ---------------------------------------------------------------------------
# Step 2: find every instance query in each target image
# ---------------------------------------------------------------------------
target_raw_images, target_images_names = read_images(IMAGE_DIR)
target_images = preprocess_images(target_raw_images, model_input_size)

for target_image, target_image_name, image in zip(target_images, target_images_names, target_raw_images):
    feature_map = image_embedder(target_image[None, ...])
    b, h, w, d = feature_map.shape
    target_features = feature_map.reshape(b, h * w, d)

    target_boxes = box_predictor(image_features=target_features, feature_map=feature_map)['pred_boxes']
    target_class_predictions = class_predictor(
        image_features=target_features,
        query_embeddings=query_embeddings[None, ...],  # [batch, queries, d]
    )
    logits = np.array(target_class_predictions['pred_logits'][0])  # [patches, queries]
    raw_boxes = np.array(target_boxes[0])

    # Best-matching patch for each query, with its score and box.
    top_inds = np.argmax(logits, axis=0)
    scores = sigmoid(logits[top_inds, query_columns])
    boxes = np.asarray(rescale_detection_box(raw_boxes, image))[top_inds]

    # NOTE(review): assumes the helper pairs instances[k] with scores[k] / boxes[k].
    image_based_plot_boxes_on_image(image, instances, scores, boxes, target_image_name, OUTPUT_DIR)

    print("Debug: target instance detection")
    print(f" target_logits: {logits.shape}")
    print(logits)