Decision Trees in Python

From Sustainability Methods


In short: Decision tree learning is a supervised machine learning method that is used for both classification and regression problems. It is a powerful method whose results are easy to interpret and understand. However, decision trees easily overfit and therefore may not generalize well to unseen data. Methods such as pruning often have to be applied.

Motivation

Decision tree learning falls into the category of supervised machine learning algorithms. The final predictive model is referred to as a decision tree. Decision trees can be used for both regression (regression trees) and classification tasks (classification trees). Thus, they can be used whenever the dataset contains features and a target variable. For instance, consider the following dataset from Mitchell (1997) (Table 1). Outlook, temperature, humidity, and wind are the features, and the target variable is whether we should play tennis or not. Can we find a pattern in the data that allows us to determine this on a given day?

Table 1. Training examples for the target concept PlayTennis

Day   Outlook    Temperature   Humidity   Wind     PlayTennis
D1    Sunny      Hot           High       Weak     No
D2    Sunny      Hot           High       Strong   No
D3    Overcast   Hot           High       Weak     Yes
D4    Rain       Mild          High       Weak     Yes
D5    Rain       Cool          Normal     Weak     Yes
D6    Rain       Cool          Normal     Strong   No
D7    Overcast   Cool          Normal     Strong   Yes
D8    Sunny      Mild          High       Weak     No
D9    Sunny      Cool          Normal     Weak     Yes
D10   Rain       Mild          Normal     Weak     Yes
D11   Sunny      Mild          Normal     Strong   Yes
D12   Overcast   Mild          High       Strong   Yes
D13   Overcast   Hot           Normal     Weak     Yes
D14   Rain       Mild          High       Strong   No

What the Method does

A decision tree breaks down this complex decision of whether one should or should not play tennis into a set of logical rules using disjunctions (OR) and conjunctions (AND). A decision tree for this dataset may look as follows: I will play tennis whenever…

  • the outlook today is overcast;
  • OR the outlook is sunny AND the humidity is normal;
  • OR the outlook is rain AND the wind is weak.
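The three rules above can be written directly as one Boolean expression. The following is a minimal sketch (the function name `play_tennis` and the tuple layout of the data are our own illustrative choices, not part of the original example) that checks the rules against all 14 rows of Table 1:

```python
# Table 1 as (Outlook, Temperature, Humidity, Wind, PlayTennis) tuples
DATA = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def play_tennis(outlook, temperature, humidity, wind):
    """Hypothetical helper: the tree's rules as a disjunction of conjunctions."""
    # Temperature is not used by the tree, so it does not appear in the rule
    return (
        outlook == "Overcast"
        or (outlook == "Sunny" and humidity == "Normal")
        or (outlook == "Rain" and wind == "Weak")
    )


# Count how many training examples the rules classify correctly
correct = sum(
    play_tennis(o, t, h, w) == (label == "Yes") for o, t, h, w, label in DATA
)
print(f"{correct}/{len(DATA)} examples classified correctly")  # 14/14
```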

The algorithm first selects the most informative feature for the root node of the tree and splits the data into subsets, one for each branch, i.e., for each value of that feature. There are different metrics for determining the most informative feature (e.g., entropy or Gini impurity for classification, variance for regression). Generally, they all measure the homogeneity of the target variable within the resulting subsets, and the algorithm chooses the feature that yields the most homogeneous subsets.
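To make "most informative" concrete, the sketch below computes the entropy H(S) = -Σ p_c log2(p_c) of the target and the information gain of each feature on Table 1, i.e., the entropy of the whole set minus the size-weighted entropy of the subsets after splitting (the criterion used by ID3, as described by Mitchell, 1997). Only the Python standard library is used; the helper names are illustrative assumptions, not part of any library.

```python
from collections import Counter
from math import log2

FEATURES = ["Outlook", "Temperature", "Humidity", "Wind"]
# Table 1 rows: four feature values followed by the PlayTennis label
ROWS = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(rows, feature_index):
    """Target entropy minus the weighted entropy of the subsets after the split."""
    labels = [row[-1] for row in rows]
    # Group the labels by the value this feature takes
    subsets = {}
    for row in rows:
        subsets.setdefault(row[feature_index], []).append(row[-1])
    weighted = sum(len(s) / len(rows) * entropy(s) for s in subsets.values())
    return entropy(labels) - weighted


gains = {name: information_gain(ROWS, i) for i, name in enumerate(FEATURES)}
for name, gain in sorted(gains.items(), key=lambda kv: -kv[1]):
    print(f"{name:12s} {gain:.3f}")
```

Outlook has the highest gain, which is why it ends up at the root of the tree in Figure 1.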

In our example, outlook was selected as the most informative feature, so the subsets consist of all examples where the outlook is sunny, overcast, or rain, respectively. The algorithm then repeats the same procedure on each of these subsets. It stops once the target variable in a subset is homogeneous, meaning that all examples in it belong to the same class (or when no features are left to split on).
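The recursive procedure can be sketched in plain Python. This is a minimal ID3-style sketch under our own assumptions: the names `id3`, `best_feature` and `predict` are hypothetical, and a majority vote is used when no features remain. It is not how scikit-learn works internally (scikit-learn grows binary trees).

```python
from collections import Counter
from math import log2

FEATURES = ["Outlook", "Temperature", "Humidity", "Wind"]
ROWS = [  # Table 1: feature values followed by the PlayTennis label
    ("Sunny", "Hot", "High", "Weak", "No"), ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"), ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"), ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"), ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"), ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"), ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"), ("Rain", "Mild", "High", "Strong", "No"),
]


def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def best_feature(rows, features):
    """Return the name of the feature with the highest information gain."""
    base = entropy([r[-1] for r in rows])

    def gain(feature):
        i = FEATURES.index(feature)
        groups = {}
        for r in rows:
            groups.setdefault(r[i], []).append(r[-1])
        return base - sum(len(g) / len(rows) * entropy(g) for g in groups.values())

    return max(features, key=gain)


def id3(rows, features):
    """Grow a tree recursively: leaves are labels, inner nodes are {feature: branches}."""
    labels = [r[-1] for r in rows]
    if len(set(labels)) == 1:  # homogeneous subset: stop with a leaf
        return labels[0]
    if not features:  # nothing left to split on: predict the majority class
        return Counter(labels).most_common(1)[0][0]
    feature = best_feature(rows, features)
    i = FEATURES.index(feature)
    remaining = [f for f in features if f != feature]
    # One branch per observed value of the chosen feature
    branches = {
        value: id3([r for r in rows if r[i] == value], remaining)
        for value in sorted({r[i] for r in rows})
    }
    return {feature: branches}


def predict(node, example):
    """Follow the branches matching the example until a leaf label is reached."""
    while isinstance(node, dict):
        feature, branches = next(iter(node.items()))
        node = branches[example[FEATURES.index(feature)]]
    return node


play_tennis_tree = id3(ROWS, FEATURES)
print(play_tennis_tree)
```

The printed nested dictionary has Outlook at the root, a Humidity test under Sunny, a Wind test under Rain, and a "Yes" leaf under Overcast, i.e., the same rules as listed above.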

Interpretability

One big advantage of decision trees is that they are considered to be human-interpretable. In contrast to black-box models such as neural networks, decision trees are white-box models: any prediction of the tree can be explained by Boolean logic. On top of that, we can visualize the final model and arrive at the same conclusion as the decision tree by following its branches. The following picture (Figure 1) shows the same logical rules that were mentioned previously.

Figure 1. PlayTennis decision tree
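A fitted scikit-learn tree can also be turned into readable rules without plotting, via `sklearn.tree.export_text`. The sketch below fits a `DecisionTreeClassifier` on Table 1 after one-hot encoding (needed because scikit-learn expects numerical input). Because scikit-learn grows binary trees on the encoded columns, the printed rules are structured differently from Figure 1, but they are read the same way: follow the conditions from the top until a class is reached.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Table 1 as a DataFrame (the Day column is omitted)
df = pd.DataFrame(
    [
        ("Sunny", "Hot", "High", "Weak", "No"), ("Sunny", "Hot", "High", "Strong", "No"),
        ("Overcast", "Hot", "High", "Weak", "Yes"), ("Rain", "Mild", "High", "Weak", "Yes"),
        ("Rain", "Cool", "Normal", "Weak", "Yes"), ("Rain", "Cool", "Normal", "Strong", "No"),
        ("Overcast", "Cool", "Normal", "Strong", "Yes"), ("Sunny", "Mild", "High", "Weak", "No"),
        ("Sunny", "Cool", "Normal", "Weak", "Yes"), ("Rain", "Mild", "Normal", "Weak", "Yes"),
        ("Sunny", "Mild", "Normal", "Strong", "Yes"), ("Overcast", "Mild", "High", "Strong", "Yes"),
        ("Overcast", "Hot", "Normal", "Weak", "Yes"), ("Rain", "Mild", "High", "Strong", "No"),
    ],
    columns=["Outlook", "Temperature", "Humidity", "Wind", "PlayTennis"],
)

# One-hot encode the features, e.g. Outlook -> Outlook_Overcast, Outlook_Rain, Outlook_Sunny
X = pd.get_dummies(df.drop(columns="PlayTennis"), dtype=float)
y = df["PlayTennis"]

clf = DecisionTreeClassifier(criterion="entropy", random_state=0).fit(X, y)

# Print the learned rules as indented text
rules = export_text(clf, feature_names=list(X.columns))
print(rules)
```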

Simple Regression Tree Model in Python

In this example, we are going to use the car price prediction dataset from Kaggle. The goal is to predict the price of a car based on its properties, such as car length, engine size, or number of cylinders.

First, we have to import the libraries and modules that we will use for this demonstration.

import pandas as pd
import numpy as np 
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

Note that this dataset is a toy dataset, meaning that, luckily for us, little data cleaning and preprocessing is needed. The accuracy of our model could potentially be improved by including more preprocessing steps or by simplifying the model. However, the purpose of this example is not to achieve the highest possible accuracy (for that, we would use more sophisticated models, such as random forest regression). Instead, the goal is to showcase how a simple decision tree model can be employed. We now import our data using pandas.

df = pd.read_csv('CarPrice_Assignment.csv')
df.columns

Index(['car_ID', 'symboling', 'CarName', 'fueltype', 'aspiration', 'doornumber', 'carbody', 'drivewheel', 'enginelocation', 'wheelbase', 'carlength', 'carwidth', 'carheight', 'curbweight', 'enginetype', 'cylindernumber', 'enginesize', 'fuelsystem', 'boreratio', 'stroke', 'compressionratio', 'horsepower', 'peakrpm', 'citympg', 'highwaympg', 'price'], dtype='object')

We remove the car_ID column, as it basically duplicates the index. We also modify the CarName column so that it only contains the car's brand, and we fix typos in the brand names.

# Drop the car_ID column, which only duplicates the index
df = df.drop(columns='car_ID')

def car_brand(car_name):
    # Split an entry of the 'CarName' column into words;
    # the first word is the car brand, which is returned in lowercase
    return car_name.split(" ")[0].lower()

# Apply the car_brand function to the 'CarName' column to create a new 'CarBrand' column
df['CarBrand'] = df['CarName'].apply(car_brand)

# Drop the old 'CarName' column from the dataframe
df = df.drop(columns='CarName')

# Fix typos in the brand names
df['CarBrand'] = df['CarBrand'].replace({
    'vokswagen': 'vw',
    'volkswagen': 'vw',
    'porcshce': 'porsche',
    'maxda': 'mazda',
    'toyouta': 'toyota',
})
df.columns

Even though decision trees can in principle handle categorical data, the scikit-learn implementation expects numerical input. Since our categorical columns do not contain ordinal data, we one-hot encode them. We could also standardize the numeric columns (except for the target variable) to have a mean of 0 and a standard deviation of 1. However, decision trees do not require this: a split only compares a feature's values against a threshold, so it depends on the order of the values, not on their scale.

# Collect the names of the categorical columns in a list
categoric = df.select_dtypes(include=object).columns.to_list()
# One-hot encode the categorical columns
df_cat = pd.get_dummies(df[categoric], dtype=float)
# Drop the original categorical columns and concatenate the encoded ones to the dataframe
df = df.drop(columns=categoric)
df = pd.concat([df, df_cat], axis=1)
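The claim that trees do not need standardized inputs can be checked quickly. The sketch below uses synthetic data (an assumption for illustration, not the car dataset) with features on very different scales. A tree fitted on the raw features and one fitted on standardized features make identical predictions on the training samples, because standardizing preserves the ordering of each feature's values and therefore the set of possible splits.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

# Synthetic data (assumption): three features on very different scales
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3)) * np.array([1.0, 100.0, 1000.0])
y = X[:, 0] + 0.01 * X[:, 1] + rng.normal(scale=0.1, size=200)

# Standardize every feature to mean 0 and standard deviation 1
X_scaled = StandardScaler().fit_transform(X)

raw_tree = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
scaled_tree = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X_scaled, y)

# Both trees partition the training samples in the same way, so predictions match
same = np.allclose(raw_tree.predict(X), scaled_tree.predict(X_scaled))
print(same)  # True
```

Only the split thresholds differ between the two trees (they live on different scales); the structure and the predictions are the same.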

We now split the data into a training and a test set and fit our decision tree model. To do so, we simply initialize the DecisionTreeRegressor class (DecisionTreeClassifier for classification) and call its fit method. We set a random state for reproducibility.

# Convert data into features X and labels y in a numpy array
X = np.asarray(df.drop(columns='price', axis=1))
y = np.asarray(df['price'])

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)

# Initialize an unpruned regression tree; random_state makes the result reproducible
reg_tree = tree.DecisionTreeRegressor(random_state=0)

# Fit the model
reg_tree.fit(X_train, y_train)

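After the DecisionTreeRegressor instance has been fitted, we can plot the tree as follows (the max_depth argument of plot_tree only limits how much of the tree is drawn, not the model itself):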

from matplotlib import pyplot as plt

tree.plot_tree(reg_tree, max_depth=1)
plt.show() # Figure 2

Figure 2. Decision tree plot

Further explanation of the plot_tree function and its other parameters can be found in the scikit-learn documentation on decision trees (see References).

Lastly, we calculate the root mean squared error (RMSE) of our training and test predictions, as well as the coefficient of determination (R²), to evaluate how well the model fits the data.
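For reference, RMSE = sqrt((1/n) Σ (y_i - ŷ_i)²) and R² = 1 - Σ (y_i - ŷ_i)² / Σ (y_i - ȳ)², where ŷ_i are the predictions and ȳ is the mean of the observed values. The short sketch below computes both by hand with NumPy on a few made-up numbers and compares them with scikit-learn's mean_squared_error and r2_score:

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

# Made-up observed and predicted values, only for illustration
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

residuals = y_true - y_pred
rmse_manual = np.sqrt(np.mean(residuals ** 2))
# R^2 compares the residual sum of squares with the total sum of squares around the mean
r2_manual = 1 - np.sum(residuals ** 2) / np.sum((y_true - y_true.mean()) ** 2)

print(rmse_manual, np.sqrt(mean_squared_error(y_true, y_pred)))  # both about 0.612
print(r2_manual, r2_score(y_true, y_pred))  # both about 0.949
```

An R² of 1 means the predictions match the observations exactly; an R² of 0 means the model does no better than always predicting the mean.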

y_pred_train = reg_tree.predict(X_train)

# Train RMSE 
RMSE_train = np.sqrt(mean_squared_error(y_train, y_pred_train))
print("RMSE for Training Data: ", RMSE_train) # Out: 291.4381969264448

# Train R_squared
R_squared_train = r2_score(y_train, y_pred_train)
print("Training R_squared for Decision Tree Regression Model: ", R_squared_train) # Out: 0.9986795498364933

We see that our model explains the training data almost perfectly. The next step is to check the same metrics on the test data:

y_pred_test = reg_tree.predict(X_test)

# Test RMSE
RMSE_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
print("RMSE for Testing Data: ", RMSE_test) # Out: 2839.9804482339855

# Test R_squared
R_squared_test = r2_score(y_test, y_pred_test)
print("Testing R_squared for Decision Tree Regression Model: ", R_squared_test) # Out: 0.8692375508177494

The test RMSE (about 2840) is much higher than the training RMSE (about 291), and the test R² (about 0.87) is correspondingly lower than the training R² (about 0.999). This suggests that the model is overfitting, which is a typical disadvantage of decision trees: a tree can keep splitting the data until each subset is completely homogeneous, which basically means that it has memorized the training data and will not generalize well to unseen data.
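The same effect can be reproduced on a small synthetic dataset (an assumption for illustration: a noisy sine curve, not the car data). A fully grown DecisionTreeRegressor reaches a training R² of 1.0, because it keeps splitting until every leaf contains only identical target values, while its test R² is clearly lower:

```python
import numpy as np
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic, noisy data (assumption): a sine curve plus Gaussian noise
rng = np.random.default_rng(42)
X = rng.uniform(0, 6, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(scale=0.3, size=200)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)

# No depth or leaf-size limits: the tree splits until every leaf is pure
full_tree = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)

train_r2 = r2_score(y_train, full_tree.predict(X_train))
test_r2 = r2_score(y_test, full_tree.predict(X_test))
print(f"train R^2 = {train_r2:.3f}, test R^2 = {test_r2:.3f}")
print("leaves:", full_tree.get_n_leaves(), "training samples:", len(X_train))
```

Since the target values here are all distinct, the number of leaves ends up close to the number of training samples: the tree has effectively memorized the noise.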

Pre- and Post-Pruning

To counteract overfitting, we can prune our decision tree. Pruning reduces the size of the tree (e.g., its depth), which prevents it from perfectly matching all training examples.

Pre-Pruning

One way to prune the tree is to limit its growth beforehand. This method is referred to as pre-pruning or early stopping. However, this can be a bit tricky, because we do not know in advance which parameter values work best. To implement pre-pruning, the max_depth parameter can be set when initializing the tree. Alternatively, we can require each leaf node to contain a minimum number of training examples through the min_samples_leaf parameter. Let us see how well our model performs on the test set when using pre-pruning.

# Set the minimum number of samples in each leaf to pre-prune the tree
reg_tree = tree.DecisionTreeRegressor(random_state=0, min_samples_leaf=4)

# Train the model 
reg_tree.fit(X_train, y_train)

# Collect the predictions of the train and test data
y_pred_train = reg_tree.predict(X_train)
y_pred_test = reg_tree.predict(X_test)

# Train RMSE 
RMSE_train = np.sqrt(mean_squared_error(y_train, y_pred_train))
print("RMSE for Training Data: ", RMSE_train) # Out: 1464.1947006424898

# Test RMSE
RMSE_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
print("RMSE for Testing Data: ", RMSE_test) # Out: 2541.4558310183947

# Train R_squared:
R_squared_train = r2_score(y_train, y_pred_train)
print("Training R_squared for Decision Tree Regression Model: ", R_squared_train) # Out: 0.9666706584900479

# Test R_squared:
R_squared_test = r2_score(y_test, y_pred_test)
print("Testing R_squared for Decision Tree Regression Model: ", R_squared_test) # Out: 0.8952829308308532

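We can see that the fit on the training data has decreased (the training R² dropped from about 0.999 to about 0.967). However, the model now generalizes better to unseen data than the unpruned tree: the test RMSE decreased from about 2840 to about 2541, and the test R² increased from about 0.87 to about 0.90.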
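Because suitable values for max_depth and min_samples_leaf are not known in advance, one common option (not part of the original example) is to choose them by cross-validation on the training set. The sketch below uses scikit-learn's GridSearchCV on synthetic data (an assumption; the grid values are arbitrary illustrative choices) and then checks that the selected tree respects both limits:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic, noisy data (assumption), standing in for the car features
rng = np.random.default_rng(42)
X = rng.uniform(0, 6, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(scale=0.3, size=200)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)

# Candidate pre-pruning settings; None means "no depth limit"
param_grid = {"max_depth": [2, 3, 4, 6, 8, None], "min_samples_leaf": [1, 2, 4, 8, 16]}

# 5-fold cross-validation on the training set only, scored by R^2
search = GridSearchCV(DecisionTreeRegressor(random_state=0), param_grid, cv=5, scoring="r2")
search.fit(X_train, y_train)

# With the default refit=True, the best setting is refitted on the whole training set
best_tree = search.best_estimator_
print("best parameters:", search.best_params_)
print("test R^2:", best_tree.score(X_test, y_test))

# Leaves are the nodes without children (children_left == -1)
is_leaf = best_tree.tree_.children_left == -1
print("smallest leaf:", best_tree.tree_.n_node_samples[is_leaf].min())
```

The test set is only used once, after the parameters have been chosen, so it still gives an honest estimate of performance on unseen data.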

Post-Pruning

Post-pruning is a method where we first let the decision tree overfit the data. Then, using a validation set, we prune the tree by removing the nodes whose deletion increases the validation performance the most. This is mathematically a bit more complex. An example using cost complexity pruning can be found in the scikit-learn documentation (see References).
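scikit-learn implements post-pruning as minimal cost-complexity pruning (the ccp_alpha parameter) rather than the node-by-node validation procedure described above. The sketch below combines the two ideas under our own assumptions (synthetic data, a 70/30 train/validation split): cost_complexity_pruning_path proposes a sequence of increasingly pruned trees, and the validation set picks among them.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic, noisy data (assumption)
rng = np.random.default_rng(42)
X = rng.uniform(0, 6, size=(300, 1))
y = np.sin(X[:, 0]) + rng.normal(scale=0.3, size=300)

# Hold out a validation set that is only used to decide how much to prune
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=0)

# Step 1: grow the full (overfitted) tree and compute its pruning path
full_tree = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)
path = full_tree.cost_complexity_pruning_path(X_train, y_train)
# Larger alpha = more aggressive pruning; clip guards against tiny negative round-off
ccp_alphas = np.clip(path.ccp_alphas, 0.0, None)

# Step 2: refit one tree per alpha and score each on the validation set
val_scores = []
for alpha in ccp_alphas:
    pruned = DecisionTreeRegressor(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)
    val_scores.append(pruned.score(X_val, y_val))

# Step 3: keep the alpha with the best validation R^2
best_alpha = ccp_alphas[int(np.argmax(val_scores))]
best_tree = DecisionTreeRegressor(random_state=0, ccp_alpha=best_alpha).fit(X_train, y_train)
print(f"alpha = {best_alpha:.4f}: {best_tree.get_n_leaves()} leaves "
      f"(unpruned: {full_tree.get_n_leaves()} leaves)")
```

The first alpha on the path corresponds to the unpruned tree, so the selected tree is never worse on the validation set than the fully grown one.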

Strengths & Weaknesses of Decision Trees

Weaknesses

  • Overfitting is a common issue. Therefore, pruning methods often need to be applied.
  • Decision trees are sensitive to noise: small changes in the training data can result in a very different tree. More sophisticated models, such as random forests, can help to overcome this.
  • If the dataset is imbalanced, decision trees can create a biased model.

Strengths

  • As pointed out previously, decision trees are white-box models. As long as they are not too large, they can be visualized and interpreted easily by humans.
  • They are relatively simple to use.
  • They can be used for both categorical and numerical data.
  • They require little data preprocessing; e.g., decision trees can handle unstandardized data and, depending on the implementation, missing values.

References

1. 1.10. Decision Trees. (n.d.). Scikit-learn. https://scikit-learn.org/stable/modules/tree.html

2. Post pruning decision trees with cost complexity pruning. (n.d.). Scikit-learn. https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html

3. Mitchell, T. (1997). Machine Learning. McGraw-Hill.

4. Shalev-Shwartz, S., & Ben-David, S. (2014). Understanding machine learning: From theory to algorithms. Cambridge University Press. https://doi.org/10.1017/cbo9781107298019

The author of this entry is Moritz Burmester. Edited by Evgeniya Zakharova.