# Author: aslan-ng — commit 868e16b ("Update app.py", verified)
import os # For filesystem operations
import shutil # For directory cleanup
import zipfile # For extracting model archives
import pathlib # For path manipulations
import pandas # For tabular data handling
import gradio # For interactive UI
import huggingface_hub # For downloading model assets
import autogluon.tabular # For loading and running AutoGluon predictors
# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Hugging Face model repo that holds the zipped AutoGluon predictor.
# NOTE: the repo id spells "autolguon"; if the download fails, check whether
# that is a typo and correct it here.
MODEL_REPO_ID = "its-zion-18/flowers-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# The archive is downloaded into CACHE_DIR and unpacked into EXTRACT_DIR.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Model input columns, in the order used to build the one-row prediction frame.
FEATURE_COLS = [
    "flower_diameter_cm",
    "petal_length_cm",
    "petal_width_cm",
    "petal_count",
    "stem_height_cm",
]

# Presumably the label column the predictor was trained on; it is not read by
# the prediction code in this file — confirm against the training setup.
TARGET_COL = "color"
# Download & unpack the native predictor
def find_predictor_root(extract_dir: pathlib.Path) -> pathlib.Path:
    """Return the directory inside *extract_dir* that holds the saved predictor.

    AutoGluon's ``TabularPredictor`` writes a ``predictor.pkl`` file at the root
    of its save directory, so the shallowest directory containing that file is
    the predictor root. This copes with archives that wrap the predictor in one
    or more folders or that carry extra entries such as a macOS ``__MACOSX``
    metadata folder (which is ignored).

    If no ``predictor.pkl`` is found, fall back to the simple heuristic: a single
    top-level directory (``__MACOSX`` ignored) is the root, otherwise
    *extract_dir* itself.

    Args:
        extract_dir: Directory the archive was extracted into.

    Returns:
        The directory to hand to ``TabularPredictor.load``.
    """
    markers = [
        marker
        for marker in extract_dir.rglob("predictor.pkl")
        if marker.is_file() and "__MACOSX" not in marker.relative_to(extract_dir).parts
    ]
    if markers:
        # Shallowest match wins; ties are broken by path text so the result is
        # deterministic regardless of filesystem iteration order.
        shallowest = min(
            markers,
            key=lambda m: (len(m.relative_to(extract_dir).parts), str(m)),
        )
        return shallowest.parent

    entries = [entry for entry in extract_dir.iterdir() if entry.name != "__MACOSX"]
    if len(entries) == 1 and entries[0].is_dir():
        return entries[0]
    return extract_dir


def _prepare_predictor_dir() -> str:
    """Download the zipped predictor from the Hub, unpack it, and return its root.

    Any previous extraction in EXTRACT_DIR is deleted first so files from an
    older archive cannot mix with the new one.

    Returns:
        Path (as ``str``) of the directory to pass to ``TabularPredictor.load``.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # ``local_dir`` puts the downloaded file under CACHE_DIR. The former
    # ``local_dir_use_symlinks=False`` argument is deprecated (and ignored) in
    # current huggingface_hub releases, so it is no longer passed.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )

    # Start from an empty extraction directory on every run.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    return str(find_predictor_root(EXTRACT_DIR))
# Fetch and unpack the model once at import time; every call to do_predict
# then reuses this single in-memory predictor.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.tabular.TabularPredictor.load(
    PREDICTOR_DIR,
    # Skip AutoGluon's check that the saving and loading Python versions match.
    require_py_version_match=False,
)
def rank_class_probabilities(proba_row: pandas.Series, top_k) -> dict:
    """Return ``{class: probability}`` for *proba_row*, most likely class first.

    Args:
        proba_row: pandas Series indexed by class label with probability values.
        top_k: Maximum number of classes to keep. Anything that is not a positive
            ``int``/``float`` (e.g. ``None`` or ``0``) keeps every class.

    Returns:
        Dict with ``str`` class labels as keys and ``float`` probabilities as
        values, in descending-probability order.
    """
    ranked = proba_row.sort_values(ascending=False)
    if isinstance(top_k, (int, float)) and top_k > 0:
        ranked = ranked.head(int(top_k))
    return {str(cls): float(val) for cls, val in ranked.items()}


def do_predict(flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k):
    """Predict the flower color for one set of measurements.

    Args:
        flower_diameter_cm, petal_length_cm, petal_width_cm, stem_height_cm:
            Numeric inputs, converted with ``float``.
        petal_count: Converted with ``int``.
        top_k: How many of the most likely classes to return
            (see ``rank_class_probabilities``).

    Returns:
        ``(probabilities, debug_text)``. ``probabilities`` maps class label to
        probability for the Gradio ``Label`` output. ``debug_text`` is ``""`` on
        success, explains the hard-label fallback when class probabilities are
        unavailable, or holds the traceback of an unexpected error (in which case
        ``probabilities`` is ``{}``).
    """
    try:
        row = {
            FEATURE_COLS[0]: float(flower_diameter_cm),
            FEATURE_COLS[1]: float(petal_length_cm),
            FEATURE_COLS[2]: float(petal_width_cm),
            FEATURE_COLS[3]: int(petal_count),
            FEATURE_COLS[4]: float(stem_height_cm),
        }
        X = pandas.DataFrame([row], columns=FEATURE_COLS)
        try:
            # as_multiclass=True (AutoGluon's default, stated explicitly) yields a
            # DataFrame with one column per class even for binary problems, so
            # the class labels are always the column names.
            proba = PREDICTOR.predict_proba(X, as_multiclass=True)
            return rank_class_probabilities(proba.iloc[0], top_k), ""
        except Exception as proba_err:
            # Best effort: if probabilities are unavailable, show the hard
            # prediction alone. predict() is only run here, where it is needed.
            pred_label = str(PREDICTOR.predict(X).iloc[0])
            return (
                {pred_label: 1.0},
                f"predict_proba failed; showing predicted label only: {proba_err!r}",
            )
    except Exception as e:
        import traceback

        # The debug textbox is hidden in the UI, so also log to stderr where the
        # server/Space logs will show it.
        traceback.print_exc()
        return {}, f"{e}\n\n{traceback.format_exc()}"
# ---------------------------------------------------------------------------
# Example records
# ---------------------------------------------------------------------------
# One row per example, holding the five measurement inputs (top_k excluded) in
# this order: flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count,
# stem_height_cm. Gradio's Examples component receives them as lists.
EXAMPLES = [
    list(measurements)
    for measurements in (
        (4.5, 5.2, 1.8, 5, 35.0),
        (2.1, 3.3, 0.9, 8, 22.0),
        (6.8, 7.1, 2.5, 6, 55.0),
        (9.0, 4.0, 1.2, 12, 80.0),
        (1.8, 2.6, 0.5, 4, 15.0),
    )
]
with gradio.Blocks() as demo:
    gradio.Markdown("# 🌼 Flower Color Classifier\nPredict the flower **color** from five measurements.")
    gradio.Markdown(
        "Enter a single flower’s measurements below. "
        "Use **Top-K** to see the most likely colors with their probabilities."
    )

    with gradio.Row():
        flower_diameter_cm = gradio.Slider(0.0, 20.0, step=0.1, value=4.5, label="flower_diameter_cm")
        petal_length_cm = gradio.Slider(0.0, 15.0, step=0.1, value=5.2, label="petal_length_cm")
        petal_width_cm = gradio.Slider(0.0, 10.0, step=0.1, value=1.8, label="petal_width_cm")
    with gradio.Row():
        petal_count = gradio.Slider(1, 100, step=1, value=5, label="petal_count")
        stem_height_cm = gradio.Slider(0.0, 200.0, step=0.5, value=35.0, label="stem_height_cm")
        top_k = gradio.Slider(1, 10, step=1, value=3, label="Top-K classes shown")

    # Outputs: a Label rendering the {class: probability} dict from do_predict
    # (values must be numeric), plus a hidden Textbox that receives debug text.
    # num_top_classes equals the Top-K slider maximum, so Top-K alone decides
    # how many classes are shown.
    proba_pretty = gradio.Label(num_top_classes=10, label="Class probabilities")
    debug_box = gradio.Textbox(label="debug", visible=False)

    # Order must match do_predict's positional parameters.
    inputs = [flower_diameter_cm, petal_length_cm, petal_width_cm, petal_count, stem_height_cm, top_k]
    outputs = [proba_pretty, debug_box]

    # Recompute whenever any input changes.
    for comp in inputs:
        comp.change(fn=do_predict, inputs=inputs, outputs=outputs)
    # Also predict once when the page loads, so the default slider values show
    # a result before the user touches anything.
    demo.load(fn=do_predict, inputs=inputs, outputs=outputs)

    # Examples: only pass the first 5 inputs (excluding top_k) to match example rows
    gradio.Examples(
        examples=EXAMPLES,
        inputs=inputs[:-1],  # exclude top_k so example length matches
        label="Representative examples",
        examples_per_page=5,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()