
Boosting performance by combining trees with GLM: A benchmarking analysis

How much of an improvement is gained by combining trees with GLM, and how does it compare to additive models?

Photo by Simon Berger on Unsplash

A common pitfall in statistical modeling is applying a method that doesn’t match the structure of the data. Linear models like logistic regression assume a linear relationship between the log-odds of the predicted event and the independent variables. While researching the topic recently, I came across this article on StackExchange discussing the use of shallow decision trees as a feature engineering step for logistic regression. Briefly, this strategy attempts to cope with nonlinear data by using trees to transform the features into dummy variables that can then be used in the logistic regression specification.

After researching this strategy, my impression was that this preliminary step may reduce the flexibility of the model while offering only a marginal improvement in results. It may be more useful to either transform the data to remove the nonlinearities, or to account for them in the specification of the model. However, I was surprised to find few articles or posts comparing the approach to other methods!

To see how decision trees combined with logistic regression (tree+GLM) perform, I’ve tested the method on three data sets and benchmarked the results against standard logistic regression and a generalized additive model (GAM) to see if there is a consistent performance difference among the methods.

The Tree + GLM Methodology

Logistic regression and decision trees are generally the first two classification models one is introduced to. Each has its own pitfalls. Regression models assume the dependent variable can be explained using a set of linear functions applied to the independent variables, giving an equation of the form:

$$\log\left(\frac{p}{1-p}\right) = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_k x_k$$

Decision trees make no assumption about the distribution of the variables; they simply create a new branch in the decision tree based on criteria like Gini impurity. However, decision trees are prone to overfitting and can be unstable under changes in the training data.
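For reference, the Gini impurity that drives these splits is a standard quantity: for a node with class proportions $p_k$, it is

$$G = 1 - \sum_k p_k^2$$

and a split is chosen so that the size-weighted impurity of the resulting child nodes is as low as possible.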

The chart below illustrates the fitted values of a linear regression model and a decision tree over the range of the independent variable in a one-variable model:

Image by author

From the chart above, we can see that the predictions from the decision tree are discontinuous, in contrast to the smooth function generated by the linear regression model. While this discontinuity is problematic for simple linear data, the decision tree’s ability to change in a nonlinear fashion justifies its use on nonlinear data.

To remedy the downsides of these two methods, several sources have suggested using a decision tree as an intermediate step to absorb potential nonlinearity in the data. In its simplest form, the process is as follows (a minimal sketch in R appears after the list):

  1. Fit a shallow decision tree, T(X), to the training data X. This tree will have N terminal nodes.
  2. The terminal node each observation falls into defines a categorical variable with N levels, denoted C_n, which is included as a feature in the logistic regression specification.
  3. The logistic regression is fit using the modified set of data.
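A minimal sketch of the procedure, assuming a hypothetical data.table dt with a binary factor outcome y and features x1 and x2 (the full implementation used in the benchmarks appears later in the post):

library(rpart)
library(partykit)
# 1. Fit a shallow tree with a handful of terminal nodes
tree <- rpart(y ~ x1 + x2, data = dt, control = rpart.control(maxdepth = 2))
# 2. Record the terminal node each observation lands in as a new factor
dt$leaf <- factor(predict(as.party(tree), dt, type = "node"))
# 3. Fit the logistic regression with the leaf factor as an extra feature
fit <- glm(y ~ x1 + x2 + leaf, family = binomial, data = dt)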

There are two simple alternatives for handling nonlinearity in data. The first is to apply some transformation to the raw data to give it a linear relationship with the dependent variable; this isn’t always an option, and its appropriateness varies by domain. The second is to change the model specification to a method that can cope with nonlinear data directly.
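As a hedged illustration of the first alternative: if a feature x is related to the log-odds roughly logarithmically, a transform inside the GLM formula may restore linearity (dt, x, and y are hypothetical names here):

# Log-transform a skewed, non-negative feature before the logistic fit
fit <- glm(y ~ log(x + 1), family = binomial, data = dt)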

Generalized additive models

There are several methodologies that can be used to cope with nonlinear data, but the one I chose for this exercise is the generalized additive model. Generalized additive models are a framework originally introduced in "Generalized Additive Models" by Trevor Hastie and Robert Tibshirani. The authors went on to write Elements of Statistical Learning, which is where I first encountered them.

The GAM takes the concept of the generalized linear model that logistic regression is based on and relaxes the assumption of linear basis functions. The GAM equivalent of the logistic regression specification would be:

$$\log\left(\frac{p}{1-p}\right) = \beta_0 + \sum_{j} f_j(x_j)$$

In this case, the notation substitutes the smooth term $f_j(x_j)$ for the linear term $\beta_j x_j$ in the original logistic regression equation, because we have changed the term operating on each independent variable to an arbitrary smooth function. Alternatively, we could think of a logistic regression model as an additive model where $f_j(x_j) = \beta_j x_j$ for some regression coefficient $\beta_j$.
It is clear that this offers us additional flexibility in how we treat the variables that we are modeling, so we will use this specification in our numerical experiment as a challenger to the tree+logistic regression specification.
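As a quick illustration of the specification (not the benchmark code itself, which appears later), a one-variable logistic GAM can be fit with the mgcv package; dt, x, and y are again hypothetical names:

library(mgcv)
# s(x) requests an arbitrary smooth function of x in place of a linear term
fit <- gam(y ~ s(x), family = binomial, data = dt)
plot(fit)  # inspect the fitted smooth f(x)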

Data

Three data sets were used for the comparisons performed in this post. These were:

  1. Synthetic data set – An artificial data set containing 10,000 observations created to include linear and nonlinear variables.
  2. Banking data – A data set generated by a banking campaign. The data used here was cleaned and formatted by Andrzej Szymanski, Ph.D., who wrote an article on using logistic regression with decision trees. I wanted to include it because it provides a valuable comparison to previously published results. The original data is drawn from Andrzej’s GitHub.
  3. Adult data – Census data used to attempt to predict whether income exceeds $50,000/year. It is available at the UCI Machine Learning Repository here.

In the code chunk below, we collect the data used in our three tests and examine it. We begin by generating the synthetic data, then download and format the Banking and Adult data sets. After this, we perform some cleaning steps to make sure the variables we are using are not represented as characters. Finally, we split the data into training and testing sets:

rm(list = ls())
library(pacman)
p_load(data.table,caret,ggplot2,plotly,mgcv,rpart,magrittr,precrec,MLmetrics,partykit,gam,rmarkdown,knitr,broom,rpart.plot,reactable)
set.seed(13)
with.nonlinear <- data.table(twoClassSim(10000,
                                         linearVars = 6,ordinal = F ) )
data.list <- list("Synthetic" = with.nonlinear,
                  "Banking" =  fread("https://raw.githubusercontent.com/AndrzejSzymanski/TDS-LR-DT/master/banking_cleansed.csv")  ,
                  "Adult" = fread("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data")  )
names(data.list[[3]]) <- c("age","workclass","fnlwgt","education","education-num","marital-status","occupation", "relationship", "race","sex","capital-gain","capital-loss","hours-per-week","native-country","Income-class")
# Convert binary columns to factors in the Banking data set
banking.names.for.factors <- names(data.list$Banking)[apply(data.list$Banking, MARGIN = 2, function(x){ length(unique(x)) }) == 2]
data.list$Banking[, (banking.names.for.factors) := lapply(X = .SD, FUN = factor),
                  .SDcols = banking.names.for.factors]
# Convert character columns to factors and integer columns to numeric in the Adult data set
adult.names.for.factors <- names(data.list$Adult)[sapply(data.list$Adult, is.character)]
data.list$Adult[, (adult.names.for.factors) := lapply(X = .SD, FUN = factor),
                .SDcols = adult.names.for.factors]
adult.names.for.numeric <- names(data.list$Adult)[sapply(data.list$Adult, is.integer)]
data.list$Adult[, (adult.names.for.numeric) := lapply(X = .SD, FUN = as.numeric),
                .SDcols = adult.names.for.numeric]
training.data <- list()
test.data <- list()
# 80/20 train/test split for each data set
for( i in 1:length(data.list)){
  train_inds <- sample(x = 1:nrow(data.list[[i]]), size = .8*nrow(data.list[[i]]))
  training.data[[i]] <- data.list[[i]][train_inds]
  test.data[[i]] <- data.list[[i]][-train_inds]  }
names(training.data)<- names(data.list)
names(test.data)<- names(data.list)

Before engaging in our benchmark analysis, it’s useful to understand which variables have a nonlinear relationship with the dependent variable. A simple and intuitive way to do this is to chart the relationship between the dependent and independent variables and see whether a nonlinear relationship is obvious from visual inspection. In the implementation here, I use a one-variable GAM to describe the general shape of the relationship between the two variables. You can see the implementation below, which generates the charts used to assess nonlinearity:

nonlinear.viz <- function(dt,dep.var,indep.var,bins = 100){
  # The Banking response y is stored as a factor; coerce it to numeric so the
  # binomial smoother can be fit (the other data sets have no y column)
  if (!is.null(dt$y)) dt$y <- as.numeric(as.character(dt$y))
  return.plot <- ggplot(dt,
                        aes_string( x = indep.var,y = dep.var) ) +
    stat_smooth(method = "gam",
                method.args = list(family = "binomial")) +
    theme_bw() +
    theme(axis.title = element_text(size = 16),axis.text = element_text(size = 11),plot.title = element_text(size = 20)) +
    ylab(dep.var) +
    xlab(indep.var) +
    ggtitle(paste("Relationship between ",dep.var," and ", indep.var,sep = ""))
  return(return.plot)}
training.data$Synthetic[,Class := ifelse(Class == "Class1", 1,0)]
test.data$Synthetic[,Class := ifelse(Class == "Class1", 1,0)]
training.data$Adult[,"Class" := ifelse(`Income-class` == ">50K", 1,0)]
test.data$Adult[,"Class" := ifelse(`Income-class` == ">50K", 1,0)]
synthetic.plot <- lapply(X = names(training.data$Synthetic)[names(training.data$Synthetic)!="Class"],
                         function(x){nonlinear.viz(dt = training.data$Synthetic,dep.var = "Class",x)})
banking.plot <- lapply(X = c("age","previous","euribor3m","cons_conf_idx","cons_price_idx","nr_employed", "emp_var_rate"   ,   "pdays"),
                       function(x){nonlinear.viz(dt = training.data$Banking,dep.var = "y",x)})
names(training.data$Adult) <- gsub(names(training.data$Adult),pattern = "-",replacement = "_")
names(test.data$Adult) <- names(training.data$Adult)
adult.plot <- lapply(X = c("age","education_num","capital_gain","capital_loss","hours_per_week"),
                     function(x){nonlinear.viz(dt = training.data$Adult,
                                               dep.var = "Class",x )})

The code above will generate charts for all variable relationships. For the sake of brevity, I concentrate only on the nonlinear variables in the sections below:

Synthetic data-set

The synthetic data set is generated by the twoClassSim function from the caret library. The function is straightforward: it generates a data set with a binary outcome and a set of variables that are either linearly or nonlinearly related to it. This data set is useful for our tests because it allows us to compare the algorithms without being concerned with the domain of the problem. The relationships between the variables are summarized in the table below:

Image by author

I’ve plotted the nonlinear variables in the synthetic data against the likelihood of the positive class using the function defined above:

Image by author

We can see that the variables behave as we would expect: in all cases, a generally negative, nonlinear relationship exists between the variable and the likelihood of an instance belonging to the positive class.

Banking data-set

The banking data set contains 27 independent variables. Of these, seven are continuous and the remainder are binary. The features and their relationships to the dependent variable are shown in the table below:

Image by author

As with the synthetic data, I’ve also produced a set of charts showing the relationship between the continuous feature variables and the dependent variable (in this case, y):

Image by author

We can see that in each of the example charts above, there is a pronounced nonlinear relationship between the dependent and independent variables that would be difficult to compensate for with a linear basis function.

Adult data-set

As with the banking data set, the adult data set contains many binary variables along with a set of continuous variables (in this case, five). They are summarized in the table below:

Image by author

I’ve also created the chart below which illustrates the relationship between each variable and the dependent variable:

Image by author

The relationships between the continuous independent variables and the dependent variable are similar to those in the banking data set. The sole exception is the education_num variable, which represents the number of years in school. It has a clearly positive relationship and, despite the concavity, might be approximated with a linear function.

Balance of data

It’s helpful to review the balance between classes in the data sets before fitting and experimenting with models. Class imbalance in data can introduce bias that may have to be addressed using resampling techniques.

We’ll investigate the imbalance using the code displayed below:

# Summarize each training set: row count, column count, and minority-class share
sum.func <- function(x,col,model.name){
  dims <- data.table( "Measure" = c("Observations","Factors"), "Count" = dim(x))
  factors <- data.table( "Minority Class Percent", min(x[,.N,by = col]$N)/sum(x[,.N,by = col]$N) )
  names(factors) <- c("Measure","Count")
  for.return <- data.table(rbind(dims,factors))
  names(for.return)[2] <- model.name
  return( for.return ) }
dep.vars <- c("Class", "y", "Income_class")
summaries <- lapply(X = 1:length(training.data),
                    FUN = function(x){sum.func(training.data[[x]], dep.vars[x] ,
                                               model.name = names(training.data)[x])   })
summaries[[1]]$Banking <- summaries[[2]]$Banking
summaries[[1]]$Adult <- summaries[[3]]$Adult
kable(summaries[[1]],digits = 3)

The table below shows the size of the minority class of the dependent variable relative to the rest of each data set.

Image by author

We can see from the table above that the Banking and Adult data sets contain a significant imbalance between the prediction classes. To correct this, we will use the synthetic minority oversampling technique (SMOTE) and review the results.

A quick note: The SMOTE step may take some time.

# SMOTE_NC handles data sets with a mix of continuous and nominal features
banking.smote <- RSBID::SMOTE_NC(data = training.data$Banking, outcome = "y")
adult.smote <- RSBID::SMOTE_NC(data = training.data$Adult, outcome = "Income_class")
p_load(data.table,rmarkdown,knitr)
resampled_data <- list(training.data$Synthetic,banking.smote,adult.smote)
dep.vars <- c("Class", "y", "Income_class")
names(resampled_data) <- names(training.data)
summaries <- lapply(X = 1:length(training.data),
                    FUN = function(x){sum.func(resampled_data[[x]], 
                                               dep.vars[x] ,
                                               model.name = names(training.data)[x])   })
summaries[[1]]$Banking <- summaries[[2]]$Banking
summaries[[1]]$Adult <- summaries[[3]]$Adult
kable(summaries[[1]],digits = 3)

Building Decision Trees

We begin by building the simple decision trees whose terminal nodes will serve as a new factor variable alongside the original features. These trees have to be small in order to avoid overfitting. I tested models using trees with maximum depths of two, three, and four, and settled on a maximum depth of four for each model based on its improved ROC AUC and PR AUC. In the code chunk below, CART decision trees are fitted, with leaf node membership then used to create a new factor variable in the data. After that, the new data is used to fit a logistic regression model.

We first fit the initial decision trees (another note: I tested several iterations of depth and other settings; the choices below performed the best):

synthetic.tree.2 <- rpart(data = training.data$Synthetic,
                          formula = Class~Linear1+Linear2+Linear3+Linear4+Linear5+Linear6+Nonlinear1 +Nonlinear2 +Nonlinear3,
                          control = rpart.control(  maxdepth = 4 ))
banking.tree.3 <- rpart(data = banking.smote,
                        formula = y~.-V1, control = rpart.control(  maxdepth = 4)  )
adult.tree.3 <- rpart(data = adult.smote,
                      formula = Class~age+workclass+fnlwgt+education+education_num +marital_status+occupation+relationship+race+sex+capital_gain +  capital_loss + hours_per_week+ native_country,
                      control = rpart.control(  maxdepth = 4))
synth.models <- list(synthetic.tree.2)
banking.models <- list(banking.tree.3)
adult.models <- list(adult.tree.3)
# Convert a fitted rpart tree to a factor of terminal-node memberships
tree.to.feature <- function(tree.model,dt){
  require(partykit)
  tree_labels <- factor( predict(as.party(tree.model), dt, type = "node") )
  return(tree_labels)}

Next, we take these results and use them to fit the logistic regression models:

synth.train.preds <- lapply(X = synth.models,FUN = function(x){   tree.to.feature(tree.model = x,dt = training.data$Synthetic)  }) %>% data.frame
banking.train.preds <-lapply(X = banking.models,FUN = function(x){   tree.to.feature(tree.model = x,dt = training.data$Banking)  }) %>% data.frame
adult.train.preds <-lapply(X = adult.models,FUN = function(x){   tree.to.feature(tree.model = x,dt = training.data$Adult)  }) %>% data.frame
names(synth.train.preds) <- c("three.nodes")
names(banking.train.preds) <- c("four.nodes")
names(adult.train.preds) <- c("four.nodes")
training.data$Synthetic <- cbind( training.data$Synthetic,synth.train.preds )
training.data$Banking <- cbind(training.data$Banking ,banking.train.preds )
training.data$Adult <- cbind(training.data$Adult ,adult.train.preds )
synth.model.three.deep<- glm(formula = Class~Linear1+Linear2+Linear3+Linear4+Linear5+Linear6+Nonlinear1 +Nonlinear2 +Nonlinear3 + three.nodes,family = "binomial",data = training.data$Synthetic)
banking.mode.four.deep <- glm(formula = y~.-V1 ,family = "binomial",data = training.data$Banking)
adult.mode.four.deep <- glm(formula = Class~age+workclass+fnlwgt+education+education_num +marital_status+occupation+relationship+race+sex+capital_gain +  capital_loss + hours_per_week+ native_country + four.nodes,
                            family = "binomial",
                            data = training.data$Adult)

Finally, we fit the GAM and GLM models which we will use to benchmark the GLM+Tree model results:

# Create GAM models and the GLM models:
synth.gam <- gam(data = training.data$Synthetic,formula = Class~s(Linear1)+s(Linear2)+s(Linear3)+s(Linear4)+s(Linear5)+s(Linear6)+s(Nonlinear1) +s(Nonlinear2) +s(Nonlinear3),family = binomial)
banking.gam <- gam(formula = y~s(age)+s(previous)+s(euribor3m)+s(cons_conf_idx)+s(cons_price_idx)+s(nr_employed)+s(emp_var_rate)+s(pdays)+`job_blue-collar`+
                     job_management+`job_other 1`+`job_other 2`+job_services+job_technician+marital_married+marital_single+ education_high.school+education_professional.course+education_university.degree+education_unknown+default_unknown+housing_unknown+
                     housing_yes+loan_unknown+loan_yes+poutcome_nonexistent+poutcome_success,family = "binomial",data = training.data$Banking)
adult.gam <- gam(formula = Class~s(age)+(workclass)+s(fnlwgt)+(education)+s(education_num) +(marital_status)+(occupation)+(relationship)+(race)+(sex)+s(capital_gain) +  s(capital_loss) + s(hours_per_week)+ (native_country),
                 family = "binomial",
                 data = training.data$Adult)
# Create GLM Models
synth.model.glm<- glm(formula = Class~Linear1+Linear2+Linear3+Linear4+Linear5+Linear6+Nonlinear1 +Nonlinear2 +Nonlinear3 ,family = "binomial",data = training.data$Synthetic)
banking.mode.glm <- glm(formula = y~.-V1-four.nodes ,family = "binomial",data = training.data$Banking)
adult.mode.glm <- glm(formula = Class~age+workclass+fnlwgt+education+education_num +marital_status+occupation+relationship+race+sex+capital_gain +  capital_loss + hours_per_week+ native_country ,
                      family = "binomial",
                      data = training.data$Adult)

Model Testing

Now that we’ve fit the models, we can compare their results. To do this, I will rely on ROC AUC and PR AUC. The code to evaluate the results is presented below:

synth.test.preds <- lapply(X = synth.models,FUN = function(x){  
  tree.to.feature(tree.model = x,dt = test.data$Synthetic)  }) %>% data.frame
banking.test.preds <-lapply(X = banking.models,FUN = function(x){ 
  tree.to.feature(tree.model = x,dt = test.data$Banking)  }) %>% data.frame
adult.test.preds <- lapply(X = adult.models,FUN = function(x){ 
  tree.to.feature(tree.model = x,dt = test.data$Adult)  }) %>% data.frame
names(synth.test.preds) <- c( "three.nodes")
names(banking.test.preds) <- c( "four.nodes")
names(adult.test.preds) <- c( "four.nodes")
test.data$Synthetic <- cbind( test.data$Synthetic,synth.test.preds )
test.data$Banking <- cbind(test.data$Banking ,banking.test.preds )
test.data$Adult <- cbind(test.data$Adult ,adult.test.preds )
# Training-set scores for the Banking models
training.for.mmdata <- data.frame(predict(banking.mode.glm,newdata = training.data$Banking, type = "response" ),
                                  predict(banking.mode.four.deep, newdata = training.data$Banking,type = "response" ),
                                  predict(banking.gam,newdata = training.data$Banking, type = "response" )  )
training.mdat <- mmdata(scores = training.for.mmdata,labels = training.data$Banking$y,
                        modnames = c("Logistic Regression", "Tree w/ GLM", "GAM"))
# Test-set scores for the Synthetic models
testing.for.mmdata <- data.frame(predict(synth.model.glm,newdata = test.data$Synthetic, type = "response" ),
                                 predict(synth.model.three.deep, newdata = test.data$Synthetic,type = "response" ),
                                 predict(synth.gam,newdata = test.data$Synthetic, type = "response" )  )
testing_mdat <- mmdata(scores = testing.for.mmdata,labels = test.data$Synthetic$Class,
                       modnames = c("Logistic Regression", "Tree w/ GLM", "GAM"))
# ROC and PR curves for the training and testing data
training_ROC <- autoplot(evalmod(training.mdat),curvetype = c("ROC"))+theme(legend.position = "bottom") +ggtitle("ROC Curve - Training Data")
training_PR <- autoplot(evalmod(training.mdat),curvetype = c("PR"))+theme(legend.position = "bottom") +ggtitle("PR Curve - Training Data")
testing_ROC <- autoplot(evalmod(testing_mdat),curvetype = c("ROC"))+theme(legend.position = "bottom") +ggtitle("ROC Curve - Testing Data")
testing_PR <- autoplot(evalmod(testing_mdat),curvetype = c("PR"))+theme(legend.position = "bottom") +ggtitle("PR Curve - Testing Data")
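If you prefer numbers to curves, precrec can also report the areas under the curves directly; this is a small optional addition rather than part of the original analysis:

# Extract the ROC AUC and PR AUC for each model from the precrec objects
auc(evalmod(training.mdat))
auc(evalmod(testing_mdat))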
Image by author

Each model performs well on both the training and testing data sets, and there is little to no deterioration in model accuracy when moving to the out-of-sample data. In all cases, there is a small improvement in the performance metrics from using decision trees as a feature engineering step. We also see that in all cases the GAM outperforms both other models.

In several cases, the GAM outperforms the GLM+Tree model by a greater margin than the GLM+Tree model outperforms plain logistic regression (see the PR AUC for the Banking, Synthetic, and Adult data sets). From this, we can see that encoding the structure of the nonlinear data in the model specification, rather than transforming it, improves performance in all cases. This is significant given that using a GAM also eliminates the need to investigate and compare candidate decision trees for the model.

Conclusion:

I compared the results from two different approaches to coping with nonlinear data. The first used a decision tree model as a feature engineering step to create a set of one-hot encoded features, which were then used in a logistic regression model. The second was a generalized additive model, which incorporates the potential nonlinearity of the data into the model specification and thereby eliminates the assumption of a linear relationship between the dependent and independent variables. These two methodologies were compared on three data sets to see if there was a consistent difference in their performance. The tests show that the GAM outperformed the GLM+Tree models in every instance. Both outperformed simple logistic regression models.

The fact that GAMs outperformed the Tree+GLM models is significant, since using them also eliminates the feature engineering step and the ambiguity about the depth of the tree used to generate the processed variables.

References:

[1] T. Hastie, R. Tibshirani and J. Friedman, The Elements of Statistical Learning (2009), Springer

[2] T. Hastie and R. Tibshirani, Generalized Additive Models (1986), Statistical Science

[3] A. Szymanski, Combining Logistic Regression and Decision Tree (2020), Towards Data Science

