File size: 4,891 Bytes
92e3a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4d358d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92e3a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868e16b
a4d358d
92e3a81
 
 
 
a4d358d
92e3a81
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os  # For filesystem operations
import shutil  # For directory cleanup
import zipfile  # For extracting model archives
import pathlib  # For path manipulations
import pandas  # For tabular data handling
import gradio  # For interactive UI
import huggingface_hub  # For downloading model assets
import autogluon.tabular  # For loading and running AutoGluon predictors

# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Hugging Face model repo that hosts the zipped AutoGluon predictor, and the
# name of the archive file inside that repo.
MODEL_REPO_ID = "its-zion-18/flowers-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local download cache, and the folder the archive gets unpacked into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Input columns, in the column order used to build the prediction DataFrame,
# followed by the name of the label column.
FEATURE_COLS = [
    "flower_diameter_cm", "petal_length_cm", "petal_width_cm",
    "petal_count", "stem_height_cm",
]
TARGET_COL = "color"

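# NOTE: "autolguon" in MODEL_REPO_ID looks like a misspelling of "autogluon".
# If the download fails, check MODEL_REPO_ID against the actual Hub repo name.
# Download, extract, and load the native AutoGluon predictor.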
def _find_predictor_root(extract_dir: pathlib.Path) -> pathlib.Path:
    """Return the folder inside *extract_dir* that holds the predictor files.

    If the archive unpacked into a single top-level directory, that directory
    is the predictor root; otherwise *extract_dir* itself is. Entries that
    archivers commonly add next to the real content -- a ``__MACOSX`` folder
    or hidden dot-files -- are ignored when making that decision, so they do
    not cause the root to be misdetected as *extract_dir*.
    """
    entries = [
        entry
        for entry in extract_dir.iterdir()
        if entry.name != "__MACOSX" and not entry.name.startswith(".")
    ]
    if len(entries) == 1 and entries[0].is_dir():
        return entries[0]
    return extract_dir


def _prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    Downloads ``ZIP_FILENAME`` from ``MODEL_REPO_ID`` into ``CACHE_DIR``,
    replaces any previous extraction in ``EXTRACT_DIR`` with a fresh one, and
    works out which extracted folder should be handed to
    ``TabularPredictor.load``.

    Returns:
        Path (as ``str``) of the extracted predictor root.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        # NOTE(review): newer huggingface_hub releases deprecate this argument
        # and ignore it; confirm whether the pinned version still needs it.
        local_dir_use_symlinks=False,
    )

    # Start from an empty folder so files from an older archive cannot
    # linger next to the freshly extracted ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    return str(_find_predictor_root(EXTRACT_DIR))

# Fetch and unpack the predictor once, at import time, then load it.
# require_py_version_match=False turns off AutoGluon's check that the Python
# version loading the predictor matches the one that saved it.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.tabular.TabularPredictor.load(
    path=PREDICTOR_DIR,
    require_py_version_match=False,
)

def do_predict(flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k):
    """Predict the flower color for one set of measurements.

    Args:
        flower_diameter_cm, petal_length_cm, petal_width_cm, stem_height_cm:
            Numeric measurements; each is coerced with ``float``.
        petal_count: Number of petals; coerced with ``int``.
        top_k: Maximum number of classes to return. Values that are not
            ``int``/``float`` or are not positive return every class.

    Returns:
        A ``(probabilities, debug_text)`` tuple:

        * ``probabilities`` maps class label -> probability, ordered from most
          to least likely. If ``predict_proba`` fails it falls back to
          ``{predicted_label: 1.0}``; on any other failure it is ``{}``.
        * ``debug_text`` is ``""`` on full success, a short message when the
          probability fallback was used, and the error plus traceback when
          prediction failed entirely.
    """
    try:
        values = (
            float(flower_diameter_cm),
            float(petal_length_cm),
            float(petal_width_cm),
            int(petal_count),
            float(stem_height_cm),
        )
        X = pandas.DataFrame([dict(zip(FEATURE_COLS, values))], columns=FEATURE_COLS)

        pred_label = str(PREDICTOR.predict(X).iloc[0])

        try:
            proba = PREDICTOR.predict_proba(X)
            if isinstance(proba, pandas.Series):
                # Normalize a Series result into a frame so the row
                # selection below works the same way in both cases.
                proba = proba.to_frame().T
            ranked = proba.iloc[0].sort_values(ascending=False)
            if isinstance(top_k, (int, float)) and top_k > 0:
                ranked = ranked.head(int(top_k))
            proba_dict = {str(cls): float(val) for cls, val in ranked.items()}
        except Exception as proba_err:
            # Best-effort fallback: still show the hard prediction, but report
            # why probabilities are unavailable instead of hiding the error.
            return {pred_label: 1.0}, f"predict_proba failed: {proba_err}"

        return proba_dict, ""  # second output is debug text
    except Exception as e:
        import traceback
        return {}, f"{e}\n\n{traceback.format_exc()}"

# ---------------------------------------------------------------------------
# Example records shown under the form. Each row supplies, in order:
#   flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count,
#   stem_height_cm
# top_k is deliberately not part of the rows.
# ---------------------------------------------------------------------------
EXAMPLES = [
    [4.5, 5.2, 1.8, 5, 35.0],
    [2.1, 3.3, 0.9, 8, 22.0],
    [6.8, 7.1, 2.5, 6, 55.0],
    [9.0, 4.0, 1.2, 12, 80.0],
    [1.8, 2.6, 0.5, 4, 15.0],
]

# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gradio.Blocks() as demo:
    gradio.Markdown("# 🌼 Flower Color Classifier\nPredict the flower **color** from five measurements.")
    gradio.Markdown(
        "Enter a single flower’s measurements below. "
        "Use **Top-K** to see the most likely colors with their probabilities."
    )

    with gradio.Row():
        flower_diameter_cm = gradio.Slider(0.0, 20.0, step=0.1, value=4.5, label="flower_diameter_cm")
        petal_length_cm    = gradio.Slider(0.0, 15.0, step=0.1, value=5.2, label="petal_length_cm")
        petal_width_cm     = gradio.Slider(0.0, 10.0, step=0.1, value=1.8, label="petal_width_cm")

    with gradio.Row():
        petal_count        = gradio.Slider(1, 100, step=1, value=5, label="petal_count")
        stem_height_cm     = gradio.Slider(0.0, 200.0, step=0.5, value=35.0, label="stem_height_cm")
        top_k              = gradio.Slider(1, 10, step=1, value=3, label="Top-K classes shown")

    # Outputs: a Label for the class -> probability mapping (its values must be
    # numeric) and a hidden Textbox that receives do_predict's debug text.
    proba_pretty = gradio.Label(num_top_classes=10, label="Class probabilities")
    debug_box    = gradio.Textbox(label="debug", visible=False)

    # Order must match do_predict's positional parameters.
    inputs = [flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k]
    outputs = [proba_pretty, debug_box]

    # Re-run the prediction whenever any input changes.
    for comp in inputs:
        comp.change(fn=do_predict, inputs=inputs, outputs=outputs)

    # Also predict once when the page loads, so the default slider values show
    # a result immediately instead of an empty Label.
    demo.load(fn=do_predict, inputs=inputs, outputs=outputs)

    # Examples: only pass the first 5 inputs (excluding top_k) to match example rows
    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs[:-1],  # exclude top_k so example length matches
        label="Representative examples",
        examples_per_page=5,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()