import torch
import torch.nn as nn
import numpy as np
import pickle
import json
import os

class ImprovedCashFlowLSTM(nn.Module):
    """LSTM encoder with an MLP head that maps the final hidden state to a multi-step cash-flow forecast."""

    def __init__(self, input_size, hidden_size=128, num_layers=2, forecast_horizon=13, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.forecast_horizon = forecast_horizon
        # Stacked LSTM encoder; inter-layer dropout only applies when num_layers > 1
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        # MLP head that emits the full forecast horizon in one shot
        self.output_layers = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, forecast_horizon)
        )
    
    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        lstm_out, _ = self.lstm(x)
        # Use the hidden state of the last time step as the sequence summary
        last_hidden = lstm_out[:, -1, :]
        return self.output_layers(last_hidden)

def load_model_and_artifacts(
    model_path="new_best_improved_model.pth",
    scaler_path="scaler.pkl",
    feature_names_path="feature_names.json",
    config_path="model_config.json"
):
    """Load the trained model weights, fitted scaler, feature names, and model config from disk."""
    paths = [model_path, scaler_path, feature_names_path, config_path]
    missing = [path for path in paths if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Missing files: {missing}")
    
    with open(config_path, "r") as f:
        config = json.load(f)
    
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)
    
    with open(feature_names_path, "r") as f:
        feature_names = json.load(f)
    
    input_size = config["input_size"]
    model = ImprovedCashFlowLSTM(
        input_size=input_size,
        hidden_size=config["hidden_size"],
        num_layers=config["num_layers"],
        forecast_horizon=config["forecast_horizon"],
        dropout=config["dropout"]
    )
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model, scaler, feature_names, config

def predict(model, scaler, sequences):
    """Generate forecasts for a batch of scaled input sequences and return them in dollar terms with rough uncertainty estimates."""
    device = torch.device("cpu")
    model.to(device)
    
    # Validate input shape: (batch_size, sequence_length=21, n_features matching the model's input_size)
    if sequences.ndim != 3 or sequences.shape[1] != 21 or sequences.shape[2] != model.lstm.input_size:
        raise ValueError(f"Expected input shape (batch_size, 21, {model.lstm.input_size}), got {sequences.shape}")
    
    # Convert to tensor
    sequences = torch.tensor(sequences, dtype=torch.float32).to(device)
    
    # Generate predictions
    with torch.no_grad():
        predictions = model(sequences).cpu().numpy()  # Shape: (batch_size, 13)
    
    # Inverse transform predictions (sales is first feature)
    dummy = np.zeros((predictions.shape[0] * predictions.shape[1], scaler.n_features_in_))
    dummy[:, 0] = predictions.flatten()
    rescaled = scaler.inverse_transform(dummy)[:, 0].reshape(predictions.shape)
    
    # Clip to the training sales range (reported as roughly $3,069–$19,372; the lower
    # bound is rounded down to 3000, which also keeps forecasts non-negative)
    rescaled = np.clip(rescaled, 3000, 19372)
    
    # Estimate uncertainty (simplified: std of predictions + base uncertainty)
    uncertainties = np.std(rescaled, axis=1, keepdims=True) + 100
    uncertainties = np.clip(uncertainties, 100, 500)
    
    return rescaled, uncertainties
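
# --------------------------------------------------------------------------------------
# Illustrative usage sketch, not part of the trained pipeline. It assumes the artifact
# files named above exist alongside this script and that the scaler was fitted with
# scaled sales in feature column 0; the random batch below merely stands in for real,
# already-scaled input sequences.
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    model, scaler, feature_names, config = load_model_and_artifacts()

    # Dummy batch: 2 sequences of 21 timesteps with the model's expected feature count
    batch = np.random.rand(2, 21, config["input_size"]).astype(np.float32)

    forecasts, uncertainties = predict(model, scaler, batch)
    print("Forecast shape:", forecasts.shape)                # (2, forecast_horizon)
    print("First forecast (USD):", np.round(forecasts[0], 2))
    print("Uncertainty estimates:", uncertainties.ravel())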