|
|
import gradio as gr |
|
|
import base64 |
|
|
import json |
|
|
import os |
|
|
from PIL import Image |
|
|
import io |
|
|
from handler import EndpointHandler |
|
|
|
|
|
|
|
|
def _create_handler():
    """
    Build the module-wide EndpointHandler.

    Returns:
        The constructed handler, or None when construction (or reading its
        ``device`` for the startup log) raises. Every UI callback checks for
        ``handler is None`` and reports an error instead of crashing.
    """
    try:
        instance = EndpointHandler()
        print(f"Handler initialized successfully! Device: {instance.device}")
        return instance
    except Exception as init_error:
        print(f"Error initializing handler: {init_error}")
        return None


print("Initializing MobileCLIP handler...")
handler = _create_handler()
|
|
|
|
|
def classify_image(image, top_k=10):
    """
    Classify an uploaded image for the public "Image Classification" tab.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            the input is empty.
        top_k: Number of results to request from the handler (coerced to int).

    Returns:
        A ``(markdown_text, rows)`` tuple. On success ``rows`` is a list of
        ``(label, score)`` tuples for the scores table; on any failure the text
        holds the error message and ``rows`` is None.
    """
    if handler is None:
        return "Error: Handler not initialized", None
    if image is None:
        return "Please upload an image", None

    try:
        # The handler receives the image as base64-encoded PNG bytes.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()

        response = handler({"inputs": {"image": encoded, "top_k": int(top_k)}})

        # A list is a successful prediction; anything else is an error payload.
        if not isinstance(response, list):
            return f"Error: {response.get('error', 'Unknown error')}", None

        pieces = [f"**Top {len(response)} Classifications:**\n\n"]
        rows = []
        for rank, entry in enumerate(response, start=1):
            percent = entry['score'] * 100
            pieces.append(f"{rank}. **{entry['label']}** (ID: {entry['id']}): {percent:.2f}%\n")
            rows.append((entry['label'], entry['score']))
        return "".join(pieces), rows
    except Exception as exc:
        return f"Error: {str(exc)}", None
|
|
|
|
|
def upsert_labels_admin(admin_token, new_items_json):
    """
    Add or update labels through the handler's admin ``upsert_labels`` operation.

    Args:
        admin_token: Admin token forwarded to the handler, which performs the
            actual authorization check.
        new_items_json: JSON text holding an array of label objects, e.g.
            ``[{"id": 100, "name": "x", "prompt": "a photo of x"}]``.

    Returns:
        A Markdown status message for the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"

    # Parse on its own so only genuine input problems are reported as "Invalid JSON"
    # (previously a JSONDecodeError raised inside the handler was misreported too).
    try:
        items = json.loads(new_items_json) if new_items_json else []
    except (json.JSONDecodeError, TypeError):
        return "❌ Error: Invalid JSON format"

    # Validate the shape locally so the admin gets a clear message instead of
    # whatever the handler raises on unexpected input.
    if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
        return "❌ Error: JSON must be an array of label objects"
    if not items:
        return "❌ Error: Please provide at least one label"

    try:
        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items,
            }
        })
    except Exception as e:
        return f"❌ Error: {str(e)}"

    if not isinstance(result, dict):
        return "❌ Error: Unexpected response from handler"
    if result.get("status") == "ok":
        return (
            f"✅ Success! Added {result.get('added', 0)} new labels. "
            f"Current version: {result.get('labels_version', 'unknown')}"
        )
    if result.get("error") == "unauthorized":
        return "❌ Error: Invalid admin token"
    return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
|
|
|
|
|
def reload_labels_admin(admin_token, version):
    """
    Reload a stored label version through the handler's ``reload_labels`` operation.

    Args:
        admin_token: Admin token forwarded to the handler, which performs the
            actual authorization check.
        version: Requested version number (typically a float from ``gr.Number``).
            None or an empty string falls back to version 1.

    Returns:
        A Markdown status message for the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"

    # Explicit None/empty check: a truthiness test would also turn an explicit 0
    # into 1 and silently reload a different version than the one requested.
    if version is None or version == "":
        requested_version = 1
    else:
        try:
            requested_version = int(version)
        except (TypeError, ValueError, OverflowError):
            # e.g. NaN, infinity, or non-numeric text
            return "❌ Error: Invalid version number"

    try:
        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": requested_version,
            }
        })
    except Exception as e:
        return f"❌ Error: {str(e)}"

    if not isinstance(result, dict):
        return "❌ Error: Unexpected response from handler"

    current_version = result.get('labels_version', 'unknown')
    if result.get("status") == "ok":
        return f"✅ Labels reloaded successfully! Current version: {current_version}"
    if result.get("status") == "nochange":
        return f"ℹ️ No change needed. Current version: {current_version}"
    if result.get("error") == "unauthorized":
        return "❌ Error: Invalid admin token"
    if result.get("error") == "invalid_version":
        return "❌ Error: Invalid version number"
    return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
|
|
|
|
|
def get_current_stats():
    """
    Build a Markdown summary of the handler's current label state.

    Reads ``class_ids``, ``labels_version``, ``device`` and ``class_names`` from
    the handler when present, falling back to 0 labels, version 1 and device
    "unknown".

    Returns:
        A Markdown string; "Handler not initialized" when no handler exists, or
        an "Error getting stats: ..." message if reading the handler fails.
    """
    if handler is None:
        return "Handler not initialized"

    try:
        num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
        version = getattr(handler, 'labels_version', 1)
        device = handler.device if hasattr(handler, 'device') else "unknown"

        # Lines are joined explicitly rather than written as an indented
        # triple-quoted f-string: appending an unindented line to an indented
        # block defeats Markdown dedenting and renders the stats as a code block.
        lines = [
            "**Current Statistics:**",
            "",
            f"- Number of labels: {num_labels}",
            f"- Labels version: {version}",
            f"- Device: {device}",
            "- Model: MobileCLIP-B",
        ]

        class_names = getattr(handler, 'class_names', None)
        if class_names is not None and len(class_names) > 0:
            sample = ', '.join(str(name) for name in class_names[:5])
            ellipsis = "..." if len(class_names) > 5 else ""
            lines.append(f"- Sample labels: {sample}{ellipsis}")

        return "\n".join(lines)
    except Exception as e:
        return f"Error getting stats: {str(e)}"
|
|
|
|
|
def get_labels_table():
    """
    Return the current labels as rows for the "Current Labels" table.

    Returns:
        A list of ``[id, name]`` rows, pairing ``handler.class_ids`` (cast to
        int) with ``handler.class_names``. An empty list is returned when the
        handler is missing, has no labels, or its labels cannot be read.
        NOTE: a plain message string must not be returned here, because
        ``gr.Dataframe`` interprets a ``str`` value as a CSV file path.
    """
    if handler is None:
        return []

    try:
        class_ids = getattr(handler, 'class_ids', None)
        if class_ids is None or len(class_ids) == 0:
            return []
        return [[int(label_id), name] for label_id, name in zip(class_ids, handler.class_names)]
    except Exception as e:
        # Keep the table usable; surface the problem in the server log instead.
        print(f"Error getting labels: {e}")
        return []
|
|
|
|
|
def remove_labels_admin(admin_token, ids_to_remove_str):
    """
    Remove labels by ID through the handler's admin ``remove_labels`` operation.

    Args:
        admin_token: Admin token forwarded to the handler, which performs the
            actual authorization check.
        ids_to_remove_str: Comma-separated label IDs, e.g. ``"1001, 1002"``.
            Blank entries are ignored and duplicate IDs are sent only once.

    Returns:
        A Markdown status message. On success it also lists the ``id: name``
        pairs that were known locally before the removal was requested.
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"

    raw_ids = str(ids_to_remove_str or "")
    if not raw_ids.strip():
        return "❌ Error: Please provide IDs to remove (comma-separated)"

    # Parse in its own try so a ValueError raised later by the handler is not
    # misreported as a formatting problem in the user's input.
    try:
        parsed_ids = [int(part) for part in (piece.strip() for piece in raw_ids.split(',')) if part]
    except ValueError:
        return "❌ Error: Invalid ID format. Please provide comma-separated numbers (e.g., 1001,1002,1003)"

    # De-duplicate while keeping the order the user typed.
    ids_to_remove = list(dict.fromkeys(parsed_ids))
    if not ids_to_remove:
        return "❌ Error: No valid IDs provided"

    try:
        # Resolve names before the handler call so they can be reported afterwards.
        # A dict works whether class_ids is a list or an array type without
        # ``.index`` (e.g. NumPy), and avoids a linear scan per requested ID.
        names_by_id = {}
        for label_id, name in zip(getattr(handler, 'class_ids', []), getattr(handler, 'class_names', [])):
            names_by_id.setdefault(label_id, name)  # first occurrence wins, like list.index
        removed_names = [
            f"{label_id}: {names_by_id[label_id]}"
            for label_id in ids_to_remove
            if label_id in names_by_id
        ]

        result = handler({
            "inputs": {
                "op": "remove_labels",
                "token": admin_token,
                "ids": ids_to_remove,
            }
        })
    except Exception as e:
        return f"❌ Error: {str(e)}"

    if not isinstance(result, dict):
        return "❌ Error: Unexpected response from handler"
    if result.get("status") == "ok":
        # Bullet items: consecutive plain lines would collapse into one line in Markdown.
        removed_list = "\n".join(f"- {entry}" for entry in removed_names) if removed_names else "None found"
        return (
            f"✅ Success! Removed {result.get('removed', 0)} labels. "
            f"Current version: {result.get('labels_version', 'unknown')}\n\n"
            f"Removed items:\n{removed_list}"
        )
    if result.get("error") == "unauthorized":
        return "❌ Error: Invalid admin token"
    if result.get("error") == "no_ids_provided":
        return "❌ Error: No IDs provided"
    return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
|
|
|
|
|
|
|
|
print("Creating Gradio interface...")

# Three tabs: public classification, token-gated label administration, and an
# informational "About" page. All model work goes through the module-level
# handler via the callback functions defined above.
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
    # 🖼️ MobileCLIP-B Zero-Shot Image Classifier

    Upload an image to classify it using MobileCLIP-B model with dynamic label management.
    """)

    with gr.Tab("🔍 Image Classification"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show",
                )
                classify_btn = gr.Button("🔍 Classify Image", variant="primary")

            with gr.Column():
                output_text = gr.Markdown(label="Classification Results")
                output_chart = gr.Dataframe(
                    headers=["Label", "Confidence"],
                    label="Classification Scores",
                    interactive=False,
                )

        # Classify on an explicit click and also automatically when the image changes.
        classify_btn.click(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
        )
        input_image.change(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
        )

    with gr.Tab("🔧 Admin Panel"):
        gr.Markdown("""
        ### Admin Functions
        **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
        """)

        with gr.Row():
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token",
            )

        with gr.Accordion("📊 Current Statistics", open=True):
            # Passing the function itself (not its result) makes Gradio evaluate it
            # on each page load instead of freezing the value at import time.
            stats_display = gr.Markdown(value=get_current_stats)
            refresh_stats_btn = gr.Button("🔄 Refresh Stats")
            refresh_stats_btn.click(
                fn=get_current_stats,
                inputs=[],
                outputs=stats_display,
            )

        with gr.Accordion("➕ Add New Labels", open=False):
            gr.Markdown("""
            Add new labels by providing JSON array:
            ```json
            [
                {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
                {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
            ]
            ```
            """)
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]',
            )
            upsert_btn = gr.Button("➕ Add Labels", variant="primary")
            upsert_output = gr.Markdown()

        with gr.Accordion("🔄 Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0,
            )
            reload_btn = gr.Button("🔄 Reload Version", variant="primary")
            reload_output = gr.Markdown()

        with gr.Accordion("🗑️ Remove Labels", open=False):
            gr.Markdown("Remove specific labels by their IDs")

            # Callable value: the table is re-read from the handler on every page load.
            labels_table = gr.Dataframe(
                value=get_labels_table,
                headers=["ID", "Name"],
                label="Current Labels",
                interactive=False,
                height=300,
            )
            refresh_labels_btn = gr.Button("🔄 Refresh Label List", size="sm")
            refresh_labels_btn.click(
                fn=get_labels_table,
                inputs=[],
                outputs=labels_table,
            )

            gr.Markdown("Enter IDs to remove (comma-separated):")
            ids_to_remove_input = gr.Textbox(
                label="IDs to Remove",
                placeholder="e.g., 1001, 1002, 1003",
                lines=1,
            )
            remove_btn = gr.Button("🗑️ Remove Selected Labels", variant="stop")
            remove_output = gr.Markdown()

        def remove_and_refresh(token, ids):
            """Remove labels, then return the status message and the refreshed table rows."""
            result = remove_labels_admin(token, ids)
            updated_table = get_labels_table()
            return result, updated_table

        def _refresh_admin_views():
            """Return fresh (statistics markdown, label table rows) for the admin panel."""
            return get_current_stats(), get_labels_table()

        # Admin events are wired after every panel exists so each label-changing
        # operation can also refresh the statistics panel and the label table.
        upsert_btn.click(
            fn=upsert_labels_admin,
            inputs=[admin_token_input, new_items_input],
            outputs=upsert_output,
        ).then(
            fn=_refresh_admin_views,
            inputs=[],
            outputs=[stats_display, labels_table],
        )

        reload_btn.click(
            fn=reload_labels_admin,
            inputs=[admin_token_input, version_input],
            outputs=reload_output,
        ).then(
            fn=_refresh_admin_views,
            inputs=[],
            outputs=[stats_display, labels_table],
        )

        remove_btn.click(
            fn=remove_and_refresh,
            inputs=[admin_token_input, ids_to_remove_input],
            outputs=[remove_output, labels_table],
        ).then(
            fn=get_current_stats,
            inputs=[],
            outputs=stats_display,
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About MobileCLIP-B Classifier

        This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.

        ### Features:
        - 🚀 **Fast inference**: < 30ms on GPU
        - 🏷️ **Dynamic labels**: Add/update labels without redeployment
        - 🔄 **Version control**: Track and reload label versions
        - 📊 **Visual results**: Classification scores and confidence

        ### Environment Variables (set in Space Settings):
        - `ADMIN_TOKEN`: Secret token for admin operations
        - `HF_LABEL_REPO`: Hub repository for label storage
        - `HF_WRITE_TOKEN`: Token with write permissions to label repo
        - `HF_READ_TOKEN`: Token with read permissions (optional)

        ### Model Details:
        - **Architecture**: MobileCLIP-B with MobileOne blocks
        - **Text Encoder**: Transformer-based, 77 token context
        - **Image Size**: 224x224
        - **Embedding Dim**: 512

        ### License:
        Model weights are licensed under Apple Sample Code License (ASCL).
        """)

print("Gradio interface created successfully!")
|
|
|
|
|
def _launch():
    """Announce startup and start serving the module-level Gradio ``demo`` app."""
    print("Launching Gradio app...")
    demo.launch()


# Serve only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch()