
How to Train a Decision Tree Classifier… In SQL

SQL can now replace Python for most supervised ML tasks. Should you make the switch?

Photo by Resource Database on Unsplash

When it comes to machine learning, I’m an avid fan of attacking data where it lives. 90%+ of the time, that’s going to be a relational database, assuming we’re talking about supervised machine learning.

Python is amazing, but pulling dozens of GB of data every time you want to train a model is a huge bottleneck, especially if you retrain frequently. Eliminating data movement makes a lot of sense. SQL is your friend.

For this article, I’ll use an always-free Oracle Database 21c provisioned on Oracle Cloud. The logic relies on Oracle-specific packages, so it may not translate directly to other database vendors. Oracle works like a charm, and the database you provision won’t cost you a dime – ever.


Dataset Loading and Preprocessing

I’ll leave the Python vs. Oracle comparison for machine learning on huge datasets for some other time. Today, it’s all about getting back to basics.

I’ll use the well-known Iris dataset today:

So download it to follow along, and make sure you have a connection established to your database instance. Tools like SQL Developer or Visual Studio Code can do that for you.
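Once the connection is configured, a quick query is enough to confirm it works and to check which database version you’re running:

```sql
select banner from v$version;
```

On an always-free Oracle Cloud instance, the banner should mention Oracle Database 21c.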

How to Create a Database Table

The following code snippet will create the iris table. The id column is mandatory, as Oracle will need it behind the scenes as the case identifier:

create sequence seq_iris;

create table iris(
    id number default seq_iris.nextval,
    sepal_length number(2, 1),
    sepal_width number(2, 1),
    petal_length number(2, 1),
    petal_width number(2, 1),
    species varchar2(15)
);

Once the table is created, it’s time to load the data.

How to Load CSV Data Into the Table

The source data is in CSV format, so you’ll have to use a tool like SQL Developer to load it into a table. Right-click on your table name, select Import Data and provide a path to your CSV file.

You’ll then need to map the CSV column names to your table column names:

Image 1 – Loading a CSV file to a table (image by author)

Once done, issue a simple select statement to verify the data was loaded successfully:

select * from iris;

Here’s what you should get back:

Image 2 – Iris dataset stored in a database (image by author)

All available data is now in one table. It’s a great starting point, but not what you want in machine learning.

Train/Test Split

Splitting your data into two subsets makes a lot of sense when building machine learning models. You train the model on one subset (larger) and evaluate it on previously unseen data.

Oracle SQL allows you to sample a portion of your data – let’s say 75% – with a random seed. You can use this technique to create a training set, and then use the minus set operator to create a test set:

-- Train set
create or replace view v_iris_train as
select * from iris
sample(75)
seed(42);

-- Test set
create or replace view v_iris_test as
select * from iris 
minus 
select * from v_iris_train;

To see how the data was split, you can print the number of rows from each:

select
    (select count(*) from v_iris_train) as n_train_instances,
    (select count(*) from v_iris_test) as n_test_instances
from dual;

Not a perfect 75:25 split, but close enough:

Image 3 – Number of training/testing instances (image by author)

You now have everything needed to create a machine learning model in SQL. Let’s get right to it!

How to Train a Decision Tree Classifier in SQL

_Note: I’m demonstrating how to train a Decision Tree classifier, but that’s not the only model available. You can choose simple linear models, neural networks, and everything in between. Refer to the official docs for all available algorithms._

There are a couple of things you need to know before training your first ML model in SQL:

  • You must name your model – It’s an arbitrary string, but you can’t create a model if one with the same name already exists. I’ll show you how to drop it first.
  • Each model has its settings – All are passed into dbms_data_mining.setting_list. The ones you’re about to see are specific to a decision tree classifier model. Refer to the official docs for more settings, or settings for different models.
  • Training function – It’s recommended to use the dbms_data_mining.create_model2() function to train the model. The 2 in the name marks a (somewhat unfortunately named) revision of the original function; it lets you pass in model settings directly instead of saving them into a table first.

As for the model, I’ll modify the settings to use entropy as the impurity metric and a max depth of 5. The model will be named ml_model_iris, and it will be dropped first if it already exists.

The rest of the code mostly has to do with error handling:

declare
    v_setlist dbms_data_mining.setting_list;
begin
    -- 1. Decision tree specific settings
    v_setlist(dbms_data_mining.algo_name) := dbms_data_mining.algo_decision_tree;
    v_setlist(dbms_data_mining.tree_impurity_metric) := dbms_data_mining.tree_impurity_entropy;
    v_setlist(dbms_data_mining.tree_term_max_depth) := '5';

    -- 2. Drop the model if it exists
    begin 
        dbms_data_mining.drop_model('ml_model_iris');
    exception
        when others then
            -- Model does not exist error code
            if sqlcode = -40284 then null;
            else 
                dbms_output.put_line('ERROR training the ML model');
                dbms_output.put_line('Reason: Unable to delete previous ML model.');
                dbms_output.put_line('Message: ' || sqlerrm);
                dbms_output.put_line('--------------------------------------------------');
            end if;
    end;

    -- 3. Create the model
    begin
        dbms_data_mining.create_model2(
            model_name => 'ml_model_iris',
            mining_function => 'CLASSIFICATION',
            data_query => 'select * from v_iris_train',
            set_list => v_setlist,
            case_id_column_name => 'id',
            target_column_name => 'species'
        );
    exception
        when others then
            dbms_output.put_line('ERROR training the ML model');
            dbms_output.put_line('Reason: Unable to train the model.');
            dbms_output.put_line('Message: ' || sqlerrm);
            dbms_output.put_line('--------------------------------------------------');
    end;

    -- 4. Other potential exceptions
    exception
        when others then
            dbms_output.put_line('ERROR training the ML model');
            dbms_output.put_line('Reason: Other.');
            dbms_output.put_line('Message: ' || sqlerrm);
            dbms_output.put_line('--------------------------------------------------');
end;
/

Running the above PL/SQL snippet will create a bunch of tables. In my opinion, these should be hidden from the table view as they introduce a lot of clutter, especially if you have multiple model versions:

Image 4 – Tables created by Oracle (image by author)
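You don’t have to scroll through the table view to find them, though. Since the generated tables share the DM$ prefix and end with the model name, a data dictionary query lists them all. A minimal sketch, assuming the default naming convention:

```sql
select table_name
from user_tables
where table_name like 'DM$%ML_MODEL_IRIS'
order by table_name;
```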

But, the tables carry a lot of useful information. Let’s go over it next.

How to Extract Model Details

For example, the DM$PX<model-name> table shows you the "decisions made" by the decision tree model. In other words, it shows which features have to be split where to make the best prediction possible:

Image 5 – Model details for decision tree classifier (image by author)
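To inspect those splits yourself, you can query the table directly. The actual table name concatenates the DM$PX prefix with the model name, so in this case it should be something like the following – check your schema for the exact name, since the naming here is my assumption:

```sql
select * from dm$pxml_model_iris;
```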

Other tables are less relevant, so I’ll leave their investigation up to you.

How to Make Predictions on the Test Set

Onto the predictions now. The code snippet below creates a table named iris_predictions, which contains a prediction and a prediction probability for each class, for every row in the test set.

The syntax is somewhat specific to machine learning tasks on Oracle, as you call the prediction_set() table function and pass in the model:

create table iris_predictions as (
    select 
        s.id, 
        t.prediction, 
        t.probability
    from v_iris_test s, 
    table(prediction_set(ml_model_iris using *)) t
);

This will create the table for you, so let’s take a peek at what’s inside:

select * from iris_predictions;

The predictions are per-class based, meaning you’ll have as many records for a single instance as there are classes. The Iris dataset has 3 classes, so the iris_predictions table will hold 3 * 33 records:

Image 6 – Per-class prediction probabilities (image by author)
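Before ranking the rows, you can verify the per-class layout with a quick count – there should be exactly three rows per test instance, one per species:

```sql
select
    count(*) as n_rows,
    count(distinct id) as n_instances
from iris_predictions;
```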

With a bit of SQL magic, you can extract only the row in which the probability is highest. The analytic row_number() function combined with descending sorting does most of the heavy lifting:

with ranked_predictions as (
    select 
        id, 
        prediction, 
        probability,
        row_number() over(partition by id order by probability desc) as rn
    from iris_predictions
)
select
    p.id,
    a.species as actual,
    p.prediction,
    p.probability,
    case
        when a.species = p.prediction then 1
        else 0
    end as is_correct
from ranked_predictions p
join v_iris_test a on a.id = p.id
where p.rn = 1;

I’ve also added an is_correct flag, so you can immediately know if the predicted value matches the actual one:

Image 7 – Predicted class with is_correct flag (image by author)

And as you can see, that’s not always the case:

Image 8 – Predicted class with is_correct flag (2) (image by author)

You now have a prediction result set, but what about model accuracy on the test set? What about other metrics, such as the confusion matrix? That’s what you’ll implement next.

How to Evaluate Your Decision Tree Classifier Model

Unfortunately, computing model accuracy on the test set and confusion matrix takes much more code than doing the same in Python.

You get the accuracy value for free when calculating the confusion matrix, so that’s a plus. The dbms_data_mining.compute_confusion_matrix() procedure requires you to pass in a dozen parameters, all specifying different table and column names. The confusion matrix itself is then saved into a table specified by the confusion_matrix_table_name parameter.

Overall, a lot of work:

declare 
    cm_accuracy number;
begin
    -- 1. Drop the old confusion matrix table if it exists
    -- (wrapped in its own block so a failed drop doesn't skip the computation)
    begin
        execute immediate 'drop table iris_confusion_matrix';
    exception
        when others then
            -- Table does not exist error code
            if sqlcode = -942 then null;
            else
                dbms_output.put_line('ERROR making evaluations');
                dbms_output.put_line('Reason: Unable to delete previous table containing confusion matrix.');
                dbms_output.put_line('Message: ' || sqlerrm);
                dbms_output.put_line('--------------------------------------------------');
            end if;
    end;

    begin
        dbms_data_mining.compute_confusion_matrix(
            accuracy => cm_accuracy,
            apply_result_table_name => 'iris_predictions',
            target_table_name => 'v_iris_test',
            case_id_column_name => 'id',
            target_column_name => 'species',
            confusion_matrix_table_name => 'iris_confusion_matrix',
            score_column_name => 'prediction',
            score_criterion_column_name => 'probability',
            cost_matrix_table_name => null,
            apply_result_schema_name => null,
            target_schema_name => null,
            score_criterion_type => 'PROBABILITY'
        );
        dbms_output.put_line('Accuracy: ' || round(cm_accuracy * 100, 2) || '%');
    exception
        when others then
            dbms_output.put_line('ERROR making evaluations');
            dbms_output.put_line('Reason: Unable to create a confusion matrix table.');
            dbms_output.put_line('Message: ' || sqlerrm);
            dbms_output.put_line('--------------------------------------------------');
    end;
end;
/

Once you run the above PL/SQL block, you’ll get the following output:

Image 9 – Decision tree classifier accuracy (image by author)

This means the model is over 90% accurate on the test set. Nice!
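If you want a second opinion without dbms_data_mining, the same number can be computed straight from the predictions table. This reuses the row_number() trick from earlier to keep only the top-ranked class per instance, then averages the correctness flag:

```sql
with ranked_predictions as (
    select
        id,
        prediction,
        row_number() over(partition by id order by probability desc) as rn
    from iris_predictions
)
select
    round(avg(case when a.species = p.prediction then 1 else 0 end) * 100, 2)
        as accuracy_pct
from ranked_predictions p
join v_iris_test a on a.id = p.id
where p.rn = 1;
```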

Now about the confusion matrix:

select * from iris_confusion_matrix;

You can see there are two Virginica instances predicted as Versicolor, and one case going the other way around:

Image 10 – Confusion matrix (image by author)

Overall, that’s your confusion matrix. It may not be presented in the format you’re familiar with, but all the data is there.
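If you miss the familiar grid layout, Oracle’s pivot clause can reshape it. A sketch, assuming the generated table uses the standard actual_target_value / predicted_target_value / value columns – adjust the class labels to match the ones in your CSV:

```sql
select *
from iris_confusion_matrix
pivot (
    sum(value)
    for predicted_target_value in (
        'Iris-setosa'     as setosa,
        'Iris-versicolor' as versicolor,
        'Iris-virginica'  as virginica
    )
)
order by actual_target_value;
```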


Summing up Machine Learning with SQL

To conclude, SQL can now do most things Python can when it comes to supervised machine learning. The important thing to remember is you’re trading ease of development for ease of data access, at least if you’re coming from a Python background.

The current Oracle SQL implementation requires more code to do the same thing, has significantly weaker community support, and could benefit from more developer-friendly documentation.

But, if transferring data from database to memory is a deal-breaker, it’s a completely viable solution for building highly accurate models.

What are your thoughts on shifting ML systems to the database? Do you have hands-on experience with a different database vendor? Make sure to share in the comment section below.

Read next:

How to Use OpenAI ChatGPT API… In SQL

