import os
import time
import wandb
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from huggingface_hub import HfApi
from preprocess_and_prepare_dataset import preprocess_and_push_dataset
from prepare_pd_df import fetch_misclassified_dataframe


#  LOAD ENVIRONMENT 
load_dotenv()


def prepare_datasets():
    """
    Fetch misclassified data, preprocess, and push datasets to Hugging Face.
    Tracks all steps and metrics in Weights & Biases (W&B).
    """

    #  CONFIGURATION 
    hf_token = os.getenv("HF_TOKEN")
    dept_dataset_dir = os.getenv("DEPARTMENT_DATASET")
    urgency_dataset_dir = os.getenv("URGENCY_DATASET")
    DB_URL = os.getenv("POSTGRES_URL")
    PREPARE_DATASET_SPACE_ID = os.getenv("PREPARE_DATASET_SPACE_ID")
    WANDB_API_KEY = os.getenv('WANDB_API_KEY')
    WANDB_PROJECT_NAME = os.getenv('WANDB_PROJECT_NAME', "sambodhan-dataset-pipeline")
    MIN_DATASET_LEN = int(os.getenv('MIN_DATASET_LEN', 1000))

    # Validate environment variables
    required_env = {
        "HF_TOKEN": hf_token,
        "DEPARTMENT_DATASET": dept_dataset_dir,
        "URGENCY_DATASET": urgency_dataset_dir,
        "POSTGRES_URL": DB_URL,
        "WANDB_API_KEY": WANDB_API_KEY,
    }
    missing_vars = [k for k, v in required_env.items() if not v]
    if missing_vars:
        raise EnvironmentError(f"Missing required environment variables: {missing_vars}")
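
    # Illustrative .env layout these lookups assume (every value below is a
    # placeholder, not a real credential):
    #   HF_TOKEN=hf_xxxxxxxx
    #   DEPARTMENT_DATASET=your-org/department-dataset
    #   URGENCY_DATASET=your-org/urgency-dataset
    #   POSTGRES_URL=postgresql+psycopg2://user:password@host:5432/dbname
    #   WANDB_API_KEY=xxxxxxxx
    #   PREPARE_DATASET_SPACE_ID=your-org/prepare-dataset-space  # optional
    #   MIN_DATASET_LEN=1000                                     # optional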

    #  INIT W&B 
    wandb.login(key=WANDB_API_KEY)
    run = wandb.init(
        project=WANDB_PROJECT_NAME,
        job_type="prepare_dataset",
        config={
            "database_url": DB_URL,
            "department_dataset": dept_dataset_dir,
            "urgency_dataset": urgency_dataset_dir,
            "hf_space_id": PREPARE_DATASET_SPACE_ID,
            "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
        },
        tags=["dataset-prep", "hf-space", "auto-sync"],
        settings=wandb.Settings(start_method="thread"),
    )

    wandb.log({"status": "starting_pipeline"})
    wandb.termlog("Starting dataset preparation pipeline...")

    #  DATABASE CONNECTION 
    try:
        engine = create_engine(DB_URL, pool_pre_ping=True)  # pre-ping discards stale pooled connections
        wandb.termlog("Created SQLAlchemy engine. Validating connection...")

        # Retry with exponential backoff (2 s, then 4 s) before giving up.
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                with engine.connect() as conn:
                    conn.exec_driver_sql("SELECT 1")
                wandb.termlog("Database connection successful.")
                wandb.log({"db_connection_status": "success"})
                break
            except SQLAlchemyError as e:
                if attempt == max_attempts:
                    wandb.termlog("Database connection failed after multiple attempts.")
                    wandb.log({"db_connection_status": "failed"})
                    raise
                wait = 2 ** attempt
                wandb.termlog(f"Attempt {attempt} failed: {e}. Retrying in {wait}s...")
                time.sleep(wait)

    except Exception as e:
        wandb.alert(
            title="Database Connection Failed",
            text=str(e),
            level=wandb.AlertLevel.ERROR,
        )
        wandb.finish(exit_code=1)
        raise

    #  DATASET PROCESSING 
    dataset_mapping = {
        "department": dept_dataset_dir,
        "urgency": urgency_dataset_dir,
    }

    for label, dataset_dir in dataset_mapping.items():
        try:
            wandb.termlog(f"Fetching misclassified data for '{label}'...")
            df = fetch_misclassified_dataframe(
                label_column=label,
                engine=engine,
                correct_ratio=0.5,
            )
            record_count = len(df)
            wandb.log({f"{label}_records_fetched": record_count})
            wandb.termlog(f"Retrieved {record_count} records for '{label}'.")

            # Check dataset length before pushing
            if record_count < MIN_DATASET_LEN:
                msg = f"Skipped pushing '{label}' dataset: insufficient data ({record_count} < {MIN_DATASET_LEN})."
                wandb.termlog(msg)
                wandb.log({f"{label}_push_status": "skipped_insufficient_data"})

                # Raise a controlled exception; the ValueError handler below
                # logs the skip and continues with the next label.
                raise ValueError(msg)

            #  If sufficient data, proceed
            wandb.termlog(f"Preprocessing and pushing '{label}' dataset to HF Hub...")
            preprocess_and_push_dataset(
                df=df,
                hf_token=hf_token,
                hf_dataset_dir=dataset_dir,
                label_column=label,
            )

            wandb.termlog(f"Successfully pushed '{label}' dataset.")
            wandb.log({f"{label}_push_status": "success"})
            wandb.alert(
                title=f"{label.capitalize()} Dataset Updated",
                text=f"Successfully pushed dataset to {dataset_dir}",
                level=wandb.AlertLevel.INFO,
            )

        except ValueError as ve:
            # Controlled skip — no crash, just log warning
            wandb.alert(
                title=f"{label.capitalize()} Dataset Skipped",
                text=str(ve),
                level=wandb.AlertLevel.WARN,
            )
            wandb.termlog(f"[SKIPPED] {ve}")
            continue  # skip to next label safely

        except Exception as e:
            # Real errors
            wandb.alert(
                title=f"{label.capitalize()} Dataset Preparation Failed",
                text=str(e),
                level=wandb.AlertLevel.ERROR,
            )
            wandb.log({f"{label}_push_status": "failed"})
            wandb.termlog(f"Error processing '{label}' dataset: {e}")
            raise

    #  PAUSE HUGGING FACE SPACE 
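    # Assumption: this script runs inside the Space named by
    # PREPARE_DATASET_SPACE_ID, so pausing it after the job completes frees
    # idle compute. The step is skipped when the variable is unset.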
    if PREPARE_DATASET_SPACE_ID:
        try:
            wandb.termlog("⏸ Attempting to pause Hugging Face Space...")
            api = HfApi()
            api.pause_space(repo_id=PREPARE_DATASET_SPACE_ID, token=hf_token)
            wandb.log({"hf_space_pause": "success"})
            wandb.termlog("Hugging Face Space paused successfully.")
        except Exception as e:
            wandb.termlog(f"Failed to pause HF Space: {e}")
            wandb.log({"hf_space_pause": "failed"})
            wandb.alert(
                title="HF Space Pause Failed",
                text=str(e),
                level=wandb.AlertLevel.WARN,
            )

    #  COMPLETE 
    wandb.log({"status": "completed"})
    wandb.termlog("Dataset preparation completed successfully!")
    run.finish(exit_code=0)


#  ENTRY POINT 
if __name__ == "__main__":
    try:
        prepare_datasets()
    except Exception as e:
        wandb.termlog(f" Pipeline failed due to an error: {e}")
        wandb.finish(exit_code=1)
        raise