One of the most invigorating aspects of the data science community is just the sheer number of available algorithms that you can add to your analytics toolkit. With all of these techniques at your disposal, you gain much flexibility in your data work and you can be confident that you will use an appropriate method when the scenario calls for it.
If your background is like mine (a more traditional inferential statistics education), you’ve likely encountered just a fraction of the possible ways to model your data. But rest assured, if you’ve already learned the fundamentals of statistics, you should be able to ease your way into more complicated methods fairly seamlessly.
To start, it’s important to acknowledge that the most common statistical techniques are members of the generalized linear model (GLM) family. These include logistic regression, linear regression, and ANOVA (which itself is just a special case of linear regression). These methods are popular, and for good reason: when the relationship between X and Y is truly linear (which it often is), they tend to perform extremely well.
But there are many situations in which X and Y are related, but in a non-linear way. What is there to do in such cases? Well, one technique is to simply modify the terms of the linear regression equation (using splines, segmenting the regression, adding polynomial terms, including interaction terms, etc.). But these models often end up being convoluted and lacking in interpretability. It’s not uncommon to catch yourself saying things like, “this model shows there is a 1.8-unit increase in the logged Y for every 1-unit increase in the quadraticness of X.” NO ONE wants to hear this except for stats geeks! And compounding the problem, when multiple predictors are in the equation, the interpretation of that one coefficient gets even more complicated to explain.
Lucky for us, there is an algorithm that can outperform the GLM when relationships are non-linear, and that also has the added benefit of being highly interpretable. I introduce to you the decision tree.
What is a decision tree?
A decision tree is a purely data-driven (non-parametric) model that one can use for both regression and classification. It works by partitioning the data into nested subsets using a sequence of binary splits. The algorithm picks a root node attribute (at the top of the ‘tree’) and then progresses downward, making binary split decisions for how to subset the data until no further worthwhile split remains. At this point the splitting stops and each observation resides in a terminal node (sometimes called a ‘leaf node’). The resulting structure is tree-like (really an inverted tree), with the root node at the top, decision nodes in the middle, and the leaf nodes at the bottom.
In the example above, the objective is to classify whether a customer will purchase a computer or not. The predictor space is three-dimensional (age, credit rating, and student status). In this case, all predictors are nominal, but decision trees can also work with continuous predictors. After a decision tree is built, one simply feeds some new data into the tree, and the algorithm will trickle the data down until it lands in a leaf node. At this point, the prediction will be the modal response of that particular leaf node (for classification) or its mean response (for regression). In the tree above, if a customer walks in who is a senior and has an excellent credit rating, you would predict that this customer is likely to purchase a computer.
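To make the "trickle down" step concrete, here is a minimal sketch in Python. The tree below is hypothetical: only the path "senior with excellent credit rating, predicted to purchase" comes from the example above, and every other branch, leaf label, and name (TREE, predict) is made up for illustration.

```python
# Hypothetical tree. Only the "senior + excellent credit -> buys" path is
# taken from the text; all other branches and leaves are illustrative.
# Each decision node asks a yes/no question: does record[attribute] == value?
TREE = {
    "test": ("age", "senior"),
    "yes": {
        "test": ("credit_rating", "excellent"),
        "yes": "buys",
        "no": "does not buy",
    },
    "no": {
        "test": ("student", "yes"),
        "yes": "buys",
        "no": "does not buy",
    },
}


def predict(node, record):
    """Walk from the root to a leaf, following the answer to each binary test."""
    while isinstance(node, dict):  # leaves are plain strings
        attribute, value = node["test"]
        node = node["yes"] if record[attribute] == value else node["no"]
    return node


customer = {"age": "senior", "credit_rating": "excellent", "student": "no"}
print(predict(TREE, customer))  # buys
```

Notice that prediction never looks at the training data again; once the tree is built, it is just a short chain of yes/no questions.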
How does a decision tree decide where to split?
So now that we know what a decision tree is, it’s natural to wonder how the tree is built and how all of this splitting actually works. In the example above you may also be wondering what determined age to be the root node and not credit rating or student status. Though specific details vary depending on the implementation (there are a few), in essence, all trees choose an attribute and split criterion that partition the data into maximally homogeneous segments. Basically, at each step, the algorithm considers every possible variable to split on and every possible ‘cutpoint’ value within the range of that variable. The split it ends up choosing is the one that produces the most homogeneous subsets of data.
How is homogeneity measured? Well, it depends on the type of response you are modeling. In the example above, the predictors are all nominal variables and the response is also a nominal variable (purchase vs not). In such cases, a common approach is to choose the split that minimizes entropy or, equivalently, maximizes information gain (both of which you can read about here). Another algorithm built on the same basic idea, CHAID, chooses the attribute split that maximizes the test statistic in a chi-squared test of independence. A high test statistic here indicates that the attribute and its splitting criterion are doing a good job of separating purchasers from non-purchasers. In the case of regression, the process is similar, but instead of class homogeneity, the algorithm takes the difference between each predicted and actual response, squares it, and then chooses the split that produces the greatest reduction in the residual sum of squares (RSS).
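As a rough illustration of the entropy idea, the sketch below computes entropy and information gain for two candidate splits. The labels are made-up toy data, and the function names (entropy, information_gain) are my own, not taken from any particular package.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of the children."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted


# Made-up labels: 5 purchasers and 5 non-purchasers.
parent = ["buy"] * 5 + ["no"] * 5

# A split that separates the classes perfectly...
pure_split = [["buy"] * 5, ["no"] * 5]
# ...versus one that barely separates them.
mixed_split = [["buy", "buy", "buy", "no", "no"],
               ["buy", "buy", "no", "no", "no"]]

print(entropy(parent))                        # 1.0 (a 50/50 mix is maximally impure)
print(information_gain(parent, pure_split))   # 1.0 (all impurity removed)
print(information_gain(parent, mixed_split))  # small, close to zero
```

A tree-building algorithm would compute a number like this for every candidate split and keep the one with the largest gain.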
Whichever criterion the algorithm uses (chi-square, RSS, entropy, information gain, etc.), the process continues recursively until there are no more splits to make. So after the tree makes its first split, it partitions each of the two resulting subsets again by splitting on some other value within that subset, and so on. The stopping criterion that determines when to stop splitting is often set by the user, and can include limits such as the minimum number of observations assigned to each terminal node, the maximum number of splits, etc.
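Here is a minimal sketch of that recursion for the regression case. It is a simplification under stated assumptions: there is only one numeric predictor, the split criterion is RSS, and the stopping rules are two hypothetical parameters I have called max_depth and min_leaf. Real implementations differ in the details.

```python
def rss(ys):
    """Residual sum of squares around the mean of ys."""
    if not ys:
        return 0.0
    mean = sum(ys) / len(ys)
    return sum((y - mean) ** 2 for y in ys)


def best_split(xs, ys, min_leaf):
    """Try every cutpoint between adjacent x values; return (total RSS, cut) or None."""
    best = None
    values = sorted(set(xs))
    for lo, hi in zip(values, values[1:]):
        cut = (lo + hi) / 2
        left = [y for x, y in zip(xs, ys) if x <= cut]
        right = [y for x, y in zip(xs, ys) if x > cut]
        if len(left) < min_leaf or len(right) < min_leaf:
            continue  # stopping rule: each leaf needs enough observations
        total = rss(left) + rss(right)
        if best is None or total < best[0]:
            best = (total, cut)
    return best


def grow(xs, ys, depth=0, max_depth=2, min_leaf=2):
    """Recursively split until a stopping rule applies; leaves hold the mean response."""
    split = best_split(xs, ys, min_leaf) if depth < max_depth else None
    if split is None or split[0] >= rss(ys):
        return sum(ys) / len(ys)  # leaf node
    cut = split[1]
    left = [(x, y) for x, y in zip(xs, ys) if x <= cut]
    right = [(x, y) for x, y in zip(xs, ys) if x > cut]
    return {
        "cut": cut,
        "left": grow([x for x, _ in left], [y for _, y in left],
                     depth + 1, max_depth, min_leaf),
        "right": grow([x for x, _ in right], [y for _, y in right],
                      depth + 1, max_depth, min_leaf),
    }


# Made-up data with a jump between x = 4 and x = 5.
xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.2, 0.9, 1.1, 5.0, 5.2, 4.9, 5.1]

tree = grow(xs, ys)
print(tree["cut"])  # 4.5
```

The first split lands on the jump in the data. Loosening max_depth or min_leaf lets the tree keep splitting on ever smaller differences, which is exactly why these limits are usually left for the user to set.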
How do decision trees compare to linear models?
When a set of X’s is truly related to Y in some non-linear way, your standard least squares regression line will be biased (meaning a straight line doesn’t approximate the true nature of the relationship in the data). In such cases you can go ahead and fit a linear regression anyway. It may even produce a respectable p-value and explain a significant portion of the overall variability in Y. But even when a linear model predicts decently, if the relationship is non-linear, a decision tree will often outperform it.
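A toy comparison shows the idea. The sketch below fits a least squares line and a one-split regression tree (a 'stump') to made-up data with a step in it, then compares their RSS. The data are deliberately chosen to favor the tree, and the function names (fit_line, fit_stump) are my own; on truly linear data the line would win instead.

```python
def fit_line(xs, ys):
    """Ordinary least squares for one predictor; returns (intercept, slope)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    slope = sxy / sxx
    return my - slope * mx, slope


def fit_stump(xs, ys):
    """One-split regression tree; returns (cut, left mean, right mean)."""
    best = None
    values = sorted(set(xs))
    for lo, hi in zip(values, values[1:]):
        cut = (lo + hi) / 2
        left = [y for x, y in zip(xs, ys) if x <= cut]
        right = [y for x, y in zip(xs, ys) if x > cut]
        ml, mr = sum(left) / len(left), sum(right) / len(right)
        err = sum((y - ml) ** 2 for y in left) + sum((y - mr) ** 2 for y in right)
        if best is None or err < best[0]:
            best = (err, cut, ml, mr)
    return best[1:]


# Made-up step-shaped data: Y jumps from about 2 to about 8 after x = 5.
xs = list(range(1, 11))
ys = [2.0, 2.1, 1.9, 2.0, 2.1, 8.0, 7.9, 8.1, 8.0, 7.9]

intercept, slope = fit_line(xs, ys)
cut, left_mean, right_mean = fit_stump(xs, ys)

rss_line = sum((y - (intercept + slope * x)) ** 2 for x, y in zip(xs, ys))
rss_stump = sum((y - (left_mean if x <= cut else right_mean)) ** 2
                for x, y in zip(xs, ys))

print("RSS, straight line:", rss_line)
print("RSS, one-split tree:", rss_stump)
```

The line has a clearly positive slope and would look 'significant', yet it misses the step entirely, while a single split captures it almost perfectly.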
So, decision trees tend to perform better than linear models when there is a true non-linear relationship between X and Y. Another big advantage over the linear model is ease of interpretation. Though many linear regressions are fairly interpretable (each coefficient measures the change in Y per unit change in X), things get much fuzzier when you start dealing with multiple predictors, interaction terms, polynomials, etc. In contrast, decision trees are extremely straightforward and can be readily understood by almost anybody (this is especially useful for communicating with those who may lack formal training in analytics).
So picking up non-linearities and ease of interpretation are the two main advantages of decision trees. As with all algorithms, however, there is no free lunch. One disadvantage of decision trees is that they tend to model the signal so well that they often pick up a lot of the noise too (this is called ‘overfitting’). Remember, there is always a random component in our data, and so really specific trees that pick up peculiarities of the data may be modeling some of that random component and treating it like a systematic pattern. Luckily there are many methods to handle this pitfall of decision trees (ensemble methods such as bagging and random forests, etc.). We won’t focus on these in this blog post, but just know that there are ways to treat this scenario with decision trees.
How do I get started with decision trees?
Now that you have learned about decision trees, the best way to get better at using them is to go test them out! There are many great packages available in R and Python to get you started. Personally, I would recommend the rpart package in R. You can run it on the iris dataset in R and have a decision tree built in minutes!
In a future post I will do a tutorial in R for building a decision tree and visualizing the results. It’s really quite simple. In the meantime, I welcome you to comment with any thoughts or questions!