Tree-Based Methods/Ensemble Schemes



Classification Trees


We will work with the HANES data set. Information on HANES can be found here. We will first import the curated HANES data set into RStudio:

  # Load the package RCurl
  library(RCurl)
  # Import the HANES data set from GitHub; the URL string is split into two pieces
  # purely for readability
  URL_text_1 <- "https://raw.githubusercontent.com/kannan-kasthuri/kannan-kasthuri.github.io"
  URL_text_2 <- "/master/Datasets/HANES/NYC_HANES_DIAB.csv"
  # Paste it to constitute a single URL 
  URL <- paste(URL_text_1,URL_text_2, sep="")
  HANES <- read.csv(text=getURL(URL))
  # Rename the GENDER factor for identification
  HANES$GENDER <- factor(HANES$GENDER, labels=c("M","F"))
  # Rename the AGEGROUP factor for identification
  HANES$AGEGROUP <- factor(HANES$AGEGROUP, labels=c("20-39","40-59","60+"))
  # Rename the HSQ_1 factor for identification
  HANES$HSQ_1 <- factor(HANES$HSQ_1, labels=c("Excellent","Very Good","Good", "Fair", "Poor"))
  # Rename the DX_DBTS as a factor
  HANES$DX_DBTS <- factor(HANES$DX_DBTS, labels=c("DIAB","DIAB NO_DX","NO DIAB"))
  # Omit all NA from the data frame
  HANES <- na.omit(HANES)
  # Observe the structure
  str(HANES)
## 'data.frame':    1112 obs. of  23 variables:
##  $ KEY              : Factor w/ 1527 levels "133370A","133370B",..: 28 43 44 53 55 70 84 90 100 107 ...
##  $ GENDER           : Factor w/ 2 levels "M","F": 1 1 1 1 1 1 1 1 1 1 ...
##  $ SPAGE            : int  29 28 27 24 30 26 31 32 34 32 ...
##  $ AGEGROUP         : Factor w/ 3 levels "20-39","40-59",..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ HSQ_1            : Factor w/ 5 levels "Excellent","Very Good",..: 2 2 2 1 1 3 1 2 1 3 ...
##  $ UCREATININE      : int  105 53 314 105 163 150 46 36 177 156 ...
##  $ UALBUMIN         : num  0.707 1 8 4 3 2 2 0.707 4 3 ...
##  $ UACR             : num  0.00673 2 3 4 2 ...
##  $ MERCURYU         : num  0.37 0.106 0.487 2.205 0.979 ...
##  $ DX_DBTS          : Factor w/ 3 levels "DIAB","DIAB NO_DX",..: 3 3 3 3 3 3 3 3 3 3 ...
##  $ A1C              : num  5 5.2 4.8 5.1 4.3 5.2 4.8 5.2 4.8 5.2 ...
##  $ CADMIUM          : num  0.2412 0.1732 0.0644 0.0929 0.1202 ...
##  $ LEAD             : num  1.454 1.019 0.863 1.243 0.612 ...
##  $ MERCURYTOTALBLOOD: num  2.34 2.57 1.32 14.66 2.13 ...
##  $ HDL              : int  42 51 42 61 52 50 57 56 42 44 ...
##  $ CHOLESTEROLTOTAL : int  184 157 145 206 120 155 156 235 156 120 ...
##  $ GLUCOSESI        : num  4.61 4.77 5.16 5 5.11 ...
##  $ CREATININESI     : num  74.3 73 80 84.9 66 ...
##  $ CREATININE       : num  0.84 0.83 0.91 0.96 0.75 0.99 0.9 0.84 0.93 1.09 ...
##  $ TRIGLYCERIDE     : int  156 43 108 65 51 29 31 220 82 35 ...
##  $ GLUCOSE          : int  83 86 93 90 92 85 72 87 96 92 ...
##  $ COTININE         : num  31.5918 0.0635 0.035 0.0514 0.035 ...
##  $ LDLESTIMATE      : int  111 97 81 132 58 99 93 135 98 69 ...
##  - attr(*, "na.action")=Class 'omit'  Named int [1:415] 2 15 16 24 26 28 33 34 35 39 ...
##   .. ..- attr(*, "names")= chr [1:415] "2" "15" "16" "24" ...

A tibble, or tbl_df, is a modern reimagining of the data frame, keeping what time has proven to be effective and throwing out what is not. To convert a data frame to a tibble, we can use the function as_tibble(df), where df is a data frame. We will convert the HANES data frame into a tibble and use the tibble from now on.

  # Load the tidyverse library
  library(tidyverse)
  # Convert HANES data frame into a tibble and observe it
  HANES_TIB <- as_tibble(HANES)
  HANES_TIB
## # A tibble: 1,112 x 23
##    KEY   GENDER SPAGE AGEGROUP HSQ_1 UCREATININE UALBUMIN    UACR MERCURYU
##  * <fct> <fct>  <int> <fct>    <fct>       <int>    <dbl>   <dbl>    <dbl>
##  1 1340… M         29 20-39    Very…         105    0.707 0.00673    0.370
##  2 1344… M         28 20-39    Very…          53    1.00  2.00       0.106
##  3 1344… M         27 20-39    Very…         314    8.00  3.00       0.487
##  4 1346… M         24 20-39    Exce…         105    4.00  4.00       2.21 
##  5 1346… M         30 20-39    Exce…         163    3.00  2.00       0.979
##  6 1352… M         26 20-39    Good          150    2.00  1.00       1.48 
##  7 1354… M         31 20-39    Exce…          46    2.00  4.00       0.106
##  8 1357… M         32 20-39    Very…          36    0.707 0.0196     0.238
##  9 1360… M         34 20-39    Exce…         177    4.00  2.00       2.30 
## 10 1362… M         32 20-39    Good          156    3.00  2.00       1.51 
## # ... with 1,102 more rows, and 14 more variables: DX_DBTS <fct>,
## #   A1C <dbl>, CADMIUM <dbl>, LEAD <dbl>, MERCURYTOTALBLOOD <dbl>,
## #   HDL <int>, CHOLESTEROLTOTAL <int>, GLUCOSESI <dbl>,
## #   CREATININESI <dbl>, CREATININE <dbl>, TRIGLYCERIDE <int>,
## #   GLUCOSE <int>, COTININE <dbl>, LDLESTIMATE <int>

When we take the logarithm of the variables A1C and UACR, we notice that there are two clusters -

  # Make a ggplot of log(A1C) vs. log(UACR), with the aesthetic color mapped to DX_DBTS
  ggplot(data = HANES_TIB) + 
    geom_point(mapping = aes(x = log(A1C), y = log(UACR), color=DX_DBTS))

These two clusters are primarily composed of non-diabetic people. Restricting to the non-diabetic people, we will call the lower cluster (i.e., log(UACR) <= -2) LOW-UACR (LU) and the upper cluster (i.e., log(UACR) > -2) HIGH-UACR (HU), and we would like to construct classification trees that predict these clusters using all variables except UACR.
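
As a quick check (a small sketch using the HANES_TIB tibble constructed above, not part of the original analysis), we can count how many observations of each diabetes status fall on either side of the log(UACR) = -2 cut:

  # Tabulate diabetes status against the two sides of the log(UACR) = -2 threshold
  # (illustrative check only)
  HANES_TIB %>%
    mutate(Cluster = ifelse(log(UACR) <= -2, "LU", "HU")) %>%
    count(DX_DBTS, Cluster)

We now create this class label for the non-diabetic observations and drop the variables we will not use -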

  # Use mutate() to create the class label and drop the variables we do not need
  mydata <- HANES_TIB %>% filter(DX_DBTS == "NO DIAB") %>% 
    mutate(Cluster = ifelse(log(UACR) <= -2, 'LU', 'HU')) %>%
    select(everything(), -KEY, -GENDER, -AGEGROUP, -HSQ_1, -DX_DBTS)
  # Shorten some variable names so the tree node labels remain readable
  library(data.table)
  setnames(mydata, old = c('MERCURYTOTALBLOOD','CREATININESI', 'CHOLESTEROLTOTAL', 
                           'LDLESTIMATE', 'CADMIUM', 'MERCURYU', 'TRIGLYCERIDE'), 
           new=c('MERTOTBLD','CREATSI', 'CHOLTOT', 'LDLE', 'CAD', 'MERU', 'TRIG'))

The tree library can be used to construct classification and regression trees. We can use the tree() function to fit a classification tree in order to predict the Cluster variable:

  # Load the tree library
  library(tree)
  # Convert Cluster into a factor so that tree() fits a classification tree
  mydata$Cluster <- as.factor(mydata$Cluster)
  # Use the tree() function to construct the classification tree, excluding UACR as well as
  # UALBUMIN and UCREATININE, the variables that make up UACR
  tree.HANES <-  tree(Cluster~.-UACR-UALBUMIN-UCREATININE, data = mydata)
  # The summary() function lists the variables that are used as internal nodes 
  # in the tree, the number of terminal nodes, and the training error rate
  summary(tree.HANES)
## 
## Classification tree:
## tree(formula = Cluster ~ . - UACR - UALBUMIN - UCREATININE, data = mydata)
## Variables actually used in tree construction:
## [1] "MERU"      "LDLE"      "MERTOTBLD" "CREATSI"   "CAD"      
## Number of terminal nodes:  10 
## Residual mean deviance:  0.6167 = 596.3 / 967 
## Misclassification error rate: 0.1218 = 119 / 977

We see that the training misclassification error rate is about 12%. For classification trees, the reported deviance is given by,

\[ -2 \sum_{m}\sum_{k}n_{mk}\log \hat{p}_{mk}\]

where \(n_{mk}\) is the number of observations in the \(m\)th terminal node that belong to the \(k\)th class, and \(\hat{p}_{mk}\) is the corresponding proportion of observations in that node. A small deviance indicates a good fit to the training data. The residual mean deviance is simply the deviance divided by \(n-|T_{0}|\), where \(|T_{0}|\) is the number of terminal nodes; in this case that is \(977-10 = 967\).
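
As a quick arithmetic check (a small sketch using only the numbers reported by summary() above), we can reproduce the residual mean deviance and the training error rate:

  # Residual mean deviance: total deviance divided by n minus the number of terminal nodes
  596.3 / (977 - 10)   # approximately 0.617, as reported above
  # Training misclassification error rate
  119 / 977            # approximately 0.122, as reported above

The plot() and text() functions can be handy in plotting the tree and viewing it -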

  # Plot the tree and label it
  plot(tree.HANES)
  text(tree.HANES, use.n=TRUE, all=TRUE, cex=.6)

To properly evaluate the performance of a classification tree, we should estimate the test error rather than rely on the training error. We will split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data. The predict() function can be used for this purpose; in the case of a classification tree, the argument type="class" instructs R to return the actual class prediction.

  # Set the seed and sample 700 observations for the training set
  set.seed(2)
  train <- sample(1:nrow(mydata), 700)
  # Form the test data by subtracting the training data indices
  test <- setdiff(seq(1,nrow(mydata)), train)
  # Make HANES test data
  HANES.test <- mydata[test, ]
  # Extract the true cluster labels for the test observations (as characters) to tabulate the accuracy
  Cluster.test <- as.character(mydata$Cluster[test])
  # Make the tree on the training data 
  tree.mydata <- tree(Cluster~.-UACR-UALBUMIN-UCREATININE, mydata, subset=train)
  # Use predict() function to make a prediction
  tree.pred <- predict(tree.mydata, HANES.test, type="class")
  # Form the accuracy table/confusion matrix
  table(tree.pred,Cluster.test)
##          Cluster.test
## tree.pred  HU  LU
##        HU 227  25
##        LU  20   5

We see the test accuracy is (227+5)/277 ≈ 0.84, or about 84%.
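
Instead of reading the accuracy off the confusion matrix by hand, we can also compute it directly (a small sketch reusing tree.pred and Cluster.test from above):

  # Overall test-set accuracy: the fraction of correctly classified test observations
  mean(tree.pred == Cluster.test)
  # Equivalently, the sum of the diagonal of the confusion matrix divided by the test-set size
  sum(diag(table(tree.pred, Cluster.test))) / length(Cluster.test)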

Next we will check whether pruning the tree improves the accuracy. We can use the function cv.tree(), which performs cross-validation to determine the optimal level of tree complexity. The function reports the number of terminal nodes of each tree considered (size), the corresponding cross-validation error (dev), and the value of the cost-complexity parameter (k). Because we supply FUN=prune.misclass below, dev is the number of cross-validation misclassifications rather than the deviance.

  # Set the seed
  set.seed(3)
  # Use cv.tree() function to do cross-validation to prune and find the optimal tree
  cv.mydata <- cv.tree(tree.mydata, FUN=prune.misclass)
  # List and show the cross-validated data frame returned by the cv.tree() function
  names(cv.mydata)
## [1] "size"   "dev"    "k"      "method"
  cv.mydata
## $size
## [1] 27 22 19 15  9  6  5  1
## 
## $dev
## [1] 115 116 115 114 100 100 102  95
## 
## $k
## [1]      -Inf 0.0000000 0.3333333 0.7500000 1.5000000 1.6666667 2.0000000
## [8] 3.2500000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
  # Plot the cross-validation error against the tree size and against the cost-complexity parameter k (alpha)
  par(mfrow=c(1,2))
  plot(cv.mydata$size, cv.mydata$dev, type="b")
  plot(cv.mydata$k, cv.mydata$dev, type="b")
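
We can also read the cross-validation results programmatically (a small sketch using the cv.mydata object above). Note that the smallest cross-validation error (95) is attained by the single-node tree, so which.min() selects size 1; the tree pruned below to 6 terminal nodes attains nearly the same cross-validated error (100) while remaining informative.

  # Tree size with the lowest cross-validation error (here the single-node tree)
  cv.mydata$size[which.min(cv.mydata$dev)]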

  # Prune the tree to 6 terminal nodes, guided by the cross-validation results above
  prune.mydata <- prune.misclass(tree.mydata, best=6)
  # Plot the pruned tree
  plot(prune.mydata)
  text(prune.mydata,cex=.6)
  # Check whether the accuracy improves by applying the pruned tree to the test data
  tree.pred <- predict(prune.mydata, HANES.test, type="class")
  table(tree.pred,Cluster.test)
##          Cluster.test
## tree.pred  HU  LU
##        HU 239  26
##        LU   8   4


We see the test accuracy with the pruned tree is (239+4)/277 ≈ 0.88, or about 88%. Not only has pruning guided by cross-validation increased the accuracy, it has also produced a smaller, more interpretable tree.
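
The same comparison can be made directly in code (a small sketch reusing the fitted trees and the test data from above):

  # Test-set accuracy of the unpruned tree
  mean(predict(tree.mydata, HANES.test, type="class") == Cluster.test)
  # Test-set accuracy of the pruned tree
  mean(predict(prune.mydata, HANES.test, type="class") == Cluster.test)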


Selected materials and references

R for Data Science - Data Transformation Part