Tree-based models consist of one or more nested if-then statements on the predictors that partition the data. Within each partition, a specific model is used to predict the outcome. This recursive partitioning technique provides for exploration of the structure of a set of data (outcome and predictors) and identification of easy-to-visualize decision rules for predicting a categorical (Classification Tree) or continuous (Regression Tree) outcome. In this tutorial we briefly describe the process of growing, examining, and pruning regression trees.
In this session we cover …
library(MASS) #for the Boston Data
library(psych) #for general functions
library(ggplot2) #for data visualization
# library(devtools)
# devtools::install_github('topepo/caret/pkg/caret') #May need the github version to correct a bug with parallelizing
library(caret) #for training and cross-validation (also calls other model libraries)
library(rpart) #for trees
#library(rattle) # Fancy tree plot This is a difficult library to install (https://gist.github.com/zhiyzuo/a489ffdcc5da87f28f8589a55aa206dd)
library(rpart.plot) # Enhanced tree plots
library(RColorBrewer) # Color selection for fancy tree plot
library(party) # Alternative decision tree algorithm
library(partykit) # Updated party functions
For this example we use data that accompany the MASS package. There is no special reason these data were selected, other than that they were used in some other examples we were working on. The data can be considered “typical” social science data, with a mix of nominal, count, and continuous variables. Of note, there are no missing data.
#loading the data
data("Boston")
Let’s have a quick look at the data file and the descriptives.
#data structure
head(Boston,10)
## crim zn indus chas nox rm age dis rad tax ptratio black
## 1 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90
## 2 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90
## 3 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83
## 4 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63
## 5 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90
## 6 0.02985 0.0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12
## 7 0.08829 12.5 7.87 0 0.524 6.012 66.6 5.5605 5 311 15.2 395.60
## 8 0.14455 12.5 7.87 0 0.524 6.172 96.1 5.9505 5 311 15.2 396.90
## 9 0.21124 12.5 7.87 0 0.524 5.631 100.0 6.0821 5 311 15.2 386.63
## 10 0.17004 12.5 7.87 0 0.524 6.004 85.9 6.5921 5 311 15.2 386.71
## lstat medv
## 1 4.98 24.0
## 2 9.14 21.6
## 3 4.03 34.7
## 4 2.94 33.4
## 5 5.33 36.2
## 6 5.21 28.7
## 7 12.43 22.9
## 8 19.15 27.1
## 9 29.93 16.5
## 10 17.10 18.9
Our outcome of interest is medv: median value of owner-occupied homes in $1000s. (Note that there is no id variable. This is convenient for some tasks.)
Descriptives
#sample descriptives
describe(Boston)
## vars n mean sd median trimmed mad min max range
## crim 1 506 3.61 8.60 0.26 1.68 0.33 0.01 88.98 88.97
## zn 2 506 11.36 23.32 0.00 5.08 0.00 0.00 100.00 100.00
## indus 3 506 11.14 6.86 9.69 10.93 9.37 0.46 27.74 27.28
## chas 4 506 0.07 0.25 0.00 0.00 0.00 0.00 1.00 1.00
## nox 5 506 0.55 0.12 0.54 0.55 0.13 0.38 0.87 0.49
## rm 6 506 6.28 0.70 6.21 6.25 0.51 3.56 8.78 5.22
## age 7 506 68.57 28.15 77.50 71.20 28.98 2.90 100.00 97.10
## dis 8 506 3.80 2.11 3.21 3.54 1.91 1.13 12.13 11.00
## rad 9 506 9.55 8.71 5.00 8.73 2.97 1.00 24.00 23.00
## tax 10 506 408.24 168.54 330.00 400.04 108.23 187.00 711.00 524.00
## ptratio 11 506 18.46 2.16 19.05 18.66 1.70 12.60 22.00 9.40
## black 12 506 356.67 91.29 391.44 383.17 8.09 0.32 396.90 396.58
## lstat 13 506 12.65 7.14 11.36 11.90 7.11 1.73 37.97 36.24
## medv 14 506 22.53 9.20 21.20 21.56 5.93 5.00 50.00 45.00
## skew kurtosis se
## crim 5.19 36.60 0.38
## zn 2.21 3.95 1.04
## indus 0.29 -1.24 0.30
## chas 3.39 9.48 0.01
## nox 0.72 -0.09 0.01
## rm 0.40 1.84 0.03
## age -0.60 -0.98 1.25
## dis 1.01 0.46 0.09
## rad 1.00 -0.88 0.39
## tax 0.67 -1.15 7.49
## ptratio -0.80 -0.30 0.10
## black -2.87 7.10 4.06
## lstat 0.90 0.46 0.32
## medv 1.10 1.45 0.41
#plots
pairs.panels(Boston)
#histogram of outcome
ggplot(data=Boston, aes(x=medv)) +
geom_histogram(binwidth=1, boundary=.5, fill="white", color="black") +
labs(x = "Median Home Value")
#correlation matrix
round(cor(Boston),2)
## crim zn indus chas nox rm age dis rad tax
## crim 1.00 -0.20 0.41 -0.06 0.42 -0.22 0.35 -0.38 0.63 0.58
## zn -0.20 1.00 -0.53 -0.04 -0.52 0.31 -0.57 0.66 -0.31 -0.31
## indus 0.41 -0.53 1.00 0.06 0.76 -0.39 0.64 -0.71 0.60 0.72
## chas -0.06 -0.04 0.06 1.00 0.09 0.09 0.09 -0.10 -0.01 -0.04
## nox 0.42 -0.52 0.76 0.09 1.00 -0.30 0.73 -0.77 0.61 0.67
## rm -0.22 0.31 -0.39 0.09 -0.30 1.00 -0.24 0.21 -0.21 -0.29
## age 0.35 -0.57 0.64 0.09 0.73 -0.24 1.00 -0.75 0.46 0.51
## dis -0.38 0.66 -0.71 -0.10 -0.77 0.21 -0.75 1.00 -0.49 -0.53
## rad 0.63 -0.31 0.60 -0.01 0.61 -0.21 0.46 -0.49 1.00 0.91
## tax 0.58 -0.31 0.72 -0.04 0.67 -0.29 0.51 -0.53 0.91 1.00
## ptratio 0.29 -0.39 0.38 -0.12 0.19 -0.36 0.26 -0.23 0.46 0.46
## black -0.39 0.18 -0.36 0.05 -0.38 0.13 -0.27 0.29 -0.44 -0.44
## lstat 0.46 -0.41 0.60 -0.05 0.59 -0.61 0.60 -0.50 0.49 0.54
## medv -0.39 0.36 -0.48 0.18 -0.43 0.70 -0.38 0.25 -0.38 -0.47
## ptratio black lstat medv
## crim 0.29 -0.39 0.46 -0.39
## zn -0.39 0.18 -0.41 0.36
## indus 0.38 -0.36 0.60 -0.48
## chas -0.12 0.05 -0.05 0.18
## nox 0.19 -0.38 0.59 -0.43
## rm -0.36 0.13 -0.61 0.70
## age 0.26 -0.27 0.60 -0.38
## dis -0.23 0.29 -0.50 0.25
## rad 0.46 -0.44 0.49 -0.38
## tax 0.46 -0.44 0.54 -0.47
## ptratio 1.00 -0.18 0.37 -0.51
## black -0.18 1.00 -0.37 0.33
## lstat 0.37 -0.37 1.00 -0.74
## medv -0.51 0.33 -0.74 1.00
For independent comparison of model predictions, we partition the data into a Training Set and an independent Test Set.
#Setting the random seed for replication
set.seed(1234)
#renaming data set
dat <- Boston
#Splitting the data into two parts: 75% training and 25% test
index <- sample(1:nrow(dat), size=0.75*nrow(dat))
trainData <- dat[index,]
testData <- dat[-index,]
#Alternative: using the caret package function
index <- createDataPartition(dat$medv, times=1, p=0.75, list=FALSE)
trainData <- dat[index,]
testData <- dat[-index,]
There are some nuanced distinctions between indexes created using the base sample() function and the caret package’s createDataPartition() function. From the caret documentation:

For bootstrap samples, simple random sampling is used. For other data splitting, the random sampling is done within the levels of y when y is a factor in an attempt to balance the class distributions within the splits. For numeric y, the sample is split into groups sections based on percentiles and sampling is done within these subgroups. For createDataPartition, the number of percentiles is set via the groups argument. Also, for createDataPartition, with very small class sizes (<= 3) the classes may not show up in both the training and test data.

Here, we proceed with the createDataPartition() version.
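Before moving on, it can be worth a quick sanity check on the partition. This sketch (using only the objects created above) confirms the split sizes and that the outcome distributions look similar across the two sets:

```r
#Sanity check on the partition
nrow(trainData)        #should be roughly 75% of the 506 rows
nrow(testData)
summary(trainData$medv)
summary(testData$medv) #distributions should look broadly similar
```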
For a “baseline”, let’s run a regression, predicting medv from all other variables. (This linear model serves as a point of comparison for the tree models below.)
#Running exploratory linear regression
lm.fit <- lm(medv ~., data=trainData)
summary(lm.fit)
##
## Call:
## lm(formula = medv ~ ., data = trainData)
##
## Residuals:
## Min 1Q Median 3Q Max
## -14.9608 -2.7813 -0.5848 1.5981 26.4313
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 39.503683 6.162763 6.410 4.47e-10 ***
## crim -0.108655 0.039259 -2.768 0.00593 **
## zn 0.045468 0.015854 2.868 0.00437 **
## indus 0.048381 0.070074 0.690 0.49036
## chas 3.233716 1.004598 3.219 0.00140 **
## nox -19.644558 4.461747 -4.403 1.40e-05 ***
## rm 3.573069 0.514584 6.944 1.74e-11 ***
## age -0.000203 0.015225 -0.013 0.98937
## dis -1.457415 0.234219 -6.222 1.34e-09 ***
## rad 0.310900 0.077180 4.028 6.83e-05 ***
## tax -0.012513 0.004223 -2.963 0.00324 **
## ptratio -1.006436 0.158528 -6.349 6.42e-10 ***
## black 0.008316 0.003256 2.554 0.01105 *
## lstat -0.490541 0.058686 -8.359 1.33e-15 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4.825 on 367 degrees of freedom
## Multiple R-squared: 0.7214, Adjusted R-squared: 0.7115
## F-statistic: 73.09 on 13 and 367 DF, p-value: < 2.2e-16
Fit of the regression is pretty good: \(R^2 = 0.72\).
Unfortunately, there do not seem to be any really good ways of visualizing these models (except that, when there are only two predictors, we can obtain a prediction plane in 3-D space).
However, we would like to assess on the Test Data. We look at the squared correlation between predicted scores and actual scores in the Test Data.
cor(predict(lm.fit, newdata=testData),testData$medv)^2
## [1] 0.7946235
Also pretty good!
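The squared correlation is one option; if a metric in the outcome’s original units is preferred, the test-set RMSE is a simple complement. A sketch, using the objects created above:

```r
#Test-set RMSE: typical prediction error in $1000s of median home value
lm.pred <- predict(lm.fit, newdata=testData)
sqrt(mean((lm.pred - testData$medv)^2))
```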
Traditional Classification and Regression Trees (as described by Breiman, Friedman, Olshen, and Stone) can be generated through the rpart package. In the terminology of tree models, the data are recursively split into terminal nodes or leaves of the tree. To obtain a prediction for a new sample, we would follow the if-then statements defined by the tree using values of the new sample’s predictors until reaching a terminal node. The model formula in the terminal node would then be used to generate the prediction. In simple (traditional) trees, the model in each terminal node is a single value (a class label or a numeric constant). In other cases, the terminal node may be defined by a more complex function of the predictors (terminal nodes have models within them).
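To make the “follow the if-then statements” idea concrete, a two-split regression tree’s prediction rule can be written out by hand. The thresholds and terminal-node means below are purely illustrative, not taken from any fitted model:

```r
#Illustrative only: a hand-coded two-split regression tree
#Each branch ends in a terminal-node mean that serves as the prediction
predict_toy <- function(lstat, rm) {
  if (lstat >= 9) {
    if (lstat >= 14) 15 else 21  #terminal-node means (made up)
  } else {
    if (rm < 7.5) 28 else 45
  }
}
predict_toy(lstat=12, rm=6.2) #takes the lstat >= 9 branch, returns 21
```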
Tree-based and rule-based models are popular modeling tools for a number of reasons. (1) They generate a set of conditions that are highly interpretable and easy to implement. (2) They can effectively handle many types of predictors (sparse, skewed, continuous, categorical, etc.) without the need for pre-processing. (3) They do not require the user to specify the form of the predictors’ relationship to the response (e.g., linear, quadratic). (4) They can (in some forms) effectively handle missing data and implicitly conduct feature selection. They have been extremely useful in many scenarios.
Basic implementation is done by Growing, Examining, and Pruning, as illustrated below.
To grow a traditional tree, we can use the rpart() function in the rpart package.
tree.fit <- rpart(formula, data=, method=, control=) where

+ formula is in the format outcome ~ predictor1+predictor2+predictor3+etc.
+ data= specifies the data frame
+ method= is “class” for a classification tree or “anova” for a regression tree
+ control= gives optional parameters for controlling tree growth. For example, control=rpart.control(minsplit=30, cp=0.001) requires that the minimum number of observations in a node be 30 before attempting a split and that a split must decrease the overall lack of fit by a factor of 0.001 (cost complexity factor) before being attempted.
rtree.fit <- rpart(medv ~ .,
data=trainData,
method="anova", #for regression tree
control=rpart.control(minsplit=30,cp=0.001))
A collection of functions help us evaluate and examine the model:

+ printcp(tree.fit) displays a table of fits across cp (complexity parameter) values
+ rsq.rpart(tree.fit) plots approximate R-squared and relative error for different splits (2 plots; labels are only appropriate for the “anova” method)
+ plotcp(tree.fit) plots the cross-validation results across cp values
+ print(tree.fit) prints the results
+ summary(tree.fit) gives detailed results, including surrogate splits
+ plot(tree.fit) plots the decision tree
+ text(tree.fit) labels the decision tree plot
+ post(tree.fit, file=) creates a postscript plot of the decision tree (there may be better ways to get good-looking tree plots)
First, we look at the error across the range of complexity parameter values (i.e., depth of the tree).
printcp(rtree.fit) # display the results
##
## Regression tree:
## rpart(formula = medv ~ ., data = trainData, method = "anova",
## control = rpart.control(minsplit = 30, cp = 0.001))
##
## Variables actually used in tree construction:
## [1] crim dis indus lstat nox rad rm
##
## Root node error: 30668/381 = 80.493
##
## n= 381
##
## CP nsplit rel error xerror xstd
## 1 0.4406452 0 1.00000 1.00597 0.098945
## 2 0.1588911 1 0.55935 0.67973 0.069755
## 3 0.0835210 2 0.40046 0.50136 0.060344
## 4 0.0495001 3 0.31694 0.45661 0.055934
## 5 0.0263705 4 0.26744 0.39656 0.050995
## 6 0.0143825 5 0.24107 0.35588 0.050372
## 7 0.0094562 6 0.22669 0.32995 0.050420
## 8 0.0088834 7 0.21723 0.32011 0.050071
## 9 0.0080952 8 0.20835 0.31467 0.049982
## 10 0.0073540 9 0.20025 0.31412 0.049975
## 11 0.0068629 10 0.19290 0.31171 0.049984
## 12 0.0052839 11 0.18604 0.31145 0.049931
## 13 0.0031558 12 0.18075 0.30537 0.048024
## 14 0.0029893 13 0.17760 0.30137 0.047905
## 15 0.0028718 14 0.17461 0.30254 0.047908
## 16 0.0021683 15 0.17174 0.30474 0.048736
## 17 0.0016971 16 0.16957 0.30548 0.048754
## 18 0.0012029 17 0.16787 0.30547 0.047800
## 19 0.0010000 18 0.16667 0.30646 0.047791
rsq.rpart(rtree.fit) #produces 2 plots
##
## Regression tree:
## rpart(formula = medv ~ ., data = trainData, method = "anova",
## control = rpart.control(minsplit = 30, cp = 0.001))
##
## Variables actually used in tree construction:
## [1] crim dis indus lstat nox rad rm
##
## Root node error: 30668/381 = 80.493
##
## n= 381
##
## CP nsplit rel error xerror xstd
## 1 0.4406452 0 1.00000 1.00597 0.098945
## 2 0.1588911 1 0.55935 0.67973 0.069755
## 3 0.0835210 2 0.40046 0.50136 0.060344
## 4 0.0495001 3 0.31694 0.45661 0.055934
## 5 0.0263705 4 0.26744 0.39656 0.050995
## 6 0.0143825 5 0.24107 0.35588 0.050372
## 7 0.0094562 6 0.22669 0.32995 0.050420
## 8 0.0088834 7 0.21723 0.32011 0.050071
## 9 0.0080952 8 0.20835 0.31467 0.049982
## 10 0.0073540 9 0.20025 0.31412 0.049975
## 11 0.0068629 10 0.19290 0.31171 0.049984
## 12 0.0052839 11 0.18604 0.31145 0.049931
## 13 0.0031558 12 0.18075 0.30537 0.048024
## 14 0.0029893 13 0.17760 0.30137 0.047905
## 15 0.0028718 14 0.17461 0.30254 0.047908
## 16 0.0021683 15 0.17174 0.30474 0.048736
## 17 0.0016971 16 0.16957 0.30548 0.048754
## 18 0.0012029 17 0.16787 0.30547 0.047800
## 19 0.0010000 18 0.16667 0.30646 0.047791
plotcp(rtree.fit) # visualize cross-validation results
#A good choice of cp for pruning is often the leftmost value for which the mean lies below the horizontal line
The detailed summary of the tree.
summary(rtree.fit) # detailed summary of splits
## Call:
## rpart(formula = medv ~ ., data = trainData, method = "anova",
## control = rpart.control(minsplit = 30, cp = 0.001))
## n= 381
##
## CP nsplit rel error xerror xstd
## 1 0.440645242 0 1.0000000 1.0059747 0.09894494
## 2 0.158891093 1 0.5593548 0.6797253 0.06975529
## 3 0.083520969 2 0.4004637 0.5013565 0.06034401
## 4 0.049500097 3 0.3169427 0.4566145 0.05593426
## 5 0.026370500 4 0.2674426 0.3965637 0.05099456
## 6 0.014382462 5 0.2410721 0.3558771 0.05037202
## 7 0.009456155 6 0.2266896 0.3299511 0.05042039
## 8 0.008883415 7 0.2172335 0.3201078 0.05007067
## 9 0.008095237 8 0.2083501 0.3146668 0.04998226
## 10 0.007353989 9 0.2002548 0.3141157 0.04997481
## 11 0.006862936 10 0.1929008 0.3117138 0.04998378
## 12 0.005283919 11 0.1860379 0.3114460 0.04993099
## 13 0.003155807 12 0.1807540 0.3053746 0.04802427
## 14 0.002989275 13 0.1775982 0.3013679 0.04790516
## 15 0.002871761 14 0.1746089 0.3025378 0.04790840
## 16 0.002168256 15 0.1717371 0.3047402 0.04873559
## 17 0.001697102 16 0.1695689 0.3054756 0.04875382
## 18 0.001202905 17 0.1678718 0.3054687 0.04780020
## 19 0.001000000 18 0.1666689 0.3064591 0.04779130
##
## Variable importance
## lstat rm indus crim age zn nox dis tax
## 29 21 13 10 10 8 4 3 1
## rad ptratio
## 1 1
##
## Node number 1: 381 observations, complexity param=0.4406452
## mean=22.37323, MSE=80.49277
## left son=2 (247 obs) right son=3 (134 obs)
## Primary splits:
## lstat < 8.935 to the right, improve=0.4406452, (0 missing)
## rm < 6.9715 to the left, improve=0.4258611, (0 missing)
## indus < 6.66 to the right, improve=0.2473444, (0 missing)
## ptratio < 19.9 to the right, improve=0.2279278, (0 missing)
## nox < 0.6695 to the right, improve=0.2090672, (0 missing)
## Surrogate splits:
## rm < 6.4775 to the left, agree=0.811, adj=0.463, (0 split)
## indus < 7.625 to the right, agree=0.803, adj=0.440, (0 split)
## age < 41.8 to the right, agree=0.777, adj=0.366, (0 split)
## zn < 16.25 to the left, agree=0.772, adj=0.351, (0 split)
## crim < 0.08276 to the right, agree=0.764, adj=0.328, (0 split)
##
## Node number 2: 247 observations, complexity param=0.08352097
## mean=17.98664, MSE=26.29994
## left son=4 (133 obs) right son=5 (114 obs)
## Primary splits:
## lstat < 14.4 to the right, improve=0.3942990, (0 missing)
## dis < 2.0754 to the left, improve=0.3460477, (0 missing)
## crim < 5.84803 to the right, improve=0.3416487, (0 missing)
## nox < 0.6635 to the right, improve=0.3177246, (0 missing)
## age < 88.7 to the right, improve=0.2379447, (0 missing)
## Surrogate splits:
## age < 88.1 to the right, agree=0.781, adj=0.526, (0 split)
## indus < 16.57 to the right, agree=0.733, adj=0.421, (0 split)
## dis < 2.23935 to the left, agree=0.733, adj=0.421, (0 split)
## crim < 0.166705 to the right, agree=0.729, adj=0.412, (0 split)
## nox < 0.5765 to the right, agree=0.725, adj=0.404, (0 split)
##
## Node number 3: 134 observations, complexity param=0.1588911
## mean=30.45896, MSE=79.53779
## left son=6 (115 obs) right son=7 (19 obs)
## Primary splits:
## rm < 7.4545 to the left, improve=0.4571967, (0 missing)
## lstat < 4.15 to the right, improve=0.3785168, (0 missing)
## age < 89.35 to the left, improve=0.1937996, (0 missing)
## nox < 0.574 to the left, improve=0.1835016, (0 missing)
## ptratio < 15 to the right, improve=0.1794687, (0 missing)
## Surrogate splits:
## lstat < 3.21 to the right, agree=0.896, adj=0.263, (0 split)
##
## Node number 4: 133 observations, complexity param=0.0263705
## mean=15.00526, MSE=18.98035
## left son=8 (76 obs) right son=9 (57 obs)
## Primary splits:
## nox < 0.607 to the right, improve=0.3203645, (0 missing)
## dis < 1.92035 to the left, improve=0.3136884, (0 missing)
## crim < 5.7819 to the right, improve=0.3064597, (0 missing)
## lstat < 19.83 to the right, improve=0.2513090, (0 missing)
## tax < 567.5 to the right, improve=0.2357110, (0 missing)
## Surrogate splits:
## tax < 397 to the right, agree=0.857, adj=0.667, (0 split)
## dis < 2.38405 to the left, agree=0.850, adj=0.649, (0 split)
## indus < 16.01 to the right, agree=0.842, adj=0.632, (0 split)
## crim < 1.40092 to the right, agree=0.797, adj=0.526, (0 split)
## rad < 16 to the right, agree=0.729, adj=0.368, (0 split)
##
## Node number 5: 114 observations, complexity param=0.009456155
## mean=21.46491, MSE=12.37105
## left son=10 (103 obs) right son=11 (11 obs)
## Primary splits:
## rm < 6.616 to the left, improve=0.20562930, (0 missing)
## indus < 4.22 to the right, improve=0.12641370, (0 missing)
## lstat < 9.725 to the right, improve=0.10074740, (0 missing)
## ptratio < 17.85 to the right, improve=0.09394059, (0 missing)
## tax < 278 to the right, improve=0.08978631, (0 missing)
## Surrogate splits:
## ptratio < 13.85 to the right, agree=0.921, adj=0.182, (0 split)
##
## Node number 6: 115 observations, complexity param=0.0495001
## mean=28.00783, MSE=42.94976
## left son=12 (66 obs) right son=13 (49 obs)
## Primary splits:
## rm < 6.659 to the left, improve=0.3073472, (0 missing)
## lstat < 5.06 to the right, improve=0.2640314, (0 missing)
## dis < 1.9704 to the right, improve=0.2601843, (0 missing)
## age < 89.45 to the left, improve=0.2487178, (0 missing)
## nox < 0.589 to the left, improve=0.2173094, (0 missing)
## Surrogate splits:
## lstat < 5.06 to the right, agree=0.722, adj=0.347, (0 split)
## indus < 4.01 to the right, agree=0.704, adj=0.306, (0 split)
## zn < 31.5 to the left, agree=0.687, adj=0.265, (0 split)
## ptratio < 15.55 to the right, agree=0.670, adj=0.224, (0 split)
## nox < 0.4045 to the right, agree=0.635, adj=0.143, (0 split)
##
## Node number 7: 19 observations
## mean=45.29474, MSE=44.52681
##
## Node number 8: 76 observations, complexity param=0.01438246
## mean=12.86974, MSE=13.41764
## left son=16 (41 obs) right son=17 (35 obs)
## Primary splits:
## lstat < 19.645 to the right, improve=0.43253920, (0 missing)
## crim < 9.87002 to the right, improve=0.40868340, (0 missing)
## dis < 1.92035 to the left, improve=0.28536200, (0 missing)
## tax < 551.5 to the right, improve=0.14356430, (0 missing)
## nox < 0.7065 to the left, improve=0.09989442, (0 missing)
## Surrogate splits:
## dis < 1.6727 to the left, agree=0.776, adj=0.514, (0 split)
## crim < 9.55467 to the right, agree=0.750, adj=0.457, (0 split)
## rm < 5.574 to the left, agree=0.671, adj=0.286, (0 split)
## nox < 0.7065 to the left, agree=0.645, adj=0.229, (0 split)
## age < 97.6 to the right, agree=0.632, adj=0.200, (0 split)
##
## Node number 9: 57 observations, complexity param=0.005283919
## mean=17.85263, MSE=12.20916
## left son=18 (26 obs) right son=19 (31 obs)
## Primary splits:
## crim < 0.55381 to the right, improve=0.23285060, (0 missing)
## ptratio < 19.45 to the right, improve=0.19027740, (0 missing)
## black < 378.085 to the left, improve=0.18937010, (0 missing)
## nox < 0.531 to the right, improve=0.13829190, (0 missing)
## tax < 280.5 to the right, improve=0.09768062, (0 missing)
## Surrogate splits:
## ptratio < 19.95 to the right, agree=0.912, adj=0.808, (0 split)
## nox < 0.531 to the right, agree=0.807, adj=0.577, (0 split)
## rad < 16 to the right, agree=0.807, adj=0.577, (0 split)
## tax < 567.5 to the right, agree=0.807, adj=0.577, (0 split)
## black < 377.48 to the left, agree=0.807, adj=0.577, (0 split)
##
## Node number 10: 103 observations, complexity param=0.002989275
## mean=20.94369, MSE=8.357412
## left son=20 (91 obs) right son=21 (12 obs)
## Primary splits:
## indus < 4.22 to the right, improve=0.10649730, (0 missing)
## rm < 6.0775 to the left, improve=0.09588832, (0 missing)
## tax < 278 to the right, improve=0.07208953, (0 missing)
## ptratio < 18.65 to the right, improve=0.07165208, (0 missing)
## dis < 3.734 to the right, improve=0.06275371, (0 missing)
## Surrogate splits:
## zn < 57.5 to the left, agree=0.913, adj=0.250, (0 split)
## nox < 0.4035 to the right, agree=0.913, adj=0.250, (0 split)
## age < 32.75 to the right, agree=0.903, adj=0.167, (0 split)
## dis < 8.57235 to the left, agree=0.903, adj=0.167, (0 split)
## tax < 208 to the right, agree=0.893, adj=0.083, (0 split)
##
## Node number 11: 11 observations
## mean=26.34545, MSE=23.58975
##
## Node number 12: 66 observations, complexity param=0.008883415
## mean=24.87727, MSE=30.79024
## left son=24 (52 obs) right son=25 (14 obs)
## Primary splits:
## rad < 5.5 to the left, improve=0.1340617, (0 missing)
## lstat < 5.41 to the right, improve=0.1334019, (0 missing)
## crim < 0.39646 to the left, improve=0.1307344, (0 missing)
## black < 376.935 to the right, improve=0.1172771, (0 missing)
## dis < 3.58055 to the right, improve=0.1093455, (0 missing)
## Surrogate splits:
## crim < 2.98347 to the left, agree=0.833, adj=0.214, (0 split)
## tax < 548 to the left, agree=0.833, adj=0.214, (0 split)
## nox < 0.618 to the left, agree=0.818, adj=0.143, (0 split)
## dis < 1.5449 to the right, agree=0.818, adj=0.143, (0 split)
## lstat < 4.04 to the right, agree=0.818, adj=0.143, (0 split)
##
## Node number 13: 49 observations, complexity param=0.008095237
## mean=32.22449, MSE=28.34716
## left son=26 (34 obs) right son=27 (15 obs)
## Primary splits:
## lstat < 4.6 to the right, improve=0.17873350, (0 missing)
## dis < 3.14095 to the right, improve=0.17635500, (0 missing)
## rm < 6.941 to the left, improve=0.15531190, (0 missing)
## crim < 0.159085 to the left, improve=0.14585030, (0 missing)
## indus < 6.305 to the left, improve=0.09883321, (0 missing)
## Surrogate splits:
## indus < 6.305 to the left, agree=0.776, adj=0.267, (0 split)
## tax < 400 to the left, agree=0.776, adj=0.267, (0 split)
## crim < 0.943545 to the left, agree=0.755, adj=0.200, (0 split)
## dis < 1.88595 to the right, agree=0.755, adj=0.200, (0 split)
## zn < 75 to the left, agree=0.735, adj=0.133, (0 split)
##
## Node number 16: 41 observations, complexity param=0.003155807
## mean=10.6439, MSE=8.77856
## left son=32 (26 obs) right son=33 (15 obs)
## Primary splits:
## crim < 9.87002 to the right, improve=0.2688965, (0 missing)
## rad < 14.5 to the right, improve=0.2109716, (0 missing)
## indus < 18.84 to the left, improve=0.2109716, (0 missing)
## nox < 0.729 to the left, improve=0.1837369, (0 missing)
## dis < 1.464 to the right, improve=0.1803603, (0 missing)
## Surrogate splits:
## indus < 18.84 to the left, agree=0.878, adj=0.667, (0 split)
## rad < 14.5 to the right, agree=0.878, adj=0.667, (0 split)
## tax < 551.5 to the right, agree=0.829, adj=0.533, (0 split)
## ptratio < 20.15 to the right, agree=0.805, adj=0.467, (0 split)
## nox < 0.646 to the right, agree=0.780, adj=0.400, (0 split)
##
## Node number 17: 35 observations, complexity param=0.002168256
## mean=15.47714, MSE=6.249763
## left son=34 (17 obs) right son=35 (18 obs)
## Primary splits:
## crim < 5.76921 to the right, improve=0.30399110, (0 missing)
## dis < 1.9467 to the left, improve=0.11615010, (0 missing)
## rm < 6.1405 to the right, improve=0.09596736, (0 missing)
## black < 318.38 to the left, improve=0.09543795, (0 missing)
## nox < 0.675 to the right, improve=0.02644998, (0 missing)
## Surrogate splits:
## indus < 18.84 to the left, agree=0.800, adj=0.588, (0 split)
## rad < 14.5 to the right, agree=0.800, adj=0.588, (0 split)
## rm < 6.1685 to the right, agree=0.771, adj=0.529, (0 split)
## tax < 551.5 to the right, agree=0.771, adj=0.529, (0 split)
## nox < 0.663 to the right, agree=0.743, adj=0.471, (0 split)
##
## Node number 18: 26 observations
## mean=16.01154, MSE=12.15102
##
## Node number 19: 31 observations
## mean=19.39677, MSE=7.030635
##
## Node number 20: 91 observations, complexity param=0.001697102
## mean=20.6011, MSE=5.389779
## left son=40 (52 obs) right son=41 (39 obs)
## Primary splits:
## rm < 6.0775 to the left, improve=0.10611520, (0 missing)
## indus < 10.3 to the left, improve=0.06608626, (0 missing)
## dis < 5.58775 to the right, improve=0.06204415, (0 missing)
## rad < 6.5 to the left, improve=0.05724333, (0 missing)
## ptratio < 20.95 to the right, improve=0.04823231, (0 missing)
## Surrogate splits:
## age < 71.55 to the left, agree=0.659, adj=0.205, (0 split)
## tax < 394.5 to the left, agree=0.626, adj=0.128, (0 split)
## ptratio < 20.55 to the left, agree=0.626, adj=0.128, (0 split)
## crim < 3.36614 to the left, agree=0.615, adj=0.103, (0 split)
## indus < 9.795 to the left, agree=0.615, adj=0.103, (0 split)
##
## Node number 21: 12 observations
## mean=23.54167, MSE=23.22243
##
## Node number 24: 52 observations, complexity param=0.006862936
## mean=23.82308, MSE=11.57331
## left son=48 (13 obs) right son=49 (39 obs)
## Primary splits:
## lstat < 7.62 to the right, improve=0.3497283, (0 missing)
## rm < 6.543 to the left, improve=0.2758680, (0 missing)
## nox < 0.5125 to the right, improve=0.1822342, (0 missing)
## tax < 267.5 to the right, improve=0.1690680, (0 missing)
## ptratio < 19.4 to the right, improve=0.1271483, (0 missing)
## Surrogate splits:
## rm < 6.053 to the left, agree=0.865, adj=0.462, (0 split)
## rad < 1.5 to the left, agree=0.769, adj=0.077, (0 split)
## ptratio < 20.6 to the right, agree=0.769, adj=0.077, (0 split)
##
## Node number 25: 14 observations
## mean=28.79286, MSE=82.70781
##
## Node number 26: 34 observations, complexity param=0.007353989
## mean=30.72941, MSE=15.35031
## left son=52 (22 obs) right son=53 (12 obs)
## Primary splits:
## rm < 7.127 to the left, improve=0.43212440, (0 missing)
## tax < 264.5 to the right, improve=0.24548080, (0 missing)
## indus < 3.19 to the right, improve=0.09197057, (0 missing)
## crim < 0.0572 to the left, improve=0.05274750, (0 missing)
## lstat < 5.495 to the right, improve=0.04800032, (0 missing)
## Surrogate splits:
## indus < 2.21 to the right, agree=0.735, adj=0.250, (0 split)
## tax < 264.5 to the right, agree=0.735, adj=0.250, (0 split)
## crim < 0.301555 to the left, agree=0.706, adj=0.167, (0 split)
## dis < 1.9704 to the right, agree=0.706, adj=0.167, (0 split)
## nox < 0.61 to the left, agree=0.676, adj=0.083, (0 split)
##
## Node number 27: 15 observations
## mean=35.61333, MSE=41.25582
##
## Node number 32: 26 observations
## mean=9.476923, MSE=5.596391
##
## Node number 33: 15 observations
## mean=12.66667, MSE=7.842222
##
## Node number 34: 17 observations
## mean=14.05882, MSE=4.025952
##
## Node number 35: 18 observations
## mean=16.81667, MSE=4.655833
##
## Node number 40: 52 observations, complexity param=0.001202905
## mean=19.94615, MSE=6.364024
## left son=80 (10 obs) right son=81 (42 obs)
## Primary splits:
## dis < 5.58775 to the right, improve=0.11147510, (0 missing)
## indus < 10.3 to the left, improve=0.09947490, (0 missing)
## rad < 7.5 to the left, improve=0.06656262, (0 missing)
## nox < 0.5485 to the left, improve=0.05213346, (0 missing)
## black < 376.835 to the left, improve=0.04388885, (0 missing)
## Surrogate splits:
## zn < 6.25 to the right, agree=0.904, adj=0.5, (0 split)
## indus < 5.16 to the left, agree=0.865, adj=0.3, (0 split)
## nox < 0.445 to the left, agree=0.865, adj=0.3, (0 split)
## crim < 0.0327 to the left, agree=0.846, adj=0.2, (0 split)
##
## Node number 41: 39 observations
## mean=21.47436, MSE=2.756266
##
## Node number 48: 13 observations
## mean=20.33846, MSE=9.045444
##
## Node number 49: 39 observations, complexity param=0.002871761
## mean=24.98462, MSE=7.01925
## left son=98 (22 obs) right son=99 (17 obs)
## Primary splits:
## rm < 6.428 to the left, improve=0.3217176, (0 missing)
## tax < 278 to the right, improve=0.2797394, (0 missing)
## rad < 3.5 to the right, improve=0.2432884, (0 missing)
## lstat < 5.745 to the right, improve=0.2058624, (0 missing)
## indus < 4.1 to the right, improve=0.1917861, (0 missing)
## Surrogate splits:
## lstat < 5.495 to the right, agree=0.795, adj=0.529, (0 split)
## indus < 3.095 to the right, agree=0.718, adj=0.353, (0 split)
## zn < 34 to the left, agree=0.692, adj=0.294, (0 split)
## tax < 280.5 to the right, agree=0.692, adj=0.294, (0 split)
## crim < 0.02819 to the right, agree=0.667, adj=0.235, (0 split)
##
## Node number 52: 22 observations
## mean=28.82727, MSE=11.85198
##
## Node number 53: 12 observations
## mean=34.21667, MSE=2.969722
##
## Node number 80: 10 observations
## mean=18.22, MSE=0.8736
##
## Node number 81: 42 observations
## mean=20.35714, MSE=6.792925
##
## Node number 98: 22 observations
## mean=23.66364, MSE=1.891405
##
## Node number 99: 17 observations
## mean=26.69412, MSE=8.474671
That is a lot of output, but here we can also look at the predictors used in the tree and their relative importance in the prediction. We see specifically that rm (average number of rooms per dwelling) and lstat (lower status of the population, percent) are driving much of the prediction.
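The importance scores reported by summary() are also stored on the fitted object, so they can be pulled out directly. A sketch:

```r
#Variable importance scores stored on the rpart fit
rtree.fit$variable.importance
#scaled to percentages, matching the summary() display
round(100 * rtree.fit$variable.importance /
        sum(rtree.fit$variable.importance))
```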
This particular tree methodology can also handle missing data. When building the tree, missing data are ignored. For each split, a variety of alternatives (called surrogate splits) are evaluated. A surrogate split is one whose results are similar to the original split actually used in the tree. If a surrogate split approximates the original split well, it can be used when the predictor data associated with the original split are not available. In practice, several surrogate splits may be saved for any particular split in the tree.
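We can see surrogates in action by deliberately blanking out the primary splitting variable in a few test cases; rpart then falls back on the surrogate splits rather than failing. A sketch, using the objects created above:

```r
#Demonstrating surrogate splits: remove the primary splitter (lstat)
#from a few test cases and predict anyway
testNA <- testData
testNA$lstat[1:5] <- NA
head(predict(rtree.fit, newdata=testNA)) #predictions still returned, via surrogates
```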
Plotting the tree.
# plot tree (old-school way)
plot(rtree.fit, uniform=TRUE,
main="Regression Tree for Median Home Value")
text(rtree.fit, use.n=TRUE, all=TRUE, cex=.8)
# create a more attractive plot of the tree
#using prp() in the rpart.plot package
prp(rtree.fit)
#using Rattle package
#fancyRpartPlot(rtree.fit)
We see the intuitive value of the tree method in the plot.
Prune back the tree to avoid overfitting the data. Hastie et al. (2008) suggest selecting the tree size associated with the numerically smallest error. That is, the size of the tree is selected by examining the error from cross-validation, specifically the minimum of the xerror column (cross-validation error) printed by printcp().
Pruning is easily done using the function prune(fit, cp=): examine the cross-validated error results from printcp(), select the complexity parameter associated with the minimum error, and place it into the prune() function. Alternatively, this can be automated using tree.fit$cptable[which.min(tree.fit$cptable[,"xerror"]),"CP"].
# prune the tree based on minimum xerror
pruned.rtree.fit<- prune(rtree.fit, cp= rtree.fit$cptable[which.min(rtree.fit$cptable[,"xerror"]),"CP"])
# plot the pruned tree using prp() in the rpart.plot package
prp(pruned.rtree.fit, main="Pruned Regression Tree for Median Home Value")
In this case the pruned tree is not that much smaller than the original tree.
There are, of course, other approaches to pruning. Breiman et al. (1984) suggest using the cross-validation approach and applying a one-standard-error rule to the optimization criterion to identify the simplest tree. That is, find the smallest tree that is within one standard error of the tree with the smallest absolute error. In the plotcp() output, this is the leftmost cp value for which the mean lies below the horizontal line placed 1 SE above the minimum of the curve.
# prune the tree based on the 1-SE rule (cp value read from the plotcp() output)
pruned2.rtree.fit<- prune(rtree.fit, cp=.01)
# plot the pruned tree using prp() in the rpart.plot package
prp(pruned2.rtree.fit, main="Pruned Regression Tree for Median Home Value")
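If you prefer not to read the cp value off the plotcp() plot, the 1-SE choice can be computed directly from the cptable (a sketch assuming the fitted rtree.fit object from above; the variable names are illustrative):

```r
cpt <- rtree.fit$cptable

# threshold: minimum cross-validation error plus its standard error
min.row <- which.min(cpt[, "xerror"])
thresh  <- cpt[min.row, "xerror"] + cpt[min.row, "xstd"]

# cptable rows run from smallest to largest tree, so the first row
# whose xerror falls at or below the threshold is the simplest such tree
cp.1se <- cpt[which(cpt[, "xerror"] <= thresh)[1], "CP"]

pruned.1se.fit <- prune(rtree.fit, cp = cp.1se)
```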
Finally, for comparison with the regression model, we examine the \(R^2\) of the original and pruned trees. (Note: The predictive value of the model would typically be established through cross-validation and test samples. We do the below only for didactic illustration.)
#original tree
cor(predict(rtree.fit, newdata=testData),testData$medv)^2
## [1] 0.8184857
#pruned tree #1
cor(predict(pruned.rtree.fit, newdata=testData),testData$medv)^2
## [1] 0.8110404
#pruned tree #2
cor(predict(pruned2.rtree.fit, newdata=testData),testData$medv)^2
## [1] 0.7846115
We see here the tradeoff between “overfit” to the training data and potential generalizability to new data. More formal evaluations would be done using cross-validation. But the smaller pruned tree is still doing pretty well (almost as well as the multiple regression).
Traditional CART-based trees recursively perform univariate splits of the dependent variable based on the values of a set of covariates. An information measure (such as the Gini coefficient) is used to select the splitting covariate. There is, however, a variable selection bias in the traditional algorithms (rpart and related methods): these approaches tend to select variables that have many possible splits or many missing values.
To overcome that bias, conditional inference trees were introduced. Unlike the other approaches, conditional inference trees use a significance test procedure to select variables at each split. The significance tests (or, more precisely, the multiple significance tests computed at each step of the algorithm: select covariate, choose split, recurse) are permutation tests, which obtain the distribution of the test statistic under the null hypothesis by calculating all possible values of the test statistic under rearrangements of the labels on the observed data points (see Wikipedia).
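To make the permutation idea concrete, here is a toy sketch (not the exact statistic ctree uses) that shuffles the outcome to build a null distribution for a simple correlation statistic:

```r
library(MASS)  # for the Boston data
data(Boston)

set.seed(1234)
x <- Boston$lstat
y <- Boston$medv

# observed association between the covariate and the outcome
obs.stat <- abs(cor(x, y))

# null distribution: the same statistic with the outcome labels shuffled
perm.stat <- replicate(1000, abs(cor(x, sample(y))))

# permutation p-value: proportion of permuted statistics at least as extreme
mean(perm.stat >= obs.stat)
```

For these data the observed correlation is far outside the permutation distribution, so the p-value is essentially zero and lstat would be a candidate splitting variable.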
More details can be found here https://stats.stackexchange.com/questions/12140/conditional-inference-trees-vs-traditional-decision-trees, and in the original paper here http://statmath.wu-wien.ac.at/~zeileis/papers/Hothorn+Hornik+Zeileis-2006.pdf.
The steps for implementation are largely the same: Grow, examine (and maybe prune).
To grow a tree using the conditional inference method, we can use the party
(party: A Laboratory for Recursive Partitioning) package or the updated package partykit
(partykit: A Toolkit for Recursive Partytioning). This package provides nonparametric regression trees for nominal, ordinal, numeric, censored, and multivariate responses.
Specifically, regression or classification trees are obtained using the function ctree(formula, data=, control=) where
+ formula is in the format outcome ~ predictor1+predictor2+predictor3+etc.
+ data= specifies the data frame
+ control= optional parameters for controlling tree growth. For example, control=ctree_control(maxdepth=3) requires that the maximum depth of the tree is 3. The default, maxdepth = Inf, means that no restrictions are applied to tree size.
ctree.fit <- ctree(medv ~ .,
data=trainData,
control=ctree_control(maxdepth=Inf))
A collection of functions help us evaluate and examine the model.
+ print(tree.fit) displays the details of the tree
+ plot(tree.fit) plots the decision tree
For our example, this is
print(ctree.fit) # display the results
##
## Model formula:
## medv ~ crim + zn + indus + chas + nox + rm + age + dis + rad +
## tax + ptratio + black + lstat
##
## Fitted party:
## [1] root
## | [2] lstat <= 8.93
## | | [3] rm <= 7.42
## | | | [4] crim <= 1.05393
## | | | | [5] rm <= 6.54
## | | | | | [6] nox <= 0.51
## | | | | | | [7] rm <= 6.121: 21.229 (n = 7, err = 8.1)
## | | | | | | [8] rm > 6.121: 24.166 (n = 29, err = 79.4)
## | | | | | [9] nox > 0.51: 19.956 (n = 9, err = 101.1)
## | | | | [10] rm > 6.54
## | | | | | [11] rm <= 6.957: 28.397 (n = 37, err = 340.7)
## | | | | | [12] rm > 6.957: 33.346 (n = 24, err = 229.8)
## | | | [13] crim > 1.05393: 37.878 (n = 9, err = 1280.5)
## | | [14] rm > 7.42: 45.295 (n = 19, err = 846.0)
## | [15] lstat > 8.93
## | | [16] lstat <= 14.37
## | | | [17] rm <= 6.59: 20.944 (n = 103, err = 860.8)
## | | | [18] rm > 6.59: 26.345 (n = 11, err = 259.5)
## | | [19] lstat > 14.37
## | | | [20] tax <= 469
## | | | | [21] crim <= 0.43571: 18.881 (n = 37, err = 291.6)
## | | | | [22] crim > 0.43571: 14.857 (n = 23, err = 60.7)
## | | | [23] tax > 469
## | | | | [24] dis <= 1.9976: 11.320 (n = 46, err = 732.7)
## | | | | [25] dis > 1.9976: 16.100 (n = 27, err = 225.9)
##
## Number of inner nodes: 12
## Number of terminal nodes: 13
plot(ctree.fit,
main="Regression CTree for Median Home Value")
For comparison with the regression model, we examine the \(R^2\) of the conditional inference tree. (Note: The predictive value of the model would typically be established through cross-validation across many test samples.)
#R-square conditional inference tree
cor(predict(ctree.fit, newdata=testData),testData$medv)^2
## [1] 0.7720055
Although the statistical approach ensures that the right-sized tree is grown without additional (post-)pruning or cross-validation, the depth of the tree here is rather large (6 levels and 13 terminal nodes), which of course makes interpretation more difficult than with less deep trees.
Prune back the tree to avoid overfitting the data. This time we might simply prune for simplicity of plotting and interpretation. Pruning is done by regrowing the tree with a different control parameter.
# regrow the tree with small depth, maxdepth = 3
pruned.ctree.fit<- ctree(medv ~ .,
data=trainData,
control=ctree_control(maxdepth=3))
#examine pruned tree
print(pruned.ctree.fit) # display the results
##
## Model formula:
## medv ~ crim + zn + indus + chas + nox + rm + age + dis + rad +
## tax + ptratio + black + lstat
##
## Fitted party:
## [1] root
## | [2] lstat <= 8.93
## | | [3] rm <= 7.42
## | | | [4] crim <= 1.05393: 27.170 (n = 106, err = 2707.5)
## | | | [5] crim > 1.05393: 37.878 (n = 9, err = 1280.5)
## | | [6] rm > 7.42: 45.295 (n = 19, err = 846.0)
## | [7] lstat > 8.93
## | | [8] lstat <= 14.37
## | | | [9] rm <= 6.59: 20.944 (n = 103, err = 860.8)
## | | | [10] rm > 6.59: 26.345 (n = 11, err = 259.5)
## | | [11] lstat > 14.37
## | | | [12] tax <= 469: 17.338 (n = 60, err = 582.0)
## | | | [13] tax > 469: 13.088 (n = 73, err = 1347.4)
##
## Number of inner nodes: 6
## Number of terminal nodes: 7
plot(pruned.ctree.fit,
main="(Pruned) Regression CTree for Median Home Value")
#R-square conditional inference tree
cor(predict(pruned.ctree.fit, newdata=testData),testData$medv)^2
## [1] 0.7030792
In this case the pruned tree provides an easier set of rules, but gives up prediction accuracy (in the hope for better generalization to other data).
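Pulling the out-of-sample \(R^2\) values together makes the tradeoff easy to scan (a sketch, assuming the fitted objects and testData from above; the helper function r2 is our own):

```r
# out-of-sample R-squared for any fitted tree
r2 <- function(fit) cor(predict(fit, newdata = testData), testData$medv)^2

data.frame(model = c("full rpart", "pruned (min xerror)", "pruned (1-SE)",
                     "ctree", "ctree (maxdepth = 3)"),
           R2    = c(r2(rtree.fit), r2(pruned.rtree.fit), r2(pruned2.rtree.fit),
                     r2(ctree.fit), r2(pruned.ctree.fit)))
```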
In this session we walked through some very basics of implementing regression tree models. Classification trees operate in much the same way, except that the outcome is a nominal variable. While individual trees are not often used in practice anymore, they provide a foundation for the forthcoming ensemble methods, where many trees are combined together. So, next we take a walk into the forest.
As always, thank you for playing!