import os  # For filesystem operations
import shutil  # For directory cleanup
import zipfile  # For extracting model archives
import pathlib  # For path manipulations

import pandas  # For tabular data handling
import gradio  # For interactive UI
import huggingface_hub  # For downloading model assets
import autogluon.tabular  # For loading and running AutoGluon predictors


# Settings
# If the repo id has a typo in "autolguon", fix it here if the download fails.
MODEL_REPO_ID = "its-zion-18/flowers-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"

FEATURE_COLS = [
    "flower_diameter_cm",
    "petal_length_cm",
    "petal_width_cm",
    "petal_count",
    "stem_height_cm",
]
TARGET_COL = "color"


# Download & load the native predictor
def _prepare_predictor_dir() -> str:
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Always extract into a clean directory so stale files never linger
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    # If the archive wraps everything in one top-level folder, load from that folder
    contents = list(EXTRACT_DIR.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)


PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.tabular.TabularPredictor.load(PREDICTOR_DIR, require_py_version_match=False)


def do_predict(flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k):
    try:
        row = {
            FEATURE_COLS[0]: float(flower_diameter_cm),
            FEATURE_COLS[1]: float(petal_length_cm),
            FEATURE_COLS[2]: float(petal_width_cm),
            FEATURE_COLS[3]: int(petal_count),
            FEATURE_COLS[4]: float(stem_height_cm),
        }
        X = pandas.DataFrame([row], columns=FEATURE_COLS)

        pred_series = PREDICTOR.predict(X)
        pred_label = str(pred_series.iloc[0])

        try:
            proba = PREDICTOR.predict_proba(X)
            if isinstance(proba, pandas.Series):
                proba = proba.to_frame().T
            # Sort classes by probability and keep only the top_k most likely
            row0 = proba.iloc[0].sort_values(ascending=False)
            if isinstance(top_k, (int, float)) and top_k > 0:
                row0 = row0.head(int(top_k))
            proba_dict = {str(cls): float(val) for cls, val in row0.items()}
        except Exception:
            # Probabilities unavailable: fall back to the predicted label alone
            proba_dict = {pred_label: 1.0}

        return proba_dict, ""  # second output is debug text
    except Exception as e:
        import traceback
        return {}, f"{e}\n\n{traceback.format_exc()}"


# ----------------
# Example records
# ----------------
EXAMPLES = [
    [4.5, 5.2, 1.8, 5, 35.0],  # diam, petal_len, petal_wid, count, stem_h
    [2.1, 3.3, 0.9, 8, 22.0],
    [6.8, 7.1, 2.5, 6, 55.0],
    [9.0, 4.0, 1.2, 12, 80.0],
    [1.8, 2.6, 0.5, 4, 15.0],
]


with gradio.Blocks() as demo:
    gradio.Markdown("# 🌼 Flower Color Classifier\nPredict the flower **color** from five measurements.")
    gradio.Markdown(
        "Enter a single flower’s measurements below. "
        "Use **Top-K** to see the most likely colors with their probabilities."
    )

    with gradio.Row():
        flower_diameter_cm = gradio.Slider(0.0, 20.0, step=0.1, value=4.5, label="flower_diameter_cm")
        petal_length_cm = gradio.Slider(0.0, 15.0, step=0.1, value=5.2, label="petal_length_cm")
        petal_width_cm = gradio.Slider(0.0, 10.0, step=0.1, value=1.8, label="petal_width_cm")

    with gradio.Row():
        petal_count = gradio.Slider(1, 100, step=1, value=5, label="petal_count")
        stem_height_cm = gradio.Slider(0.0, 200.0, step=0.5, value=35.0, label="stem_height_cm")
        top_k = gradio.Slider(1, 10, step=1, value=3, label="Top-K classes shown")

    # Outputs: Label for class probabilities (dict values must be numeric)
    # and a hidden Textbox for debug text
    proba_pretty = gradio.Label(num_top_classes=10, label="Class probabilities")
    debug_box = gradio.Textbox(label="debug", visible=False)

    inputs = [flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k]

    # Trigger on any change
    for comp in inputs:
        comp.change(fn=do_predict, inputs=inputs, outputs=[proba_pretty, debug_box])

    # Examples: only pass the first 5 inputs (excluding top_k) to match example rows
    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs[:-1],  # exclude top_k so example length matches
        label="Representative examples",
        examples_per_page=5,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()