Spaces:
Running
Running
| import os # For filesystem operations | |
| import shutil # For directory cleanup | |
| import zipfile # For extracting model archives | |
| import pathlib # For path manipulations | |
| import pandas # For tabular data handling | |
| import gradio # For interactive UI | |
| import huggingface_hub # For downloading model assets | |
| import autogluon.tabular # For loading and running AutoGluon predictors | |
# Settings

# Hugging Face Hub repo and file that hold the zipped predictor.
# NOTE: "autolguon" in the repo id looks like a typo of "autogluon"; the id is
# kept verbatim. If the download fails, correct the repo id here.
MODEL_REPO_ID = "its-zion-18/flowers-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local working directories: the download target and, beneath it, the folder
# the archive is unpacked into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Input columns. do_predict indexes this list by position, so the order here
# must stay aligned with the order of its arguments.
FEATURE_COLS = [
    "flower_diameter_cm",
    "petal_length_cm",
    "petal_width_cm",
    "petal_count",
    "stem_height_cm",
]

# Name of the predicted column.
TARGET_COL = "color"
# Download & load the native predictor
def _prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    The archive ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into
    ``CACHE_DIR`` and extracted into a freshly emptied ``EXTRACT_DIR``.

    Returns:
        The directory to load the predictor from, as a string: the archive's
        single top-level folder when it has exactly one, otherwise
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Start from an empty directory so files left by a previous extraction
    # cannot mix with the contents of the current archive.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    # Archives zipped on macOS add "__MACOSX" / ".DS_Store" entries next to
    # the real payload. Ignore them so a lone predictor folder is still
    # recognised as the root instead of falling back to EXTRACT_DIR.
    archive_noise = {"__MACOSX", ".DS_Store"}
    contents = [
        entry for entry in EXTRACT_DIR.iterdir() if entry.name not in archive_noise
    ]
    if len(contents) == 1 and contents[0].is_dir():
        return str(contents[0])
    return str(EXTRACT_DIR)
# Runs at import time: fetch and unpack the model once, then keep the loaded
# predictor in a module-level name so do_predict can reuse it on every call.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.tabular.TabularPredictor.load(
    PREDICTOR_DIR,
    # NOTE(review): presumably allows loading a predictor saved under a
    # different Python version -- confirm against the AutoGluon docs.
    require_py_version_match=False,
)
def do_predict(flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k):
    """Predict for a single flower and return its class probabilities.

    Args:
        flower_diameter_cm: Converted with ``float``.
        petal_length_cm: Converted with ``float``.
        petal_width_cm: Converted with ``float``.
        petal_count: Converted with ``int``.
        stem_height_cm: Converted with ``float``.
        top_k: When a positive number, only that many highest-probability
            classes are returned; otherwise all classes are returned.

    Returns:
        A ``(probabilities, debug_text)`` tuple. ``probabilities`` maps class
        name to probability, highest first. If probabilities cannot be
        computed, it is ``{predicted_label: 1.0}``. If anything else fails,
        it is ``{}`` and ``debug_text`` holds the error message and traceback;
        on success ``debug_text`` is ``""``.
    """
    try:
        # Pair each raw value with its converter; position i feeds FEATURE_COLS[i].
        converters = (float, float, float, int, float)
        raw_values = (
            flower_diameter_cm,
            petal_length_cm,
            petal_width_cm,
            petal_count,
            stem_height_cm,
        )
        record = {
            FEATURE_COLS[position]: convert(raw)
            for position, (convert, raw) in enumerate(zip(converters, raw_values))
        }
        frame = pandas.DataFrame([record], columns=FEATURE_COLS)

        predicted = PREDICTOR.predict(frame)
        top_label = str(predicted.iloc[0])

        try:
            scores = PREDICTOR.predict_proba(frame)
            # Normalise a Series result into a one-row frame.
            if isinstance(scores, pandas.Series):
                scores = scores.to_frame().T
            ranked = scores.iloc[0].sort_values(ascending=False)
            limit_requested = isinstance(top_k, (int, float)) and top_k > 0
            if limit_requested:
                ranked = ranked.head(int(top_k))
            result = {str(name): float(score) for name, score in ranked.items()}
        except Exception:
            # Best effort: without usable probabilities, report the predicted
            # label with full confidence.
            result = {top_label: 1.0}

        return result, ""  # second output is debug text
    except Exception as err:
        import traceback

        # Return the failure to the caller as debug text instead of raising.
        return {}, f"{err}\n\n{traceback.format_exc()}"
# ----------------
# Example records
# ----------------
# One row per example, in the order:
# flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm
EXAMPLES = [
    [4.5, 5.2, 1.8, 5, 35.0],
    [2.1, 3.3, 0.9, 8, 22.0],
    [6.8, 7.1, 2.5, 6, 55.0],
    [9.0, 4.0, 1.2, 12, 80.0],
    [1.8, 2.6, 0.5, 4, 15.0],
]
# Build the UI. Components are created in the same order as before.
with gradio.Blocks() as demo:
    gradio.Markdown("# 🌼 Flower Color Classifier\nPredict the flower **color** from five measurements.")
    gradio.Markdown(
        "Enter a single flower’s measurements below. "
        "Use **Top-K** to see the most likely colors with their probabilities."
    )

    # First row: flower and petal dimensions.
    with gradio.Row():
        flower_diameter_cm = gradio.Slider(
            minimum=0.0, maximum=20.0, step=0.1, value=4.5, label="flower_diameter_cm"
        )
        petal_length_cm = gradio.Slider(
            minimum=0.0, maximum=15.0, step=0.1, value=5.2, label="petal_length_cm"
        )
        petal_width_cm = gradio.Slider(
            minimum=0.0, maximum=10.0, step=0.1, value=1.8, label="petal_width_cm"
        )

    # Second row: petal count, stem height and the Top-K selector.
    with gradio.Row():
        petal_count = gradio.Slider(
            minimum=1, maximum=100, step=1, value=5, label="petal_count"
        )
        stem_height_cm = gradio.Slider(
            minimum=0.0, maximum=200.0, step=0.5, value=35.0, label="stem_height_cm"
        )
        top_k = gradio.Slider(
            minimum=1, maximum=10, step=1, value=3, label="Top-K classes shown"
        )

    # Outputs: a Label for the class -> probability dict, and a hidden Textbox
    # that receives the debug text returned by do_predict.
    proba_pretty = gradio.Label(num_top_classes=10, label="Class probabilities")
    debug_box = gradio.Textbox(label="debug", visible=False)

    inputs = [
        flower_diameter_cm,
        petal_length_cm,
        petal_width_cm,
        petal_count,
        stem_height_cm,
        top_k,
    ]

    # Re-run the prediction whenever any input changes.
    for comp in inputs:
        comp.change(
            fn=do_predict,
            inputs=inputs,
            outputs=[proba_pretty, debug_box],
        )

    # Example rows hold only the five measurements, so top_k (the last input)
    # is left out to keep each row the same length as the input list.
    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs[:-1],
        label="Representative examples",
        examples_per_page=5,
        cache_examples=False,
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()