{ "cells": [
 { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aifgSPecKkfY", "outputId": "9db7f3b7-2a36-42b6-8eb3-6ca07425437d" }, "id": "aifgSPecKkfY", "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ] },
 { "cell_type": "markdown", "id": "aacf5211", "metadata": { "id": "aacf5211" }, "source": [ "### Importing Libraries" ] },
 { "cell_type": "code", "execution_count": null, "id": "24577b88", "metadata": { "id": "24577b88" }, "outputs": [], "source": [
  "import warnings\n",
  "\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "import matplotlib.pyplot as plt\n",
  "import seaborn as sns\n",
  "from sklearn.ensemble import RandomForestClassifier\n",
  "from sklearn.linear_model import LogisticRegression\n",
  "from sklearn.metrics import accuracy_score, classification_report\n",
  "from sklearn.model_selection import GridSearchCV, train_test_split\n",
  "from sklearn.neighbors import KNeighborsClassifier\n",
  "from sklearn.neural_network import MLPClassifier\n",
  "from sklearn.preprocessing import StandardScaler\n",
  "from sklearn.svm import SVC\n",
  "from xgboost import XGBClassifier\n",
  "\n",
  "# Seed shared by the train/test split and every stochastic model below.\n",
  "RANDOM_STATE = 42\n",
  "\n",
  "# NOTE: this silences *all* warnings, including GridSearchCV's FitFailedWarning,\n",
  "# so a hyperparameter combination that fails to fit is skipped without any message.\n",
  "warnings.filterwarnings('ignore')"
 ] },
 { "cell_type": "markdown", "id": "d70990dc", "metadata": { "id": "d70990dc" }, "source": [ "### Data Load\n", "\n", "The CSV is read from the mounted Google Drive; change `DATA_PATH` if it is stored elsewhere." ] },
 { "cell_type": "code", "execution_count": null, "id": "3de86ddb", "metadata": { "id": "3de86ddb" }, "outputs": [], "source": [
  "DATA_PATH = '/content/drive/MyDrive/heart_failure_clinical_records_dataset.csv'\n",
  "data = pd.read_csv(DATA_PATH)"
 ] },
 { "cell_type": "markdown", "source": [ "### Exploratory Data Analysis" ], "metadata": { "id": "P20f_aZ0nanU" }, "id": "P20f_aZ0nanU" },
 { "cell_type": "code", "source": [ "data.shape" ], "metadata": { "id": "R0JxTMpInaUs" }, "id": "R0JxTMpInaUs", "execution_count": null, "outputs": [] },
 { "cell_type": "code", "execution_count": null, "id": "c7f83776", "metadata": { "id": "c7f83776" }, "outputs": [], "source": [ "data.head()" ] },
 { "cell_type": "code", "execution_count": null, "id": "ac3d6a1e", "metadata": { "id": "ac3d6a1e" }, "outputs": [], "source": [ "data.info()" ] },
 { "cell_type": "code", "execution_count": null, "id": "e754b5e8", "metadata": { "id": "e754b5e8" }, "outputs": [], "source": [ "data.isnull().sum()" ] },
 { "cell_type": "code", "execution_count": null, "id": "e95bcd68", "metadata": { "id": "e95bcd68" }, "outputs": [], "source": [ "data.duplicated().sum()" ] },
 { "cell_type": "code", "execution_count": null, "id": "2ce23598", "metadata": { "id": "2ce23598" }, "outputs": [], "source": [
  "# Age buckets for plotting only (dropped again before modelling).\n",
  "# include_lowest=True keeps an age exactly equal to the first edge (40) in \"40-45\";\n",
  "# with pd.cut's default right-closed bins it would otherwise become NaN.\n",
  "labels = [\"40-45\", \"46-50\", \"51-55\", \"56-60\", \"61-65\", \"66-70\", \"71-75\", \"76-80\", \"81-95\"]\n",
  "data['age_group'] = pd.cut(data['age'], bins=[40, 45, 50, 55, 60, 65, 70, 75, 80, 95], labels=labels, include_lowest=True)"
 ] },
 { "cell_type": "markdown", "id": "852a3203", "metadata": { "id": "852a3203" }, "source": [ "### Data Visualization" ] },
 { "cell_type": "code", "execution_count": null, "id": "fc5f6131", "metadata": { "id": "fc5f6131" }, "outputs": [], "source": [
  "plt.figure(figsize=(10,6))\n",
  "sns.countplot(data=data, x='age_group', hue='DEATH_EVENT', palette=[\"lightblue\", \"red\"])\n",
  "plt.title(\"Death Count by Age Group\")\n",
  "plt.xlabel(\"Age Group\")\n",
  "plt.ylabel(\"Patient Count\")\n",
  "plt.legend([\"Survived\", \"Died\"])\n",
  "plt.show()"
 ] },
 { "cell_type": "code", "source": [
  "# age_group is categorical, so it is excluded from the numeric correlation matrix.\n",
  "corr_matrix = data.drop(columns=['age_group'], errors='ignore').corr()\n",
  "plt.figure(figsize=(12, 10))\n",
  "sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=\".2f\")\n",
  "plt.title('Correlation Matrix of Heart Failure Clinical Records')\n",
  "plt.show()"
 ], "metadata": { "id": "687Lx-xInvLN" }, "id": "687Lx-xInvLN", "execution_count": null, "outputs": [] },
 { "cell_type": "code", "source": [
  "# sort_index() puts class 0 before class 1 so the counts line up with the labels,\n",
  "# independent of which class happens to be more frequent.\n",
  "death_counts = data['DEATH_EVENT'].value_counts().sort_index()\n",
  "plt.figure(figsize=(6, 6))\n",
  "plt.pie(death_counts, labels=['Not Died', 'Died'], autopct='%1.1f%%', startangle=90, colors=['skyblue', 'lightcoral'])\n",
  "plt.title('Distribution of DEATH_EVENT')\n",
  "plt.show()"
 ], "metadata": { "id": "CFGNvM9un7CB" }, "id": "CFGNvM9un7CB", "execution_count": null, "outputs": [] },
 { "cell_type": "code", "source": [
  "# Select a subset of numerical features that showed some correlation with DEATH_EVENT\n",
  "selected_features = ['time', 'serum_creatinine', 'ejection_fraction', 'age', 'serum_sodium', 'DEATH_EVENT']\n",
  "\n",
  "sns.pairplot(data[selected_features], hue='DEATH_EVENT', diag_kind='kde')\n",
  "plt.suptitle('Pairplot of Selected Numerical Features by DEATH_EVENT', y=1.02)\n",
  "plt.show()"
 ], "metadata": { "id": "akxmasIGn_Ps" }, "id": "akxmasIGn_Ps", "execution_count": null, "outputs": [] },
 { "cell_type": "markdown", "source": [ "# Data Preprocessing" ], "metadata": { "id": "lAmTgq0AoJbP" }, "id": "lAmTgq0AoJbP" },
 { "cell_type": "markdown", "id": "6318b50d", "metadata": { "id": "6318b50d" }, "source": [ "### Data Split\n", "\n", "**Caveat:** `time` is the follow-up period in days. It is probably only known once follow-up has ended, so models that use it may score better here than a risk predictor used at admission would." ] },
 { "cell_type": "code", "execution_count": null, "id": "f9bbf4a6", "metadata": { "id": "f9bbf4a6" }, "outputs": [], "source": [
  "# age_group was only needed for plotting. Reassigning (instead of inplace=True)\n",
  "# with errors='ignore' keeps this cell safe to re-run.\n",
  "data = data.drop(columns=['age_group'], errors='ignore')"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "67245c6b", "metadata": { "id": "67245c6b" }, "outputs": [], "source": [
  "X = data.drop('DEATH_EVENT', axis=1)\n",
  "y = data['DEATH_EVENT']\n",
  "# Stratify so both splits keep the same share of DEATH_EVENT cases.\n",
  "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y)"
 ] },
 { "cell_type": "markdown", "source": [ "### Feature Scaling" ], "metadata": { "id": "9RC0CaRQoPSL" }, "id": "9RC0CaRQoPSL" },
 { "cell_type": "code", "execution_count": null, "id": "eff46e4d", "metadata": { "id": "eff46e4d" }, "outputs": [], "source": [
  "# Work on explicit copies so the column assignments below cannot trigger\n",
  "# pandas' SettingWithCopyWarning.\n",
  "X_train = X_train.copy()\n",
  "X_test = X_test.copy()\n",
  "\n",
  "continuous_features = ['age', 'creatinine_phosphokinase', 'ejection_fraction', 'platelets', 'serum_creatinine', 'serum_sodium', 'time']\n",
  "scaler = StandardScaler()\n",
  "# Fit on the training split only so no test-set statistics leak into training.\n",
  "X_train[continuous_features] = scaler.fit_transform(X_train[continuous_features])\n",
  "X_test[continuous_features] = scaler.transform(X_test[continuous_features])"
 ] },
 { "cell_type": "markdown", "source": [ "## Modeling\n", "\n", "Each model is tuned with a 5-fold `GridSearchCV`; the evaluation cells reuse `best_estimator_` (already refit on the full training split) so the reported scores always match the tuned hyperparameters." ], "metadata": { "id": "RgfpGCrFoYYo" }, "id": "RgfpGCrFoYYo" },
 { "cell_type": "markdown", "id": "c6c49e64", "metadata": { "id": "c6c49e64" }, "source": [ "### Logistic Regression" ] },
 { "cell_type": "code", "execution_count": null, "id": "c65331a8", "metadata": { "id": "c65331a8" }, "outputs": [], "source": [
  "# Solvers support different penalties: lbfgs handles only l2 (or none), saga handles\n",
  "# l1/l2/elasticnet, and elasticnet additionally needs l1_ratio. Separate sub-grids keep\n",
  "# GridSearchCV from trying invalid combinations (which would fail silently here).\n",
  "C_values = [0.01, 0.1, 1, 10, 100]\n",
  "log_params = [\n",
  "    {'penalty': ['l2'], 'solver': ['lbfgs', 'saga'], 'C': C_values},\n",
  "    {'penalty': ['l1'], 'solver': ['saga'], 'C': C_values},\n",
  "    {'penalty': ['elasticnet'], 'solver': ['saga'], 'C': C_values, 'l1_ratio': [0.25, 0.5, 0.75]},\n",
  "]\n",
  "\n",
  "log_grid = GridSearchCV(LogisticRegression(max_iter=1000, random_state=RANDOM_STATE), log_params, cv=5)\n",
  "log_grid.fit(X_train, y_train)\n",
  "\n",
  "print(\" Logistic Regression Best Params:\", log_grid.best_params_)"
 ] },
 { "cell_type": "markdown", "source": [ "#### Evaluation" ], "metadata": { "id": "A7F1ne-9okC3" }, "id": "A7F1ne-9okC3" },
 { "cell_type": "code", "execution_count": null, "id": "bb425d64", "metadata": { "id": "bb425d64" }, "outputs": [], "source": [
  "log_model = log_grid.best_estimator_\n",
  "y_pred_log = log_model.predict(X_test)\n",
  "print(\" Logistic Regression\")\n",
  "print(f\"Accuracy: {accuracy_score(y_test, y_pred_log):.4f}\")\n",
  "print(classification_report(y_test, y_pred_log))"
 ] },
 { "cell_type": "markdown", "id": "9ec5c7bd", "metadata": { "id": "9ec5c7bd" }, "source": [ "### Random Forest" ] },
 { "cell_type": "code", "execution_count": null, "id": "355a5349", "metadata": { "id": "355a5349" }, "outputs": [], "source": [
  "rf_params = {\n",
  "    'n_estimators': [50, 100, 200],\n",
  "    'max_depth': [None, 5, 10],\n",
  "    'min_samples_split': [2, 5],\n",
  "    'min_samples_leaf': [1, 2]\n",
  "}\n",
  "\n",
  "rf_grid = GridSearchCV(RandomForestClassifier(random_state=RANDOM_STATE), rf_params, cv=5)\n",
  "rf_grid.fit(X_train, y_train)\n",
  "\n",
  "print(\" Random Forest Best Params:\", rf_grid.best_params_)"
 ] },
 { "cell_type": "markdown", "source": [ "#### Evaluation" ], "metadata": { "id": "ZgnqGv2_onMp" }, "id": "ZgnqGv2_onMp" },
 { "cell_type": "code", "execution_count": null, "id": "7a814143", "metadata": { "id": "7a814143" }, "outputs": [], "source": [
  "rf_model = rf_grid.best_estimator_\n",
  "y_pred_rf = rf_model.predict(X_test)\n",
  "print(\" Random Forest\")\n",
  "print(f\"Accuracy: {accuracy_score(y_test, y_pred_rf):.4f}\")\n",
  "print(classification_report(y_test, y_pred_rf))"
 ] },
 { "cell_type": "markdown", "id": "8ae23a4c", "metadata": { "id": "8ae23a4c" }, "source": [ "### SVM" ] },
 { "cell_type": "code", "execution_count": null, "id": "98d79b19", "metadata": { "id": "98d79b19" }, "outputs": [], "source": [
  "# gamma is ignored by the linear kernel, so it is only searched for rbf.\n",
  "svm_params = [\n",
  "    {'kernel': ['linear'], 'C': [0.1, 1, 10]},\n",
  "    {'kernel': ['rbf'], 'C': [0.1, 1, 10], 'gamma': ['scale', 'auto']},\n",
  "]\n",
  "\n",
  "svm_grid = GridSearchCV(SVC(probability=True, random_state=RANDOM_STATE), svm_params, cv=5)\n",
  "svm_grid.fit(X_train, y_train)\n",
  "\n",
  "print(\" SVM Best Params:\", svm_grid.best_params_)"
 ] },
 { "cell_type": "markdown", "source": [ "#### Evaluation" ], "metadata": { "id": "lGcRpN66oqox" }, "id": "lGcRpN66oqox" },
 { "cell_type": "code", "execution_count": null, "id": "5b3f845f", "metadata": { "id": "5b3f845f" }, "outputs": [], "source": [
  "svm_model = svm_grid.best_estimator_\n",
  "y_pred_svm = svm_model.predict(X_test)\n",
  "print(\"\\n SVM\")\n",
  "print(f\"Accuracy: {accuracy_score(y_test, y_pred_svm):.4f}\")\n",
  "print(classification_report(y_test, y_pred_svm))"
 ] },
 { "cell_type": "markdown", "id": "397c4db9", "metadata": { "id": "397c4db9" }, "source": [ "### MLP" ] },
 { "cell_type": "code", "execution_count": null, "id": "161c3769", "metadata": { "id": "161c3769" }, "outputs": [], "source": [
  "# learning_rate is not searched: MLPClassifier only uses it with solver='sgd',\n",
  "# and the default solver here is adam.\n",
  "mlp_params = {\n",
  "    'hidden_layer_sizes': [(64,), (64, 32), (128, 64)],\n",
  "    'activation': ['relu', 'tanh'],\n",
  "    'alpha': [0.0001, 0.001]\n",
  "}\n",
  "\n",
  "mlp_grid = GridSearchCV(MLPClassifier(max_iter=1000, random_state=RANDOM_STATE), mlp_params, cv=5)\n",
  "mlp_grid.fit(X_train, y_train)\n",
  "\n",
  "print(\" MLP Best Params:\", mlp_grid.best_params_)"
 ] },
 { "cell_type": "markdown", "source": [ "#### Evaluation" ], "metadata": { "id": "xP9abpojovRZ" }, "id": "xP9abpojovRZ" },
 { "cell_type": "code", "execution_count": null, "id": "c3f80cb8", "metadata": { "id": "c3f80cb8" }, "outputs": [], "source": [
  "mlp_model = mlp_grid.best_estimator_\n",
  "y_pred_mlp = mlp_model.predict(X_test)\n",
  "print(\"\\n MLP Neural Network\")\n",
  "print(f\"Accuracy: {accuracy_score(y_test, y_pred_mlp):.4f}\")\n",
  "print(classification_report(y_test, y_pred_mlp))"
 ] },
 { "cell_type": "markdown", "id": "26b1f47b", "metadata": { "id": "26b1f47b" }, "source": [ "### XGBoost" ] },
 { "cell_type": "code", "execution_count": null, "id": "c2cccaf0", "metadata": { "id": "c2cccaf0" }, "outputs": [], "source": [
  "xgb_params = {\n",
  "    'n_estimators': [50, 100, 200],\n",
  "    'max_depth': [3, 4, 5],\n",
  "    'learning_rate': [0.01, 0.1, 0.2]\n",
  "}\n",
  "\n",
  "xgb_grid = GridSearchCV(\n",
  "    XGBClassifier(eval_metric='logloss', random_state=RANDOM_STATE),\n",
  "    xgb_params, cv=5\n",
  ")\n",
  "xgb_grid.fit(X_train, y_train)\n",
  "\n",
  "print(\" XGBoost Best Params:\", xgb_grid.best_params_)"
 ] },
 { "cell_type": "markdown", "source": [ "#### Evaluation" ], "metadata": { "id": "gzj365Wkoyni" }, "id": "gzj365Wkoyni" },
 { "cell_type": "code", "execution_count": null, "id": "01cefcfa", "metadata": { "id": "01cefcfa" }, "outputs": [], "source": [
  "xgb_model = xgb_grid.best_estimator_\n",
  "y_pred_xgb = xgb_model.predict(X_test)\n",
  "print(\"\\n XGBoost\")\n",
  "print(f\"Accuracy: {accuracy_score(y_test, y_pred_xgb):.4f}\")\n",
  "print(classification_report(y_test, y_pred_xgb))"
 ] },
 { "cell_type": "markdown", "id": "eecde701", "metadata": { "id": "eecde701" }, "source": [ "### KNN" ] },
 { "cell_type": "code", "execution_count": null, "id": "985c647f", "metadata": { "id": "985c647f" }, "outputs": [], "source": [
  "knn_params = {\n",
  "    'n_neighbors': [3, 5, 7, 9],\n",
  "    'weights': ['uniform', 'distance'],\n",
  "    'metric': ['euclidean', 'manhattan']\n",
  "}\n",
  "\n",
  "knn_grid = GridSearchCV(KNeighborsClassifier(), knn_params, cv=5)\n",
  "knn_grid.fit(X_train, y_train)\n",
  "\n",
  "print(\" KNN Best Params:\", knn_grid.best_params_)"
 ] },
 { "cell_type": "markdown", "source": [ "#### Evaluation" ], "metadata": { "id": "20E5x9Rmo3Le" }, "id": "20E5x9Rmo3Le" },
 { "cell_type": "code", "execution_count": null, "id": "a5f50c88", "metadata": { "id": "a5f50c88" }, "outputs": [], "source": [
  "knn_model = knn_grid.best_estimator_\n",
  "y_pred_knn = knn_model.predict(X_test)\n",
  "print(\"\\n KNN\")\n",
  "print(f\"Accuracy: {accuracy_score(y_test, y_pred_knn):.4f}\")\n",
  "print(classification_report(y_test, y_pred_knn))"
 ] },
 { "cell_type": "markdown", "id": "658b2f4c", "metadata": { "id": "658b2f4c" }, "source": [ "### Model Accuracies\n", "\n", "Test-set accuracy for every model above. The test split holds only 20% of the rows, so small differences between models may not be meaningful." ] },
 { "cell_type": "code", "execution_count": null, "id": "8eb234da", "metadata": { "id": "8eb234da" }, "outputs": [], "source": [
  "# Computed from the predictions above so the chart always reflects the current run.\n",
  "predictions = {\n",
  "    'Random Forest': y_pred_rf,\n",
  "    'SVM': y_pred_svm,\n",
  "    'MLP': y_pred_mlp,\n",
  "    'XGBoost': y_pred_xgb,\n",
  "    'KNN': y_pred_knn,\n",
  "    'Logistic Regression': y_pred_log,\n",
  "}\n",
  "accuracies = {name: accuracy_score(y_test, preds) for name, preds in predictions.items()}\n",
  "\n",
  "plt.figure(figsize=(10, 6))\n",
  "plt.bar(list(accuracies.keys()), list(accuracies.values()), color=['blue', 'green', 'purple', 'orange', 'red', 'cyan'])\n",
  "plt.ylim(0, 1)\n",
  "plt.ylabel('Accuracy')\n",
  "plt.title('Model Accuracy Comparison')\n",
  "\n",
  "plt.xticks(rotation=30)\n",
  "plt.show()"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "0ef22c6c", "metadata": { "id": "0ef22c6c" }, "outputs": [], "source": [
  "import gradio as gr\n",
  "import joblib"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "28aa35d9", "metadata": { "id": "28aa35d9" }, "outputs": [], "source": [
  "# The Random Forest is the model served by the Gradio app below; the scaler is saved\n",
  "# with it because the app applies the same scaling to its inputs.\n",
  "joblib.dump(rf_model, \"heart_model.pkl\")\n",
  "joblib.dump(scaler, \"scaler.pkl\")\n",
  "print(\"Model and scaler saved successfully\")"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "165b4cab", "metadata": { "id": "165b4cab" }, "outputs": [], "source": [
  "# joblib files are pickles: only load files this notebook produced or that you trust.\n",
  "model = joblib.load(\"heart_model.pkl\")\n",
  "scaler = joblib.load(\"scaler.pkl\")"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "c41a4646", "metadata": { "id": "c41a4646" }, "outputs": [], "source": [
  "def predict_heart_risk(age, cpk, ef, platelets, sc, ss, time, anaemia, diabetes, high_bp, sex, smoking):\n",
  "    \"\"\"Classify one patient record with the loaded `model` and `scaler`.\n",
  "\n",
  "    Arguments arrive in the order of the Gradio `inputs` list. They are rearranged\n",
  "    into a one-row DataFrame, the continuous columns are scaled with the saved\n",
  "    scaler, and the model's 0/1 prediction is mapped to a text label.\n",
  "    \"\"\"\n",
  "    # An empty gr.Number field may arrive as None, which the scaler cannot handle.\n",
  "    if any(value is None for value in (age, cpk, ef, platelets, sc, ss, time)):\n",
  "        return \"Please fill in all numeric fields.\"\n",
  "\n",
  "    # Column names/order should match the training features (X without DEATH_EVENT).\n",
  "    row = pd.DataFrame([[\n",
  "        age, anaemia, cpk, diabetes, ef, high_bp,\n",
  "        platelets, sc, ss, sex, smoking, time\n",
  "    ]], columns=[\n",
  "        'age', 'anaemia', 'creatinine_phosphokinase', 'diabetes',\n",
  "        'ejection_fraction', 'high_blood_pressure', 'platelets',\n",
  "        'serum_creatinine', 'serum_sodium', 'sex', 'smoking', 'time'\n",
  "    ])\n",
  "\n",
  "    continuous_features = ['age', 'creatinine_phosphokinase', 'ejection_fraction', 'platelets', 'serum_creatinine', 'serum_sodium', 'time']\n",
  "    row[continuous_features] = scaler.transform(row[continuous_features])\n",
  "\n",
  "    prediction = model.predict(row)[0]\n",
  "    return \" At Risk\" if prediction == 1 else \" Not At Risk\""
 ] },
 { "cell_type": "code", "execution_count": null, "id": "5ca7be47", "metadata": { "id": "5ca7be47" }, "outputs": [], "source": [
  "# Order must match the parameters of predict_heart_risk.\n",
  "inputs = [\n",
  "    gr.Number(label=\"Age\"),\n",
  "    gr.Number(label=\"Creatinine Phosphokinase, Range [0,100000]\"),\n",
  "    gr.Number(label=\"Ejection Fraction, Range [5,85] \"),\n",
  "    gr.Number(label=\"Platelets, Range [5000,2000000]\"),\n",
  "    gr.Number(label=\"Serum Creatinine, Range [0.1,60]\"),\n",
  "    gr.Number(label=\"Serum Sodium, Range [95,255]\"),\n",
  "    gr.Number(label=\"Follow-up Time (days)\"),\n",
  "    # Binary flags default to 0 so the model never receives a missing value.\n",
  "    gr.Radio([0, 1], value=0, label=\"Anaemia (0=No, 1=Yes)\"),\n",
  "    gr.Radio([0, 1], value=0, label=\"Diabetes (0=No, 1=Yes)\"),\n",
  "    gr.Radio([0, 1], value=0, label=\"High Blood Pressure (0=No, 1=Yes)\"),\n",
  "    gr.Radio([0, 1], value=0, label=\"Sex (0=Female, 1=Male)\"),\n",
  "    gr.Radio([0, 1], value=0, label=\"Smoking (0=No, 1=Yes)\")\n",
  "]"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "563bc8b2", "metadata": { "id": "563bc8b2" }, "outputs": [], "source": [
  "gr.Interface(\n",
  "    fn=predict_heart_risk,\n",
  "    inputs=inputs,\n",
  "    outputs=\"text\",\n",
  "    title=\" Heart Failure Risk Predictor\",\n",
  "    description=\"Enter patient data to predict the risk of a death event (DEATH_EVENT) during follow-up.\",\n",
  "    allow_flagging=\"never\"\n",
  ").launch()"
 ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 5 }