Heart disease continues to be a paramount challenge of modern medicine. Since 2016, an estimated 1 million heart surgeries have been performed globally each year, and in 2020 the Centers for Disease Control and Prevention identified heart disease as the leading cause of death in the United States. Our goal is to use the given patient attributes to develop a model that predicts whether or not a patient has heart disease. If the model performs well, it could be applied to new patients for whom these attributes are available, allowing medical providers to prioritize patients for diagnostic testing, surgery, and other interventions.
In this project, we will preprocess the patient data, train and tune multilayer perceptron (MLP) and k-nearest neighbors (k-NN) classifiers, and compare their performance.
We will be using a dataset containing patient information related to heart disease (UCI Heart Disease Data source).
# Import packages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, classification_report,f1_score
from sklearn.model_selection import cross_val_score, RandomizedSearchCV
from sklearn import metrics
from sklearn.metrics import mean_squared_error
from sklearn.metrics import make_scorer
# Import data
heart = pd.read_csv("heart_disease_uci.csv")
Our target for prediction is the column 'num', which indicates the presence of heart disease in any record with a non-zero value.
heart.head()
 | id | age | sex | dataset | cp | trestbps | chol | fbs | restecg | thalch | exang | oldpeak | slope | ca | thal | num |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 63 | Male | Cleveland | typical angina | 145.0 | 233.0 | True | lv hypertrophy | 150.0 | False | 2.3 | downsloping | 0.0 | fixed defect | 0 |
1 | 2 | 67 | Male | Cleveland | asymptomatic | 160.0 | 286.0 | False | lv hypertrophy | 108.0 | True | 1.5 | flat | 3.0 | normal | 2 |
2 | 3 | 67 | Male | Cleveland | asymptomatic | 120.0 | 229.0 | False | lv hypertrophy | 129.0 | True | 2.6 | flat | 2.0 | reversable defect | 1 |
3 | 4 | 37 | Male | Cleveland | non-anginal | 130.0 | 250.0 | False | normal | 187.0 | False | 3.5 | downsloping | 0.0 | normal | 0 |
4 | 5 | 41 | Female | Cleveland | atypical angina | 130.0 | 204.0 | False | lv hypertrophy | 172.0 | False | 1.4 | upsloping | 0.0 | normal | 0 |
# Examine suspected categorical variables
print(f"Unique values for 'cp':",heart.cp.unique(),"\n")
print(f"Unique values for 'restecg':",heart.restecg.unique(),"\n")
print(f"Unique values for 'slope':",heart.slope.unique(),"\n")
print(f"Unique values for 'ca':",heart.ca.unique(),"\n")
print(f"Unique values for 'thal':",heart.thal.unique(),"\n")
print(f"Unique values for 'num':",heart.num.unique(),"\n")
Unique values for 'cp': ['typical angina' 'asymptomatic' 'non-anginal' 'atypical angina']
Unique values for 'restecg': ['lv hypertrophy' 'normal' 'st-t abnormality' nan]
Unique values for 'slope': ['downsloping' 'flat' 'upsloping' nan]
Unique values for 'ca': [ 0.  3.  2.  1. nan]
Unique values for 'thal': ['fixed defect' 'normal' 'reversable defect' nan]
Unique values for 'num': [0 2 1 3 4]
# Examine numerical variables
heart.describe()
 | id | age | trestbps | chol | thalch | oldpeak | ca | num |
---|---|---|---|---|---|---|---|---|
count | 920.000000 | 920.000000 | 861.000000 | 890.000000 | 865.000000 | 858.000000 | 309.000000 | 920.000000 |
mean | 460.500000 | 53.510870 | 132.132404 | 199.130337 | 137.545665 | 0.878788 | 0.676375 | 0.995652 |
std | 265.725422 | 9.424685 | 19.066070 | 110.780810 | 25.926276 | 1.091226 | 0.935653 | 1.142693 |
min | 1.000000 | 28.000000 | 0.000000 | 0.000000 | 60.000000 | -2.600000 | 0.000000 | 0.000000 |
25% | 230.750000 | 47.000000 | 120.000000 | 175.000000 | 120.000000 | 0.000000 | 0.000000 | 0.000000 |
50% | 460.500000 | 54.000000 | 130.000000 | 223.000000 | 140.000000 | 0.500000 | 0.000000 | 1.000000 |
75% | 690.250000 | 60.000000 | 140.000000 | 268.000000 | 157.000000 | 1.500000 | 1.000000 | 2.000000 |
max | 920.000000 | 77.000000 | 200.000000 | 603.000000 | 202.000000 | 6.200000 | 3.000000 | 4.000000 |
heart.isnull().sum()
id            0
age           0
sex           0
dataset       0
cp            0
trestbps     59
chol         30
fbs          90
restecg       2
thalch       55
exang        55
oldpeak      62
slope       309
ca          611
thal        486
num           0
dtype: int64
heart.dtypes
id            int64
age           int64
sex          object
dataset      object
cp           object
trestbps    float64
chol        float64
fbs          object
restecg      object
thalch      float64
exang        object
oldpeak     float64
slope        object
ca          float64
thal         object
num           int64
dtype: object
With so many null values concentrated in a few specific features, we will retain only the features/columns of interest for the models.
The fields 'slope', 'ca', and 'thal' each have many null values (>300). Types of data:
- slope: categorical
- ca: discrete (0, 1, 2, 3), treated as categorical
- thal: categorical
Since imputing these categorical variables would introduce false data, we will drop these three columns (along with 'id' and 'dataset', which are not useful as predictors). Columns being retained:
# List of columns to retain for analysis
columns_keep = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs',
'restecg', 'thalch', 'exang', 'oldpeak', 'num']
heartpred = heart[columns_keep].copy()
# Note: The target column 'num' is still included here for preprocessing
# and will be separated when the data is split.
heartpred.head()
 | age | sex | cp | trestbps | chol | fbs | restecg | thalch | exang | oldpeak | num |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | Male | typical angina | 145.0 | 233.0 | True | lv hypertrophy | 150.0 | False | 2.3 | 0 |
1 | 67 | Male | asymptomatic | 160.0 | 286.0 | False | lv hypertrophy | 108.0 | True | 1.5 | 2 |
2 | 67 | Male | asymptomatic | 120.0 | 229.0 | False | lv hypertrophy | 129.0 | True | 2.6 | 1 |
3 | 37 | Male | non-anginal | 130.0 | 250.0 | False | normal | 187.0 | False | 3.5 | 0 |
4 | 41 | Female | atypical angina | 130.0 | 204.0 | False | lv hypertrophy | 172.0 | False | 1.4 | 0 |
# Populate missing values in numerical columns with the median
heartpred.trestbps = heartpred.trestbps.fillna(value=heartpred['trestbps'].median())
heartpred.chol = heartpred.chol.fillna(value=heartpred['chol'].median())
heartpred.thalch = heartpred.thalch.fillna(value=heartpred['thalch'].median())
heartpred.oldpeak = heartpred.oldpeak.fillna(value=heartpred['oldpeak'].median())
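The four separate fillna calls above can also be written in one step; the following is an equivalent sketch (same result, so running it after the cell above is a no-op):
# Equivalent, more compact form of the four fillna calls above
median_cols = ['trestbps', 'chol', 'thalch', 'oldpeak']
heartpred[median_cols] = heartpred[median_cols].fillna(heartpred[median_cols].median())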
heartpred.isnull().sum()
age          0
sex          0
cp           0
trestbps     0
chol         0
fbs         90
restecg      2
thalch       0
exang       55
oldpeak      0
num          0
dtype: int64
While the median is assumed to be an appropriate replacement for missing values in the numerical columns ('trestbps', 'chol', 'thalch', 'oldpeak'), there is no comparable 'median' for the categorical columns. For this reason, the remaining records with missing categorical values will be dropped rather than risk being imputed incorrectly.
# Replace any empty strings with NaN so they are treated as missing values
heartpred.replace('', np.nan, inplace=True)
# Drop rows that still have missing values (categorical columns only)
heartpred.dropna(subset=['fbs', 'restecg', 'exang'], inplace=True)
heartpred.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalch      0
exang       0
oldpeak     0
num         0
dtype: int64
# Data set reduced by 146 records
heartpred.describe()
 | age | trestbps | chol | thalch | oldpeak | num |
---|---|---|---|---|---|---|
count | 774.000000 | 774.000000 | 774.000000 | 774.000000 | 774.000000 | 774.000000 |
mean | 53.071059 | 132.775194 | 219.301034 | 138.677003 | 0.885401 | 0.919897 |
std | 9.430970 | 18.577723 | 92.594114 | 25.808812 | 1.081890 | 1.133424 |
min | 28.000000 | 0.000000 | 0.000000 | 60.000000 | -1.000000 | 0.000000 |
25% | 46.000000 | 120.000000 | 198.000000 | 120.000000 | 0.000000 | 0.000000 |
50% | 54.000000 | 130.000000 | 228.000000 | 140.000000 | 0.500000 | 1.000000 |
75% | 60.000000 | 140.000000 | 269.000000 | 159.000000 | 1.500000 | 1.000000 |
max | 77.000000 | 200.000000 | 603.000000 | 202.000000 | 6.200000 | 4.000000 |
As a distance-based algorithm, k-NN is sensitive to the scale of the numerical variables: features with larger magnitudes dominate the distance calculation, so we will standardize these features. We will not normalize them to a 0-1 range, as min-max scaling compresses the influence of outliers that we prefer to retain. Standardized inputs also tend to help the MLP classifier, since gradient-based training converges more reliably when features are on comparable scales.
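For reference, this is roughly what min-max scaling would look like on the same columns; a sketch only, not used in this project:
from sklearn.preprocessing import MinMaxScaler
# Sketch only: min-max scaling (NOT used here) rescales each numerical feature
# to the [0, 1] range, which compresses the influence of outliers.
minmax_demo = MinMaxScaler().fit_transform(heartpred[['age', 'trestbps', 'chol', 'thalch', 'oldpeak']])
print(minmax_demo.min(axis=0), minmax_demo.max(axis=0))  # all zeros and all ones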
# Create scaler for standardization
scaler = StandardScaler()
# Apply to the numerical columns
scale_cols = ['age', 'trestbps', 'chol', 'thalch', 'oldpeak']
heartpred[scale_cols] = scaler.fit_transform(heartpred[scale_cols])
# Confirm scaling applied
heartpred.describe()
 | age | trestbps | chol | thalch | oldpeak | num |
---|---|---|---|---|---|---|
count | 7.740000e+02 | 7.740000e+02 | 7.740000e+02 | 7.740000e+02 | 7.740000e+02 | 774.000000 |
mean | -5.909714e-17 | 2.003852e-16 | 2.626380e-16 | 5.531034e-16 | -9.848568e-16 | 0.919897 |
std | 1.000647e+00 | 1.000647e+00 | 1.000647e+00 | 1.000647e+00 | 1.000647e+00 | 1.133424 |
min | -2.660094e+00 | -7.151632e+00 | -2.369944e+00 | -3.050426e+00 | -1.743819e+00 | 0.000000 |
25% | -7.502549e-01 | -6.881066e-01 | -2.301961e-01 | -7.241356e-01 | -8.189126e-01 | 0.000000 |
50% | 9.856263e-02 | -1.494795e-01 | 9.400804e-02 | 5.129461e-02 | -3.564594e-01 | 1.000000 |
75% | 7.351758e-01 | 3.891477e-01 | 5.370871e-01 | 7.879533e-01 | 5.684470e-01 | 1.000000 |
max | 2.538913e+00 | 3.620911e+00 | 4.146560e+00 | 2.455128e+00 | 4.915507e+00 | 4.000000 |
heartpred.head()
 | age | sex | cp | trestbps | chol | fbs | restecg | thalch | exang | oldpeak | num |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.053482 | Male | typical angina | 0.658461 | 0.148042 | True | lv hypertrophy | 0.439010 | False | 1.308372 | 0 |
1 | 1.477891 | Male | asymptomatic | 1.466402 | 0.720803 | False | lv hypertrophy | -1.189394 | True | 0.568447 | 2 |
2 | 1.477891 | Male | asymptomatic | -0.688107 | 0.104815 | False | lv hypertrophy | -0.375192 | True | 1.585844 | 1 |
3 | -1.705175 | Male | non-anginal | -0.149479 | 0.331758 | False | normal | 1.873556 | False | 2.418260 | 0 |
4 | -1.280766 | Female | atypical angina | -0.149479 | -0.165355 | False | lv hypertrophy | 1.291983 | False | 0.475956 | 0 |
Encoding the categorical variables will be necessary for either model and can be accomplished as ordinal or dummy/one-hot encoding. Ordinal encoding would provide an extra dimension of information (the natural rank of the classes), but unfortunately cannot be used here as a result of limited domain knowledge.
The 2 remaining features with multiple classes below cannot be interpreted as having a natural ranking order without deeper understanding of these patient attributes. For this reason, only dummy/one-hot encoding will be used.
print(f"Unique values for 'cp':",heart.cp.unique(),"\n")
print(f"Unique values for 'restecg':",heart.restecg.unique(),"\n")
Unique values for 'cp': ['typical angina' 'asymptomatic' 'non-anginal' 'atypical angina']
Unique values for 'restecg': ['lv hypertrophy' 'normal' 'st-t abnormality' nan]
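For illustration only: if a clinically meaningful severity ordering of the 'cp' categories were known, an ordinal encoding could be applied with scikit-learn's OrdinalEncoder. The ordering below is purely hypothetical and is not used in this project.
from sklearn.preprocessing import OrdinalEncoder
# Hypothetical ordering of chest-pain types (illustration only, not clinically validated)
cp_order = [['asymptomatic', 'non-anginal', 'atypical angina', 'typical angina']]
cp_demo = pd.DataFrame({'cp': ['typical angina', 'asymptomatic', 'non-anginal']})
print(OrdinalEncoder(categories=cp_order).fit_transform(cp_demo[['cp']]))
# typical angina -> 3.0, asymptomatic -> 0.0, non-anginal -> 1.0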
# View data types
heartpred.dtypes
age         float64
sex          object
cp           object
trestbps    float64
chol        float64
fbs          object
restecg      object
thalch      float64
exang        object
oldpeak     float64
num           int64
dtype: object
# Convert object data type to category (for encoding)
heartpred.sex = heartpred.sex.astype('category')
heartpred.cp = heartpred.cp.astype('category')
heartpred.fbs = heartpred.fbs.astype('category')
heartpred.restecg = heartpred.restecg.astype('category')
heartpred.exang = heartpred.exang.astype('category')
heartpred.dtypes
age          float64
sex         category
cp          category
trestbps     float64
chol         float64
fbs         category
restecg     category
thalch       float64
exang       category
oldpeak      float64
num            int64
dtype: object
# Split categorical variable into dummy variables
heartpred = pd.get_dummies(heartpred,prefix_sep='_')
heartpred.head()
 | age | trestbps | chol | thalch | oldpeak | num | sex_Female | sex_Male | cp_asymptomatic | cp_atypical angina | cp_non-anginal | cp_typical angina | fbs_False | fbs_True | restecg_lv hypertrophy | restecg_normal | restecg_st-t abnormality | exang_False | exang_True |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.053482 | 0.658461 | 0.148042 | 0.439010 | 1.308372 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 1 | 0 | 0 | 1 | 0 |
1 | 1.477891 | 1.466402 | 0.720803 | -1.189394 | 0.568447 | 2 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 1 |
2 | 1.477891 | -0.688107 | 0.104815 | -0.375192 | 1.585844 | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 1 |
3 | -1.705175 | -0.149479 | 0.331758 | 1.873556 | 2.418260 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 |
4 | -1.280766 | -0.149479 | -0.165355 | 1.291983 | 0.475956 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 0 |
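As a side note, pd.get_dummies produces complementary column pairs for binary categories (e.g., sex_Female and sex_Male). Passing drop_first=True would keep k-1 indicator columns per feature and remove that redundancy; a small sketch of that option, which is not applied in this project:
# Sketch: drop_first=True keeps k-1 indicator columns per categorical feature
sex_demo = pd.DataFrame({'sex': ['Male', 'Female', 'Male']})
print(pd.get_dummies(sex_demo, prefix_sep='_', drop_first=True))  # only 'sex_Male' is kept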
The decision to simplify the target variable 'num' into a binary classification means that our models will predict the presence of any heart disease, rather than its severity (levels 1, 2, 3, 4).
It is noted this may impact the efficacy of the models, as the simplification discards information about, and relationships between, the different levels of heart disease.
# Simplify target variable to a binary classification (any non-zero value becomes 1)
heartpred.num = heartpred.num.replace([2, 3, 4], 1)
heartpred.num.value_counts()
1    397
0    377
Name: num, dtype: int64
The target column is not perfectly balanced (397 vs. 377). This mild imbalance could slightly favor the majority class in the k-NN model, but it is small enough that it will remain unaddressed at this stage.
heartpred.columns
Index(['age', 'trestbps', 'chol', 'thalch', 'oldpeak', 'num', 'sex_Female', 'sex_Male', 'cp_asymptomatic', 'cp_atypical angina', 'cp_non-anginal', 'cp_typical angina', 'fbs_False', 'fbs_True', 'restecg_lv hypertrophy', 'restecg_normal', 'restecg_st-t abnormality', 'exang_False', 'exang_True'], dtype='object')
target = 'num'
predictors = ['age', 'trestbps', 'chol', 'thalch', 'oldpeak', 'sex_Female',
'sex_Male', 'cp_asymptomatic', 'cp_atypical angina', 'cp_non-anginal',
'cp_typical angina', 'fbs_False', 'fbs_True', 'restecg_lv hypertrophy',
'restecg_normal', 'restecg_st-t abnormality', 'exang_False',
'exang_True']
X = heartpred[predictors]
y = heartpred[target]
# 70/30 split
train_X, valid_X, train_y, valid_y = train_test_split(X,y, test_size=0.3, random_state=1)
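The split above is purely random. Given the mild class imbalance noted earlier (397 vs. 377), this is acceptable, but a stratified split could be used to preserve the exact class ratio in both partitions; a sketch of that alternative (not used here):
# Sketch: stratified 70/30 split preserving the 397/377 class ratio (not used in this project)
strat_train_X, strat_valid_X, strat_train_y, strat_valid_y = train_test_split(
    X, y, test_size=0.3, random_state=1, stratify=y)
print(strat_train_y.value_counts(normalize=True))  # class proportions match the full dataset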
# Number of predictors in the input layer
len(X.columns)
18
# Create MLP Classifiers with parameters, keeping in mind the number of input variables is 18
np.random.seed(6136)
# Default parameters
ann_default = MLPClassifier(max_iter=2500)
# hidden_layer_sizes=(100,)
# activation='relu'
# solver='adam'
# learning_rate='constant'
# Specify the number of neurons in each hidden layer.
# The 'sgd' solver outperformed the default 'adam' solver in earlier
# experiments on the same data, so it is used again here.
ann_layers1 = MLPClassifier(hidden_layer_sizes=[16,],solver='sgd',max_iter=2000)
ann_layers2 = MLPClassifier(hidden_layer_sizes=[14,8,4],solver='sgd',max_iter=2000)
ann_layers3 = MLPClassifier(hidden_layer_sizes=[14,10,8,4],solver='sgd',max_iter=2000)
%%time
ann_default.fit(train_X, train_y)
CPU times: user 16.1 s, sys: 296 ms, total: 16.4 s Wall time: 4.22 s
MLPClassifier(max_iter=2500)
%%time
ann_layers1.fit(train_X, train_y)
CPU times: user 2.01 s, sys: 24.6 ms, total: 2.03 s Wall time: 516 ms
MLPClassifier(hidden_layer_sizes=[16], max_iter=2000, solver='sgd')
%%time
ann_layers2.fit(train_X, train_y)
CPU times: user 5.19 s, sys: 73.9 ms, total: 5.26 s Wall time: 1.35 s
MLPClassifier(hidden_layer_sizes=[14, 8, 4], max_iter=2000, solver='sgd')
%%time
ann_layers3.fit(train_X, train_y)
CPU times: user 5.07 s, sys: 48.2 ms, total: 5.12 s Wall time: 1.29 s
MLPClassifier(hidden_layer_sizes=[14, 10, 8, 4], max_iter=2000, solver='sgd')
# Generate predictions on validation data
%%time
y_pred_def = ann_default.predict(valid_X)
CPU times: user 11.3 ms, sys: 2.31 ms, total: 13.6 ms Wall time: 4.81 ms
%%time
y_pred_lr1 = ann_layers1.predict(valid_X)
CPU times: user 13.6 ms, sys: 2.34 ms, total: 16 ms Wall time: 6.67 ms
%%time
y_pred_lr2 = ann_layers2.predict(valid_X)
CPU times: user 13.1 ms, sys: 2.27 ms, total: 15.3 ms Wall time: 3.9 ms
%%time
y_pred_lr3 = ann_layers3.predict(valid_X)
CPU times: user 9.95 ms, sys: 2.71 ms, total: 12.7 ms Wall time: 5.97 ms
print("DEFAULT\n", classification_report(valid_y, y_pred_def))
print("layers=[16,]\n", classification_report(valid_y, y_pred_lr1))
print("layers=[14,8,4]\n", classification_report(valid_y, y_pred_lr2))
print("layers=[14,10,8,4]\n", classification_report(valid_y, y_pred_lr3))
DEFAULT
               precision    recall  f1-score   support

           0       0.72      0.78      0.75       112
           1       0.78      0.73      0.75       121

    accuracy                           0.75       233
   macro avg       0.75      0.75      0.75       233
weighted avg       0.75      0.75      0.75       233

layers=[16,]
               precision    recall  f1-score   support

           0       0.76      0.81      0.79       112
           1       0.82      0.77      0.79       121

    accuracy                           0.79       233
   macro avg       0.79      0.79      0.79       233
weighted avg       0.79      0.79      0.79       233

layers=[14,8,4]
               precision    recall  f1-score   support

           0       0.74      0.81      0.77       112
           1       0.81      0.74      0.77       121

    accuracy                           0.77       233
   macro avg       0.77      0.77      0.77       233
weighted avg       0.78      0.77      0.77       233

layers=[14,10,8,4]
               precision    recall  f1-score   support

           0       0.76      0.80      0.78       112
           1       0.81      0.76      0.78       121

    accuracy                           0.78       233
   macro avg       0.78      0.78      0.78       233
weighted avg       0.78      0.78      0.78       233
print("**************************** RMSE Values ****************************")
print("DEFAULT\n", mean_squared_error(valid_y, y_pred_def)**(1/2))
print("layers=[16,]\n", mean_squared_error(valid_y, y_pred_lr1)**(1/2))
print("layers=[14,8,4]\n", mean_squared_error(valid_y, y_pred_lr2)**(1/2))
print("layers=[14,12,10,8,6,4]\n", mean_squared_error(valid_y, y_pred_lr3)**(1/2))
**************************** RMSE Values ****************************
DEFAULT
 0.49892588490336864
layers=[16,]
 0.45858524745629287
layers=[14,8,4]
 0.47693585644067304
layers=[14,10,8,4]
 0.4678505318706754
# Define parameter grid
param_grid = {
'n_neighbors': list(range(1,round(np.sqrt(900)),2)), # odd k values from 1 up to ~sqrt(900) ≈ 30, roughly the square root of the total number of records
'weights': ['uniform','distance'],
'metric': ['euclidean', 'cosine']
}
# Use GridSearchCV to search through the parameter grid and find the best parameters (focus on recall)
score_measure = 'recall'
k_fold = 10
gridSearch = GridSearchCV(KNeighborsClassifier(),
param_grid,
cv=k_fold,
scoring=make_scorer(recall_score), #, average='micro'
n_jobs=-1, # n_jobs=-1 will utilize all available CPUs
error_score='raise')
%%time
grid_result = gridSearch.fit(train_X,train_y)
CPU times: user 586 ms, sys: 36.1 ms, total: 623 ms Wall time: 1.59 s
grid_result.best_params_
{'metric': 'euclidean', 'n_neighbors': 3, 'weights': 'uniform'}
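Beyond the single best combination, the full cross-validation results can be inspected to see how the other parameter settings scored; a short optional sketch using GridSearchCV's cv_results_ attribute:
# Optional: mean cross-validated recall for each parameter combination searched
cv_results = pd.DataFrame(grid_result.cv_results_)
cols = ['param_metric', 'param_n_neighbors', 'param_weights', 'mean_test_score']
print(cv_results[cols].sort_values('mean_test_score', ascending=False).head())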
# Use these best parameters to fit a model
knn = KNeighborsClassifier()
final_model = knn.set_params(**grid_result.best_params_)
%%time
final_model.fit(train_X,train_y)
CPU times: user 2.69 ms, sys: 655 µs, total: 3.35 ms Wall time: 2.82 ms
KNeighborsClassifier(metric='euclidean', n_neighbors=3)
%%time
# Generate predictions on validation data
Y_predict = final_model.predict(valid_X)
CPU times: user 65.8 ms, sys: 13.5 ms, total: 79.3 ms Wall time: 49 ms
bestRecallKnn = gridSearch.best_estimator_
print("Recall Score : ",recall_score(valid_y,Y_predict)) #average='weighted'
Recall Score : 0.7272727272727273
print(classification_report(valid_y, Y_predict))
              precision    recall  f1-score   support

           0       0.72      0.78      0.75       112
           1       0.78      0.73      0.75       121

    accuracy                           0.75       233
   macro avg       0.75      0.75      0.75       233
weighted avg       0.75      0.75      0.75       233
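For completeness, the confusion matrices of both final candidates on the validation data can be examined with the already-imported confusion_matrix; a sketch:
# Sketch: confusion matrices (rows = actual class, columns = predicted class; 0 = no disease, 1 = disease)
print("k-NN:\n", confusion_matrix(valid_y, Y_predict))
print("MLP [16,]:\n", confusion_matrix(valid_y, y_pred_lr1))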
Model | precision | recall | accuracy | f1-score |
---|---|---|---|---|
k-NN | 0.78 | 0.73 | 0.75 | 0.75 |
MLP | 0.82 | 0.77 | 0.79 | 0.79 |
Note: Metrics above are for the positive class (1 = heart disease present).
Model | train | predict |
---|---|---|
k-NN | 627 ms | 79 ms |
MLP | 2030 ms | 16 ms |
Note: The values for k-NN above include the GridSearchCV hyperparameter search.
Overall, we would select the MLP classifier over k-NN for this business objective of predicting heart disease. The multilayer perceptron outperformed k-NN on every standard performance metric we evaluated, and it generated its predictions considerably faster. While the MLP took significantly longer to train, faster prediction is the more valuable property here, since medical professionals waiting on results need to decide quickly whether a patient's care should be escalated.