Classification Trees
Regression Trees
For classification trees, the response variable \(Y\) is categorical (with any number of categories); for regression trees, \(Y\) is continuous.
There can be one or more predictors \(x_j\), each of which may be either continuous or categorical.
Idea: split the predictors’ values (one predictor at a time) into disjoint sets to make a series of binary decisions about the predicted value of the response
Pro: binary decisions are easy to interpret and explain
Con: trees are unstable — small changes in the input (training data) can dramatically change the fitted tree and its predictions.
The earliest usage dates to 1972, by Robert Messenger and Lewis Mandell.
In R, we may use the `rpart()` function in the rpart package or the `tree()` function in the tree package.
We begin by importing the Chicago Food Inspections Data.
# Start from an empty workspace and attach the tidyverse (readr, dplyr, stringr, ...).
rm(list = ls())
library(tidyverse)
## -- Attaching packages --------------------------------------------- tidyverse 1.2.1 --
## v ggplot2 3.2.1 v purrr 0.3.2
## v tibble 2.1.3 v dplyr 0.8.3
## v tidyr 1.0.0 v stringr 1.4.0
## v readr 1.3.1 v forcats 0.4.0
## -- Conflicts ------------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Read the Chicago inspections file. The location/geography columns are skipped
# because nothing below uses them; the inspection date is parsed as a Date
# from month/day/year text.
food <- read_csv(
  "https://uofi.box.com/shared/static/5637axblfhajotail80yw7j2s4r27hxd.csv",
  col_types = cols(
    Address = col_skip(),
    `Census Tracts` = col_skip(),
    City = col_skip(),
    `Community Areas` = col_skip(),
    `Historical Wards 2003-2015` = col_skip(),
    `Inspection Date` = col_date(format = "%m/%d/%Y"),
    Location = col_skip(),
    State = col_skip(),
    Wards = col_skip(),
    `Zip Codes` = col_skip()
  )
)
dim(food)
## [1] 187787 13
# Lower-case every column name so the columns are easier to reference.
names(food) <- tolower(names(food))
# pairs(food) # useful for numeric variables, but there are none here.
# Newest inspections first, so the later distinct() call keeps each
# business's most recent inspection.
food <- food %>% arrange(desc(`inspection date`))
head(food$`inspection date`)
tail(food$`inspection date`)
## [1] "2019-05-31" "2019-05-31" "2019-05-31" "2019-05-31" "2019-05-31"
## [6] "2019-05-31"
## [1] "2010-01-04" "2010-01-04" "2010-01-04" "2010-01-04" "2010-01-04"
## [6] "2010-01-04"
# One row per business: distinct() keeps only the first row for each
# `license #`, which after the newest-first sort is its latest inspection.
food2 <- food %>% distinct(`license #`, .keep_all = TRUE)
dim(food2)
## [1] 36444 13
# Number of violations per inspection = number of "|" separators + 1
# (NA when the violations text itself is NA).
food2$totalviolations <- str_count(food2$violations, "\\|") + 1
# Keep the three pass/fail outcomes and license numbers of at least 1.
food3 <- food2 %>%
  filter(
    results %in% c("Pass", "Fail", "Pass w/ Conditions"),
    `license #` >= 1
  )
table(food3$results)
##
## Fail Pass Pass w/ Conditions
## 688 13328 6868
# Shape of the violation-count distribution (NAs reported by summary()).
summary(food3$totalviolations)
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## 1.00 2.00 3.00 4.26 6.00 30.00 6541
plot(density(food3$totalviolations, na.rm = TRUE))
hist(food3$totalviolations)
# How many inspections have a non-missing violation count.
sum(!is.na(food3$totalviolations))
## [1] 14343
# Outlier screen 1 -- three-sigma rule: flag counts lying more than three
# standard deviations from the mean (missing values ignored).
mn <- with(food3, mean(totalviolations, na.rm = TRUE))
sg <- with(food3, sd(totalviolations, na.rm = TRUE))
tsr <- which(abs(food3$totalviolations - mn) > 3 * sg)
length(tsr)
## [1] 272
food3$totalviolations[tsr]
## [1] 21 15 16 16 16 20 16 16 18 17 15 16 14 15 15 15 18 18 15 17 15 17 15
## [24] 14 19 18 17 14 15 20 15 14 20 17 15 16 15 19 14 17 21 23 17 20 14 15
## [47] 14 15 15 14 15 14 14 14 21 16 15 15 18 18 14 16 17 16 16 18 14 19 20
## [70] 14 14 14 18 19 14 20 15 14 15 20 14 17 20 17 14 18 17 19 16 15 14 16
## [93] 14 23 17 17 14 16 14 20 14 14 14 14 15 14 19 19 18 15 14 18 15 18 30
## [116] 15 15 17 17 20 17 20 14 16 16 19 14 14 14 17 17 14 18 14 25 14 18 16
## [139] 16 14 14 16 15 16 14 18 15 15 18 15 14 17 17 14 26 14 17 15 15 16 15
## [162] 15 18 16 17 14 14 14 14 19 18 18 19 18 15 14 17 18 14 15 15 15 14 16
## [185] 16 16 15 19 17 16 14 17 14 14 15 19 15 18 15 19 18 15 18 15 15 14 20
## [208] 17 15 15 16 18 15 18 14 15 15 20 16 14 17 15 14 14 15 20 15 14 14 14
## [231] 17 15 14 15 15 25 21 17 20 14 17 15 17 15 16 14 15 15 17 16 14 17 21
## [254] 15 15 19 20 18 14 20 16 24 14 20 16 14 24 14 20 16 16 14
# Outlier screen 2 -- boxplot (Tukey) rule: flag counts outside
# [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR] (missing values ignored).
q1 <- unname(quantile(food3$totalviolations, probs = 0.25, na.rm = TRUE))
q3 <- unname(quantile(food3$totalviolations, probs = 0.75, na.rm = TRUE))
iqr <- q3 - q1
lwr <- which(food3$totalviolations < q1 - 1.5 * iqr)
upr <- which(food3$totalviolations > q3 + 1.5 * iqr)
length(lwr)
length(upr)
## [1] 0
## [1] 373
food3$totalviolations[upr]
## [1] 21 13 15 16 16 16 20 13 16 13 13 13 16 18 17 15 16 14 15 15 15 18 18
## [24] 13 15 17 13 15 17 15 13 14 19 13 13 18 13 17 14 13 15 20 15 14 20 17
## [47] 13 15 16 13 15 19 14 13 13 13 13 17 21 23 17 20 14 15 14 13 13 15 15
## [70] 14 15 14 14 14 13 21 16 13 15 15 18 18 13 14 16 17 16 16 18 13 13 14
## [93] 19 20 14 14 14 18 19 13 13 13 14 13 20 13 15 14 15 13 20 13 14 17 20
## [116] 17 13 13 14 18 17 19 16 13 13 15 14 16 14 23 17 17 14 13 13 16 14 20
## [139] 13 14 14 14 14 15 14 19 13 19 13 13 18 15 14 13 18 15 18 30 15 15 17
## [162] 17 20 17 20 14 16 16 13 19 14 14 14 17 17 14 18 14 25 13 14 18 13 16
## [185] 16 14 14 16 15 13 16 14 18 15 13 15 18 15 14 13 17 17 14 26 13 14 17
## [208] 13 13 13 15 15 16 15 15 18 13 16 17 13 13 14 14 13 14 13 14 19 18 18
## [231] 19 18 15 13 14 13 13 17 18 14 15 15 15 14 13 13 13 13 16 16 13 16 15
## [254] 13 13 19 13 13 17 16 14 13 17 14 14 15 19 13 15 18 15 13 13 13 19 18
## [277] 15 18 15 15 14 13 13 20 13 17 15 15 16 13 18 13 15 18 13 14 15 13 13
## [300] 15 20 13 13 16 14 13 13 17 13 15 14 14 15 20 15 13 14 14 14 17 13 15
## [323] 14 13 13 13 15 15 25 13 21 13 13 17 20 14 17 15 17 15 16 14 15 13 15
## [346] 17 16 14 17 21 15 15 19 20 18 14 20 16 24 14 20 16 14 24 14 20 16 13
## [369] 16 14 13 13 13
# Outlier screen 3 -- Hampel identifier: flag counts more than three scaled
# MADs (MAD about the median, times 1.4826) away from the median.
md <- median(food3$totalviolations, na.rm = TRUE)
sg2 <- mad(food3$totalviolations, center = md, constant = 1.4826, na.rm = TRUE)
hi <- which(abs(food3$totalviolations - md) > 3 * sg2)
length(hi)
## [1] 547
food3$totalviolations[hi]
## [1] 21 12 13 15 16 16 16 20 13 16 12 13 13 13 16 18 17 12 15 12 16 14 12
## [24] 15 15 15 18 12 12 12 18 13 15 17 13 15 17 12 15 13 14 19 13 13 18 13
## [47] 12 17 14 13 15 20 15 14 20 17 13 15 16 12 12 13 15 19 14 13 13 13 13
## [70] 12 17 12 21 23 17 20 14 15 14 13 13 15 15 14 15 14 12 14 12 14 13 21
## [93] 12 16 13 15 12 15 12 18 12 18 13 12 12 14 12 12 12 16 12 17 16 16 18
## [116] 13 13 14 19 20 14 14 14 18 12 19 13 13 13 14 13 20 12 12 13 15 14 12
## [139] 12 15 12 12 13 20 13 14 17 20 17 13 13 12 14 18 12 17 12 19 16 12 13
## [162] 13 15 14 12 16 14 12 23 17 17 12 14 13 13 16 12 14 12 20 12 13 14 12
## [185] 14 14 12 12 14 15 14 19 12 13 12 19 13 13 18 12 15 14 13 12 18 15 18
## [208] 30 12 15 12 15 12 17 12 17 20 17 20 12 14 16 16 13 12 12 12 12 19 14
## [231] 12 14 12 14 17 12 17 14 18 14 25 12 12 13 12 14 18 13 16 16 12 14 14
## [254] 16 12 15 13 16 14 18 15 13 12 15 18 12 12 15 12 14 13 17 17 14 12 12
## [277] 12 26 12 13 12 14 17 13 12 13 13 15 15 16 15 15 18 12 13 16 12 12 12
## [300] 12 17 12 13 13 12 12 12 14 14 13 14 13 12 12 12 12 14 12 19 18 18 19
## [323] 18 15 13 14 13 12 13 12 17 18 14 15 15 12 15 14 12 13 13 13 13 16 12
## [346] 16 13 12 16 15 13 13 12 12 19 12 13 13 17 12 12 16 14 13 17 14 14 15
## [369] 12 19 12 13 12 12 15 18 12 15 13 12 12 12 13 12 12 13 19 18 15 12 18
## [392] 15 15 12 12 12 14 13 13 20 13 12 12 17 15 15 12 16 12 13 12 18 13 15
## [415] 18 12 13 14 15 13 13 12 12 12 12 15 20 12 13 13 16 14 12 13 12 12 13
## [438] 12 17 13 12 15 14 14 15 20 12 15 12 13 14 14 14 12 17 13 12 12 15 12
## [461] 14 13 13 13 12 15 12 12 15 25 12 12 13 12 12 12 12 21 13 13 17 20 12
## [484] 12 12 14 12 17 15 17 15 12 16 12 14 12 15 13 12 15 12 17 16 14 17 12
## [507] 21 15 12 15 19 20 12 12 18 14 20 16 24 12 14 20 16 12 12 12 14 24 14
## [530] 12 20 16 12 12 12 13 12 12 12 16 14 12 13 12 12 13 13
# Before removing missing values (for simplicity and clarity), count the NAs
# in each column that will be used for modeling.
sum(is.na(food3$results))
## [1] 0
sum(is.na(food3$risk))
## [1] 8
sum(is.na(food3$totalviolations))
## [1] 6541
# Row indices with a missing risk level or a missing violation count. A row
# missing both appears twice, which only inflates the count printed below.
miss <- c(which(is.na(food3$risk)), which(is.na(food3$totalviolations)))
length(miss)
## [1] 6549
# Keep every row not listed in `miss`. Indexing with setdiff() instead of
# `food3[-miss, ]` avoids the negative-index pitfall: `x[-integer(0), ]`
# returns zero rows (not all rows) when nothing is missing. Row order is the
# same as with negative indexing.
keep <- setdiff(seq_len(nrow(food3)), miss)
food4 <- select(food3[keep, ], results, risk, totalviolations)
sum(is.na(food4))
## [1] 0
# Random subsample of 800 businesses for the tree example.
set.seed(448)
rs <- sample(nrow(food4), 800)
food5 <- food4[rs, ]
# risk is a text column. Passed to tree() as character, it is coerced to NA
# ("NAs introduced by coercion") and never used for a split. Store it as a
# factor so the tree can split on it.
# NOTE: the tree summary and confusion matrices rendered further below were
# produced before this conversion; re-knit to refresh them.
food5$risk <- factor(food5$risk)
rm(food, food2, food3, food4, keep)
# Split the sample into 75% training / 25% testing rows
# (R code borrowed from Dalpiaz's textbook).
set.seed(448)
ids <- sample(nrow(food5), floor(0.75 * nrow(food5)))
trainingData <- food5[ids, ]
testingData <- food5[-ids, ]
colnames(food5)
## [1] "results" "risk" "totalviolations"
# Response = inspection results (first column); predictors = remaining columns.
trainingData_response <- trainingData$results
testingData_response <- testingData$results
trainingData_predictors <- trainingData[, -1]
testingData_predictors <- testingData[, -1]
Classification Algorithm (described by Wei-Yin Loh http://pages.stat.wisc.edu/~loh/treeprogs/guide/wires11.pdf)
1. Start at the root node: the tree begins with all observations in a single node, which is then split into child nodes that are either leaves (terminal nodes) or internal nodes.
2. For each predictor \(X\), find the set \(S\) that minimizes the sum of the node impurities in the two child nodes \(\{X \in S\}\) and \(\{X \notin S\}\), and choose the split \(\{X^* \in S^*\}\) that gives the minimum over all \(X\) and \(S\).
3. If a stopping criterion is reached, exit. Otherwise, apply step 2 (the split search) to each child node in turn.
Node impurity: a node is pure when all of its observations belong to the same class; the more the classes are mixed within a node, the higher its impurity. We want low node impurity, i.e., high node purity.
Stopping criterion: to keep the tree structure simple, we often specify a maximum tree size (i.e., the number of terminal nodes, or leaves) or set a value for the cost-complexity parameter of the tree.
# Classification tree with the tree package
# (R code borrowed from James Le at DataCamp https://www.datacamp.com/community/tutorials/decision-trees-R)
library(tree)
## Registered S3 method overwritten by 'tree':
## method from
## print.tree cli
# A factor response makes tree() grow a classification tree. The formula finds
# `response` in the workspace and the predictors in trainingData.
response <- trainingData_response %>% factor()
fittree <- tree(response ~ risk + totalviolations, data = trainingData)
# NOTE(review): the coercion warning below, together with only totalviolations
# being used, indicates that risk reached tree() as a character column (coerced
# to NA). It needs to be a factor for the tree to split on it.
## Warning in tree(response ~ risk + totalviolations, data = trainingData):
## NAs introduced by coercion
summary(fittree)
##
## Classification tree:
## tree(formula = response ~ risk + totalviolations, data = trainingData)
## Variables actually used in tree construction:
## [1] "totalviolations"
## Number of terminal nodes: 4
## Residual mean deviance: 1.262 = 752.1 / 596
## Misclassification error rate: 0.2867 = 172 / 600
plot(fittree)
text(fittree, pretty = 0)
# Optional pruning by cross-validated misclassification error:
#cvfittree <- cv.tree(fittree, FUN = prune.misclass) #cross-validated
#plot(cvfittree)
#prune.misclass(fittree, best=SIZE) # based on some value of complexity parameter or misclassification error
# Score the chosen tree on the training data.
trainpredz <- predict(fittree, trainingData, type = "class")
## Warning in pred1.tree(object, tree.matrix(newdata)): NAs introduced by
## coercion
# Confusion matrix: predicted class (rows) vs observed class (columns).
table(trainpredz, trainingData_response)
## trainingData_response
## trainpredz Fail Pass Pass w/ Conditions
## Fail 0 0 0
## Pass 7 238 121
## Pass w/ Conditions 5 39 190
mean(trainpredz == trainingData_response) # classification rate
## [1] 0.7133333
mean(trainpredz != trainingData_response) # misclassification rate
## [1] 0.2866667
# Observed class proportions in the training data, for comparison.
prop.table(table(trainingData_response))
## trainingData_response
## Fail Pass Pass w/ Conditions
## 0.0200000 0.4616667 0.5183333
# Score the chosen tree on the held-out testing data.
testpredz <- predict(fittree, testingData, type = "class")
## Warning in pred1.tree(object, tree.matrix(newdata)): NAs introduced by
## coercion
# Confusion matrix: predicted class (rows) vs observed class (columns).
table(testpredz, testingData_response)
## testingData_response
## testpredz Fail Pass Pass w/ Conditions
## Fail 0 0 0
## Pass 1 83 44
## Pass w/ Conditions 1 18 53
mean(testpredz == testingData_response) # classification rate
## [1] 0.68
mean(testpredz != testingData_response) # misclassification rate
## [1] 0.32
# Observed class proportions in the testing data, for comparison.
prop.table(table(testingData_response))
## testingData_response
## Fail Pass Pass w/ Conditions
## 0.010 0.505 0.485
For the regression-tree example, we will use only the second of the 20 parts of the 2015 US Natality Data. A valid data analysis would require importing all 20 parts, combining them into a single data frame, and subsetting from that entire dataset.
# Regression-tree example: clear the workspace and re-attach the tidyverse.
rm(list = ls())
library(tidyverse)
# Part 2 of the 2015 US Natality Data; readr guesses the column types.
birth1 <- read_csv(
  "https://uofi.box.com/shared/static/kr2s5wp3jpxdlpcjmq2qs15oo4797u4y.csv"
)
## Parsed with column specification:
## cols(
## .default = col_double(),
## MAGE_IMPFLG = col_logical(),
## MAGE_REPFLG = col_logical(),
## MAR_P = col_character(),
## MAR_IMP = col_logical(),
## FAGERPT_FLG = col_logical(),
## WIC = col_character(),
## CIG_REC = col_character(),
## RF_PDIAB = col_character(),
## RF_GDIAB = col_character(),
## RF_PHYPE = col_character(),
## RF_GHYPE = col_character(),
## RF_EHYPE = col_character(),
## RF_PPTERM = col_character(),
## RF_INFTR = col_character(),
## RF_REDRG = col_character(),
## RF_ARTEC = col_character(),
## RF_CESAR = col_character(),
## IP_GON = col_character(),
## IP_SYPH = col_character(),
## IP_CHLAM = col_character()
## # ... with 42 more columns
## )
## See spec(...) for full column specifications.
names(birth1) <- tolower(names(birth1))
dim(birth1)
## [1] 200000 241
# Collect rows whose key variables hold the data key's "unknown / not stated"
# codes, printing how many rows carry each code.
rs1 <- which(birth1$dbwt == 9999); sum(birth1$dbwt == 9999)
## [1] 59
rs1 <- c(rs1, which(birth1$apgar5 == 99)); sum(birth1$apgar5 == 99)
## [1] 957
rs1 <- c(rs1, which(birth1$precare == 99)); sum(birth1$precare == 99)
## [1] 3078
rs1 <- c(rs1, which(birth1$cig_0 == 99)); sum(birth1$cig_0 == 99)
## [1] 1103
rs1 <- c(rs1, which(birth1$wtgain == 99)); sum(birth1$wtgain == 99)
## [1] 8789
#mager has no unknowns
rs1 <- c(rs1, which(birth1$dmeth_rec == 9)); sum(birth1$dmeth_rec == 9)
## [1] 0
rs2 <- unique(rs1)
# Keep every row that was not flagged. Indexing with setdiff() instead of
# `birth1[-rs2, ]` avoids the negative-index pitfall: `x[-integer(0), ]`
# returns zero rows if nothing was flagged. Row order is unchanged.
birth11 <- birth1[setdiff(seq_len(nrow(birth1)), rs2), ]
# Random subsample of 667 births for the example.
set.seed(448)
rs <- sample(nrow(birth11), 667)
birth01 <- birth11[rs, ]
rm(birth1, birth11)
# Suppose we know these variables to be collinear: precare, previs, cig_0,
# cig_1, cig_2, cig_3, pwgt_r, dwgt_r, wtgain, apgar5, apgar10
names(birth01)
## [1] "dob_yy" "dob_mm" "dob_tt" "dob_wk"
## [5] "bfacil" "f_facility" "bfacil3" "mage_impflg"
## [9] "mage_repflg" "mager" "mager14" "mager9"
## [13] "mbstate_rec" "restatus" "mrace31" "mrace6"
## [17] "mrace15" "mbrace" "mraceimp" "mhisp_r"
## [21] "f_hisp" "mracehisp" "mar_p" "dmar"
## [25] "mar_imp" "f_mar_p" "meduc" "f_meduc"
## [29] "fagerpt_flg" "fagecomb" "fagerec11" "frace31"
## [33] "frace6" "frace15" "fhisp_r" "f_fhisp"
## [37] "fracehisp" "feduc" "f_feduc" "priorlive"
## [41] "priordead" "priorterm" "lbo_rec" "tbo_rec"
## [45] "illb_r" "illb_r11" "ilop_r" "ilop_r11"
## [49] "ilp_r" "ilp_r11" "precare" "f_mcb"
## [53] "precare5" "previs" "previs_rec" "f_tpcv"
## [57] "wic" "f_wic" "cig_0" "cig_1"
## [61] "cig_2" "cig_3" "cig0_r" "cig1_r"
## [65] "cig2_r" "cig3_r" "f_cigs" "f_cigs_1"
## [69] "f_cigs_2" "f_cigs_3" "cig_rec" "f_tobaco"
## [73] "m_ht_in" "f_m_ht" "bmi" "bmi_r"
## [77] "pwgt_r" "f_pwgt" "dwgt_r" "f_dwgt"
## [81] "wtgain" "wtgain_rec" "f_wtgain" "rf_pdiab"
## [85] "rf_gdiab" "rf_phype" "rf_ghype" "rf_ehype"
## [89] "rf_ppterm" "f_rf_pdiab" "f_rf_gdiab" "f_rf_phyper"
## [93] "f_rf_ghyper" "f_rf_eclamp" "f_rf_ppb" "rf_inftr"
## [97] "rf_redrg" "rf_artec" "f_rf_inft" "f_rf_inf_drg"
## [101] "f_rf_inf_art" "rf_cesar" "rf_cesarn" "f_rf_cesar"
## [105] "f_rf_ncesar" "no_risks" "ip_gon" "ip_syph"
## [109] "ip_chlam" "ip_hepb" "ip_hepc" "f_ip_gonor"
## [113] "f_ip_syph" "f_ip_chlam" "f_ip_hepatb" "f_ip_hepatc"
## [117] "no_infec" "ob_ecvs" "ob_ecvf" "f_ob_succ"
## [121] "f_ob_fail" "ld_indl" "ld_augm" "ld_ster"
## [125] "ld_antb" "ld_chor" "ld_anes" "f_ld_indl"
## [129] "f_ld_augm" "f_ld_ster" "f_ld_antb" "f_ld_chor"
## [133] "f_ld_anes" "no_lbrdlv" "me_pres" "me_rout"
## [137] "me_trial" "f_me_pres" "f_me_rout" "f_me_trial"
## [141] "rdmeth_rec" "dmeth_rec" "f_dmeth_rec" "mm_mtr"
## [145] "mm_plac" "mm_rupt" "mm_uhyst" "mm_aicu"
## [149] "f_mm_mtr" "f_mm_plac" "f_mm_rupt" "f_mm_uhyst"
## [153] "f_mm_aicu" "no_mmorb" "attend" "mtran"
## [157] "pay" "pay_rec" "f_pay" "f_pay_rec"
## [161] "apgar5" "apgar5r" "f_apgar5" "apgar10"
## [165] "apgar10r" "dplural" "imp_plur" "setorder_r"
## [169] "sex" "imp_sex" "dlmp_mm" "dlmp_yy"
## [173] "compgst_imp" "obgest_flg" "combgest" "gestrec10"
## [177] "gestrec3" "lmpused" "oegest_comb" "oegest_r10"
## [181] "oegest_r3" "dbwt" "bwtr12" "bwtr4"
## [185] "ab_aven1" "ab_aven6" "ab_nicu" "ab_surf"
## [189] "ab_anti" "ab_seiz" "f_ab_vent" "f_ab_vent6"
## [193] "f_ab_niuc" "f_ab_surfac" "f_ab_antibio" "f_ab_seiz"
## [197] "no_abnorm" "ca_anen" "ca_mnsb" "ca_cchd"
## [201] "ca_cdh" "ca_omph" "ca_gast" "f_ca_anen"
## [205] "f_ca_menin" "f_ca_heart" "f_ca_hernia" "f_ca_ompha"
## [209] "f_ca_gastro" "ca_limb" "ca_cleft" "ca_clpal"
## [213] "ca_down" "ca_disor" "ca_hypo" "f_ca_limb"
## [217] "f_ca_cleftlp" "f_ca_cleft" "f_ca_downs" "f_ca_chrom"
## [221] "f_ca_hypos" "no_congen" "itran" "ilive"
## [225] "bfed" "f_bfed" "ubfacil" "urf_diab"
## [229] "urf_chyper" "urf_phyper" "urf_eclam" "ume_forcp"
## [233] "ume_vac" "uop_induc" "uld_breech" "uca_anen"
## [237] "uca_spina" "uca_ompha" "uca_celftlp" "uca_hernia"
## [241] "uca_downs"
# Choose the 70% training rows now (30% testing); they are applied to birth02
# once the modeling columns are built.
set.seed(448)
ids <- sample(nrow(birth01), floor(0.70 * nrow(birth01)))
# Predictors actually used: apgar5, wtgain, precare, smoking before pregnancy
# (1 when cig0_r is nonzero), mager9 and dmeth_rec; the last four as factors.
birth01$cigb4preg <- ifelse(birth01$cig0_r == 0, 0, 1)
birth02 <- birth01 %>%
  mutate(
    newprecare = factor(precare),
    newcigb4preg = factor(cigb4preg),
    newmager9 = factor(mager9),
    newdmeth_rec = factor(dmeth_rec)
  ) %>%
  select(dbwt, apgar5, wtgain, newprecare, newcigb4preg, newmager9, newdmeth_rec)
colnames(birth02)
## [1] "dbwt" "apgar5" "wtgain" "newprecare"
## [5] "newcigb4preg" "newmager9" "newdmeth_rec"
trainingData <- birth02[ids, ]
testingData <- birth02[-ids, ]
# Response = birth weight dbwt (first column); predictors = remaining columns.
trainingData_response <- trainingData$dbwt
testingData_response <- testingData$dbwt
trainingData_predictors <- trainingData[, -1]
testingData_predictors <- testingData[, -1]
rm(birth01)
# Regression tree with the tree package
# (R code borrowed from James Le at DataCamp https://www.datacamp.com/community/tutorials/decision-trees-R)
library(tree)
# Numeric response, so tree() grows a regression tree; look at its shape first.
response <- trainingData_response
hist(response)
fittree <- tree(
  response ~ apgar5 + wtgain + newprecare + newcigb4preg + newmager9 + newdmeth_rec,
  data = trainingData
)
summary(fittree)
##
## Regression tree:
## tree(formula = response ~ apgar5 + wtgain + newprecare + newcigb4preg +
## newmager9 + newdmeth_rec, data = trainingData)
## Variables actually used in tree construction:
## [1] "wtgain" "apgar5"
## Number of terminal nodes: 3
## Residual mean deviance: 271400 = 125600000 / 463
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -2347.00 -288.00 16.37 0.00 306.50 1402.00
plot(fittree)
text(fittree, pretty = 0)
# Optional pruning by cross-validation. For a regression tree use cv.tree()'s
# default deviance-based pruning (prune.tree); prune.misclass only applies to
# classification trees.
#cvfittree <- cv.tree(fittree) #cross-validated
#plot(cvfittree)
#prune.tree(fittree, best=SIZE) # based on the cross-validated deviance
# with the satisfactory tree now, let's score it on the trainingData
trainpredz <- predict(fittree, trainingData, type = "vector")
res <- trainpredz - trainingData_response # residuals
sqrt(mean(res^2)) # rmse
## [1] 519.2347
# R^2 = 1 - SSE/SST. For an in-sample least-squares tree (leaf means) this
# equals var(pred)/var(y), but unlike that ratio it remains valid on new data.
1 - sum(res^2) / sum((trainingData_response - mean(trainingData_response))^2) # rsquare
## [1] 0.05493916
# R code
# with the satisfactory tree now, let's score it on the testingData
testpredz <- predict(fittree, testingData, type = "vector")
rez <- testpredz - testingData_response # residuals
sqrt(mean(rez^2)) # rmse
## [1] 560.196
# Out-of-sample R^2 = 1 - SSE/SST. The ratio var(pred)/var(y) only equals R^2
# for in-sample least-squares fits; on held-out data it ignores prediction
# error entirely, so it is not an R^2.
# NOTE: the value printed below was computed as
# var(testpredz)/var(testingData_response); re-knit to refresh it.
1 - sum(rez^2) / sum((testingData_response - mean(testingData_response))^2) # rsquare
## [1] 0.06227451
# As a rough comparison, fit a linear regression on the two predictors the tree used.
# NOTE(review): the lm is fit on a random subsample of 201 training rows (the
# size of testingData), not on all of trainingData -- confirm this is intended.
td <- trainingData[sample(nrow(trainingData), 201), ]
md <- lm(dbwt ~ apgar5 + wtgain, data = td)
# Predict from the fitted model `md` instead of refitting an identical lm() inline.
predict(md, newdata = testingData)
## 1 2 3 4 5 6 7 8
## 3266.741 3180.064 3125.891 3255.906 3125.891 3018.748 3288.410 3310.079
## 9 10 11 12 13 14 15 16
## 3461.764 3125.891 3721.795 3418.425 3331.748 3277.575 2963.371 3320.914
## 17 18 19 20 21 22 23 24
## 3505.102 3200.529 3538.811 3353.418 3223.402 3461.764 3006.710 2963.371
## 25 26 27 28 29 30 31 32
## 3613.449 3341.379 3255.906 3537.606 3505.102 3147.560 3602.614 3396.756
## 33 34 35 36 37 38 39 40
## 3494.268 3429.260 3115.056 3288.410 3407.591 3234.237 3115.056 2963.371
## 41 42 43 44 45 46 47 48
## 3104.221 3526.772 3342.583 3450.929 3494.268 3298.040 3310.079 3342.583
## 49 50 51 52 53 54 55 56
## 3505.102 3136.725 3277.575 3526.772 3331.748 3103.017 3038.009 3266.741
## 57 58 59 60 61 62 63 64
## 3136.725 3416.017 3483.433 3266.741 2963.371 3396.756 3342.583 3375.087
## 65 66 67 68 69 70 71 72
## 2963.371 3255.906 3287.206 3385.922 3580.945 3104.221 3223.402 3277.575
## 73 74 75 76 77 78 79 80
## 3396.756 3375.087 3288.410 3093.387 3234.237 3155.986 3342.583 3288.410
## 81 82 83 84 85 86 87 88
## 3147.560 3515.937 3234.237 3071.717 3189.694 3342.583 2963.371 3266.741
## 89 90 91 92 93 94 95 96
## 2985.040 3678.457 3331.748 3234.237 3407.591 3180.064 3548.441 3060.883
## 97 98 99 100 101 102 103 104
## 3483.433 3028.379 3450.929 3180.064 3125.891 3396.756 3071.717 3331.748
## 105 106 107 108 109 110 111 112
## 3559.276 3158.394 3093.387 3299.245 3505.102 3396.756 3093.387 3288.410
## 113 114 115 116 117 118 119 120
## 3190.898 3039.213 3320.914 3450.929 3602.614 3212.568 3158.394 3526.772
## 121 122 123 124 125 126 127 128
## 3613.449 3212.568 3104.221 3364.252 3180.064 3082.552 3288.410 2761.122
## 129 130 131 132 133 134 135 136
## 3147.560 3212.568 3147.560 3189.694 3125.891 3234.237 3158.394 3071.717
## 137 138 139 140 141 142 143 144
## 3201.733 3472.599 3320.914 3201.733 3255.906 3233.033 3310.079 3830.141
## 145 146 147 148 149 150 151 152
## 3310.079 3212.568 3450.929 3170.433 3180.064 3342.583 3385.922 3093.387
## 153 154 155 156 157 158 159 160
## 3331.748 3255.906 3580.945 3266.741 3190.898 3495.472 3212.568 3255.906
## 161 162 163 164 165 166 167 168
## 3223.402 3461.764 3472.599 3180.064 3201.733 3233.033 3407.591 3385.922
## 169 170 171 172 173 174 175 176
## 3190.898 3395.552 2935.681 3320.914 3418.425 3299.245 3310.079 3407.591
## 177 178 179 180 181 182 183 184
## 3201.733 3299.245 3310.079 3310.079 3331.748 3233.033 3277.575 3396.756
## 185 186 187 188 189 190 191 192
## 3364.252 3234.237 3093.387 3245.071 3234.237 3288.410 3397.960 3472.599
## 193 194 195 196 197 198 199 200
## 3320.914 3233.033 3396.756 3234.237 3223.402 3613.449 3201.733 3060.883
## 201
## 3069.309
# Test-set RMSE of the linear model md.
modelr::rmse(md, data = testingData)
## [1] 582.8404
# NOTE(review): confirm whether rsq::rsq() evaluates md on `data` or reports
# md's in-sample R^2 on td.
rsq::rsq(md, adj = FALSE, data = testingData)
## [1] 0.07416619