Things Covered in this Week’s Notes

  • Classification Trees
  • Regression Trees


Classification Trees

A Procedure for Classification Trees

  • Remove any collinearity in the predictors and any extreme outliers.
  • Split the data into training and testing sets.
  • Build a classification tree on the training dataset.
  • Evaluate its performance on the training dataset.
  • Modify the tree model, if needed, and re-evaluate its performance on the training dataset.
  • Score the tree model on the testing dataset and evaluate its performance.
  • Interpret the tree model and its resulting prediction performance.

Remove any collinearity in the predictors and any extreme outliers.

We begin by importing the Chicago Food Inspections Data.

rm(list=ls())
library(tidyverse)
## -- Attaching packages --------------------------------------------- tidyverse 1.2.1 --
## v ggplot2 3.2.1     v purrr   0.3.2
## v tibble  2.1.3     v dplyr   0.8.3
## v tidyr   1.0.0     v stringr 1.4.0
## v readr   1.3.1     v forcats 0.4.0
## -- Conflicts ------------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
food <- read_csv("https://uofi.box.com/shared/static/5637axblfhajotail80yw7j2s4r27hxd.csv", 
    col_types = cols(Address = col_skip(), 
        `Census Tracts` = col_skip(), City = col_skip(), 
        `Community Areas` = col_skip(), `Historical Wards 2003-2015` = col_skip(), 
        `Inspection Date` = col_date(format = "%m/%d/%Y"), 
        Location = col_skip(), State = col_skip(), 
        Wards = col_skip(), `Zip Codes` = col_skip()))

dim(food)
## [1] 187787     13
colnames(food) <- tolower(colnames(food))

#pairs(food) # a scatterplot matrix would help spot collinearity among numeric predictors, but there are none here

food <- arrange(food, desc(`inspection date`)) #sorts data by newest observations first
head(food$`inspection date`);tail(food$`inspection date`)
## [1] "2019-05-31" "2019-05-31" "2019-05-31" "2019-05-31" "2019-05-31"
## [6] "2019-05-31"
## [1] "2010-01-04" "2010-01-04" "2010-01-04" "2010-01-04" "2010-01-04"
## [6] "2010-01-04"
food2 <- distinct(food, `license #`, .keep_all=TRUE) # keeps one row per business (license #); since the data are sorted newest first, this is each business's most recent inspection
dim(food2)
## [1] 36444    13
food2$totalviolations <- str_count(food2$violations, "\\|") +1 # new variable: number of violations, which are separated by "|" (NA when the violations field is missing)

food3 <- filter(food2, results == "Pass" | results == "Fail" | results == "Pass w/ Conditions", `license #` >= 1)
table(food3$results)
## 
##               Fail               Pass Pass w/ Conditions 
##                688              13328               6868
summary(food3$totalviolations)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
##    1.00    2.00    3.00    4.26    6.00   30.00    6541
plot(density(food3$totalviolations, na.rm=TRUE))

hist(food3$totalviolations)

sum(!is.na(food3$totalviolations)) # number of non-missing violation counts
## [1] 14343
#3 sigma rule
mn<-mean(food3$totalviolations, na.rm=TRUE)
sg<-sd(food3$totalviolations, na.rm=TRUE)
tsr <- which(abs(food3$totalviolations-mn) > 3*sg)
length(tsr)
## [1] 272
food3$totalviolations[tsr]
##   [1] 21 15 16 16 16 20 16 16 18 17 15 16 14 15 15 15 18 18 15 17 15 17 15
##  [24] 14 19 18 17 14 15 20 15 14 20 17 15 16 15 19 14 17 21 23 17 20 14 15
##  [47] 14 15 15 14 15 14 14 14 21 16 15 15 18 18 14 16 17 16 16 18 14 19 20
##  [70] 14 14 14 18 19 14 20 15 14 15 20 14 17 20 17 14 18 17 19 16 15 14 16
##  [93] 14 23 17 17 14 16 14 20 14 14 14 14 15 14 19 19 18 15 14 18 15 18 30
## [116] 15 15 17 17 20 17 20 14 16 16 19 14 14 14 17 17 14 18 14 25 14 18 16
## [139] 16 14 14 16 15 16 14 18 15 15 18 15 14 17 17 14 26 14 17 15 15 16 15
## [162] 15 18 16 17 14 14 14 14 19 18 18 19 18 15 14 17 18 14 15 15 15 14 16
## [185] 16 16 15 19 17 16 14 17 14 14 15 19 15 18 15 19 18 15 18 15 15 14 20
## [208] 17 15 15 16 18 15 18 14 15 15 20 16 14 17 15 14 14 15 20 15 14 14 14
## [231] 17 15 14 15 15 25 21 17 20 14 17 15 17 15 16 14 15 15 17 16 14 17 21
## [254] 15 15 19 20 18 14 20 16 24 14 20 16 14 24 14 20 16 16 14
# box plot rule
q1<-as.vector(quantile(food3$totalviolations,1/4, na.rm=TRUE))
q3<-as.vector(quantile(food3$totalviolations,3/4, na.rm=TRUE))
iqr <- as.vector(q3-q1)
lwr <- which(food3$totalviolations < q1-1.5*iqr)
upr <- which(food3$totalviolations > q3+1.5*iqr)
length(lwr);length(upr)
## [1] 0
## [1] 373
food3$totalviolations[upr]
##   [1] 21 13 15 16 16 16 20 13 16 13 13 13 16 18 17 15 16 14 15 15 15 18 18
##  [24] 13 15 17 13 15 17 15 13 14 19 13 13 18 13 17 14 13 15 20 15 14 20 17
##  [47] 13 15 16 13 15 19 14 13 13 13 13 17 21 23 17 20 14 15 14 13 13 15 15
##  [70] 14 15 14 14 14 13 21 16 13 15 15 18 18 13 14 16 17 16 16 18 13 13 14
##  [93] 19 20 14 14 14 18 19 13 13 13 14 13 20 13 15 14 15 13 20 13 14 17 20
## [116] 17 13 13 14 18 17 19 16 13 13 15 14 16 14 23 17 17 14 13 13 16 14 20
## [139] 13 14 14 14 14 15 14 19 13 19 13 13 18 15 14 13 18 15 18 30 15 15 17
## [162] 17 20 17 20 14 16 16 13 19 14 14 14 17 17 14 18 14 25 13 14 18 13 16
## [185] 16 14 14 16 15 13 16 14 18 15 13 15 18 15 14 13 17 17 14 26 13 14 17
## [208] 13 13 13 15 15 16 15 15 18 13 16 17 13 13 14 14 13 14 13 14 19 18 18
## [231] 19 18 15 13 14 13 13 17 18 14 15 15 15 14 13 13 13 13 16 16 13 16 15
## [254] 13 13 19 13 13 17 16 14 13 17 14 14 15 19 13 15 18 15 13 13 13 19 18
## [277] 15 18 15 15 14 13 13 20 13 17 15 15 16 13 18 13 15 18 13 14 15 13 13
## [300] 15 20 13 13 16 14 13 13 17 13 15 14 14 15 20 15 13 14 14 14 17 13 15
## [323] 14 13 13 13 15 15 25 13 21 13 13 17 20 14 17 15 17 15 16 14 15 13 15
## [346] 17 16 14 17 21 15 15 19 20 18 14 20 16 24 14 20 16 14 24 14 20 16 13
## [369] 16 14 13 13 13
# hampel identifier
md<-median(food3$totalviolations, na.rm=TRUE)
sg2<-1.4826*(median(abs(food3$totalviolations-md), na.rm=TRUE))
hi <- which(abs(food3$totalviolations-md) > 3*sg2)
length(hi)
## [1] 547
food3$totalviolations[hi]
##   [1] 21 12 13 15 16 16 16 20 13 16 12 13 13 13 16 18 17 12 15 12 16 14 12
##  [24] 15 15 15 18 12 12 12 18 13 15 17 13 15 17 12 15 13 14 19 13 13 18 13
##  [47] 12 17 14 13 15 20 15 14 20 17 13 15 16 12 12 13 15 19 14 13 13 13 13
##  [70] 12 17 12 21 23 17 20 14 15 14 13 13 15 15 14 15 14 12 14 12 14 13 21
##  [93] 12 16 13 15 12 15 12 18 12 18 13 12 12 14 12 12 12 16 12 17 16 16 18
## [116] 13 13 14 19 20 14 14 14 18 12 19 13 13 13 14 13 20 12 12 13 15 14 12
## [139] 12 15 12 12 13 20 13 14 17 20 17 13 13 12 14 18 12 17 12 19 16 12 13
## [162] 13 15 14 12 16 14 12 23 17 17 12 14 13 13 16 12 14 12 20 12 13 14 12
## [185] 14 14 12 12 14 15 14 19 12 13 12 19 13 13 18 12 15 14 13 12 18 15 18
## [208] 30 12 15 12 15 12 17 12 17 20 17 20 12 14 16 16 13 12 12 12 12 19 14
## [231] 12 14 12 14 17 12 17 14 18 14 25 12 12 13 12 14 18 13 16 16 12 14 14
## [254] 16 12 15 13 16 14 18 15 13 12 15 18 12 12 15 12 14 13 17 17 14 12 12
## [277] 12 26 12 13 12 14 17 13 12 13 13 15 15 16 15 15 18 12 13 16 12 12 12
## [300] 12 17 12 13 13 12 12 12 14 14 13 14 13 12 12 12 12 14 12 19 18 18 19
## [323] 18 15 13 14 13 12 13 12 17 18 14 15 15 12 15 14 12 13 13 13 13 16 12
## [346] 16 13 12 16 15 13 13 12 12 19 12 13 13 17 12 12 16 14 13 17 14 14 15
## [369] 12 19 12 13 12 12 15 18 12 15 13 12 12 12 13 12 12 13 19 18 15 12 18
## [392] 15 15 12 12 12 14 13 13 20 13 12 12 17 15 15 12 16 12 13 12 18 13 15
## [415] 18 12 13 14 15 13 13 12 12 12 12 15 20 12 13 13 16 14 12 13 12 12 13
## [438] 12 17 13 12 15 14 14 15 20 12 15 12 13 14 14 14 12 17 13 12 12 15 12
## [461] 14 13 13 13 12 15 12 12 15 25 12 12 13 12 12 12 12 21 13 13 17 20 12
## [484] 12 12 14 12 17 15 17 15 12 16 12 14 12 15 13 12 15 12 17 16 14 17 12
## [507] 21 15 12 15 19 20 12 12 18 14 20 16 24 12 14 20 16 12 12 12 14 24 14
## [530] 12 20 16 12 12 12 13 12 12 12 16 14 12 13 12 12 13 13
#removing missing values for simplification and clarity
sum(is.na(food3$results))
## [1] 0
sum(is.na(food3$risk))
## [1] 8
sum(is.na(food3$totalviolations))
## [1] 6541
miss <- c(which(is.na(food3$risk)), which(is.na(food3$totalviolations))) # rows missing risk or totalviolations

length(miss)
## [1] 6549
food4 <- select(food3[-miss,], results, risk, totalviolations)
sum(is.na(food4))
## [1] 0
set.seed(448)
rs <- sample(nrow(food4),800) # work with a random subsample of 800 businesses to keep the example small
food5 <- food4[rs,]

rm(food,food2,food3,food4)
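The three outlier rules above (three-sigma, boxplot, Hampel) can be wrapped in one helper for reuse on other variables. This is a minimal sketch; flag_outliers() is a hypothetical name, not a function from any package. It relies on base R's mad(), whose default scale constant 1.4826 matches the Hampel code above, and it returns the positions of the flagged values.

```r
# Flag outliers in a numeric vector using one of three rules.
# Returns the positions (indices) of flagged values; NAs are never flagged.
flag_outliers <- function(x, rule = c("sigma", "boxplot", "hampel"), k = 3) {
  rule <- match.arg(rule)
  if (rule == "sigma") {
    # three-sigma rule: more than k standard deviations from the mean
    idx <- abs(x - mean(x, na.rm = TRUE)) > k * sd(x, na.rm = TRUE)
  } else if (rule == "boxplot") {
    # boxplot rule: more than 1.5 IQR beyond the quartiles (k is not used here)
    q <- quantile(x, c(0.25, 0.75), na.rm = TRUE, names = FALSE)
    iqr <- q[2] - q[1]
    idx <- x < q[1] - 1.5 * iqr | x > q[2] + 1.5 * iqr
  } else {
    # Hampel identifier: more than k scaled MADs from the median
    md <- median(x, na.rm = TRUE)
    idx <- abs(x - md) > k * mad(x, center = md, na.rm = TRUE)
  }
  which(idx)  # which() silently drops NA comparisons
}

# usage on a toy vector with one extreme value and one NA
x <- c(1:20, 200, NA)
flag_outliers(x, "sigma")
flag_outliers(x, "boxplot")
flag_outliers(x, "hampel")
```

Applied to food3$totalviolations (before the rm() call), the "hampel" rule should flag the same positions as hi above, since it uses the same formula.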

Partition the data into training and testing sets.

# R code (borrowed from Dalpiaz's textbook)

# partitioning the data - 75% training, 25% testing
set.seed(448)
ids<-sample(nrow(food5),floor(0.75*nrow(food5)))
trainingData <- food5[ids,]
testingData <- food5[-ids,]

colnames(food5)
## [1] "results"         "risk"            "totalviolations"
#assuming the response is results of the inspection
trainingData_response <- trainingData$results
testingData_response <- testingData$results

trainingData_predictors <- trainingData[,-1]
testingData_predictors <- testingData[,-1]
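Because only about 2% of the sampled inspections are Fail, a simple random split can leave very few Fail cases in the testing set (2 of 200 in this partition). One option, not used in the original notes, is a stratified split that samples the same fraction within each class. The base-R sketch below uses a hypothetical helper name, stratified_split().

```r
# Stratified train/test split: sample the same fraction of rows within each
# class so that rare classes appear in both the training and testing sets.
stratified_split <- function(y, prop = 0.75, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  idx_by_class <- split(seq_along(y), y)   # row positions grouped by class
  train_ids <- unlist(lapply(idx_by_class, function(ids) {
    # index into ids with sample.int() to avoid sample()'s length-1 gotcha
    ids[sample.int(length(ids), floor(prop * length(ids)))]
  }), use.names = FALSE)
  sort(train_ids)
}

# usage on a toy response with a rare class
y <- c(rep("Fail", 8), rep("Pass", 40), rep("Cond", 52))
ids <- stratified_split(y, prop = 0.75, seed = 448)
table(y[ids])
```

With the inspection data, ids <- stratified_split(food5$results, prop = 0.75, seed = 448) could replace the sample() call above, although the resulting partition, and every number reported below, would differ from those shown.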

Build a classification tree on the training dataset.

Classification Algorithm (described by Wei-Yin Loh http://pages.stat.wisc.edu/~loh/treeprogs/guide/wires11.pdf)

  1. Start at the root node, which contains all the observations. The tree grows by splitting a node into child nodes, each of which becomes either a leaf (terminal node) or an internal node that is split further.

  2. For each predictor \(X\), find the set \(S\) that minimizes the sum of the node impurities in the two child nodes, and choose the split \(\{X^* \in S^*\}\) that gives the minimum over all \(X\) and \(S\). (If \(X\) is numeric, \(S = (-\infty, c]\), so the split has the form \(X \le c\).)

  3. If a stopping criterion is reached, exit. Otherwise, apply step 2 to each child node in turn.

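  • node impurity: a measure of how mixed the classes are within a node. A node whose observations all belong to one class is pure (impurity 0), while an even mix of classes gives the highest impurity. Common measures are the Gini index and entropy (deviance); tree() uses deviance by default. We want splits that produce child nodes with low impurity, i.e. high purity.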

  • stopping criterion: to keep the tree simple, we often specify a maximum tree size (the number of terminal nodes, i.e. leaves) or set a value for the cost-complexity parameter of the tree.
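To make steps 1 and 2 concrete for a single numeric predictor, the sketch below computes the Gini index \(1 - \sum_k p_k^2\), where \(p_k\) is the proportion of class \(k\) in a node, and searches the midpoints between observed values for the split \(X \le c\) with the smallest impurity. It is only an illustration: the child impurities are weighted by their share of observations (a common CART convention, assumed here), gini() and best_split() are hypothetical names, and tree() uses deviance rather than Gini by default, so its chosen split may differ.

```r
# Gini impurity of a node: 1 - sum_k p_k^2, where p_k is the share of class k
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Find the cut point c for the split x <= c that minimises the impurity of the
# two child nodes, each weighted by its share of the observations.
# Assumes x and y contain no missing values.
best_split <- function(x, y) {
  vals <- sort(unique(x))
  if (length(vals) < 2) return(NULL)             # a constant predictor cannot split
  cuts <- (head(vals, -1) + tail(vals, -1)) / 2   # midpoints between observed values
  n <- length(y)
  impurity <- sapply(cuts, function(cut_pt) {
    left  <- y[x <= cut_pt]
    right <- y[x > cut_pt]
    length(left) / n * gini(left) + length(right) / n * gini(right)
  })
  list(cut = cuts[which.min(impurity)], impurity = min(impurity))
}

# usage: a perfectly separable toy example ($cut is 3.5, $impurity is 0)
best_split(c(1, 2, 3, 4, 5, 6), c("a", "a", "a", "b", "b", "b"))
```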

# R code (borrowed from James Le at DataCamp https://www.datacamp.com/community/tutorials/decision-trees-R)

library(tree)
## Registered S3 method overwritten by 'tree':
##   method     from
##   print.tree cli
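# note: the "NAs introduced by coercion" warning most likely arises because risk is
# stored as character; tree() expects factors, so risk is effectively dropped (only
# totalviolations is used below). Converting it first, e.g.
# trainingData$risk <- factor(trainingData$risk) (and likewise for testingData),
# would let the tree consider it.
response <- factor(trainingData_response)
fittree <- tree(response ~ risk + totalviolations, data = trainingData)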
## Warning in tree(response ~ risk + totalviolations, data = trainingData):
## NAs introduced by coercion
summary(fittree)
## 
## Classification tree:
## tree(formula = response ~ risk + totalviolations, data = trainingData)
## Variables actually used in tree construction:
## [1] "totalviolations"
## Number of terminal nodes:  4 
## Residual mean deviance:  1.262 = 752.1 / 596 
## Misclassification error rate: 0.2867 = 172 / 600
plot(fittree)
text(fittree, pretty=0)

#cvfittree <- cv.tree(fittree, FUN = prune.misclass) #cross-validated misclassification error by tree size
#plot(cvfittree)
#prune.misclass(fittree, best=SIZE) # SIZE: number of terminal nodes chosen from the cross-validation results
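The commented lines above outline the "modify the tree model" step. A runnable sketch of that cross-validate-then-prune workflow follows; it uses the built-in iris data in place of the inspection data (an assumption made so the block is self-contained) and picks the smallest tree size with the lowest cross-validated misclassification count, which is one reasonable rule among several.

```r
library(tree)

set.seed(448)                              # cv.tree() assigns folds at random
fit <- tree(Species ~ ., data = iris)      # grow a classification tree

# cross-validated misclassification count for each candidate subtree size
cv_fit <- cv.tree(fit, FUN = prune.misclass)

# smallest tree size that attains the minimum cross-validated error
best_size <- min(cv_fit$size[cv_fit$dev == min(cv_fit$dev)])

# prune to that size (prune.misclass returns the next larger tree if no
# subtree of exactly that size exists)
pruned <- prune.misclass(fit, best = best_size)
summary(pruned)
```

plot(cv_fit) shows the error against tree size. With the inspection data, the same calls would be cv.tree(fittree, FUN = prune.misclass) and prune.misclass(fittree, best = best_size); given the coercion warning above, converting risk to a factor first is likely worthwhile.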

Evaluate its performance on the training dataset.

# with the satisfactory tree now, let's score it on the trainingData
trainpredz <- predict(fittree, trainingData, type="class")
## Warning in pred1.tree(object, tree.matrix(newdata)): NAs introduced by
## coercion
table(trainpredz,trainingData_response) #confusion matrix
##                     trainingData_response
## trainpredz           Fail Pass Pass w/ Conditions
##   Fail                  0    0                  0
##   Pass                  7  238                121
##   Pass w/ Conditions    5   39                190
mean(trainpredz==trainingData_response) #classification rate
## [1] 0.7133333
mean(trainpredz!=trainingData_response) #misclassification rate
## [1] 0.2866667
table(trainingData_response)/sum(table(trainingData_response)) #comparing to training prior proportions
## trainingData_response
##               Fail               Pass Pass w/ Conditions 
##          0.0200000          0.4616667          0.5183333
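The classification rate is easier to judge next to a baseline. The base-R helper below, a hypothetical class_metrics(), reports accuracy, per-class recall (the share of each actual class predicted correctly), and the no-information rate, i.e. the accuracy of always predicting the most common class in the data supplied. It keeps never-predicted classes as zero rows, like the all-zero Fail row above.

```r
# Summarise predictions against actual classes.
class_metrics <- function(pred, actual) {
  # use the same levels for rows and columns so the confusion matrix is square
  lv <- union(levels(factor(actual)), levels(factor(pred)))
  cm <- table(predicted = factor(pred, levels = lv),
              actual    = factor(actual, levels = lv))
  list(
    confusion    = cm,
    accuracy     = sum(diag(cm)) / sum(cm),
    recall       = diag(cm) / colSums(cm),      # share of each actual class predicted correctly
    no_info_rate = max(colSums(cm)) / sum(cm)   # accuracy of always predicting the most common class
  )
}

# usage on toy predictions
pred   <- c("Pass", "Pass", "Cond", "Cond", "Pass")
actual <- c("Pass", "Cond", "Cond", "Fail", "Pass")
class_metrics(pred, actual)
```

For the tree above, class_metrics(trainpredz, trainingData_response) would give a Fail recall of 0 and a no-information rate of about 0.518, compared with the 0.713 classification rate.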

Score the tree model on the testing dataset and evaluate its performance.

# with the satisfactory tree now, let's score it on the testingData
testpredz <- predict(fittree, testingData, type="class")
## Warning in pred1.tree(object, tree.matrix(newdata)): NAs introduced by
## coercion
table(testpredz,testingData_response) #confusion matrix
##                     testingData_response
## testpredz            Fail Pass Pass w/ Conditions
##   Fail                  0    0                  0
##   Pass                  1   83                 44
##   Pass w/ Conditions    1   18                 53
mean(testpredz==testingData_response) #classification rate
## [1] 0.68
mean(testpredz!=testingData_response) #misclassification rate
## [1] 0.32
table(testingData_response)/sum(table(testingData_response)) #comparing to testing prior proportions
## testingData_response
##               Fail               Pass Pass w/ Conditions 
##              0.010              0.505              0.485

Interpret the tree model and its resulting prediction performance.

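  • How to interpret the model depends on the questions you want to answer: consider how each step helps the audience understand the data and what insights the investigation yields. Here, the tree splits only on totalviolations and never predicts Fail, so every failed inspection is misclassified. Its classification rate is 0.713 on the training data and 0.68 on the testing data, above the 0.485 the testing data would give by always predicting the training majority class, Pass w/ Conditions.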

Regression Trees

For the data examples, we use only the second of the 20 parts of the 2015 US Natality Data. A valid data analysis would require importing all 20 parts, combining them into a single data frame, and subsetting from that full dataset.

Remove any collinearity in the predictors and any extreme outliers.

rm(list = ls())
library(tidyverse)
birth1 <- read_csv("https://uofi.box.com/shared/static/kr2s5wp3jpxdlpcjmq2qs15oo4797u4y.csv")
## Parsed with column specification:
## cols(
##   .default = col_double(),
##   MAGE_IMPFLG = col_logical(),
##   MAGE_REPFLG = col_logical(),
##   MAR_P = col_character(),
##   MAR_IMP = col_logical(),
##   FAGERPT_FLG = col_logical(),
##   WIC = col_character(),
##   CIG_REC = col_character(),
##   RF_PDIAB = col_character(),
##   RF_GDIAB = col_character(),
##   RF_PHYPE = col_character(),
##   RF_GHYPE = col_character(),
##   RF_EHYPE = col_character(),
##   RF_PPTERM = col_character(),
##   RF_INFTR = col_character(),
##   RF_REDRG = col_character(),
##   RF_ARTEC = col_character(),
##   RF_CESAR = col_character(),
##   IP_GON = col_character(),
##   IP_SYPH = col_character(),
##   IP_CHLAM = col_character()
##   # ... with 42 more columns
## )
## See spec(...) for full column specifications.
colnames(birth1) <- tolower(colnames(birth1))
dim(birth1)
## [1] 200000    241
# drop records with values coded as unknown in the data key (9999 for dbwt; 99 for apgar5, precare, cig_0 and wtgain; 9 for dmeth_rec)
rs1 <- which(birth1$dbwt==9999); sum(birth1$dbwt==9999)
## [1] 59
rs1 <- c(rs1,which(birth1$apgar5==99)); sum(birth1$apgar5==99)
## [1] 957
rs1 <- c(rs1,which(birth1$precare==99)); sum(birth1$precare==99)
## [1] 3078
rs1 <- c(rs1,which(birth1$cig_0==99)); sum(birth1$cig_0==99)
## [1] 1103
rs1 <- c(rs1,which(birth1$wtgain==99)); sum(birth1$wtgain==99)
## [1] 8789
#mager has no unknowns
rs1 <- c(rs1,which(birth1$dmeth_rec==9)); sum(birth1$dmeth_rec==9)
## [1] 0
rs2 <- unique(rs1)
birth11 <- birth1[-rs2,]

set.seed(448)
rs <- sample(nrow(birth11),667)
birth01 <- birth11[rs,]
rm(birth1,birth11)

# Suppose we know these variables to be collinear within groups: (precare, previs), (cig_0, cig_1, cig_2, cig_3), (pwgt_r, dwgt_r, wtgain), (apgar5, apgar10); we keep at most one variable from each group


colnames(birth01)
##   [1] "dob_yy"       "dob_mm"       "dob_tt"       "dob_wk"      
##   [5] "bfacil"       "f_facility"   "bfacil3"      "mage_impflg" 
##   [9] "mage_repflg"  "mager"        "mager14"      "mager9"      
##  [13] "mbstate_rec"  "restatus"     "mrace31"      "mrace6"      
##  [17] "mrace15"      "mbrace"       "mraceimp"     "mhisp_r"     
##  [21] "f_hisp"       "mracehisp"    "mar_p"        "dmar"        
##  [25] "mar_imp"      "f_mar_p"      "meduc"        "f_meduc"     
##  [29] "fagerpt_flg"  "fagecomb"     "fagerec11"    "frace31"     
##  [33] "frace6"       "frace15"      "fhisp_r"      "f_fhisp"     
##  [37] "fracehisp"    "feduc"        "f_feduc"      "priorlive"   
##  [41] "priordead"    "priorterm"    "lbo_rec"      "tbo_rec"     
##  [45] "illb_r"       "illb_r11"     "ilop_r"       "ilop_r11"    
##  [49] "ilp_r"        "ilp_r11"      "precare"      "f_mcb"       
##  [53] "precare5"     "previs"       "previs_rec"   "f_tpcv"      
##  [57] "wic"          "f_wic"        "cig_0"        "cig_1"       
##  [61] "cig_2"        "cig_3"        "cig0_r"       "cig1_r"      
##  [65] "cig2_r"       "cig3_r"       "f_cigs"       "f_cigs_1"    
##  [69] "f_cigs_2"     "f_cigs_3"     "cig_rec"      "f_tobaco"    
##  [73] "m_ht_in"      "f_m_ht"       "bmi"          "bmi_r"       
##  [77] "pwgt_r"       "f_pwgt"       "dwgt_r"       "f_dwgt"      
##  [81] "wtgain"       "wtgain_rec"   "f_wtgain"     "rf_pdiab"    
##  [85] "rf_gdiab"     "rf_phype"     "rf_ghype"     "rf_ehype"    
##  [89] "rf_ppterm"    "f_rf_pdiab"   "f_rf_gdiab"   "f_rf_phyper" 
##  [93] "f_rf_ghyper"  "f_rf_eclamp"  "f_rf_ppb"     "rf_inftr"    
##  [97] "rf_redrg"     "rf_artec"     "f_rf_inft"    "f_rf_inf_drg"
## [101] "f_rf_inf_art" "rf_cesar"     "rf_cesarn"    "f_rf_cesar"  
## [105] "f_rf_ncesar"  "no_risks"     "ip_gon"       "ip_syph"     
## [109] "ip_chlam"     "ip_hepb"      "ip_hepc"      "f_ip_gonor"  
## [113] "f_ip_syph"    "f_ip_chlam"   "f_ip_hepatb"  "f_ip_hepatc" 
## [117] "no_infec"     "ob_ecvs"      "ob_ecvf"      "f_ob_succ"   
## [121] "f_ob_fail"    "ld_indl"      "ld_augm"      "ld_ster"     
## [125] "ld_antb"      "ld_chor"      "ld_anes"      "f_ld_indl"   
## [129] "f_ld_augm"    "f_ld_ster"    "f_ld_antb"    "f_ld_chor"   
## [133] "f_ld_anes"    "no_lbrdlv"    "me_pres"      "me_rout"     
## [137] "me_trial"     "f_me_pres"    "f_me_rout"    "f_me_trial"  
## [141] "rdmeth_rec"   "dmeth_rec"    "f_dmeth_rec"  "mm_mtr"      
## [145] "mm_plac"      "mm_rupt"      "mm_uhyst"     "mm_aicu"     
## [149] "f_mm_mtr"     "f_mm_plac"    "f_mm_rupt"    "f_mm_uhyst"  
## [153] "f_mm_aicu"    "no_mmorb"     "attend"       "mtran"       
## [157] "pay"          "pay_rec"      "f_pay"        "f_pay_rec"   
## [161] "apgar5"       "apgar5r"      "f_apgar5"     "apgar10"     
## [165] "apgar10r"     "dplural"      "imp_plur"     "setorder_r"  
## [169] "sex"          "imp_sex"      "dlmp_mm"      "dlmp_yy"     
## [173] "compgst_imp"  "obgest_flg"   "combgest"     "gestrec10"   
## [177] "gestrec3"     "lmpused"      "oegest_comb"  "oegest_r10"  
## [181] "oegest_r3"    "dbwt"         "bwtr12"       "bwtr4"       
## [185] "ab_aven1"     "ab_aven6"     "ab_nicu"      "ab_surf"     
## [189] "ab_anti"      "ab_seiz"      "f_ab_vent"    "f_ab_vent6"  
## [193] "f_ab_niuc"    "f_ab_surfac"  "f_ab_antibio" "f_ab_seiz"   
## [197] "no_abnorm"    "ca_anen"      "ca_mnsb"      "ca_cchd"     
## [201] "ca_cdh"       "ca_omph"      "ca_gast"      "f_ca_anen"   
## [205] "f_ca_menin"   "f_ca_heart"   "f_ca_hernia"  "f_ca_ompha"  
## [209] "f_ca_gastro"  "ca_limb"      "ca_cleft"     "ca_clpal"    
## [213] "ca_down"      "ca_disor"     "ca_hypo"      "f_ca_limb"   
## [217] "f_ca_cleftlp" "f_ca_cleft"   "f_ca_downs"   "f_ca_chrom"  
## [221] "f_ca_hypos"   "no_congen"    "itran"        "ilive"       
## [225] "bfed"         "f_bfed"       "ubfacil"      "urf_diab"    
## [229] "urf_chyper"   "urf_phyper"   "urf_eclam"    "ume_forcp"   
## [233] "ume_vac"      "uop_induc"    "uld_breech"   "uca_anen"    
## [237] "uca_spina"    "uca_ompha"    "uca_celftlp"  "uca_hernia"  
## [241] "uca_downs"
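The notes assume the collinear variables are known. One quick check is to list pairs of numeric columns whose absolute Pearson correlation exceeds a cutoff. The sketch below assumes a cutoff of 0.8 (a common but arbitrary choice) and a hypothetical helper name, high_cor_pairs(). Pairwise correlations reveal only pairwise linear relationships, so variance inflation factors are a more thorough check for multicollinearity.

```r
# List pairs of numeric columns whose absolute correlation exceeds a cutoff.
high_cor_pairs <- function(df, cutoff = 0.8) {
  num <- df[, vapply(df, is.numeric, logical(1)), drop = FALSE]  # numeric columns only
  r <- cor(num, use = "pairwise.complete.obs")
  idx <- which(abs(r) > cutoff & upper.tri(r), arr.ind = TRUE)   # each pair once
  data.frame(var1 = rownames(r)[idx[, 1]],
             var2 = colnames(r)[idx[, 2]],
             r    = r[idx],
             stringsAsFactors = FALSE)
}

# usage on toy data: b is an exact linear function of a
set.seed(1)
a <- rnorm(100)
d <- data.frame(a = a, b = 2 * a + 1, c = rnorm(100), lab = "x")
high_cor_pairs(d)
```

For these notes, something like high_cor_pairs(birth01[, c("precare", "previs", "cig_0", "cig_1", "wtgain", "pwgt_r", "dwgt_r", "apgar5", "apgar10")]) would list the strongly related pairs among the candidate variables.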

Partition the data into training and testing sets.

# partitioning the data - 70% training, 30% testing
set.seed(448)
ids<-sample(nrow(birth01),floor(0.70*nrow(birth01)))

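# Suppose we use the following variables as predictors: apgar5, precare, cig_0 (via its recode cig0_r), wtgain, mager9, dmeth_rec
# cigb4preg: 0 if cig0_r is 0 (no cigarettes before pregnancy), 1 otherwise
birth01$cigb4preg <- ifelse(birth01$cig0_r==0,0,1)

# recode the categorical predictors as factors and keep only the response and predictors
birth02 <- birth01 %>%
  mutate(newprecare = factor(precare), newcigb4preg = factor(cigb4preg),
         newmager9 = factor(mager9), newdmeth_rec = factor(dmeth_rec)) %>%
  select(dbwt, apgar5, wtgain, newprecare, newcigb4preg, newmager9, newdmeth_rec)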
colnames(birth02)
## [1] "dbwt"         "apgar5"       "wtgain"       "newprecare"  
## [5] "newcigb4preg" "newmager9"    "newdmeth_rec"
trainingData <- birth02[ids,]
testingData <- birth02[-ids,]

# the response is the infant's birth weight in grams (dbwt)
trainingData_response <- trainingData$dbwt
testingData_response <- testingData$dbwt

trainingData_predictors <- trainingData[,-1]
testingData_predictors <- testingData[,-1]

rm(birth01)

Build a regression tree on the training dataset.

# R code (borrowed from James Le at DataCamp https://www.datacamp.com/community/tutorials/decision-trees-R)

library(tree)
response <- trainingData_response
hist(response)

fittree <- tree(response ~ apgar5+ wtgain+ newprecare+ newcigb4preg+ newmager9+ newdmeth_rec, data = trainingData)
summary(fittree)
## 
## Regression tree:
## tree(formula = response ~ apgar5 + wtgain + newprecare + newcigb4preg + 
##     newmager9 + newdmeth_rec, data = trainingData)
## Variables actually used in tree construction:
## [1] "wtgain" "apgar5"
## Number of terminal nodes:  3 
## Residual mean deviance:  271400 = 125600000 / 463 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -2347.00  -288.00    16.37     0.00   306.50  1402.00
plot(fittree)
text(fittree, pretty=0)

#cvfittree <- cv.tree(fittree) # cross-validated; the default FUN = prune.tree uses deviance (RSS), since misclassification does not apply to regression trees
#plot(cvfittree)
#prune.tree(fittree, best=SIZE) # SIZE: number of terminal nodes chosen from the cross-validation results
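A regression tree follows the same recipe as the classification algorithm but scores a split by the residual sum of squares (RSS) of the two child nodes, each predicted by its own mean; for regression trees this is the deviance that tree() reports. The base-R sketch below, with the hypothetical name best_split_reg(), searches one numeric predictor for the RSS-minimizing cut point.

```r
# RSS of a node whose prediction is the mean of its responses
rss <- function(y) sum((y - mean(y))^2)

# Find the cut point c for the split x <= c with the smallest total RSS
# of the two child nodes. Assumes x and y contain no missing values.
best_split_reg <- function(x, y) {
  vals <- sort(unique(x))
  if (length(vals) < 2) return(NULL)             # a constant predictor cannot split
  cuts <- (head(vals, -1) + tail(vals, -1)) / 2   # midpoints between observed values
  total <- sapply(cuts, function(cut_pt) rss(y[x <= cut_pt]) + rss(y[x > cut_pt]))
  list(cut = cuts[which.min(total)], rss = min(total))
}

# usage: two clearly separated groups of responses ($cut is 3.5, $rss is 4)
best_split_reg(c(1, 2, 3, 4, 5, 6), c(10, 11, 12, 30, 31, 32))
```

best_split_reg(trainingData$wtgain, trainingData$dbwt) would give the best single cut on wtgain; it need not match the tree's root split, which is chosen over all predictors.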

Evaluate its performance on the training dataset.

# with the satisfactory tree now, let's score it on the trainingData
trainpredz <- predict(fittree, trainingData, type="vector")
res <- trainpredz-trainingData_response #residuals
sqrt(mean( (res)^2))#rmse
## [1] 519.2347
var(trainpredz)/var(trainingData_response) #rsquare: on the training data this ratio equals R^2, since each leaf predicts the mean of its observations
## [1] 0.05493916

Score the tree model on the testing dataset and evaluate its performance.

# R code
# with the satisfactory tree now, let's score it on the testingData
testpredz <- predict(fittree, testingData, type="vector")
rez <- testpredz-testingData_response #residuals
sqrt(mean( (rez)^2))#rmse
## [1] 560.196
var(testpredz)/var(testingData_response) # variance ratio; on new data this is not R^2 (1 - SSE/SST is)
## [1] 0.06227451
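# as a rough comparison to linear regression, fit on a random 201-row subset
# of the training data (the same size as the testing set)
td<-trainingData[sample(nrow(trainingData),201),]
md <- lm(dbwt ~ apgar5 + wtgain, data=td)
predict(md, newdata=testingData) # reuse the fitted model instead of refitting it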
##        1        2        3        4        5        6        7        8 
## 3266.741 3180.064 3125.891 3255.906 3125.891 3018.748 3288.410 3310.079 
##        9       10       11       12       13       14       15       16 
## 3461.764 3125.891 3721.795 3418.425 3331.748 3277.575 2963.371 3320.914 
##       17       18       19       20       21       22       23       24 
## 3505.102 3200.529 3538.811 3353.418 3223.402 3461.764 3006.710 2963.371 
##       25       26       27       28       29       30       31       32 
## 3613.449 3341.379 3255.906 3537.606 3505.102 3147.560 3602.614 3396.756 
##       33       34       35       36       37       38       39       40 
## 3494.268 3429.260 3115.056 3288.410 3407.591 3234.237 3115.056 2963.371 
##       41       42       43       44       45       46       47       48 
## 3104.221 3526.772 3342.583 3450.929 3494.268 3298.040 3310.079 3342.583 
##       49       50       51       52       53       54       55       56 
## 3505.102 3136.725 3277.575 3526.772 3331.748 3103.017 3038.009 3266.741 
##       57       58       59       60       61       62       63       64 
## 3136.725 3416.017 3483.433 3266.741 2963.371 3396.756 3342.583 3375.087 
##       65       66       67       68       69       70       71       72 
## 2963.371 3255.906 3287.206 3385.922 3580.945 3104.221 3223.402 3277.575 
##       73       74       75       76       77       78       79       80 
## 3396.756 3375.087 3288.410 3093.387 3234.237 3155.986 3342.583 3288.410 
##       81       82       83       84       85       86       87       88 
## 3147.560 3515.937 3234.237 3071.717 3189.694 3342.583 2963.371 3266.741 
##       89       90       91       92       93       94       95       96 
## 2985.040 3678.457 3331.748 3234.237 3407.591 3180.064 3548.441 3060.883 
##       97       98       99      100      101      102      103      104 
## 3483.433 3028.379 3450.929 3180.064 3125.891 3396.756 3071.717 3331.748 
##      105      106      107      108      109      110      111      112 
## 3559.276 3158.394 3093.387 3299.245 3505.102 3396.756 3093.387 3288.410 
##      113      114      115      116      117      118      119      120 
## 3190.898 3039.213 3320.914 3450.929 3602.614 3212.568 3158.394 3526.772 
##      121      122      123      124      125      126      127      128 
## 3613.449 3212.568 3104.221 3364.252 3180.064 3082.552 3288.410 2761.122 
##      129      130      131      132      133      134      135      136 
## 3147.560 3212.568 3147.560 3189.694 3125.891 3234.237 3158.394 3071.717 
##      137      138      139      140      141      142      143      144 
## 3201.733 3472.599 3320.914 3201.733 3255.906 3233.033 3310.079 3830.141 
##      145      146      147      148      149      150      151      152 
## 3310.079 3212.568 3450.929 3170.433 3180.064 3342.583 3385.922 3093.387 
##      153      154      155      156      157      158      159      160 
## 3331.748 3255.906 3580.945 3266.741 3190.898 3495.472 3212.568 3255.906 
##      161      162      163      164      165      166      167      168 
## 3223.402 3461.764 3472.599 3180.064 3201.733 3233.033 3407.591 3385.922 
##      169      170      171      172      173      174      175      176 
## 3190.898 3395.552 2935.681 3320.914 3418.425 3299.245 3310.079 3407.591 
##      177      178      179      180      181      182      183      184 
## 3201.733 3299.245 3310.079 3310.079 3331.748 3233.033 3277.575 3396.756 
##      185      186      187      188      189      190      191      192 
## 3364.252 3234.237 3093.387 3245.071 3234.237 3288.410 3397.960 3472.599 
##      193      194      195      196      197      198      199      200 
## 3320.914 3233.033 3396.756 3234.237 3223.402 3613.449 3201.733 3060.883 
##      201 
## 3069.309
modelr::rmse(md, data=testingData)
## [1] 582.8404
rsq::rsq(md,adj=FALSE,data=testingData) # caution: this appears to report md's fitted R^2 rather than a test-set R^2; see ?rsq::rsq
## [1] 0.07416619
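Because the variance ratio equals R^2 only for in-sample fitted values, a test-set R^2 is better computed as \(1 - \mathrm{SSE}/\mathrm{SST}\) on the testing data; it is negative when the model does worse than predicting the mean of the test responses. The sketch below uses hypothetical helper names, r2_oos() and rmse(); the latter should reproduce the RMSE lines above.

```r
# Out-of-sample R^2: 1 - SSE/SST, with SST taken around the mean of the
# observed test responses.
r2_oos <- function(pred, obs) {
  1 - sum((obs - pred)^2) / sum((obs - mean(obs))^2)
}

# Root mean squared error of predictions
rmse <- function(pred, obs) sqrt(mean((obs - pred)^2))

# usage: reversed predictions have R^2 = -3, yet their variance ratio is 1
obs <- c(1, 2, 3, 4)
bad <- c(4, 3, 2, 1)
r2_oos(bad, obs)
var(bad) / var(obs)
```

For the tree, r2_oos(testpredz, testingData_response) would replace the variance ratio; for the linear model, r2_oos(predict(md, newdata = testingData), testingData$dbwt).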

Interpret the tree model and its resulting prediction performance.

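  • As with the classification tree, interpretation depends on the questions you want to answer and on what the model helps the audience learn about the data. Here the tree uses only wtgain and apgar5, with three terminal nodes. Its RMSE grows from about 519 g on the training data to about 560 g on the testing data, and the variance ratio is only about 0.055 (training) to 0.062 (testing), so the tree captures little of the variation in birth weight. For comparison, the linear model on apgar5 and wtgain has a testing RMSE of about 583 g.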