Project: Healthcare cost prediction

external-content.duckduckgo.com.jpeg

The objectives of this project are manifold:

  • Select the best features that can help our prediction based our their information.
  • Identify the patients in the top 10% of the population with the highest expenditure on medical care.
  • Explain the factors that may be at the origin of these high expenses.

In order to respond to these question, we will use the supervised learning machine to classify whether or not a patient is in the top 10% of health costs.

The data used in this project come from Centers for Medicare and Medicaid Limited Data Set Files, which records medicare claims (What's are medical claims ?) happened in all medical settings in US. Medical claims are created when a patient visits a doctor. They include diagnosis codes, procedure codes, as well as costs. These claims are used to adjust Medicare payments to health care plans for the health expenditure risk of their enrollees. It’s intended use is to pay insurance plans appropriately for their expected relative costs. For example, health plans that care for overwhelmingly healthy populations are paid less than those that care for much sicker populations.

Since these claims are not necessarely clear for a non-medical person will develop its various components:

  • Patient informations (e.g: Age, sex, gender, race, e.g).
  • Hierarchical Chronic Condictions (HCC): Cluster diagnosis codes into meaningful categories of chronic diseases (e.g: Vascular Diseases,etc...). Go check:The top 10 of HCC
  • Clinical Classification Software (CCS): Provides a way to classify diagnoses into a limited number of categories.
  • Procedure: Medical services used by patients (e.g: Anesthesia, Home visit, Consultations, etc..)
  • Spending: Bill received by the patient.

In order to classify whether or not a patient is in the top 10% of health costs we will use the medicare claims of 2012 and use the spendings of 2013 to create a binary target ( 0: not in top 10% high cost patients | 1: top 10% high cost patients)

Data can be found here.

Import packages:

In [2]:
import os, sys
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import xgboost as xgb 
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, RandomizedSearchCV, cross_validate
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.preprocessing import StandardScaler, RobustScaler, OneHotEncoder, OrdinalEncoder
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.feature_selection import RFE, RFECV
from imblearn.under_sampling import RandomUnderSampler
import shap
import pickle

shap.initjs()
pd.options.mode.chained_assignment = None  # default='warn

First step is to creat a DataFrame gathering of the necessary informations ( hcc, patients info, ccs, procedure and spendings).

To do that we join all the csv using their common feature which is the patient ID.

In [3]:
dxccs = pd.read_csv('DxCCS.csv',na_values=0).set_index('BENE_ID')

hcc = pd.read_csv('HCC.csv',na_values=0).set_index('bene_id')
hcc.index.rename('BENE_ID',inplace=True)

patient = pd.read_csv('Patient.csv',na_values=0).set_index('BENE_ID')
patient.drop(['STATE_CD','State_County_CD','State','County','BENE_ZIP','gender'], axis=1, inplace=True)

procedure = pd.read_csv('Procedure.csv',na_values=0).set_index('BENE_ID')

spending12 = pd.read_csv('Spending2012.csv',na_values=0).set_index('BENE_ID')
spending12['Spending2012'] = spending12.sum(axis=1)
spending13 = pd.read_csv('Spending2013.csv',na_values=0).set_index('BENE_ID')
spending2012 = spending12['Spending2012']
spending2013 = spending13['Spending2013']
spending = pd.concat([spending2012,spending2013],join='inner',axis=1)

df = patient.join([hcc,dxccs,procedure,spending],how='inner').reset_index()

df.fillna(value=0, inplace=True)
#df['total_spending'] = df['Spending2012'] + df['Spending2013']
df.drop('BENE_ID', axis=1, inplace=True)

Now we can take a look at our DataFrame:

In [4]:
df.head()
Out[4]:
AGE SEX race HCC1 HCC2 HCC5 HCC7 HCC8 HCC9 HCC10 ... I3E O1F P8A P2A I4A P2C M5A _029 Spending2012 Spending2013
0 86 1 White 0.0 0.0 0.0 0.0 0.0 0.0 1.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 12671.70 4845.05
1 87 2 White 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1281.69 3947.39
2 79 2 White 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 6053.86 6063.92
3 86 2 White 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2520.09 2392.67
4 73 1 White 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2115.43 4204.09

5 rows × 463 columns

In [5]:
print(f'The DataFrame has {df.shape[0]} rows and {df.shape[1]} columns')
The DataFrame has 131273 rows and 463 columns

Just by looking the shape of our DataFrame we can see that we will have to work in high dimensionality. Because we want to explain our predictions using a small number of features (and optionally reduce the computation time) we will have to perform a features selection.

Exploration Data Analysis

1/ Patients

The patient DataFrame is composed of 3 features:

In [6]:
patient.columns
Out[6]:
Index(['AGE', 'SEX', 'race'], dtype='object')
  • Age distribution:
In [7]:
age_bins = pd.cut(patient['AGE'],bins=[0,10,20,30,40,50,60,70,80,90,100,120]).value_counts()

fig, ax = plt.subplots(1,2,figsize=(15,7))

sns.barplot(age_bins.index, age_bins, ax=ax[0])
plt.setp(ax[0].get_xticklabels(), rotation=45)

sns.distplot(patient['AGE'],ax=ax[1])

fig.suptitle('Age distribution', fontsize=14)

plt.show()

As we can see the majority of the patients in our dataset have between 60 and 90 years old with a peak above approximatively 75 years old.

  • Sex repartition:
In [8]:
sex_info = patient.groupby('SEX')['AGE'].aggregate(['count','mean'])

labels = ['Men', 'Women']
explode = (0.1,0)
color_palette_list = ['#009ACD', '#ADD8E6', '#63D1F4', '#0EBFE9',   
                      '#C1F0F6', '#0099CC']

fig, axes = plt.subplots(figsize=(10,7))

axes.pie(sex_info['count'], explode, labels=labels,  autopct='%1.0f%%', 
       shadow=False, startangle=0,   
       pctdistance=1.2,labeldistance=1.05, 
       colors=color_palette_list[0:2])
axes.axis('equal')

axes.legend(frameon=False, bbox_to_anchor=(1.5,0.8))
axes.set_title('Gender',loc='left')

plt.show()

Based on our data, it's seems that women (56%) are more represented that men (44%).

  • Race repartition
In [9]:
race_info = patient['race'].value_counts()

labels = ['White','Hispanic','Black','Other']

color_palette_list = ['#009ACD', '#ADD8E6', '#63D1F4', '#0EBFE9',   
                      '#C1F0F6', '#0099CC']

fig, ax = plt.subplots(figsize=(10,7))

ax.pie(race_info,  labels=labels,  autopct='%1.0f%%', 
       shadow=False, startangle=0,   
       pctdistance=1.2,labeldistance=0.6, 
       colors=color_palette_list[0:4])
ax.axis('equal')

ax.legend(frameon=False, bbox_to_anchor=(1.5,0.8))
ax.set_title('Race',loc='center')

plt.show()

White people are mostly present (87%) followed by hispanics (5%) and blacks/Others (4%).

2/HCC

HCC referred to hierarchical Chronic conditions and are classified into 70 categories (here is an example):

In [11]:
#Example of HCC categories and their meanings:

hcc_variables = pd.read_csv('HCC_variable_list.csv')
hcc_variables['hcc'] = hcc_variables['HCC No.'].apply(lambda x:x.split(" ")[0])
hcc_list = hcc_variables[['hcc','HCC Name']].set_index('hcc')

for i,hcc_index in enumerate(hcc_list.index[:10]):
    print(f'{i+1}/ {hcc_index}:'+hcc_list.loc[hcc_index,'HCC Name'])
1/ HCC1:HIV/AIDS
2/ HCC2:Septicemia/Shock  
3/ HCC5:Opportunistic Infections        
4/ HCC7:Metastatic Cancer and Acute Leukemia
5/ HCC8:Lung, Upper Digestive Tract, and Other Severe Cancers     
6/ HCC9:Lymphatic, Head and Neck, Brain, and Other Major Cancers       
7/ HCC10:Breast, Prostate, Colorectal and Other Cancers and Tumors     
8/ HCC15:Diabetes with Renal or Peripheral Circulatory Manifestation    
9/ HCC16:Diabetes with Neurologic or Other Specified Manifestation      
10/ HCC17:Diabetes with Acute Complications           
  • Most frequent chronic conditions:
In [12]:
# Most common chronic conditions:
df_hcc = df.loc[:,hcc.columns]

diseases = df_hcc.sum().sort_values(ascending=True)

common_diseases = diseases.tail(n=10).sort_values(ascending=False).index

#hcc_list[hcc_list.index.isin(common_diseases)]

for i,hcc in enumerate(common_diseases):
    print(f'{i+1}.'+hcc_list.loc[hcc,'HCC Name'])
1.Specified Heart Arrhythmias      
2.Vascular Disease           
3.Diabetes without Complication  
4.Chronic Obstructive Pulmonary Disease                          
5.Renal Failure               
6.Congestive Heart Failure   
7.Breast, Prostate, Colorectal and Other Cancers and Tumors     
8.Major Depressive, Bipolar, and Paranoid Disorders              
9.Polyneuropathy                                                 
10.Diabetes with Renal or Peripheral Circulatory Manifestation    
In [13]:
#Chronic conditions plot:

plt.figure(figsize=(15,7))

sns.barplot(x=diseases.index, y=diseases)
plt.title('Number of cases per chronic conditions in total population')
plt.xticks(rotation=90)

plt.show()
  • Most frequent chronic conditions grouped by race (White, hispanic, black and other)
In [14]:
# Chronic conditions group by race:
diseases_by_race = df.groupby('race')[df_hcc.columns].sum()
palette_list = ['mako','magma','viridis']

for race, palette in zip(race_info.index,palette_list):
    
    fig, axes = plt.subplots(figsize=(15,7))

    sns.barplot(x=diseases_by_race.loc[race,:].sort_values().index, y=diseases_by_race.loc[race,:].sort_values(), palette=palette)
    plt.xticks(rotation=90)
    plt.title(f'Number of cases per chronic conditions in {race} population')

    plt.show()

It is quiet interesting to notice that depending on the race the most common chronic conditions are note the same:

  • White:
    • HCC92: Specified Heart Arrhythmias
    • HCC105: Vascular Disease
    • HCC108: Chronic Obstructive Pulmonary Disease
  • Hispanic:
    • HCC19: Inflammatory Bowel Disease
    • HCC55: Major Depressive, Bipolar, and Paranoid Disorders
    • HCC108: Chronic Obstructive Pulmonary Disease
  • Black:
    • HCC19: Inflammatory Bowel Disease
    • HCC131: Renal Failure
    • HCC55: Major Depressive, Bipolar, and Paranoid Disorders

3/CCS

CCS referred to clinical classification software and provides 283 diagnosis categories (Check all CSS categories: CCSCatgoryNames).

In [15]:
ccs_variables = pd.read_csv('dxCCS_variable_list.csv')
ccs_variables['CCS Category'] = ccs_variables['CCS Category'].apply(lambda x: 'CCS'+x.replace("'","").replace(" ",""))
ccs_variables.set_index('CCS Category',inplace=True)
df_css = df.loc[:,dxccs.columns]
  • Most common diagnosis categories:
In [16]:
diagnosis_category = df_css[df_css !=0].count().sort_values(ascending=False)

common_diagnostics = diagnosis_category.head(n=10).index

for i, ccs in enumerate(common_diagnostics):
    print(f'{i+1}/{ccs}:'+ ccs_variables.loc[ccs,'CCS Category Description'][0])
    
1/CCS98:'HTN'
2/CCS53:'Hyperlipidem'
3/CCS10:'Immuniz/scrn'
4/CCS259:'Unclassified'
5/CCS256:'Exam/eval'
6/CCS258:'Other screen'
7/CCS211:'Ot conn tiss'
8/CCS200:'Oth skin dx'
9/CCS257:'Ot aftercare'
10/CCS204:'Ot joint dx'
In [17]:
fig, ax = plt.subplots(figsize=(10,7))

sns.barplot(x=diagnosis_category.head(n=10).index, y= diagnosis_category.head(n=10),palette='mako')

plt.title('Number of cases per Diagnostis Categories')

plt.show()

4/Procedure

Procedure referred to medical services used by patients and provided 104 Procedure indicators (Betos)

  • Most common procedures:
In [20]:
procedure_variables = pd.read_csv('Procedure_variable_list.csv').set_index('betos')
most_common_procedure = procedure.count().sort_values(ascending=False)[:10].index

for i, procedure_index in enumerate(most_common_procedure):
    print(f'{i+1}/ {procedure_index}:'+procedure_variables.loc[procedure_index,'description'])
1/ M1B:Office visits - established
2/ T1H:Lab tests - other (non-Medicare fee schedule)
3/ T1A:Lab tests - routine venipuncture (non Medicare fee schedule)
4/ T1B:Lab tests - automated general profiles
5/ T1D:Lab tests - blood counts
6/ M5D:Specialist - other
7/ O1G:Immunizations/Vaccinations
8/ M5C:Specialist - opthamology
9/ M1A:Office visits - new
10/ T2A:Other tests - electrocardiograms
In [21]:
plt.figure(figsize=(20,7))

sns.barplot(x=procedure.count().sort_values().index, y=procedure.count().sort_values())
plt.title('Number of cases per procedure in total population')
plt.xticks(rotation=90)

plt.show()

5/Spending

  • Spending relationship between 2012 and 2013:
In [22]:
plt.figure(figsize=(15,10))

sns.scatterplot(x='Spending2012',y='Spending2013',data=df,hue=df['Spending2012'].tolist(),palette='mako')

plt.title('Spending relationship between 2012 & 2013')

plt.show()

A linear relationship between the spendings of 2012 and 2013 seems to exist which seems logical. Thus, we could probably predict the patients spendings of 2013 with the spending of 2012 as a single predictor using a linear regression.

  • Spending2013 distribution:
In [23]:
fig, ax  = plt.subplots(figsize=(7,5))

sns.distplot(df['Spending2013'])

plt.title('Spending2013 distribution')
plt.show()

Healthcare costs in 2013 vary mainly between 0 and 20 000 $.

However, the costs can be much higher for some patients. Despite the fact that patients above a max threshold (Q3 + IQR*1.5 ) could be considered as outliers, they were not removed from the dataset because it is precisely the patients we want to classify (those who pay the most) and explain why.

In [24]:
race_encoded = pd.get_dummies(df['race'],prefix='race')
df = pd.concat([df,race_encoded], axis=1)

High costs patient classification

We are almost ready to use ML in order to classify our data. Still, we have to create the target feature based on 'spending2013'.

In [26]:
#function to define the top 10 patient with high medical costs:

def high_cost(x):
    
    '''function that binarize the spendings of 2013 into two categories:  
    those who belong to the top 10 patient with high medical costs (1) or not (0)'''
    
    if x >= np.percentile(df['Spending2013'],q=90):
        return 1
    else:
        return 0

df['target'] = df['Spending2013'].apply(lambda x: high_cost(x))
    
In [28]:
df.drop(['race','Spending2012','Spending2013'], axis=1, inplace=True)
y = df['target']
X = df.drop('target', axis=1)

We can have a look at a sample of our final DataFrame:

In [29]:
df.head()
Out[29]:
AGE SEX HCC1 HCC2 HCC5 HCC7 HCC8 HCC9 HCC10 HCC15 ... P2A I4A P2C M5A _029 race_Black race_Hispanic race_Other race_White target
0 86 1 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0 0 0 1 0
1 87 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0 0 0 1 0
2 79 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0 0 0 1 0
3 86 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0 0 0 1 0
4 73 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0 0 0 1 0

5 rows × 465 columns

Another important thing to do is to check how our dataset is balanced.

In [32]:
# As we can see the dataset is imbalanced which will lead to poor metrics results (such as accuracy):

(df['target'].value_counts()/ df.shape[0])*100
Out[32]:
0    89.999467
1    10.000533
Name: target, dtype: float64

As we can observe almost 90 % of our data belong to the class 0 (Not top 10 patient with high medical costs) whereas 10% belong to the class 1 (top 10 patient with high medical costs).ipynb_checkpoints/

Because our dataset is umbalanced we have to find a way to solve this issue. Regarding our dataset, the best way is to perform an undersampling of the majority class. Using an algorithm like SMOTE (oversampling of the minority class) won't produce good results since we are working with a high dimensionality dataset which affect significantly algorithms based on distances.

Undersampling of the majority class:

In [37]:
rus = RandomUnderSampler(sampling_strategy='majority',random_state=0)
X_resampled, y_resampled = rus.fit_resample(X,y)

Now, our dataset is balanced (50/50).

In [38]:
X_train_resampled, X_test_resampled, y_train_resampled, y_test_resampled = train_test_split(X_resampled, y_resampled)

In order to classify our dataset we used a gradient boosting algorithm using XGBoost library. XGBoosting gave the best results compare to other models based on trees.

Metrics used were accuracy_score and f1_score (harmonic mean of the precision and recall)

In [44]:
#Baseline XGBoost classification using all features:

xgb_clf = xgb.XGBClassifier()

xgb_clf.fit(X_train_resampled,y_train_resampled)
y_pred_train_clf = xgb_clf.predict(X_train_resampled)
acc_score_xgb = accuracy_score(y_train_resampled, y_pred_train_clf)
f1_score_xgb = f1_score(y_train_resampled, y_pred_train_clf)

# 5 folds Cross validation XGBoost:
xgb_clf_cv = cross_validate(xgb_clf,X_train_resampled, y_train_resampled,scoring=['accuracy','f1'],cv=5,n_jobs=-1)
xgb_cv_acc = xgb_clf_cv['test_accuracy'].mean()
xgb_cv_f1 = xgb_clf_cv['test_f1'].mean()

print(f'Accuracy score on the train set: {acc_score_xgb}')
print(f'F1 Score on the train set: {f1_score_xgb}')
      
print(f'Accuracy score on validation set: {xgb_cv_acc}')
print(f'F1 score on validation set:{xgb_cv_f1}')
Accuracy score on the train set: 0.8636502132845826
F1 Score on the train set: 0.8549197600907764
Accuracy score on validation set: 0.7136903677475612
F1 score on validation set:0.7057479876512697

Using the 464 features we achieved an accuracy of 0.714 and a f1 score of 0.705.

As we said in the beginning of this project we don't want to use all the features, the aim is to find a small numbers ( ~ 20) of features that could bring significant information for the classification.

Feature selection

In order to find the best 20 features that explain the classification, we will compare two methods:

  • Feature importance: Based on the total information gain ( total reduction of impurity contributed by all splits for a given feature)

  • Shap values: We can compute the average absolute of shapley values for each features which explain how much a feature contribute to the prediction (in our case how it increases or decreases the probability to be part of the top 10 of who pay the most.)

1/Using feature importance

In [51]:
#We sort the top 20 features based on their importances (information gain):

sorted_idx = xgb_clf.feature_importances_.argsort()
xgb_clf.feature_importances_[sorted_idx][::-1][:20]

X_train_resampled.columns[sorted_idx][::-1][:20]
X_train_resampled_20_features = X_train_resampled.loc[:,X_train_resampled.columns[sorted_idx][::-1][:20]]
In [52]:
plt.figure(figsize=(10,7))

sns.barplot(y=X_train_resampled.columns[sorted_idx][::-1][:20],x=xgb_clf.feature_importances_[sorted_idx][::-1][:20])

plt.xticks(rotation=45)
plt.title('Best 20 features based on their importance')

plt.show()

2/Using SHAP

In [54]:
# defining the model explainer
explainer = shap.TreeExplainer(xgb_clf)

# getting the SHAP values
shap_values = explainer.shap_values(X_train_resampled)

# plotting the SHAP values
shap.summary_plot(shap_values, X_train_resampled, feature_names=X_train_resampled.columns, plot_type="bar")
Setting feature_perturbation = "tree_path_dependent" because no background data was given.

Comparison between feature importance & SHAP

In [55]:
features_shap = pd.DataFrame(shap_values,columns=X_train_resampled.columns).abs().mean().sort_values(ascending=False)[::20].index

fig, ax = plt.subplots(1,2,figsize=(15,7))

sns.barplot(y=X_train_resampled.columns[sorted_idx][::-1][:20],x=xgb_clf.feature_importances_[sorted_idx][::-1][:20], ax=ax[0], palette='magma')
ax[0].set_title('Best 20 features with feature selection')

sns.barplot(y=features_shap, x=pd.DataFrame(shap_values,columns=X_train_resampled.columns).abs().mean().sort_values(ascending=False)[::20],ax=ax[1])
ax[1].set_title('Best 20 features with SHAP')

plt.show()

Comparing the two methods of feature selection we can observe that each model determines the importance of features in a different way (e.g: M2B (Hospital visit -subsequent) is the most important feature using feature importance wheareas it is the age using SHAP)

Now, we will compare classification results using the 20 best features according to their importances and shap values.

Classification based on feature importances

In [60]:
#Baseline XGBoost using the top 20 features with feature importances:

xgb_clf = xgb.XGBClassifier()

xgb_clf.fit(X_train_resampled_20_features,y_train_resampled)
y_pred_train_clf = xgb_clf.predict(X_train_resampled_20_features)
acc_score_xgb = accuracy_score(y_train_resampled, y_pred_train_clf)
f1_score_xgb = f1_score(y_train_resampled, y_pred_train_clf)

xgb_clf_cv = cross_validate(xgb_clf,X_train_resampled_20_features, y_train_resampled,scoring=['accuracy','f1'],cv=10,n_jobs=-1)
xgb_cv_acc = xgb_clf_cv['test_accuracy'].mean()
xgb_cv_f1 = xgb_clf_cv['test_f1'].mean()

print(f'Accuracy score on the train set: {acc_score_xgb}')
print(f'F1 Score on the train set: {f1_score_xgb}')
      
print(f'Accuracy score on validation set: {xgb_cv_acc}')
print(f'F1 score on validation set:{xgb_cv_f1}')
Accuracy score on the train set: 0.7691448303879748
F1 Score on the train set: 0.7555388255538826
Accuracy score on validation set: 0.7105937204332123
F1 score on validation set:0.6979527759013684
In [61]:
#Baseline XGBoost using the top 20 features with SHAP:
X_train_resampled_shap = X_train_resampled.loc[:,features_shap]

xgb_clf_shap = xgb.XGBClassifier()

xgb_clf_shap.fit(X_train_resampled_shap,y_train_resampled)
y_pred_train_clf_shap = xgb_clf_shap.predict(X_train_resampled_shap)
acc_score_xgb_shap = accuracy_score(y_train_resampled, y_pred_train_clf_shap)
f1_score_xgb_shap = f1_score(y_train_resampled, y_pred_train_clf_shap)

xgb_clf_cv_shap = cross_validate(xgb_clf_shap,X_train_resampled_shap, y_train_resampled,scoring=['accuracy','f1'],cv=10,n_jobs=-1)
xgb_acc_cv_shap = xgb_clf_cv_shap['test_accuracy'].mean()
xgb_f1_cv_shap = xgb_clf_cv_shap['test_f1'].mean()

print(f'Accuracy on train set: {acc_score_xgb_shap}')
print(f'F1 score on train set: {f1_score_xgb_shap}')

print(f'Accuracy Score on validation set: {xgb_acc_cv_shap}')
print(f'F1 score on validation set: {xgb_f1_cv_shap}')
Accuracy on train set: 0.7115072110501727
F1 score on train set: 0.6875996700577399
Accuracy Score on validation set: 0.6608258720832807
F1 score on validation set: 0.6390259433024374

Using feature importances gives better results on a validation set. We will tweak the hyperparameters of the XGBoost model based using feature importance.

Hyperparametrization

Successive random search CV were performed by reducing progessively the space search in order find the best model.

In [65]:
#Final Random Search CV:
params = {'max_depth':range(4,10),
         'n_estimators':range(600,1000,20),
         'learning_rate':np.linspace(0.001,0.02),
         'reg_alpha':np.linspace(0.01,0.2),
         'col_sample_bytree':[0.6,0.7,0.8,0.9,1],
         'gamma':np.linspace(0.001,0.04),
         'min_samples_leaf':range(1,15)}

xgb_rand_cv = RandomizedSearchCV(xgb.XGBClassifier(),params,n_iter=40,n_jobs=-1, cv=5)
xgb_rand_cv.fit(X_train_resampled_20_features,y_train_resampled)

xgb_best_esti = xgb_rand_cv.best_estimator_
xgb_best_score = xgb_rand_cv.best_score_
In [89]:
print(f'Best accuracy score on a validation set: {xgb_best_score}')
Best accuracy score on a validation set: 0.7189212303267285

Model explainability with SHAP

Now that we have our final model to predict whether or not a patient belong to the top ten high cost patient, we would like to explain how the classifaction work. In other words, we want to understand how much each feature is implicated in the classification.

The library SHAP allows to do that using shapley values again. Here we use the Tree SHAP implementation integrated into XGBoost to explain the entire dataset.

In [66]:
#Computing shap values:

xgb_best_esti.fit(X_train_resampled_20_features,y_train_resampled)

explainer_xgb = shap.TreeExplainer(xgb_best_esti)
shap_values_xgb = explainer_xgb.shap_values(X_train_resampled_20_features)
Setting feature_perturbation = "tree_path_dependent" because no background data was given.

SHAP summary plot

The summary plot is density scatter plot of SHAP values for each feature to identify how much impact each feature has on the model output for individuals in the validation dataset. For instance we can see that P9A(Medicare Fee Schedule) has a high impact on the model output.

Features are sorted by the sum of the SHAP value magnitudes across all samples. The color of each point represents the feature value of that individual. We can see that higher the features values are, the higher are the shap values which leads to increase the probability to belong to the top 10 highest cost patient

In [67]:
#Shap summary plot:

shap.summary_plot(shap_values_xgb, X_train_resampled_20_features, feature_names=X_train_resampled_20_features.columns, plot_size=(20, 15))
In [70]:
X_train_resampled_20_features.reset_index(drop=True,inplace=True)

Visualize a single prediction

We can vizualize for each patient how each feature depending on their value contribute (increase or decrease the probability from the base value) to the model output.

For instance, for this patient we can observe that all the features tend to reduce the probability to belong to the top 10% patient that pay the most.

In [71]:
def force_plot(ind_num):
    return shap.force_plot(explainer_xgb.expected_value, shap_values_xgb[ind_num,:], X_train_resampled_20_features.loc[ind_num,:], feature_names= X_train_resampled_20_features.columns, link='logit')
In [72]:
#Vizualise a single prediction:
force_plot(1)
Out[72]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Another example,here we can clearly see this patient had to undergo numerous procedures (P9A: Medical Fee Schedule,O1E: Durable Medical Equipement, O1A: Ambulance,...) during the year 2012 and also the fact that he or she suffers from a kidney diseases (CCS158) tends to increase the probability to pay more. According to this force plot, this patient belongs to the top 10% high cost patients.

In [78]:
force_plot(10000)
Out[78]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Visualize multiple predictions

The force plot can also be use to vizualize multiple predictions but also allows to observe each feature effect. For instance, in the next plot we use the force plot to understand how the age contribute to the prediction. We see that before approximatively 75 years old, age tends to reduce the probability to belongs to the top 10% high cost patients and pass this age it increases the probability.

In [85]:
shap.force_plot(explainer_xgb.expected_value, shap_values_xgb[:1000,:], X_train_resampled_20_features.loc[:1000,:])
Out[85]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Throughout this project we have tried to understand and give meaning to our data. Because we are working in high dimensionality, we had to reduce the number of features and keep a small number of features that really carry information while being careful not to make our model too simple so that our classification could be relevant. With 20 features selected on the basis of their "importance" we obtained an accuracy of 0.72 on the validation set.

Finally, using our best model, we computed the shapley values for each feature using SHAP in order to give explainability to our model. With this approach, we calculated the overall importance of the variables but especially the effects of the variables for each example of the dataset.

In [ ]: