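# Zero-Shot Object Detection Arena: a Gradio app that compares four zero-shot
# object detection models on the same image and text prompts.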
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from PIL import Image
import time
def extract_model_short_name(model_id):
    return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
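# Hugging Face Hub IDs of the four checkpoints being compared, plus short
# display names derived from them for labels and timing readouts.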
model_llmdet_id = "iSEE-Laboratory/llmdet_tiny"
model_mm_grounding_id = "rziga/mm_grounding_dino_tiny_o365v1_goldg"
model_omdet_id = "omlab/omdet-turbo-swin-tiny-hf"
model_owlv2_id = "google/owlv2-large-patch14-ensemble"

model_llmdet_name = extract_model_short_name(model_llmdet_id)
model_mm_grounding_name = extract_model_short_name(model_mm_grounding_id)
model_omdet_name = extract_model_short_name(model_omdet_id)
model_owlv2_name = extract_model_short_name(model_owlv2_id)
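# Load the requested checkpoint, run zero-shot detection with the given prompts,
# and return box/label annotations above the threshold plus a timing string.
# Assumption: since the Space runs on ZeroGPU and `spaces` is imported, the call
# is wrapped with the standard @spaces.GPU decorator so GPU hardware is allocated.
@spaces.GPU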
def detect(model_id: str, image: Image.Image, prompts: list, threshold: float):
    t0 = time.perf_counter()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(model_id)
    model = (
        AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device).eval()
    )
    texts = [prompts]
    inputs = processor(images=image, text=texts, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]]
    )
    result = results[0]
    annotations = []
    for box, score, label_name in zip(result["boxes"], result["scores"], result["text_labels"]):
        if score >= threshold:
            xmin, ymin, xmax, ymax = [int(x) for x in box.tolist()]
            annotations.append(((xmin, ymin, xmax, ymax), f"{label_name} {score:.2f}"))
    elapsed_ms = (time.perf_counter() - t0) * 1000
    model_name = extract_model_short_name(model_id)
    time_taken = f"**Inference time ({model_name}):** {elapsed_ms:.0f} ms"
    return annotations, time_taken
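# Run every model on the same image and prompt list; for each model, return the
# (image, annotations) pair for its AnnotatedImage and the timing string for its
# Markdown panel, in the order expected by the outputs list below.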
def run_detection(
    image: Image.Image, prompts_str: str, threshold_llm, threshold_mm, threshold_owlv2, threshold_omdet,
):
    prompts = [p.strip() for p in prompts_str.split(",") if p.strip()]
    ann_llm, time_llm = detect(model_llmdet_id, image, prompts, threshold_llm)
    ann_mm, time_mm = detect(model_mm_grounding_id, image, prompts, threshold_mm)
    ann_owlv2, time_owlv2 = detect(model_owlv2_id, image, prompts, threshold_owlv2)
    ann_omdet, time_omdet = detect(model_omdet_id, image, prompts, threshold_omdet)
    return (
        (image, ann_llm),
        time_llm,
        (image, ann_mm),
        time_mm,
        (image, ann_owlv2),
        time_owlv2,
        (image, ann_omdet),
        time_omdet,
    )
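# Gradio UI: an input column (image, comma-separated prompts, per-model
# thresholds) followed by two rows of annotated outputs, one panel per model.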
with gr.Blocks() as app:
    gr.Markdown("# Zero-Shot Object Detection Arena")
    gr.Markdown(
        "### Compare different zero-shot object detection models on the same image and prompts."
    )
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload an image", height=400)
            prompts = gr.Textbox(
                label="Prompts (comma-separated)", value="a cat, a remote control"
            )
            with gr.Accordion("Per-model confidence thresholds", open=True):
                threshold_llm = gr.Slider(
                    label="Threshold for LLMDet", minimum=0.0, maximum=1.0, value=0.3
                )
                threshold_mm = gr.Slider(
                    label="Threshold for MM GroundingDINO Tiny",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                )
                threshold_owlv2 = gr.Slider(
                    label="Threshold for OwlV2 Large",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                )
                threshold_omdet = gr.Slider(
                    label="Threshold for OMDet Turbo Swin Tiny",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                )
            generate_btn = gr.Button(value="Detect")
    with gr.Row():
        with gr.Column(scale=2):
            output_image_llm = gr.AnnotatedImage(
                label=f"Annotated image for {model_llmdet_name}", height=400
            )
            output_time_llm = gr.Markdown()
        with gr.Column(scale=2):
            output_image_mm = gr.AnnotatedImage(
                label=f"Annotated image for {model_mm_grounding_name}", height=400
            )
            output_time_mm = gr.Markdown()
    with gr.Row():
        with gr.Column(scale=2):
            output_image_owlv2 = gr.AnnotatedImage(
                label=f"Annotated image for {model_owlv2_name}", height=400
            )
            output_time_owlv2 = gr.Markdown()
        with gr.Column(scale=2):
            output_image_omdet = gr.AnnotatedImage(
                label=f"Annotated image for {model_omdet_name}", height=400
            )
            output_time_omdet = gr.Markdown()
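    # Example images from COCO val2017, each paired with prompts and per-model thresholds.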
| gr.Markdown("### Examples") | |
| example_data = [ | |
| [ | |
| "http://images.cocodataset.org/val2017/000000039769.jpg", | |
| "a cat, a remote control", | |
| 0.30, | |
| 0.30, | |
| 0.10, | |
| 0.30, | |
| ], | |
| [ | |
| "http://images.cocodataset.org/val2017/000000000139.jpg", | |
| "a person, a tv, a remote", | |
| 0.35, | |
| 0.30, | |
| 0.12, | |
| 0.30, | |
| ], | |
| ] | |
    gr.Examples(
        examples=example_data,
        inputs=[
            image,
            prompts,
            threshold_llm,
            threshold_mm,
            threshold_owlv2,
            threshold_omdet,
        ],
        label="Click an example to populate the inputs",
    )
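    # Shared input and output component lists, reused by both event handlers below.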
    inputs = [
        image,
        prompts,
        threshold_llm,
        threshold_mm,
        threshold_owlv2,
        threshold_omdet,
    ]
    outputs = [
        output_image_llm,
        output_time_llm,
        output_image_mm,
        output_time_mm,
        output_image_owlv2,
        output_time_owlv2,
        output_image_omdet,
        output_time_omdet,
    ]
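    # Run all four detectors when the button is clicked or a new image is uploaded.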
    generate_btn.click(
        fn=run_detection,
        inputs=inputs,
        outputs=outputs,
    )
    image.upload(
        fn=run_detection,
        inputs=inputs,
        outputs=outputs,
    )

app.launch()