This post describes how we can use machine learning, and the XGBoost (eXtreme Gradient Boosting) library in particular, together with a set of clinical measurements to predict liver disease in patients.
The data that we will be using has been sourced from https://www.kaggle.com/uciml/indian-liver-patient-records. This data set contains 416 liver patient records and 167 non-liver patient records collected from the state of Andhra Pradesh in India.
To start with, let us load the R libraries that we’ll be using.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(xgboost)
library(methods)
Now it’s time to look at which features (attributes or data points) are present in the available data:
data <- read.csv('indian_liver_patient.csv')
names(data)
## [1] "Age" "Gender"
## [3] "Total_Bilirubin" "Direct_Bilirubin"
## [5] "Alkaline_Phosphotase" "Alamine_Aminotransferase"
## [7] "Aspartate_Aminotransferase" "Total_Protiens"
## [9] "Albumin" "Albumin_and_Globulin_Ratio"
## [11] "Dataset"
The final feature, “Dataset”, contains the label: a value of 1 indicates a liver patient and a value of 2 indicates a non-liver patient.
Now let’s take a closer look at the data itself.
head(data)
##   Age Gender Total_Bilirubin Direct_Bilirubin Alkaline_Phosphotase
## 1  65 Female             0.7              0.1                  187
## 2  62   Male            10.9              5.5                  699
## 3  62   Male             7.3              4.1                  490
## 4  58   Male             1.0              0.4                  182
## 5  72   Male             3.9              2.0                  195
## 6  46   Male             1.8              0.7                  208
##   Alamine_Aminotransferase Aspartate_Aminotransferase Total_Protiens
## 1                       16                         18            6.8
## 2                       64                        100            7.5
## 3                       60                         68            7.0
## 4                       14                         20            6.8
## 5                       27                         59            7.3
## 6                       19                         14            7.6
##   Albumin Albumin_and_Globulin_Ratio Dataset
## 1     3.3                       0.90       1
## 2     3.2                       0.74       1
## 3     3.3                       0.89       1
## 4     3.4                       1.00       1
## 5     2.4                       0.40       1
## 6     4.4                       1.30       1
We can also run the helpful “summary” function to get a quick view into the entire data set.
summary(data)
##       Age           Gender    Total_Bilirubin  Direct_Bilirubin
##  Min.   : 4.00   Female:142   Min.   : 0.400   Min.   : 0.100
##  1st Qu.:33.00   Male  :441   1st Qu.: 0.800   1st Qu.: 0.200
##  Median :45.00                Median : 1.000   Median : 0.300
##  Mean   :44.75                Mean   : 3.299   Mean   : 1.486
##  3rd Qu.:58.00                3rd Qu.: 2.600   3rd Qu.: 1.300
##  Max.   :90.00                Max.   :75.000   Max.   :19.700
##
##  Alkaline_Phosphotase Alamine_Aminotransferase Aspartate_Aminotransferase
##  Min.   :  63.0       Min.   :  10.00          Min.   :  10.0
##  1st Qu.: 175.5       1st Qu.:  23.00          1st Qu.:  25.0
##  Median : 208.0       Median :  35.00          Median :  42.0
##  Mean   : 290.6       Mean   :  80.71          Mean   : 109.9
##  3rd Qu.: 298.0       3rd Qu.:  60.50          3rd Qu.:  87.0
##  Max.   :2110.0       Max.   :2000.00          Max.   :4929.0
##
##  Total_Protiens     Albumin      Albumin_and_Globulin_Ratio
##  Min.   :2.700   Min.   :0.900   Min.   :0.3000
##  1st Qu.:5.800   1st Qu.:2.600   1st Qu.:0.7000
##  Median :6.600   Median :3.100   Median :0.9300
##  Mean   :6.483   Mean   :3.142   Mean   :0.9471
##  3rd Qu.:7.200   3rd Qu.:3.800   3rd Qu.:1.1000
##  Max.   :9.600   Max.   :5.500   Max.   :2.8000
##                                  NA's   :4
##     Dataset
##  Min.   :1.000
##  1st Qu.:1.000
##  Median :1.000
##  Mean   :1.286
##  3rd Qu.:2.000
##  Max.   :2.000
##
The data seems to be generally clean and the only feature with missing values is Albumin_and_Globulin_Ratio.
There are a few things we need to do to make this data acceptable to the XGBoost algorithm. Let’s start by imputing the missing values in Albumin_and_Globulin_Ratio in a simplistic way by replacing them with the mean of the column.
data$Albumin_and_Globulin_Ratio[is.na(data$Albumin_and_Globulin_Ratio)] <- mean(data$Albumin_and_Globulin_Ratio, na.rm = TRUE)
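One refinement worth considering is to compute the fill value from the training rows only, so that nothing from the test rows leaks into preprocessing. Here is a minimal sketch of that idea on a made-up vector; the helper name impute_with_train_mean and the toy values are assumptions for illustration, not part of the original analysis.

```r
# Fill missing values with the mean of the training rows only.
# `x` is a numeric vector, `train_rows` the indices of the training rows.
impute_with_train_mean <- function(x, train_rows) {
  fill <- mean(x[train_rows], na.rm = TRUE)
  x[is.na(x)] <- fill
  x
}

# Toy column with two gaps (values are made up).
ratio <- c(0.9, NA, 1.1, 0.7, NA, 2.0)
train_rows <- 1:4

ratio_filled <- impute_with_train_mean(ratio, train_rows)
print(ratio_filled)
```

With only 4 missing values out of 583 rows, the difference from the simple column mean will be small here, but the habit pays off on data sets with more gaps.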
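The next step is to replace non-numeric values with numeric ones. There’s only one non-numeric field here, Gender, so let’s replace the gender strings with numbers using a str2int function I like to keep handy. It numbers the distinct values in sorted order, so Female becomes 1 and Male becomes 2.
str2int <- function(x) {
  # Convert to character so the lookup works for both factors and strings
  x <- as.character(x)
  strings <- sort(unique(x))
  numbers <- seq_along(strings)
  names(numbers) <- strings
  # Look up each value by name and drop the names from the result
  return(unname(numbers[x]))
}
data$Gender <- str2int(data$Gender)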
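To see what this mapping does, here is a small standalone check on a made-up vector. The helper is repeated (in a slightly more defensive form) so the snippet runs on its own, and it is compared against base R’s factor(), which assigns codes in the same sorted order. The toy vector is an assumption for illustration.

```r
# Standalone copy of the helper: number the distinct strings in sorted order.
str2int <- function(x) {
  x <- as.character(x)
  strings <- sort(unique(x))
  numbers <- seq_along(strings)
  names(numbers) <- strings
  unname(numbers[x])
}

# Toy input (made up): "Female" sorts before "Male".
gender <- c("Male", "Female", "Male", "Female")

codes <- str2int(gender)
print(codes)

# Base R equivalent: factor levels are also the sorted unique values.
codes_factor <- as.integer(factor(gender))
print(identical(codes, codes_factor))
```

Since Gender has only two values, the integer coding is harmless for a tree model; for a column with many unordered categories, one-hot encoding would usually be a safer choice.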
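We also check whether any features are highly correlated with each other and remove the redundant ones. For each column, the code below looks at its correlation with every later column and drops the column if any of those values exceeds 0.8.
tmp <- cor(data)
# Zero the diagonal and upper triangle so each pair is considered only once
tmp[!lower.tri(tmp)] <- 0
# Drop a column if its correlation with any later column is above 0.8
data.new <- data[, !apply(tmp, 2, function(x) any(x > 0.8))]
data <- data.new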
names(data)
## [1] "Age" "Gender"
## [3] "Direct_Bilirubin" "Alkaline_Phosphotase"
## [5] "Alamine_Aminotransferase" "Aspartate_Aminotransferase"
## [7] "Total_Protiens" "Albumin"
## [9] "Albumin_and_Globulin_Ratio" "Dataset"
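You will notice that Total_Bilirubin has been eliminated from the list of features, since its correlation with a feature that was kept exceeds 0.8. The most plausible partner is Direct_Bilirubin, because total bilirubin is the sum of direct and indirect bilirubin.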
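If you would like to see which pairs trigger the filter rather than only the surviving column names, a small helper can list them. The sketch below is an illustration on made-up data: the function name high_cor_pairs and the toy data frame are assumptions. It also differs from the filter above in using the absolute correlation, so strongly negative pairs are reported too.

```r
# List feature pairs whose absolute correlation exceeds a threshold.
# `df` should contain numeric feature columns only (no label column).
high_cor_pairs <- function(df, threshold = 0.8) {
  cm <- cor(df)
  cm[!lower.tri(cm)] <- 0  # keep each pair once, drop the diagonal
  hits <- which(abs(cm) > threshold, arr.ind = TRUE)
  data.frame(
    feature_a   = colnames(cm)[hits[, "col"]],
    feature_b   = rownames(cm)[hits[, "row"]],
    correlation = cm[hits],
    stringsAsFactors = FALSE
  )
}

# Toy data (made up): b is roughly 2 * a, c is unrelated.
toy <- data.frame(
  a = c(1, 2, 3, 4, 5),
  b = c(2.1, 3.9, 6.2, 8.0, 9.9),
  c = c(5, 1, 4, 2, 3)
)

pairs <- high_cor_pairs(toy)
print(pairs)
```

Passing only the feature columns keeps the label out of the correlation matrix, so a feature can never be dropped merely for being strongly related to the outcome.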
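As a final step, we convert the label variable into the 0 or 1 values that XGBoost expects for binary classification. After subtracting 1, a value of 0 indicates a liver patient and a value of 1 indicates a non-liver patient.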
data$Dataset <- data$Dataset - 1
Let us split the available data into train and test sets with a 75:25 ratio.
sample_size <- floor(0.75 * nrow(data))
set.seed(123)
train_ind <- sample(seq_len(nrow(data)), size = sample_size)
train <- data[train_ind, ]
test <- data[-train_ind, ]
train_label <- as.numeric(train$Dataset)
test_label <- as.numeric(test$Dataset)
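Because the two classes are unbalanced (416 versus 167 records), a purely random split can shift the class proportions slightly between train and test. One alternative is a stratified split that samples 75% within each class. The sketch below uses base R only; the function name stratified_split is an assumption, and the label vector is a stand-in built from the class counts quoted earlier rather than the real data.

```r
# Return training indices that contain `train_frac` of each class.
stratified_split <- function(labels, train_frac = 0.75, seed = 123) {
  set.seed(seed)
  idx_by_class <- split(seq_along(labels), labels)
  train_idx <- unlist(lapply(idx_by_class, function(idx) {
    # sample.int avoids the surprise of sample() on a length-1 vector
    idx[sample.int(length(idx), size = floor(train_frac * length(idx)))]
  }), use.names = FALSE)
  sort(train_idx)
}

# Stand-in labels with the same class counts as the data set (0 = liver patient).
labels <- c(rep(0, 416), rep(1, 167))

train_idx <- stratified_split(labels)
test_idx  <- setdiff(seq_along(labels), train_idx)

print(table(labels[train_idx]))
print(table(labels[test_idx]))
```

On the real data you would pass data$Dataset as the labels and then index the data frame with train_idx exactly as in the split above.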
To use XGBoost the data needs to be converted to a different format called DMatrix. DMatrix is an internal data structure used by XGBoost which is optimized for both memory efficiency and training speed.
train <- as(as.matrix(train[ , -which(names(train) %in% c("Dataset"))]), "dgCMatrix")
test <- as(as.matrix(test[ , -which(names(test) %in% c("Dataset"))]), "dgCMatrix")
dtrain <- xgb.DMatrix(data = train, label=train_label)
dtest <- xgb.DMatrix(data = test, label=test_label)
watchlist <- list(train=dtrain, test=dtest)
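Finally, it’s time to train the model. “early_stopping_rounds” is a useful parameter for limiting overfitting: training stops once the evaluation metric on the last data set in the watchlist (here, the test set) has not improved for that many consecutive rounds. The watchlist also lets us monitor the metric on both the training and test sets as training progresses. Keep in mind that because the test set steers early stopping here, it is no longer a fully independent hold-out.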
xgbModel <- xgb.train(data = dtrain, max.depth = 100, eta = 0.001,
                      nthread = 2, nrounds = 10000,
                      watchlist = watchlist, objective = "binary:logistic",
                      early_stopping_rounds = 300, verbose = FALSE)
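To make the early-stopping rule concrete, here is a minimal sketch of the idea in plain R. It is not XGBoost’s actual implementation, only an illustration of the rule “stop once the validation error has gone `patience` rounds without a new best”. The function name early_stop_round and the error values are made up.

```r
# Walk through per-round validation errors and report where training would stop.
early_stop_round <- function(val_error, patience) {
  best <- Inf
  best_round <- 0
  for (i in seq_along(val_error)) {
    if (val_error[i] < best) {
      # New best score: remember it and reset the patience window
      best <- val_error[i]
      best_round <- i
    } else if (i - best_round >= patience) {
      # No improvement for `patience` rounds in a row: stop here
      return(list(stopped_at = i, best_round = best_round))
    }
  }
  list(stopped_at = length(val_error), best_round = best_round)
}

# Made-up validation errors: the best value in the first six rounds is at round 3.
errs <- c(0.30, 0.25, 0.22, 0.23, 0.24, 0.26, 0.21)

res <- early_stop_round(errs, patience = 3)
print(res)
```

In this toy run training stops at round 6 and never reaches the even better value in round 7, which is why a generous patience (300 rounds in our case) makes sense with a very small learning rate such as eta = 0.001.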
It’s time to make some predictions!
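We remove the labels from the full data set and use the model we trained to predict them, or in other words, to predict whether each patient has liver disease or not. Note that the full data set includes the 75% of rows the model was trained on, so the figures below are more optimistic than an evaluation on the held-out test rows alone would be.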
fulldata <- as(as.matrix(data[ , -which(names(data) %in% c("Dataset"))]), "dgCMatrix")
test_pred <- predict(xgbModel, newdata = fulldata)
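A confusion matrix is a good way to inspect and quantify the results of our model’s predictions. Recent versions of caret expect factors with matching levels, so we round the predicted probabilities to 0 or 1 and convert both vectors to factors.
pred_class <- factor(round(test_pred), levels = c(0, 1))
confusionMatrix(pred_class, factor(data$Dataset, levels = c(0, 1)))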
## Confusion Matrix and Statistics
##
##           Reference
## Prediction   0   1
##          0 395  46
##          1  21 121
##
## Accuracy : 0.8851
## 95% CI : (0.8564, 0.9098)
## No Information Rate : 0.7136
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7057
## Mcnemar's Test P-Value : 0.003367
##
## Sensitivity : 0.9495
## Specificity : 0.7246
## Pos Pred Value : 0.8957
## Neg Pred Value : 0.8521
## Prevalence : 0.7136
## Detection Rate : 0.6775
## Detection Prevalence : 0.7564
## Balanced Accuracy : 0.8370
##
## 'Positive' Class : 0
##
There’s a lot to digest here, and you can read up on what each individual field indicates. As a high-level summary, we get an accuracy of 88.51% on the full data set, against a no-information rate of 71.36%, and we successfully predicted liver disease in 395 out of 416 cases (a sensitivity of 94.95%). Of concern, however, are the 46 false positives, where we incorrectly predicted liver disease in a non-liver patient; they are the reason the specificity is only 72.46%. The remaining values, such as a kappa of 0.71, suggest that the model performs reasonably well overall, and more training data (especially for non-liver patients) might increase the accuracy and reduce the false positives.
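If you want to check how the headline figures follow from the four cells of the confusion matrix, the arithmetic is short. The sketch below copies the counts from the matrix above and treats class 0 (liver patient) as the positive class, as caret did; the variable names are my own.

```r
# Counts copied from the confusion matrix (positive class = 0, liver patient).
tp <- 395  # liver patients predicted as liver patients
fn <- 21   # liver patients predicted as non-liver patients
fp <- 46   # non-liver patients predicted as liver patients
tn <- 121  # non-liver patients predicted as non-liver patients

total <- tp + fn + fp + tn

metrics <- c(
  accuracy          = (tp + tn) / total,
  no_info_rate      = (tp + fn) / total,   # share of the majority class
  sensitivity       = tp / (tp + fn),
  specificity       = tn / (tn + fp),
  pos_pred_value    = tp / (tp + fp),
  neg_pred_value    = tn / (tn + fn),
  balanced_accuracy = (tp / (tp + fn) + tn / (tn + fp)) / 2
)

print(round(metrics, 4))
```

The rounded values match the caret output above, which is a handy sanity check that we are reading the rows and columns of the matrix the right way round.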
Another interesting data point is which features were the most important to our model when it made the prediction (or should we say diagnosis) of whether a patient has liver disease or not.
xgb.importance(colnames(fulldata), model = xgbModel)
## Feature Gain Cover Frequency
## 1: Age 0.18293615 0.19141002 0.20333048
## 2: Direct_Bilirubin 0.17775292 0.16542985 0.03819063
## 3: Alkaline_Phosphotase 0.17697704 0.17246281 0.16729775
## 4: Alamine_Aminotransferase 0.13449196 0.16824644 0.12348846
## 5: Total_Protiens 0.09506672 0.06812345 0.12906641
## 6: Albumin 0.08763175 0.11666587 0.11900981
## 7: Aspartate_Aminotransferase 0.07817340 0.06409471 0.11538618
## 8: Albumin_and_Globulin_Ratio 0.04802362 0.03874675 0.07947559
## 9: Gender 0.01894645 0.01482009 0.02475469
And there you have it: we were able to train a model that predicts whether a patient has liver disease or not based on a set of available data points. Our model achieved an accuracy of 88.51% on the full data set, which is pretty good against a no-information rate of 71.36%.
The dataset we used was limited, and to have a model that truly generalizes well we would need to collect much more data, but the results we achieved are promising. Hospitals and health authorities would likely have more of the data required for such a model to approach (or even surpass) human-level diagnostic accuracy.
Comments and suggestions are welcome - and if anyone has interesting data that they’d like to collaborate on, please hit me up on twitter @shyamvalsan or email shyam.valsan@gmail.com
Acknowledgements
This dataset was downloaded from the UCI ML Repository:
Lichman, M. (2013). UCI Machine Learning Repository. Irvine, CA: University of California, School of Information and Computer Science.
The XGBoost library is used for training the prediction model
This analysis was published as a Kaggle kernel