import gradio as gr
import base64
import json
import os
from PIL import Image
import io
from handler import EndpointHandler

# One shared handler instance, created at import time and used by every
# callback below (public classification and admin operations alike).
handler = EndpointHandler()


def classify_image(image, top_k=10):
    """
    Main classification function for public interface.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        top_k: Number of top classes to request from the handler (coerced to int).

    Returns:
        A ``(confidences, markdown)`` tuple. ``confidences`` maps label -> score,
        which is the value format ``gr.Label`` renders, or is None on error.
        ``markdown`` is the human-readable ranking or an error message.
    """
    if image is None:
        return None, "Please upload an image"

    try:
        # The handler consumes base64-encoded image bytes, so re-encode the PIL image.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode()

        result = handler({
            "inputs": {
                "image": img_b64,
                "top_k": int(top_k),
            }
        })

        # A list of {label, id, score} dicts means success; anything else is
        # treated as an error payload.
        if not isinstance(result, list):
            return None, f"Error: {result.get('error', 'Unknown error')}"

        if not result:
            return None, "No classifications returned"

        lines = [f"**Top {len(result)} Classifications:**", ""]
        confidences = {}
        for rank, item in enumerate(result, 1):
            score_pct = item['score'] * 100
            lines.append(
                f"{rank}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%"
            )
            # float() so numpy scalars from the model serialize cleanly.
            confidences[item['label']] = float(item['score'])
        return confidences, "\n".join(lines) + "\n"
    except Exception as e:
        # UI boundary: surface any failure as a message instead of crashing the event.
        return None, f"Error: {str(e)}"


def upsert_labels_admin(admin_token, new_items_json):
    """
    Admin function to add new labels.

    Args:
        admin_token: Token forwarded to the handler, which performs the authorization check.
        new_items_json: JSON text holding an array of label objects. A single
            object is accepted as shorthand for a one-element array. Blank text
            means "no items".

    Returns:
        A Markdown status message for display.
    """
    if not admin_token:
        return "Error: Admin token required"

    try:
        # Treat empty or whitespace-only editor contents as "no items" rather
        # than reporting them as malformed JSON.
        if new_items_json and new_items_json.strip():
            items = json.loads(new_items_json)
        else:
            items = []
        if isinstance(items, dict):
            items = [items]
        if not isinstance(items, list):
            return "❌ Error: JSON must be an array of label objects"

        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items,
            }
        })

        if result.get("status") == "ok":
            return (
                f"✅ Success! Added {result.get('added', 0)} new labels. "
                f"Current version: {result.get('labels_version', 'unknown')}"
            )
        elif result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        else:
            return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except json.JSONDecodeError:
        return "❌ Error: Invalid JSON format"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def reload_labels_admin(admin_token, version):
    """
    Admin function to reload a specific label version.

    Args:
        admin_token: Token forwarded to the handler, which performs the authorization check.
        version: Version number from the UI. None (field cleared) falls back to 1.
            Any other value, including 0, is passed through so the handler can
            validate it.

    Returns:
        A Markdown status message for display.
    """
    if not admin_token:
        return "Error: Admin token required"

    try:
        # `is not None` rather than truthiness: an explicit 0 must reach the
        # handler's own validation instead of being silently replaced by 1.
        version_num = int(version) if version is not None else 1
        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": version_num,
            }
        })

        if result.get("status") == "ok":
            return f"✅ Labels reloaded successfully! Current version: {result.get('labels_version', 'unknown')}"
        elif result.get("status") == "nochange":
            return f"ℹ️ No change needed. Current version: {result.get('labels_version', 'unknown')}"
        elif result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        elif result.get("error") == "invalid_version":
            return "❌ Error: Invalid version number"
        else:
            # Prefer the handler's detail message when it provides one, matching upsert.
            return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def get_current_stats():
    """
    Get current label statistics.

    Reads optional attributes from the shared handler (class_ids, labels_version,
    device, class_names) and falls back to defaults when they are absent.

    Returns:
        A Markdown summary, or an error message if reading the handler fails.
    """
    try:
        class_ids = getattr(handler, 'class_ids', None)
        num_labels = len(class_ids) if class_ids is not None else 0
        version = getattr(handler, 'labels_version', 1)
        device = getattr(handler, 'device', "unknown")

        lines = [
            "**Current Statistics:**",
            f"- Number of labels: {num_labels}",
            f"- Labels version: {version}",
            f"- Device: {device}",
            "- Model: MobileCLIP-B",
        ]

        class_names = getattr(handler, 'class_names', None)
        if class_names is not None and len(class_names) > 0:
            # str() guards against non-str label containers (e.g. numpy string arrays).
            sample = ', '.join(str(name) for name in class_names[:5])
            if len(class_names) > 5:
                sample += "..."
            lines.append(f"- Sample labels: {sample}")

        return "\n".join(lines)
    except Exception as e:
        return f"Error getting stats: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
    # 🖼️ MobileCLIP-B Zero-Shot Image Classifier

    Upload an image to classify it using MobileCLIP-B model with dynamic label management.
    """)

    with gr.Tab("🔍 Image Classification"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show",
                )
                classify_btn = gr.Button("🚀 Classify Image", variant="primary")

            with gr.Column():
                # classify_image returns a {label: score} dict, which is the value
                # format gr.Label renders as ranked confidence bars. gr.BarPlot would
                # instead need a DataFrame plus explicit x/y column names.
                output_chart = gr.Label(label="Classification Confidence")
                output_text = gr.Markdown(label="Classification Results")

        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/cheetah.jpg"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/elephant.jpg"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/giraffe.jpg"],
            ],
            inputs=input_image,
            label="Example Images",
        )

        classify_btn.click(
            classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_chart, output_text],
        )

    with gr.Tab("🔧 Admin Panel"):
        gr.Markdown("""
        ### Admin Functions

        **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
        """)

        with gr.Row():
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token",
            )

        with gr.Accordion("📊 Current Statistics", open=True):
            # Pass the function itself (not its result) so Gradio recomputes the
            # stats on every page load instead of freezing them at startup.
            stats_display = gr.Markdown(value=get_current_stats)
            refresh_stats_btn = gr.Button("🔄 Refresh Stats")
            refresh_stats_btn.click(
                get_current_stats,
                outputs=stats_display,
            )

        with gr.Accordion("➕ Add New Labels", open=False):
            gr.Markdown("""
            Add new labels by providing JSON array:

            ```json
            [
              {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
              {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
            ]
            ```
            """)
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]',
            )
            upsert_btn = gr.Button("➕ Add Labels", variant="primary")
            upsert_output = gr.Markdown()
            # Refresh the stats panel after the upsert so it reflects the new label set.
            upsert_btn.click(
                upsert_labels_admin,
                inputs=[admin_token_input, new_items_input],
                outputs=upsert_output,
            ).then(get_current_stats, outputs=stats_display)

        with gr.Accordion("🔄 Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0,
            )
            reload_btn = gr.Button("🔄 Reload Version", variant="primary")
            reload_output = gr.Markdown()
            # Refresh the stats panel after the reload so the version shown is current.
            reload_btn.click(
                reload_labels_admin,
                inputs=[admin_token_input, version_input],
                outputs=reload_output,
            ).then(get_current_stats, outputs=stats_display)

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About MobileCLIP-B Classifier

        This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.

        ### Features:
        - 🚀 **Fast inference**: < 30ms on GPU
        - 🏷️ **Dynamic labels**: Add/update labels without redeployment
        - 🔄 **Version control**: Track and reload label versions
        - 📊 **Visual results**: Bar charts and confidence scores

        ### Environment Variables (set in Space Settings):
        - `ADMIN_TOKEN`: Secret token for admin operations
        - `HF_LABEL_REPO`: Hub repository for label storage (e.g., "username/labels")
        - `HF_WRITE_TOKEN`: Token with write permissions to label repo
        - `HF_READ_TOKEN`: Token with read permissions (optional, defaults to write token)

        ### Model Details:
        - **Architecture**: MobileCLIP-B with MobileOne blocks
        - **Text Encoder**: Transformer-based, 77 token context
        - **Image Size**: 224x224
        - **Embedding Dim**: 512

        ### License:
        Model weights are licensed under Apple Sample Code License (ASCL).
        """)


if __name__ == "__main__":
    demo.launch()