prepare_dataset / prepare_pd_df.py
mr-kush's picture
refactor fetch_misclassified_dataframe to improve SQL queries and remove unused imports
a546051
# prepare_pd_df.py
import pandas as pd
from sqlalchemy import text
def fetch_misclassified_dataframe(label_column: str,
                                  engine,
                                  correct_ratio: float = 0.5,
                                  random_state: int = 42
                                  ) -> pd.DataFrame:
    """
    Build a DataFrame of grievance text with department/urgency labels.

    Rows come from two queries against the tables `misclassified_complaints`
    and `complaints`:

    - Misclassified rows: every reviewed row of `misclassified_complaints`
      whose `correct_<label_column>` is not NULL and differs from
      `model_predicted_<label_column>`. Labels are taken from the
      `correct_department` / `correct_urgency` columns.
    - Correct rows: a random sample of complaints that have no row in
      `misclassified_complaints` and whose `<label_column>` is not NULL.
      Labels are taken from the `complaints` table itself. The sample size
      is `int(n_misclassified * correct_ratio)`, capped at the number of
      such complaints available.

    Args:
        label_column (str): either 'department' or 'urgency'.
        engine: SQLAlchemy engine; only `engine.connect()` is used, as a
            context manager.
        correct_ratio (float): fraction of the misclassified count to sample
            from the correct set. Values that yield a sample size <= 0
            result in no correct rows being added.
        random_state (int): random seed for sampling.

    Returns:
        pd.DataFrame with exactly the columns
        ['grievance', 'department', 'urgency']. Empty (with those columns)
        when there are no misclassified rows.

    Raises:
        ValueError: if `label_column` is not 'department' or 'urgency'.
    """
    if label_column not in {"department", "urgency"}:
        raise ValueError("label_column must be either 'department' or 'urgency'")

    output_columns = ["grievance", "department", "urgency"]

    # label_column is restricted to two literals above, so interpolating it
    # into the SQL text below cannot inject arbitrary SQL.
    # IS DISTINCT FROM is a NULL-safe inequality: a NULL prediction still
    # counts as different from a non-NULL corrected label.
    miscond = (
        f"mc.correct_{label_column} IS NOT NULL "
        f"AND mc.model_predicted_{label_column} IS DISTINCT FROM mc.correct_{label_column}"
    )

    # Reviewed records whose prediction differs from the corrected label.
    sql_mis = text(f"""
        SELECT c.message AS grievance,
               mc.correct_department AS department,
               mc.correct_urgency AS urgency
        FROM misclassified_complaints mc
        JOIN complaints c ON c.id = mc.complaint_id
        WHERE mc.reviewed = TRUE
          AND {miscond}
    """)

    # Complaints that never appear in misclassified_complaints.
    # - NOT EXISTS instead of NOT IN: NOT IN yields no rows at all if the
    #   subquery returns a NULL complaint_id.
    # - ORDER BY c.id: a seeded sample is only reproducible when the input
    #   row order is stable between runs.
    # - Only the three output columns are selected so the combined frame has
    #   exactly the documented columns.
    sql_corr = text(f"""
        SELECT c.message AS grievance,
               c.department AS department,
               c.urgency AS urgency
        FROM complaints c
        WHERE NOT EXISTS (
                  SELECT 1
                  FROM misclassified_complaints mc
                  WHERE mc.complaint_id = c.id
              )
          AND c.{label_column} IS NOT NULL
        ORDER BY c.id
    """)

    with engine.connect() as conn:
        df_mis = pd.read_sql(sql_mis, conn)

        # Nothing misclassified: return an empty frame with the right schema.
        if df_mis.empty:
            return pd.DataFrame(columns=output_columns)

        n_correct = int(len(df_mis) * correct_ratio)

        # Skip the (potentially large) second query when no correct rows
        # are going to be sampled anyway.
        df_corr_all = pd.read_sql(sql_corr, conn) if n_correct > 0 else None

    df_mis = df_mis[output_columns].reset_index(drop=True)

    if df_corr_all is None or df_corr_all.empty:
        return df_mis

    # Reproducible random sample, capped at the number of available rows.
    df_corr = df_corr_all[output_columns].sample(
        n=min(n_correct, len(df_corr_all)),
        random_state=random_state,
    )

    return pd.concat([df_mis, df_corr], ignore_index=True)
# # If this file is run directly, simple test:
# if __name__ == "__main__":
# # Quick sanity test for department label
# df_test = fetch_misclassified_dataframe(label_column="department",
# engine=engine,  # NOTE: `engine` is a required argument; create a SQLAlchemy engine first
# correct_ratio=0.5)
# print("Rows fetched:", len(df_test))
# print(df_test.head())
# # Basic assertion: if rows>0 then none of grievances should be null
# if len(df_test) > 0:
# assert df_test['grievance'].isna().sum() == 0, "Some grievances are null"