Overview

Tree-based models consist of one or more nested if-then statements for the predictors that partition the data. Within these partitions, a specific model is used to predict the outcome. This recursive partitioning technique supports exploration of the structure of a set of data (outcome and predictors) and identification of easy-to-visualize decision rules for predicting a categorical (Classification Tree) or continuous (Regression Tree) outcome. In this tutorial we briefly describe the process of growing, examining, and pruning regression trees.

Outline

In this session we cover …

  1. Introduction to Data (Boston Data)
  2. Multivariate Regression Baseline
  3. Regression Tree (CART method): rpart (rpart package)
  4. Regression Tree (Conditional Inference method): ctree (partykit package)
  5. Conclusion

Prelim - Loading libraries used in this script.

library(MASS)  #for the Boston Data

library(psych)  #for general functions
library(ggplot2)  #for data visualization

# library(devtools)
# devtools::install_github('topepo/caret/pkg/caret') #May need the github version to correct a bug with parallelizing
library(caret)  #for training and cross validation (also calls other model libraries)
library(rpart)  #for trees
#library(rattle)    # Fancy tree plot This is a difficult library to install (https://gist.github.com/zhiyzuo/a489ffdcc5da87f28f8589a55aa206dd) 
library(rpart.plot)             # Enhanced tree plots
library(RColorBrewer)       # Color selection for fancy tree plot
library(party)                  # Alternative decision tree algorithm
library(partykit)               # Updated party functions

1. Introduction to Data

For this example we use data that accompany the MASS package. There is no special reason these data were selected, other than that they were used in some other examples we were working on. The data can be considered “typical” social science data, with a mix of nominal, count, and continuous variables. Of note, there are no missing data.

Reading in the Boston data set.

#loading the data
data("Boston")

Prelim - Descriptives

Let's have a quick look at the data file and the descriptives.

#data structure
head(Boston,10)
##       crim   zn indus chas   nox    rm   age    dis rad tax ptratio  black
## 1  0.00632 18.0  2.31    0 0.538 6.575  65.2 4.0900   1 296    15.3 396.90
## 2  0.02731  0.0  7.07    0 0.469 6.421  78.9 4.9671   2 242    17.8 396.90
## 3  0.02729  0.0  7.07    0 0.469 7.185  61.1 4.9671   2 242    17.8 392.83
## 4  0.03237  0.0  2.18    0 0.458 6.998  45.8 6.0622   3 222    18.7 394.63
## 5  0.06905  0.0  2.18    0 0.458 7.147  54.2 6.0622   3 222    18.7 396.90
## 6  0.02985  0.0  2.18    0 0.458 6.430  58.7 6.0622   3 222    18.7 394.12
## 7  0.08829 12.5  7.87    0 0.524 6.012  66.6 5.5605   5 311    15.2 395.60
## 8  0.14455 12.5  7.87    0 0.524 6.172  96.1 5.9505   5 311    15.2 396.90
## 9  0.21124 12.5  7.87    0 0.524 5.631 100.0 6.0821   5 311    15.2 386.63
## 10 0.17004 12.5  7.87    0 0.524 6.004  85.9 6.5921   5 311    15.2 386.71
##    lstat medv
## 1   4.98 24.0
## 2   9.14 21.6
## 3   4.03 34.7
## 4   2.94 33.4
## 5   5.33 36.2
## 6   5.21 28.7
## 7  12.43 22.9
## 8  19.15 27.1
## 9  29.93 16.5
## 10 17.10 18.9

Our outcome of interest is medv: median value of owner-occupied homes in $1000s.

(Note that there is no id variable. This is convenient for some tasks.)

Descriptives

#sample descriptives
describe(Boston)
##         vars   n   mean     sd median trimmed    mad    min    max  range
## crim       1 506   3.61   8.60   0.26    1.68   0.33   0.01  88.98  88.97
## zn         2 506  11.36  23.32   0.00    5.08   0.00   0.00 100.00 100.00
## indus      3 506  11.14   6.86   9.69   10.93   9.37   0.46  27.74  27.28
## chas       4 506   0.07   0.25   0.00    0.00   0.00   0.00   1.00   1.00
## nox        5 506   0.55   0.12   0.54    0.55   0.13   0.38   0.87   0.49
## rm         6 506   6.28   0.70   6.21    6.25   0.51   3.56   8.78   5.22
## age        7 506  68.57  28.15  77.50   71.20  28.98   2.90 100.00  97.10
## dis        8 506   3.80   2.11   3.21    3.54   1.91   1.13  12.13  11.00
## rad        9 506   9.55   8.71   5.00    8.73   2.97   1.00  24.00  23.00
## tax       10 506 408.24 168.54 330.00  400.04 108.23 187.00 711.00 524.00
## ptratio   11 506  18.46   2.16  19.05   18.66   1.70  12.60  22.00   9.40
## black     12 506 356.67  91.29 391.44  383.17   8.09   0.32 396.90 396.58
## lstat     13 506  12.65   7.14  11.36   11.90   7.11   1.73  37.97  36.24
## medv      14 506  22.53   9.20  21.20   21.56   5.93   5.00  50.00  45.00
##          skew kurtosis   se
## crim     5.19    36.60 0.38
## zn       2.21     3.95 1.04
## indus    0.29    -1.24 0.30
## chas     3.39     9.48 0.01
## nox      0.72    -0.09 0.01
## rm       0.40     1.84 0.03
## age     -0.60    -0.98 1.25
## dis      1.01     0.46 0.09
## rad      1.00    -0.88 0.39
## tax      0.67    -1.15 7.49
## ptratio -0.80    -0.30 0.10
## black   -2.87     7.10 4.06
## lstat    0.90     0.46 0.32
## medv     1.10     1.45 0.41
#plots
pairs.panels(Boston)

#histogram of outcome
ggplot(data=Boston, aes(x=medv)) +
  geom_histogram(binwidth=1, boundary=.5, fill="white", color="black") + 
  labs(x = "Median Home Value")

#correlation matrix
round(cor(Boston),2)
##          crim    zn indus  chas   nox    rm   age   dis   rad   tax
## crim     1.00 -0.20  0.41 -0.06  0.42 -0.22  0.35 -0.38  0.63  0.58
## zn      -0.20  1.00 -0.53 -0.04 -0.52  0.31 -0.57  0.66 -0.31 -0.31
## indus    0.41 -0.53  1.00  0.06  0.76 -0.39  0.64 -0.71  0.60  0.72
## chas    -0.06 -0.04  0.06  1.00  0.09  0.09  0.09 -0.10 -0.01 -0.04
## nox      0.42 -0.52  0.76  0.09  1.00 -0.30  0.73 -0.77  0.61  0.67
## rm      -0.22  0.31 -0.39  0.09 -0.30  1.00 -0.24  0.21 -0.21 -0.29
## age      0.35 -0.57  0.64  0.09  0.73 -0.24  1.00 -0.75  0.46  0.51
## dis     -0.38  0.66 -0.71 -0.10 -0.77  0.21 -0.75  1.00 -0.49 -0.53
## rad      0.63 -0.31  0.60 -0.01  0.61 -0.21  0.46 -0.49  1.00  0.91
## tax      0.58 -0.31  0.72 -0.04  0.67 -0.29  0.51 -0.53  0.91  1.00
## ptratio  0.29 -0.39  0.38 -0.12  0.19 -0.36  0.26 -0.23  0.46  0.46
## black   -0.39  0.18 -0.36  0.05 -0.38  0.13 -0.27  0.29 -0.44 -0.44
## lstat    0.46 -0.41  0.60 -0.05  0.59 -0.61  0.60 -0.50  0.49  0.54
## medv    -0.39  0.36 -0.48  0.18 -0.43  0.70 -0.38  0.25 -0.38 -0.47
##         ptratio black lstat  medv
## crim       0.29 -0.39  0.46 -0.39
## zn        -0.39  0.18 -0.41  0.36
## indus      0.38 -0.36  0.60 -0.48
## chas      -0.12  0.05 -0.05  0.18
## nox        0.19 -0.38  0.59 -0.43
## rm        -0.36  0.13 -0.61  0.70
## age        0.26 -0.27  0.60 -0.38
## dis       -0.23  0.29 -0.50  0.25
## rad        0.46 -0.44  0.49 -0.38
## tax        0.46 -0.44  0.54 -0.47
## ptratio    1.00 -0.18  0.37 -0.51
## black     -0.18  1.00 -0.37  0.33
## lstat      0.37 -0.37  1.00 -0.74
## medv      -0.51  0.33 -0.74  1.00

Prelim - Split Training and Test Data

For independent comparison of model predictions, we partition the data into a Training Set and an independent Test Set.

#Setting the random seed for replication
set.seed(1234)

#renaming data set 
dat <- Boston

#Splitting the data into two parts: 75% training and 25% test (base R version)
index <- sample(1:nrow(dat), size=0.75*nrow(dat))
trainData <- dat[index,]
testData <- dat[-index,]

#Alternative: using the caret package's createDataPartition() function
index <- createDataPartition(dat$medv, times=1, p=0.75, list=FALSE)
trainData <- dat[index,]
testData <- dat[-index,]

There are some nuanced distinctions between indexes created using the base sample() function and the caret package’s createDataPartition() function. From the documentation for caret … For bootstrap samples, simple random sampling is used. For other data splitting, the random sampling is done within the levels of y when y is a factor in an attempt to balance the class distributions within the splits. For numeric y, the sample is split into sections based on percentiles and sampling is done within these subgroups. For createDataPartition, the number of percentiles is set via the groups argument. Also, for createDataPartition, with very small class sizes (<= 3) the classes may not show up in both the training and test data.
Here, we proceed with the createDataPartition() version.
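
As a quick check (these lines are an addition, not part of the original script), we can confirm that the outcome distribution is comparable across the two partitions produced by createDataPartition().

#check that the outcome distribution is similar in the training and test partitions
summary(trainData$medv)
summary(testData$medv)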

2. Regression - as a preliminary prediction model

For a “baseline”, let's run a regression, predicting medv from all other variables. (This serves as a point of comparison for the tree-based models below.)

#Running exploratory linear regression
lm.fit <- lm(medv ~., data=trainData)
summary(lm.fit) 
## 
## Call:
## lm(formula = medv ~ ., data = trainData)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -14.9608  -2.7813  -0.5848   1.5981  26.4313 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  39.503683   6.162763   6.410 4.47e-10 ***
## crim         -0.108655   0.039259  -2.768  0.00593 ** 
## zn            0.045468   0.015854   2.868  0.00437 ** 
## indus         0.048381   0.070074   0.690  0.49036    
## chas          3.233716   1.004598   3.219  0.00140 ** 
## nox         -19.644558   4.461747  -4.403 1.40e-05 ***
## rm            3.573069   0.514584   6.944 1.74e-11 ***
## age          -0.000203   0.015225  -0.013  0.98937    
## dis          -1.457415   0.234219  -6.222 1.34e-09 ***
## rad           0.310900   0.077180   4.028 6.83e-05 ***
## tax          -0.012513   0.004223  -2.963  0.00324 ** 
## ptratio      -1.006436   0.158528  -6.349 6.42e-10 ***
## black         0.008316   0.003256   2.554  0.01105 *  
## lstat        -0.490541   0.058686  -8.359 1.33e-15 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4.825 on 367 degrees of freedom
## Multiple R-squared:  0.7214, Adjusted R-squared:  0.7115 
## F-statistic: 73.09 on 13 and 367 DF,  p-value: < 2.2e-16

Fit of the regression is pretty good. \(R^2 = 0.72\)

Unfortunately, there do not seem to be any really good ways to visualize these models (beyond the case of two predictors, where we can plot the prediction plane in 3-D space).

Test of Prediction

However, we would like to assess prediction in the Test Data. We look at the squared correlation between predicted scores and actual scores in the Test Data.

cor(predict(lm.fit, newdata=testData),testData$medv)^2
## [1] 0.7946235

Also pretty good!

3. Regression Tree (CART method) - as an alternative prediction model

Traditional Classification and Regression Trees (as described by Breiman, Friedman, Olshen, and Stone) can be generated through the rpart package. In the terminology of tree models, the data are recursively split into terminal nodes or leaves of the tree. To obtain a prediction for a new sample, we would follow the if-then statements defined by the tree using values of the new sample’s predictors until reaching a terminal node. The model in the terminal node would then be used to generate the prediction. In simple (traditional) trees, that model is a single value (a class label or a numeric value). In other cases, the terminal node may be defined by a more complex function of the predictors (terminal nodes have models within them).

Tree-based and rule-based models are popular modeling tools for a number of reasons. (1) They generate a set of conditions that are highly interpretable and are easy to implement. (2) They can effectively handle many types of predictors (sparse, skewed, continuous, categorical, etc.) without the need for pre-processing. (3) These models do not require the user to specify the form of the predictors’ relationship to the response (e.g., linear, quadratic). (4) These models can (in some forms) effectively handle missing data, and they implicitly conduct feature selection. They have been extremely useful in many scenarios.

Basic implementation is done by Growing, Examining, Pruning - as illustrated below.

1. Grow a Tree

To grow a traditional tree, we can use the rpart() function in the rpart package.

tree.fit <- rpart(formula, data=, method=,control=) where
+ formula is in the format outcome ~ predictor1+predictor2+predictor3+etc.
+ data= specifies the data frame
+ method= "class" for a classification tree; "anova" for a regression tree
+ control= optional parameters for controlling tree growth. For example, control=rpart.control(minsplit=30, cp=0.001) requires that the minimum number of observations in a node be 30 before attempting a split and that a split must decrease the overall lack of fit by a factor of 0.001 (cost complexity factor) before being attempted.

rtree.fit <- rpart(medv ~ ., 
                  data=trainData,
                  method="anova", #for regression tree
                  control=rpart.control(minsplit=30,cp=0.001))

2. Examine the Tree

A collection of functions help us evaluate and examine the model.

+ printcp(tree.fit) displays a table of fits across cp (complexity parameter) values
+ rsq.rpart(tree.fit) plots approximate R-squared and relative error for different splits (2 plots); labels are only appropriate for the "anova" method
+ plotcp(tree.fit) plots the cross-validation results across cp values
+ print(tree.fit) prints the results
+ summary(tree.fit) gives detailed results including surrogate splits
+ plot(tree.fit) plots the decision tree
+ text(tree.fit) labels the decision tree plot
+ post(tree.fit, file=) creates a postscript plot of the decision tree (there may be better ways to get good looking tree plots)

First we look at how the error behaves across the range of complexity parameter values (depth of tree).

printcp(rtree.fit) # display the results 
## 
## Regression tree:
## rpart(formula = medv ~ ., data = trainData, method = "anova", 
##     control = rpart.control(minsplit = 30, cp = 0.001))
## 
## Variables actually used in tree construction:
## [1] crim  dis   indus lstat nox   rad   rm   
## 
## Root node error: 30668/381 = 80.493
## 
## n= 381 
## 
##           CP nsplit rel error  xerror     xstd
## 1  0.4406452      0   1.00000 1.00597 0.098945
## 2  0.1588911      1   0.55935 0.67973 0.069755
## 3  0.0835210      2   0.40046 0.50136 0.060344
## 4  0.0495001      3   0.31694 0.45661 0.055934
## 5  0.0263705      4   0.26744 0.39656 0.050995
## 6  0.0143825      5   0.24107 0.35588 0.050372
## 7  0.0094562      6   0.22669 0.32995 0.050420
## 8  0.0088834      7   0.21723 0.32011 0.050071
## 9  0.0080952      8   0.20835 0.31467 0.049982
## 10 0.0073540      9   0.20025 0.31412 0.049975
## 11 0.0068629     10   0.19290 0.31171 0.049984
## 12 0.0052839     11   0.18604 0.31145 0.049931
## 13 0.0031558     12   0.18075 0.30537 0.048024
## 14 0.0029893     13   0.17760 0.30137 0.047905
## 15 0.0028718     14   0.17461 0.30254 0.047908
## 16 0.0021683     15   0.17174 0.30474 0.048736
## 17 0.0016971     16   0.16957 0.30548 0.048754
## 18 0.0012029     17   0.16787 0.30547 0.047800
## 19 0.0010000     18   0.16667 0.30646 0.047791
rsq.rpart(rtree.fit) #produces 2 plots
## 
## Regression tree:
## rpart(formula = medv ~ ., data = trainData, method = "anova", 
##     control = rpart.control(minsplit = 30, cp = 0.001))
## 
## Variables actually used in tree construction:
## [1] crim  dis   indus lstat nox   rad   rm   
## 
## Root node error: 30668/381 = 80.493
## 
## n= 381 
## 
##           CP nsplit rel error  xerror     xstd
## 1  0.4406452      0   1.00000 1.00597 0.098945
## 2  0.1588911      1   0.55935 0.67973 0.069755
## 3  0.0835210      2   0.40046 0.50136 0.060344
## 4  0.0495001      3   0.31694 0.45661 0.055934
## 5  0.0263705      4   0.26744 0.39656 0.050995
## 6  0.0143825      5   0.24107 0.35588 0.050372
## 7  0.0094562      6   0.22669 0.32995 0.050420
## 8  0.0088834      7   0.21723 0.32011 0.050071
## 9  0.0080952      8   0.20835 0.31467 0.049982
## 10 0.0073540      9   0.20025 0.31412 0.049975
## 11 0.0068629     10   0.19290 0.31171 0.049984
## 12 0.0052839     11   0.18604 0.31145 0.049931
## 13 0.0031558     12   0.18075 0.30537 0.048024
## 14 0.0029893     13   0.17760 0.30137 0.047905
## 15 0.0028718     14   0.17461 0.30254 0.047908
## 16 0.0021683     15   0.17174 0.30474 0.048736
## 17 0.0016971     16   0.16957 0.30548 0.048754
## 18 0.0012029     17   0.16787 0.30547 0.047800
## 19 0.0010000     18   0.16667 0.30646 0.047791

plotcp(rtree.fit) # visualize cross-validation results 

#A good choice of cp for pruning is often the leftmost value for which the mean lies below the horizontal line

The detailed summary of the tree.

summary(rtree.fit) # detailed summary of splits
## Call:
## rpart(formula = medv ~ ., data = trainData, method = "anova", 
##     control = rpart.control(minsplit = 30, cp = 0.001))
##   n= 381 
## 
##             CP nsplit rel error    xerror       xstd
## 1  0.440645242      0 1.0000000 1.0059747 0.09894494
## 2  0.158891093      1 0.5593548 0.6797253 0.06975529
## 3  0.083520969      2 0.4004637 0.5013565 0.06034401
## 4  0.049500097      3 0.3169427 0.4566145 0.05593426
## 5  0.026370500      4 0.2674426 0.3965637 0.05099456
## 6  0.014382462      5 0.2410721 0.3558771 0.05037202
## 7  0.009456155      6 0.2266896 0.3299511 0.05042039
## 8  0.008883415      7 0.2172335 0.3201078 0.05007067
## 9  0.008095237      8 0.2083501 0.3146668 0.04998226
## 10 0.007353989      9 0.2002548 0.3141157 0.04997481
## 11 0.006862936     10 0.1929008 0.3117138 0.04998378
## 12 0.005283919     11 0.1860379 0.3114460 0.04993099
## 13 0.003155807     12 0.1807540 0.3053746 0.04802427
## 14 0.002989275     13 0.1775982 0.3013679 0.04790516
## 15 0.002871761     14 0.1746089 0.3025378 0.04790840
## 16 0.002168256     15 0.1717371 0.3047402 0.04873559
## 17 0.001697102     16 0.1695689 0.3054756 0.04875382
## 18 0.001202905     17 0.1678718 0.3054687 0.04780020
## 19 0.001000000     18 0.1666689 0.3064591 0.04779130
## 
## Variable importance
##   lstat      rm   indus    crim     age      zn     nox     dis     tax 
##      29      21      13      10      10       8       4       3       1 
##     rad ptratio 
##       1       1 
## 
## Node number 1: 381 observations,    complexity param=0.4406452
##   mean=22.37323, MSE=80.49277 
##   left son=2 (247 obs) right son=3 (134 obs)
##   Primary splits:
##       lstat   < 8.935    to the right, improve=0.4406452, (0 missing)
##       rm      < 6.9715   to the left,  improve=0.4258611, (0 missing)
##       indus   < 6.66     to the right, improve=0.2473444, (0 missing)
##       ptratio < 19.9     to the right, improve=0.2279278, (0 missing)
##       nox     < 0.6695   to the right, improve=0.2090672, (0 missing)
##   Surrogate splits:
##       rm    < 6.4775   to the left,  agree=0.811, adj=0.463, (0 split)
##       indus < 7.625    to the right, agree=0.803, adj=0.440, (0 split)
##       age   < 41.8     to the right, agree=0.777, adj=0.366, (0 split)
##       zn    < 16.25    to the left,  agree=0.772, adj=0.351, (0 split)
##       crim  < 0.08276  to the right, agree=0.764, adj=0.328, (0 split)
## 
## Node number 2: 247 observations,    complexity param=0.08352097
##   mean=17.98664, MSE=26.29994 
##   left son=4 (133 obs) right son=5 (114 obs)
##   Primary splits:
##       lstat < 14.4     to the right, improve=0.3942990, (0 missing)
##       dis   < 2.0754   to the left,  improve=0.3460477, (0 missing)
##       crim  < 5.84803  to the right, improve=0.3416487, (0 missing)
##       nox   < 0.6635   to the right, improve=0.3177246, (0 missing)
##       age   < 88.7     to the right, improve=0.2379447, (0 missing)
##   Surrogate splits:
##       age   < 88.1     to the right, agree=0.781, adj=0.526, (0 split)
##       indus < 16.57    to the right, agree=0.733, adj=0.421, (0 split)
##       dis   < 2.23935  to the left,  agree=0.733, adj=0.421, (0 split)
##       crim  < 0.166705 to the right, agree=0.729, adj=0.412, (0 split)
##       nox   < 0.5765   to the right, agree=0.725, adj=0.404, (0 split)
## 
## Node number 3: 134 observations,    complexity param=0.1588911
##   mean=30.45896, MSE=79.53779 
##   left son=6 (115 obs) right son=7 (19 obs)
##   Primary splits:
##       rm      < 7.4545   to the left,  improve=0.4571967, (0 missing)
##       lstat   < 4.15     to the right, improve=0.3785168, (0 missing)
##       age     < 89.35    to the left,  improve=0.1937996, (0 missing)
##       nox     < 0.574    to the left,  improve=0.1835016, (0 missing)
##       ptratio < 15       to the right, improve=0.1794687, (0 missing)
##   Surrogate splits:
##       lstat < 3.21     to the right, agree=0.896, adj=0.263, (0 split)
## 
## Node number 4: 133 observations,    complexity param=0.0263705
##   mean=15.00526, MSE=18.98035 
##   left son=8 (76 obs) right son=9 (57 obs)
##   Primary splits:
##       nox   < 0.607    to the right, improve=0.3203645, (0 missing)
##       dis   < 1.92035  to the left,  improve=0.3136884, (0 missing)
##       crim  < 5.7819   to the right, improve=0.3064597, (0 missing)
##       lstat < 19.83    to the right, improve=0.2513090, (0 missing)
##       tax   < 567.5    to the right, improve=0.2357110, (0 missing)
##   Surrogate splits:
##       tax   < 397      to the right, agree=0.857, adj=0.667, (0 split)
##       dis   < 2.38405  to the left,  agree=0.850, adj=0.649, (0 split)
##       indus < 16.01    to the right, agree=0.842, adj=0.632, (0 split)
##       crim  < 1.40092  to the right, agree=0.797, adj=0.526, (0 split)
##       rad   < 16       to the right, agree=0.729, adj=0.368, (0 split)
## 
## Node number 5: 114 observations,    complexity param=0.009456155
##   mean=21.46491, MSE=12.37105 
##   left son=10 (103 obs) right son=11 (11 obs)
##   Primary splits:
##       rm      < 6.616    to the left,  improve=0.20562930, (0 missing)
##       indus   < 4.22     to the right, improve=0.12641370, (0 missing)
##       lstat   < 9.725    to the right, improve=0.10074740, (0 missing)
##       ptratio < 17.85    to the right, improve=0.09394059, (0 missing)
##       tax     < 278      to the right, improve=0.08978631, (0 missing)
##   Surrogate splits:
##       ptratio < 13.85    to the right, agree=0.921, adj=0.182, (0 split)
## 
## Node number 6: 115 observations,    complexity param=0.0495001
##   mean=28.00783, MSE=42.94976 
##   left son=12 (66 obs) right son=13 (49 obs)
##   Primary splits:
##       rm    < 6.659    to the left,  improve=0.3073472, (0 missing)
##       lstat < 5.06     to the right, improve=0.2640314, (0 missing)
##       dis   < 1.9704   to the right, improve=0.2601843, (0 missing)
##       age   < 89.45    to the left,  improve=0.2487178, (0 missing)
##       nox   < 0.589    to the left,  improve=0.2173094, (0 missing)
##   Surrogate splits:
##       lstat   < 5.06     to the right, agree=0.722, adj=0.347, (0 split)
##       indus   < 4.01     to the right, agree=0.704, adj=0.306, (0 split)
##       zn      < 31.5     to the left,  agree=0.687, adj=0.265, (0 split)
##       ptratio < 15.55    to the right, agree=0.670, adj=0.224, (0 split)
##       nox     < 0.4045   to the right, agree=0.635, adj=0.143, (0 split)
## 
## Node number 7: 19 observations
##   mean=45.29474, MSE=44.52681 
## 
## Node number 8: 76 observations,    complexity param=0.01438246
##   mean=12.86974, MSE=13.41764 
##   left son=16 (41 obs) right son=17 (35 obs)
##   Primary splits:
##       lstat < 19.645   to the right, improve=0.43253920, (0 missing)
##       crim  < 9.87002  to the right, improve=0.40868340, (0 missing)
##       dis   < 1.92035  to the left,  improve=0.28536200, (0 missing)
##       tax   < 551.5    to the right, improve=0.14356430, (0 missing)
##       nox   < 0.7065   to the left,  improve=0.09989442, (0 missing)
##   Surrogate splits:
##       dis  < 1.6727   to the left,  agree=0.776, adj=0.514, (0 split)
##       crim < 9.55467  to the right, agree=0.750, adj=0.457, (0 split)
##       rm   < 5.574    to the left,  agree=0.671, adj=0.286, (0 split)
##       nox  < 0.7065   to the left,  agree=0.645, adj=0.229, (0 split)
##       age  < 97.6     to the right, agree=0.632, adj=0.200, (0 split)
## 
## Node number 9: 57 observations,    complexity param=0.005283919
##   mean=17.85263, MSE=12.20916 
##   left son=18 (26 obs) right son=19 (31 obs)
##   Primary splits:
##       crim    < 0.55381  to the right, improve=0.23285060, (0 missing)
##       ptratio < 19.45    to the right, improve=0.19027740, (0 missing)
##       black   < 378.085  to the left,  improve=0.18937010, (0 missing)
##       nox     < 0.531    to the right, improve=0.13829190, (0 missing)
##       tax     < 280.5    to the right, improve=0.09768062, (0 missing)
##   Surrogate splits:
##       ptratio < 19.95    to the right, agree=0.912, adj=0.808, (0 split)
##       nox     < 0.531    to the right, agree=0.807, adj=0.577, (0 split)
##       rad     < 16       to the right, agree=0.807, adj=0.577, (0 split)
##       tax     < 567.5    to the right, agree=0.807, adj=0.577, (0 split)
##       black   < 377.48   to the left,  agree=0.807, adj=0.577, (0 split)
## 
## Node number 10: 103 observations,    complexity param=0.002989275
##   mean=20.94369, MSE=8.357412 
##   left son=20 (91 obs) right son=21 (12 obs)
##   Primary splits:
##       indus   < 4.22     to the right, improve=0.10649730, (0 missing)
##       rm      < 6.0775   to the left,  improve=0.09588832, (0 missing)
##       tax     < 278      to the right, improve=0.07208953, (0 missing)
##       ptratio < 18.65    to the right, improve=0.07165208, (0 missing)
##       dis     < 3.734    to the right, improve=0.06275371, (0 missing)
##   Surrogate splits:
##       zn  < 57.5     to the left,  agree=0.913, adj=0.250, (0 split)
##       nox < 0.4035   to the right, agree=0.913, adj=0.250, (0 split)
##       age < 32.75    to the right, agree=0.903, adj=0.167, (0 split)
##       dis < 8.57235  to the left,  agree=0.903, adj=0.167, (0 split)
##       tax < 208      to the right, agree=0.893, adj=0.083, (0 split)
## 
## Node number 11: 11 observations
##   mean=26.34545, MSE=23.58975 
## 
## Node number 12: 66 observations,    complexity param=0.008883415
##   mean=24.87727, MSE=30.79024 
##   left son=24 (52 obs) right son=25 (14 obs)
##   Primary splits:
##       rad   < 5.5      to the left,  improve=0.1340617, (0 missing)
##       lstat < 5.41     to the right, improve=0.1334019, (0 missing)
##       crim  < 0.39646  to the left,  improve=0.1307344, (0 missing)
##       black < 376.935  to the right, improve=0.1172771, (0 missing)
##       dis   < 3.58055  to the right, improve=0.1093455, (0 missing)
##   Surrogate splits:
##       crim  < 2.98347  to the left,  agree=0.833, adj=0.214, (0 split)
##       tax   < 548      to the left,  agree=0.833, adj=0.214, (0 split)
##       nox   < 0.618    to the left,  agree=0.818, adj=0.143, (0 split)
##       dis   < 1.5449   to the right, agree=0.818, adj=0.143, (0 split)
##       lstat < 4.04     to the right, agree=0.818, adj=0.143, (0 split)
## 
## Node number 13: 49 observations,    complexity param=0.008095237
##   mean=32.22449, MSE=28.34716 
##   left son=26 (34 obs) right son=27 (15 obs)
##   Primary splits:
##       lstat < 4.6      to the right, improve=0.17873350, (0 missing)
##       dis   < 3.14095  to the right, improve=0.17635500, (0 missing)
##       rm    < 6.941    to the left,  improve=0.15531190, (0 missing)
##       crim  < 0.159085 to the left,  improve=0.14585030, (0 missing)
##       indus < 6.305    to the left,  improve=0.09883321, (0 missing)
##   Surrogate splits:
##       indus < 6.305    to the left,  agree=0.776, adj=0.267, (0 split)
##       tax   < 400      to the left,  agree=0.776, adj=0.267, (0 split)
##       crim  < 0.943545 to the left,  agree=0.755, adj=0.200, (0 split)
##       dis   < 1.88595  to the right, agree=0.755, adj=0.200, (0 split)
##       zn    < 75       to the left,  agree=0.735, adj=0.133, (0 split)
## 
## Node number 16: 41 observations,    complexity param=0.003155807
##   mean=10.6439, MSE=8.77856 
##   left son=32 (26 obs) right son=33 (15 obs)
##   Primary splits:
##       crim  < 9.87002  to the right, improve=0.2688965, (0 missing)
##       rad   < 14.5     to the right, improve=0.2109716, (0 missing)
##       indus < 18.84    to the left,  improve=0.2109716, (0 missing)
##       nox   < 0.729    to the left,  improve=0.1837369, (0 missing)
##       dis   < 1.464    to the right, improve=0.1803603, (0 missing)
##   Surrogate splits:
##       indus   < 18.84    to the left,  agree=0.878, adj=0.667, (0 split)
##       rad     < 14.5     to the right, agree=0.878, adj=0.667, (0 split)
##       tax     < 551.5    to the right, agree=0.829, adj=0.533, (0 split)
##       ptratio < 20.15    to the right, agree=0.805, adj=0.467, (0 split)
##       nox     < 0.646    to the right, agree=0.780, adj=0.400, (0 split)
## 
## Node number 17: 35 observations,    complexity param=0.002168256
##   mean=15.47714, MSE=6.249763 
##   left son=34 (17 obs) right son=35 (18 obs)
##   Primary splits:
##       crim  < 5.76921  to the right, improve=0.30399110, (0 missing)
##       dis   < 1.9467   to the left,  improve=0.11615010, (0 missing)
##       rm    < 6.1405   to the right, improve=0.09596736, (0 missing)
##       black < 318.38   to the left,  improve=0.09543795, (0 missing)
##       nox   < 0.675    to the right, improve=0.02644998, (0 missing)
##   Surrogate splits:
##       indus < 18.84    to the left,  agree=0.800, adj=0.588, (0 split)
##       rad   < 14.5     to the right, agree=0.800, adj=0.588, (0 split)
##       rm    < 6.1685   to the right, agree=0.771, adj=0.529, (0 split)
##       tax   < 551.5    to the right, agree=0.771, adj=0.529, (0 split)
##       nox   < 0.663    to the right, agree=0.743, adj=0.471, (0 split)
## 
## Node number 18: 26 observations
##   mean=16.01154, MSE=12.15102 
## 
## Node number 19: 31 observations
##   mean=19.39677, MSE=7.030635 
## 
## Node number 20: 91 observations,    complexity param=0.001697102
##   mean=20.6011, MSE=5.389779 
##   left son=40 (52 obs) right son=41 (39 obs)
##   Primary splits:
##       rm      < 6.0775   to the left,  improve=0.10611520, (0 missing)
##       indus   < 10.3     to the left,  improve=0.06608626, (0 missing)
##       dis     < 5.58775  to the right, improve=0.06204415, (0 missing)
##       rad     < 6.5      to the left,  improve=0.05724333, (0 missing)
##       ptratio < 20.95    to the right, improve=0.04823231, (0 missing)
##   Surrogate splits:
##       age     < 71.55    to the left,  agree=0.659, adj=0.205, (0 split)
##       tax     < 394.5    to the left,  agree=0.626, adj=0.128, (0 split)
##       ptratio < 20.55    to the left,  agree=0.626, adj=0.128, (0 split)
##       crim    < 3.36614  to the left,  agree=0.615, adj=0.103, (0 split)
##       indus   < 9.795    to the left,  agree=0.615, adj=0.103, (0 split)
## 
## Node number 21: 12 observations
##   mean=23.54167, MSE=23.22243 
## 
## Node number 24: 52 observations,    complexity param=0.006862936
##   mean=23.82308, MSE=11.57331 
##   left son=48 (13 obs) right son=49 (39 obs)
##   Primary splits:
##       lstat   < 7.62     to the right, improve=0.3497283, (0 missing)
##       rm      < 6.543    to the left,  improve=0.2758680, (0 missing)
##       nox     < 0.5125   to the right, improve=0.1822342, (0 missing)
##       tax     < 267.5    to the right, improve=0.1690680, (0 missing)
##       ptratio < 19.4     to the right, improve=0.1271483, (0 missing)
##   Surrogate splits:
##       rm      < 6.053    to the left,  agree=0.865, adj=0.462, (0 split)
##       rad     < 1.5      to the left,  agree=0.769, adj=0.077, (0 split)
##       ptratio < 20.6     to the right, agree=0.769, adj=0.077, (0 split)
## 
## Node number 25: 14 observations
##   mean=28.79286, MSE=82.70781 
## 
## Node number 26: 34 observations,    complexity param=0.007353989
##   mean=30.72941, MSE=15.35031 
##   left son=52 (22 obs) right son=53 (12 obs)
##   Primary splits:
##       rm    < 7.127    to the left,  improve=0.43212440, (0 missing)
##       tax   < 264.5    to the right, improve=0.24548080, (0 missing)
##       indus < 3.19     to the right, improve=0.09197057, (0 missing)
##       crim  < 0.0572   to the left,  improve=0.05274750, (0 missing)
##       lstat < 5.495    to the right, improve=0.04800032, (0 missing)
##   Surrogate splits:
##       indus < 2.21     to the right, agree=0.735, adj=0.250, (0 split)
##       tax   < 264.5    to the right, agree=0.735, adj=0.250, (0 split)
##       crim  < 0.301555 to the left,  agree=0.706, adj=0.167, (0 split)
##       dis   < 1.9704   to the right, agree=0.706, adj=0.167, (0 split)
##       nox   < 0.61     to the left,  agree=0.676, adj=0.083, (0 split)
## 
## Node number 27: 15 observations
##   mean=35.61333, MSE=41.25582 
## 
## Node number 32: 26 observations
##   mean=9.476923, MSE=5.596391 
## 
## Node number 33: 15 observations
##   mean=12.66667, MSE=7.842222 
## 
## Node number 34: 17 observations
##   mean=14.05882, MSE=4.025952 
## 
## Node number 35: 18 observations
##   mean=16.81667, MSE=4.655833 
## 
## Node number 40: 52 observations,    complexity param=0.001202905
##   mean=19.94615, MSE=6.364024 
##   left son=80 (10 obs) right son=81 (42 obs)
##   Primary splits:
##       dis   < 5.58775  to the right, improve=0.11147510, (0 missing)
##       indus < 10.3     to the left,  improve=0.09947490, (0 missing)
##       rad   < 7.5      to the left,  improve=0.06656262, (0 missing)
##       nox   < 0.5485   to the left,  improve=0.05213346, (0 missing)
##       black < 376.835  to the left,  improve=0.04388885, (0 missing)
##   Surrogate splits:
##       zn    < 6.25     to the right, agree=0.904, adj=0.5, (0 split)
##       indus < 5.16     to the left,  agree=0.865, adj=0.3, (0 split)
##       nox   < 0.445    to the left,  agree=0.865, adj=0.3, (0 split)
##       crim  < 0.0327   to the left,  agree=0.846, adj=0.2, (0 split)
## 
## Node number 41: 39 observations
##   mean=21.47436, MSE=2.756266 
## 
## Node number 48: 13 observations
##   mean=20.33846, MSE=9.045444 
## 
## Node number 49: 39 observations,    complexity param=0.002871761
##   mean=24.98462, MSE=7.01925 
##   left son=98 (22 obs) right son=99 (17 obs)
##   Primary splits:
##       rm    < 6.428    to the left,  improve=0.3217176, (0 missing)
##       tax   < 278      to the right, improve=0.2797394, (0 missing)
##       rad   < 3.5      to the right, improve=0.2432884, (0 missing)
##       lstat < 5.745    to the right, improve=0.2058624, (0 missing)
##       indus < 4.1      to the right, improve=0.1917861, (0 missing)
##   Surrogate splits:
##       lstat < 5.495    to the right, agree=0.795, adj=0.529, (0 split)
##       indus < 3.095    to the right, agree=0.718, adj=0.353, (0 split)
##       zn    < 34       to the left,  agree=0.692, adj=0.294, (0 split)
##       tax   < 280.5    to the right, agree=0.692, adj=0.294, (0 split)
##       crim  < 0.02819  to the right, agree=0.667, adj=0.235, (0 split)
## 
## Node number 52: 22 observations
##   mean=28.82727, MSE=11.85198 
## 
## Node number 53: 12 observations
##   mean=34.21667, MSE=2.969722 
## 
## Node number 80: 10 observations
##   mean=18.22, MSE=0.8736 
## 
## Node number 81: 42 observations
##   mean=20.35714, MSE=6.792925 
## 
## Node number 98: 22 observations
##   mean=23.66364, MSE=1.891405 
## 
## Node number 99: 17 observations
##   mean=26.69412, MSE=8.474671

That is a lot of output, but here we can also look at the predictors used in the tree and their relative importance in the prediction. We see specifically that rm (average number of rooms per dwelling) and lstat (lower status of the population, percent) are driving much of the prediction.
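
If we only want those importance scores, a small sketch (these lines are an addition, not part of the original output) pulls them directly from the fitted rpart object rather than reading through summary().

#variable importance scores stored in the fitted rpart object
rtree.fit$variable.importance
#scaled to percentages, matching the summary() printout
round(100*rtree.fit$variable.importance/sum(rtree.fit$variable.importance))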

This particular tree methodology can also handle missing data. When building the tree, missing data are ignored. For each split, a variety of alternatives (called surrogate splits) are evaluated. A surrogate split is one whose results are similar to the original split actually used in the tree. If a surrogate split approximates the original split well, it can be used when the predictor data associated with the original split are not available. In practice, several surrogate splits may be saved for any particular split in the tree.
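
As a hedged illustration (an addition, not part of the original script), we can blank out the primary splitting variable for a few test cases and still obtain predictions, because rpart falls back on the stored surrogate splits.

#predictions when the primary splitting variable is missing
test.missing <- testData[1:5, ]
test.missing$lstat <- NA                     #remove the primary split variable
predict(rtree.fit, newdata=test.missing)     #surrogate splits (rm, indus, age, ...) are used instead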

Plotting the tree.

# plot tree (old school way)
plot(rtree.fit, uniform=TRUE, 
    main="Regression Tree for Median Home Value")
text(rtree.fit, use.n=TRUE, all=TRUE, cex=.8)

# create more attractive plot of tree 
#using prp() in the rpart.plot package
prp(rtree.fit)

#using Rattle package
#fancyRpartPlot(rtree.fit)

We see the intuitive value of the tree method in the plot.
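
For a closer look at a single branch, a short sketch (assumed here, not in the original) using path.rpart() prints the chain of if-then rules that routes a case to a given terminal node; node 7 is the high-value terminal node seen in the summary above.

#print the decision rules leading to terminal node 7
path.rpart(rtree.fit, nodes=7)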

3. Prune the Tree

Prune back the tree to avoid overfitting the data. Hastie et al. (2008) suggest selecting the tree size associated with the numerically smallest error. That is, the size of the tree is selected by examining the error using cross-validation, specifically the minimum of the xerror column (cross-validation error) printed by printcp( ).

Pruning is easily done using the function prune(fit, cp= ) by examining the cross-validated error results from printcp(), selecting the complexity parameter associated with minimum error, and placing it into the prune( ) function. Alternatively, this can be automated using tree.fit$cptable[which.min(tree.fit$cptable[,"xerror"]),"CP"].

# prune the tree based on minimum xerror
pruned.rtree.fit<- prune(rtree.fit, cp= rtree.fit$cptable[which.min(rtree.fit$cptable[,"xerror"]),"CP"])

# plot the pruned tree using prp() in the rpart.plot package 
prp(pruned.rtree.fit, main="Pruned Regression Tree for Median Home Value")

In this case the pruned tree is not that much smaller than the original tree.

There are, of course, other approaches for pruning. Breiman et al. (1984) suggest using the cross-validation approach and applying a one-standard-error rule to the optimization criterion to identify the simplest tree. That is, find the smallest tree that is within one standard error of the tree with the smallest absolute error; this corresponds to the leftmost cp value for which the mean lies below the horizontal line that plotcp() places 1 SE above the minimum of the curve.

# prune the tree based on 1 SE error 
pruned2.rtree.fit<- prune(rtree.fit, cp=.01)

# plot the pruned tree using prp() in the rpart.plot package
prp(pruned2.rtree.fit, main="Pruned Regression Tree for Median Home Value")
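
Rather than reading the cp value off the plot, a small sketch (an assumption, not part of the original code) can apply the one-standard-error rule directly to the cptable.

#automating the 1-SE rule from the cp table
cptable <- rtree.fit$cptable
min.row <- which.min(cptable[,"xerror"])
threshold <- cptable[min.row,"xerror"] + cptable[min.row,"xstd"]
best.row <- min(which(cptable[,"xerror"] <= threshold))  #smallest tree below the threshold
pruned1se.rtree.fit <- prune(rtree.fit, cp=cptable[best.row,"CP"])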

Test of Prediction

Finally, for comparison with the regression model, we examine the \(R^2\) of the original and pruned trees. (Note: The predictive value of the model would typically be established through cross-validation and test samples. We do the below only for didactic illustration.)

#original tree
cor(predict(rtree.fit, newdata=testData),testData$medv)^2
## [1] 0.8184857
#pruned tree #1
cor(predict(pruned.rtree.fit, newdata=testData),testData$medv)^2
## [1] 0.8110404
#pruned tree #2
cor(predict(pruned2.rtree.fit, newdata=testData),testData$medv)^2
## [1] 0.7846115

We see here the tradeoff between “overfit” to the training data and potential generalizability to new data. More formal evaluations would be done using cross-validation. But the smaller pruned tree is still doing pretty well (almost as well as the multiple regression).

4. Regression Tree (Conditional Inference method) - as an alternative prediction model

Traditional CART-based trees recursively perform univariate splits of the dependent variable based on values of a set of covariates. An information measure (such as the Gini index) is used to select the splitting covariate. There is, however, a variable selection bias in the traditional (rpart and related) algorithms. These approaches tend to select variables that have many possible splits or many missing values.

To overcome that bias, conditional inference trees were introduced. Unlike the other approaches, Conditional Inference Trees use a significance test procedure to select variables at each split. The significance tests (more precisely, the multiple significance tests computed at each step of the algorithm: select covariate - choose split - recurse) are permutation tests, used to obtain the distribution of the test statistic under the null hypothesis by calculating all possible values of the test statistic under rearrangements of the labels on the observed data points (see Wikipedia).

More details can be found here https://stats.stackexchange.com/questions/12140/conditional-inference-trees-vs-traditional-decision-trees, and in the original paper here http://statmath.wu-wien.ac.at/~zeileis/papers/Hothorn+Hornik+Zeileis-2006.pdf.

The steps for implementation are largely the same: Grow, examine (and maybe prune).

1. Grow a Tree

To grow a tree using the conditional inference method, we can use the party (party: A Laboratory for Recursive Partitioning) package or the updated package partykit (partykit: A Toolkit for Recursive Partytioning). These packages provide nonparametric regression trees for nominal, ordinal, numeric, censored, and multivariate responses.

Specifically, regression or classification trees are obtained using the function ctree(formula, data=, control=) where

+ formula is in the format outcome ~ predictor1+predictor2+predictor3+etc.
+ data= specifies the data frame
+ control= optional parameters for controlling tree growth. For example, control=ctree_control(maxdepth=3) requires that the maximum depth of the tree is 3. The default maxdepth = Inf means that no restrictions are applied to tree size.

ctree.fit <- ctree(medv ~ ., 
                   data=trainData,
                   control=ctree_control(maxdepth=Inf))
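
Because the splits are selected by significance tests (as described above), tree size can also be controlled through the test criterion rather than the depth. As a sketch (an assumption, not part of the original tutorial), raising mincriterion (1 - alpha) in ctree_control() demands stronger evidence before each split and typically yields a smaller tree.

#example: require p < .01 at each split (mincriterion = 1 - alpha)
strict.ctree.fit <- ctree(medv ~ ., 
                          data=trainData,
                          control=ctree_control(mincriterion=0.99))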

2. Examine the Tree

A collection of functions help us evaluate and examine the model.

+ print(tree.fit) displays the details of the tree
+ plot(tree.fit) plots the decision tree

For our example, this is

print(ctree.fit) # display the results 
## 
## Model formula:
## medv ~ crim + zn + indus + chas + nox + rm + age + dis + rad + 
##     tax + ptratio + black + lstat
## 
## Fitted party:
## [1] root
## |   [2] lstat <= 8.93
## |   |   [3] rm <= 7.42
## |   |   |   [4] crim <= 1.05393
## |   |   |   |   [5] rm <= 6.54
## |   |   |   |   |   [6] nox <= 0.51
## |   |   |   |   |   |   [7] rm <= 6.121: 21.229 (n = 7, err = 8.1)
## |   |   |   |   |   |   [8] rm > 6.121: 24.166 (n = 29, err = 79.4)
## |   |   |   |   |   [9] nox > 0.51: 19.956 (n = 9, err = 101.1)
## |   |   |   |   [10] rm > 6.54
## |   |   |   |   |   [11] rm <= 6.957: 28.397 (n = 37, err = 340.7)
## |   |   |   |   |   [12] rm > 6.957: 33.346 (n = 24, err = 229.8)
## |   |   |   [13] crim > 1.05393: 37.878 (n = 9, err = 1280.5)
## |   |   [14] rm > 7.42: 45.295 (n = 19, err = 846.0)
## |   [15] lstat > 8.93
## |   |   [16] lstat <= 14.37
## |   |   |   [17] rm <= 6.59: 20.944 (n = 103, err = 860.8)
## |   |   |   [18] rm > 6.59: 26.345 (n = 11, err = 259.5)
## |   |   [19] lstat > 14.37
## |   |   |   [20] tax <= 469
## |   |   |   |   [21] crim <= 0.43571: 18.881 (n = 37, err = 291.6)
## |   |   |   |   [22] crim > 0.43571: 14.857 (n = 23, err = 60.7)
## |   |   |   [23] tax > 469
## |   |   |   |   [24] dis <= 1.9976: 11.320 (n = 46, err = 732.7)
## |   |   |   |   [25] dis > 1.9976: 16.100 (n = 27, err = 225.9)
## 
## Number of inner nodes:    12
## Number of terminal nodes: 13
plot(ctree.fit,
     main="Regression CTree for Median Home Value")

For comparison with the regression model, we examine the \(R^2\) of the conditional inference tree. (Note: The predictive value of the model would typically be established through cross-validation across many test samples.)

#R-square conditional inference tree
cor(predict(ctree.fit, newdata=testData),testData$medv)^2
## [1] 0.7720055

Although the statistical approach ensures that the right-sized tree is grown without additional (post-)pruning or cross-validation, the depth of the tree here is rather large (6 levels and 13 terminal nodes), which of course makes interpretation more difficult than with shallower trees.

3. Prune the Tree

Prune back the tree to avoid overfitting the data. This time we might simply prune for simplicity of plotting and interpretation. Pruning here is done by regrowing the tree with a different control parameter.

# regrow the tree with small depth, maxdepth = 3
pruned.ctree.fit<- ctree(medv ~ ., 
                         data=trainData,
                         control=ctree_control(maxdepth=3))

#examine pruned tree
print(pruned.ctree.fit) # display the results 
## 
## Model formula:
## medv ~ crim + zn + indus + chas + nox + rm + age + dis + rad + 
##     tax + ptratio + black + lstat
## 
## Fitted party:
## [1] root
## |   [2] lstat <= 8.93
## |   |   [3] rm <= 7.42
## |   |   |   [4] crim <= 1.05393: 27.170 (n = 106, err = 2707.5)
## |   |   |   [5] crim > 1.05393: 37.878 (n = 9, err = 1280.5)
## |   |   [6] rm > 7.42: 45.295 (n = 19, err = 846.0)
## |   [7] lstat > 8.93
## |   |   [8] lstat <= 14.37
## |   |   |   [9] rm <= 6.59: 20.944 (n = 103, err = 860.8)
## |   |   |   [10] rm > 6.59: 26.345 (n = 11, err = 259.5)
## |   |   [11] lstat > 14.37
## |   |   |   [12] tax <= 469: 17.338 (n = 60, err = 582.0)
## |   |   |   [13] tax > 469: 13.088 (n = 73, err = 1347.4)
## 
## Number of inner nodes:    6
## Number of terminal nodes: 7
plot(pruned.ctree.fit,
     main="(Pruned) Regression CTree for Median Home Value")

#R-square conditional inference tree
cor(predict(pruned.ctree.fit, newdata=testData),testData$medv)^2
## [1] 0.7030792

In this case the pruned tree provides an easier set of rules, but gives up prediction accuracy (in the hope for better generalization to other data).
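
To pull the comparisons together, a brief sketch (an addition, not part of the original script) collects the test-set R-squared of each fitted model into one table.

#collect test-set R-squared for all fitted models
r2 <- function(fit) cor(predict(fit, newdata=testData), testData$medv)^2
data.frame(model = c("linear regression", "rpart (full)", "rpart (pruned, min xerror)",
                     "rpart (pruned, cp=.01)", "ctree (full)", "ctree (maxdepth=3)"),
           test.R2 = c(r2(lm.fit), r2(rtree.fit), r2(pruned.rtree.fit),
                       r2(pruned2.rtree.fit), r2(ctree.fit), r2(pruned.ctree.fit)))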

5. Conclusion

In this session we walked through some very basics of implementing regression tree models. Classification trees operate in much the same way, except that the outcome is a nominal variable. While individual trees are not often used in practice much anymore, they provide a foundation for the forthcoming ensemble methods - where many trees are combined together. So, next we take a walk into the forest.

As always, thank you for playing!