|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
from datetime import datetime, timedelta |
|
|
from io import StringIO |
|
|
import os |
|
|
import json |
|
|
from scipy.stats import linregress |
|
|
|
|
|
|
|
|
# Report, before anything is loaded, which of the expected model artifact
# files are present in the current working directory.
files_to_check = [
    "new_best_improved_model.pth",
    "scaler.pkl",
    "feature_names.json",
    "model_config.json",
]
st.write("Debug: Checking file paths...")
for file in files_to_check:
    status = "Found" if os.path.exists(file) else "Missing"
    st.write(f"{file}: {status}")
|
|
|
|
|
# The inference helpers are required by every section below; if they cannot
# be imported, show the reason and halt the script run.
try:
    from inference import (
        load_model_and_artifacts,
        predict,
        derive_features,
    )
except Exception as import_error:
    st.error("Error importing inference: " + str(import_error))
    st.stop()
|
|
|
|
|
st.title("Store Sales Time Series Forecasting")
st.markdown("Forecast 13-week store sales using an LSTM model trained on Kaggle Store Sales data.")

# Load the model together with its scaler, ordered feature list and config.
# Any failure is reported in the UI and stops the current script run.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so this load repeats each time; consider caching it (e.g. st.cache_resource)
# if the deployed Streamlit version supports it.
try:
    st.write("Debug: Loading model and artifacts...")
    model, scaler, feature_names, config = load_model_and_artifacts()
    st.success("Model and artifacts loaded successfully")
except Exception as load_error:
    st.error(f"Error loading model or artifacts: {load_error!s}")
    st.stop()
|
|
|
|
|
|
|
|
st.header("Model Performance Metrics")

# Evaluation scores for the trained model. They are fixed literals here,
# not recomputed at runtime.
metrics = dict(MAE=710.75, RMSE=1108.36, MAPE=7.16, R2=0.8633)

st.markdown(f"""


- **MAE**: ${metrics['MAE']:.2f}


- **RMSE**: ${metrics['RMSE']:.2f}


- **MAPE**: {metrics['MAPE']:.2f}%


- **R² Score**: {metrics['R2']:.4f}


""")

st.header("Model Architecture")

# Hyper-parameters come from the loaded model config; the parameter count
# is a hard-coded literal.
st.markdown(f"""


- **Input Size**: {config['input_size']} features


- **Hidden Size**: {config['hidden_size']}


- **Number of Layers**: {config['num_layers']}


- **Forecast Horizon**: {config['forecast_horizon']} weeks


- **Dropout**: {config['dropout']}


- **Attention**: {config['has_attention']}


- **Input Projection**: {config['has_input_projection']}


- **Parameters**: 227,441


""")
|
|
|
|
|
|
|
|
def analyze_input_data(df, date_col, sales_col):
    """Compute summary statistics and exploratory charts for a sales history.

    Parameters
    ----------
    df : pandas.DataFrame
        Sales history containing at least ``date_col`` and ``sales_col``.
        Rows may be in any order; the analysis sorts them chronologically.
    date_col : str
        Name of the date column (any values ``pd.to_datetime`` accepts).
    sales_col : str
        Name of the numeric sales column.

    Returns
    -------
    tuple[pandas.DataFrame, list]
        A single-column ("Value") frame of metrics rounded to 2 decimals and
        indexed by metric name, plus three Plotly figures: the sales line, the
        sales with a 7-row moving average, and a 20-bin sales histogram.

    Notes
    -----
    For backward compatibility the caller's ``date_col`` is still parsed to
    datetime in place. All other intermediate columns (e.g. ``MA_7``) are
    computed on a private sorted copy and never added to ``df``.
    The "Trend Slope" is the least-squares slope of sales per row of the
    chronologically sorted data; it is NaN when fewer than two rows exist.
    """
    # Kept in place: earlier versions converted the caller's column and
    # downstream code may rely on the parsed dtype.
    df[date_col] = pd.to_datetime(df[date_col])

    # sort_values returns a new frame, so everything below works on a copy.
    # Uploaded files may list the newest row first; without sorting, the
    # trend slope would flip sign and the moving average would run backwards.
    history = df.sort_values(date_col, kind="mergesort").reset_index(drop=True)
    sales = history[sales_col]

    # linregress cannot fit fewer than two points.
    if len(sales) >= 2:
        trend_slope = linregress(np.arange(len(sales)), sales).slope
    else:
        trend_slope = float("nan")

    metrics = {
        "Mean Sales ($)": sales.mean(),
        "Std Sales ($)": sales.std(),
        "Min Sales ($)": sales.min(),
        "Max Sales ($)": sales.max(),
        "Median Sales ($)": sales.median(),
        "Trend Slope": trend_slope,
    }
    metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["Value"])
    metrics_df["Value"] = metrics_df["Value"].round(2)

    fig1 = px.line(history, x=date_col, y=sales_col, title="Historical Sales Over Time")
    fig1.update_traces(line=dict(color="blue"))

    # Helper column lives only on the private copy.
    history["MA_7"] = history[sales_col].rolling(window=7, min_periods=1).mean()
    fig2 = px.line(
        history,
        x=date_col,
        y=[sales_col, "MA_7"],
        title="Sales with 7-Day Moving Average",
    )
    # Wide-form px.line names each trace after its column, so select by the
    # actual sales column name rather than a hard-coded "sales".
    fig2.update_traces(line=dict(color="blue"), selector=dict(name=sales_col))
    fig2.update_traces(line=dict(color="orange", dash="dash"), selector=dict(name="MA_7"))

    fig3 = px.histogram(history, x=sales_col, nbins=20, title="Sales Distribution")
    fig3.update_traces(marker=dict(color="blue"))

    return metrics_df, [fig1, fig2, fig3]
|
|
|
|
|
|
|
|
st.header("Generate Synthetic Test Data")
st.markdown("Create a sample dataset with 21 timesteps matching the training data distribution (sales ~$3,000–19,000).")


def _generate_synthetic_df(feature_names, sequence_length=21, seed=42):
    """Build a synthetic input window with one column per entry of ``feature_names``.

    The raw sales series is drawn first, and every derived feature (names
    containing "lag", "ma" or "ratio") is computed from it. The result
    therefore does not depend on "sales" being the first feature. Features
    that match none of the recognised names stay at zero.

    Returns a DataFrame with the feature columns plus a ``Date`` column of
    consecutive calendar days ending today.
    """
    # Global NumPy seed kept so the generated values match earlier releases.
    np.random.seed(seed)
    values = np.zeros((sequence_length, len(feature_names)))
    sales = np.random.normal(8954.97, 3307.49, sequence_length)

    def rolling_mean(window):
        return pd.Series(sales).rolling(window=window, min_periods=1).mean().values

    day_phase = np.linspace(0, 2 * np.pi, sequence_length)
    month_phase = np.linspace(0, 2 * np.pi * 12 / sequence_length, sequence_length)

    for i, feature in enumerate(feature_names):
        if feature == "sales":
            values[:, i] = sales
        elif feature == "onpromotion":
            values[:, i] = np.random.choice([0, 1], sequence_length, p=[0.8, 0.2])
        elif feature in ("dayofweek_sin", "dayofweek_cos"):
            values[:, i] = np.sin(day_phase) if "sin" in feature else np.cos(day_phase)
        elif feature in ("month_sin", "month_cos"):
            values[:, i] = np.sin(month_phase) if "sin" in feature else np.cos(month_phase)
        elif feature == "trend":
            values[:, i] = np.linspace(0, sequence_length, sequence_length)
        elif feature == "is_weekend":
            values[:, i] = np.random.choice([0, 1], sequence_length, p=[0.7, 0.3])
        elif feature == "quarter":
            values[:, i] = np.random.choice([1, 2, 3, 4], sequence_length)
        elif "lag" in feature:
            # assumes the lag length is the last "_"-separated token
            lag = int(feature.split("_")[-1])
            values[:, i] = np.roll(sales, lag)
            if lag > 0:
                # np.roll wraps the tail to the front; back-fill those rows
                # with the unshifted values instead.
                values[:lag, i] = sales[:lag]
        elif "ma" in feature:
            values[:, i] = rolling_mean(int(feature.split("_")[-1]))
        elif "ratio" in feature:
            window = int(feature.split("_")[-1])
            values[:, i] = sales / (rolling_mean(window) + 1e-8)
        elif "promo" in feature:
            values[:, i] = np.random.choice([0, 1], sequence_length, p=[0.8, 0.2])
        elif feature == "dcoilwtico":
            values[:, i] = np.random.normal(80, 10, sequence_length)
        elif feature == "is_holiday":
            values[:, i] = np.random.choice([0, 1], sequence_length, p=[0.9, 0.1])

    synthetic = pd.DataFrame(values, columns=feature_names)
    # NOTE(review): these dates are one day apart while the forecast below is
    # weekly — confirm which spacing the model was trained on.
    end_date = datetime.now().date()
    synthetic["Date"] = [end_date - timedelta(days=x) for x in range(sequence_length - 1, -1, -1)]
    return synthetic


if st.button("Generate Synthetic Data"):
    st.session_state["synthetic_df"] = _generate_synthetic_df(feature_names)

# st.button is True only on the run right after the click, and clicking a
# download button reruns the whole script. Rendering from session state keeps
# the analysis and forecast on screen across those reruns.
if "synthetic_df" in st.session_state:
    synthetic_df = st.session_state["synthetic_df"]
    sequence_length = len(synthetic_df)
    n_features = len(feature_names)
    # Forecast starts after the most recent synthetic date.
    end_date = pd.to_datetime(synthetic_df["Date"]).max().date()

    st.subheader("Input Data Analysis")
    metrics_df, infographics = analyze_input_data(synthetic_df, "Date", "sales")
    st.write("**Statistical Metrics**")
    st.dataframe(metrics_df)
    st.write("**Infographics**")
    for fig in infographics:
        st.plotly_chart(fig)

    preview_columns = ["Date", "sales", "onpromotion", "dcoilwtico", "is_holiday"]
    st.subheader("Synthetic Data Preview")
    st.dataframe(synthetic_df[preview_columns].head())

    # Export in the same column layout the upload section expects.
    csv_buffer = StringIO()
    synthetic_df[preview_columns].rename(columns={"Date": "date"}).to_csv(csv_buffer, index=False)
    st.download_button(
        label="Download Synthetic Data CSV",
        data=csv_buffer.getvalue(),
        file_name="synthetic_sales_data.csv",
        mime="text/csv",
    )

    try:
        sequences = synthetic_df[feature_names].values.reshape(1, sequence_length, n_features)
        sequences_scaled = scaler.transform(sequences.reshape(-1, n_features)).reshape(1, sequence_length, n_features)
        predictions, uncertainties = predict(model, scaler, sequences_scaled)

        if predictions.shape != (1, 13) or uncertainties.shape != (1, 13):
            raise ValueError(f"Expected predictions and uncertainties of shape (1, 13), got {predictions.shape} and {uncertainties.shape}")

        forecast_dates = [end_date + timedelta(days=x * 7) for x in range(1, 14)]
        forecast_df = pd.DataFrame({
            'Date': forecast_dates,
            'Predicted Sales ($)': predictions[0],
            'Uncertainty ($)': uncertainties[0],
        })

        st.subheader("13-Week Forecast")
        st.dataframe(forecast_df)

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=synthetic_df["Date"],
            y=synthetic_df["sales"],
            mode='lines+markers',
            name='Historical Sales',
            line=dict(color='blue'),
        ))
        fig.add_trace(go.Scatter(
            x=forecast_df['Date'],
            y=forecast_df['Predicted Sales ($)'],
            mode='lines+markers',
            name='Predicted Sales',
            line=dict(color='red', dash='dash'),
        ))
        fig.add_trace(go.Scatter(
            x=forecast_df['Date'],
            y=forecast_df['Predicted Sales ($)'] + forecast_df['Uncertainty ($)'],
            mode='lines',
            name='Upper Bound',
            line=dict(color='green', dash='dash'),
            showlegend=True,
        ))
        # 'tonexty' shades the band between this trace and the upper bound.
        fig.add_trace(go.Scatter(
            x=forecast_df['Date'],
            y=forecast_df['Predicted Sales ($)'] - forecast_df['Uncertainty ($)'],
            mode='lines',
            name='Lower Bound',
            line=dict(color='green', dash='dash'),
            fill='tonexty',
            fillcolor='rgba(0, 255, 0, 0.1)',
        ))
        fig.update_layout(
            title="Historical and 13-Week Forecasted Sales",
            xaxis_title="Date",
            yaxis_title="Sales ($)",
            template="plotly_white",
        )
        st.plotly_chart(fig)

        csv_buffer = StringIO()
        forecast_df.to_csv(csv_buffer, index=False)
        st.download_button(
            label="Download Forecast CSV",
            data=csv_buffer.getvalue(),
            file_name="forecast_results.csv",
            mime="text/csv",
        )
    except Exception as e:
        st.error(f"Error generating forecast: {str(e)}")
|
|
|
|
|
|
|
|
st.header("Upload Custom Data")
st.markdown("""


Upload a CSV with 21 timesteps containing the following columns:


- **date**: Date in YYYY-MM-DD format (e.g., 2025-06-22)


- **sales**: Weekly sales in USD (e.g., 3000 to 19372)


- **onpromotion**: 0 or 1 indicating if items are on promotion


- **dcoilwtico**: Oil price (e.g., 70 to 90)


- **is_holiday**: 0 or 1 indicating if the day is a holiday





The remaining features will be derived automatically. Download a sample CSV below to see the expected format.


""")

# A short example file that shows the expected column layout, newest row first.
sample_data = pd.DataFrame.from_records(
    [
        ("2025-06-22", 8954.97, 0, 80.0, 0),
        ("2025-06-15", 9500.00, 1, 82.5, 0),
        ("2025-06-08", 8000.00, 0, 78.0, 1),
    ],
    columns=["date", "sales", "onpromotion", "dcoilwtico", "is_holiday"],
)
st.download_button(
    label="Download Sample CSV",
    data=sample_data.to_csv(index=False),
    file_name="sample_input.csv",
    mime="text/csv",
)
|
|
|
|
|
|
|
|
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    try:
        data = pd.read_csv(uploaded_file)
        required_columns = ["date", "sales", "onpromotion", "dcoilwtico", "is_holiday"]
        expected_rows = 21
        missing_columns = [col for col in required_columns if col not in data.columns]

        if not missing_columns and len(data) == expected_rows:
            # Parse dates once, up front; the plots, the forecast anchor and
            # derive_features all receive real datetimes rather than strings.
            data["date"] = pd.to_datetime(data["date"])

            st.subheader("Input Data Analysis")
            # Analyse a copy so helper columns created for the charts never
            # reach the frame passed to derive_features.
            metrics_df, infographics = analyze_input_data(data.copy(), "date", "sales")
            st.write("**Statistical Metrics**")
            st.dataframe(metrics_df)
            st.write("**Infographics**")
            for fig in infographics:
                st.plotly_chart(fig)

            n_features = len(feature_names)
            sequences = derive_features(data, feature_names, sequence_length=expected_rows)
            sequences_scaled = scaler.transform(sequences.reshape(-1, n_features)).reshape(1, expected_rows, n_features)
            predictions, uncertainties = predict(model, scaler, sequences_scaled)

            if predictions.shape != (1, 13) or uncertainties.shape != (1, 13):
                raise ValueError(f"Expected predictions and uncertainties of shape (1, 13), got {predictions.shape} and {uncertainties.shape}")

            # The sample file lists the newest week first, but nothing enforces
            # that order; anchor the forecast on the latest date in the file.
            end_date = data["date"].max().date()
            forecast_dates = [end_date + timedelta(days=x * 7) for x in range(1, 14)]
            forecast_df = pd.DataFrame({
                'Date': forecast_dates,
                'Predicted Sales ($)': predictions[0],
                'Uncertainty ($)': uncertainties[0],
            })

            st.subheader("13-Week Forecast")
            st.dataframe(forecast_df)

            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=data["date"],
                y=data["sales"],
                mode='lines+markers',
                name='Historical Sales',
                line=dict(color='blue'),
            ))
            fig.add_trace(go.Scatter(
                x=forecast_df['Date'],
                y=forecast_df['Predicted Sales ($)'],
                mode='lines+markers',
                name='Predicted Sales',
                line=dict(color='red', dash='dash'),
            ))
            fig.add_trace(go.Scatter(
                x=forecast_df['Date'],
                y=forecast_df['Predicted Sales ($)'] + forecast_df['Uncertainty ($)'],
                mode='lines',
                name='Upper Bound',
                line=dict(color='green', dash='dash'),
                showlegend=True,
            ))
            # 'tonexty' shades the band between this trace and the upper bound.
            fig.add_trace(go.Scatter(
                x=forecast_df['Date'],
                y=forecast_df['Predicted Sales ($)'] - forecast_df['Uncertainty ($)'],
                mode='lines',
                name='Lower Bound',
                line=dict(color='green', dash='dash'),
                fill='tonexty',
                fillcolor='rgba(0, 255, 0, 0.1)',
            ))
            fig.update_layout(
                title="Historical and 13-Week Forecasted Sales",
                xaxis_title="Date",
                yaxis_title="Sales ($)",
                template="plotly_white",
            )
            st.plotly_chart(fig)

            csv_buffer = StringIO()
            forecast_df.to_csv(csv_buffer, index=False)
            st.download_button(
                label="Download Forecast CSV",
                data=csv_buffer.getvalue(),
                file_name="custom_forecast_results.csv",
                mime="text/csv",
            )
        else:
            # Say exactly what is wrong instead of only restating the format.
            problems = []
            if missing_columns:
                problems.append(f"missing columns: {', '.join(missing_columns)}")
            if len(data) != expected_rows:
                problems.append(f"found {len(data)} rows")
            st.error(
                f"Invalid CSV. Expected {expected_rows} rows and columns: "
                f"{', '.join(required_columns)} ({'; '.join(problems)})"
            )
    except Exception as e:
        st.error(f"Error processing CSV or generating forecast: {str(e)}")