Pushing down ML in SQL Engines: An Exploration

Demonstration of cross-domain optimization to improve XGBoost model inference

exploration
Author

Daniel Mesejo

Published

March 24, 2024

Modified

May 24, 2024

SQL + XGBoost

Welcome once again to LETSQL’s exploration series!

In the previous blog posts, we showed how we optimize ML preprocessing and leverage User Defined Aggregate Functions (UDAFs) for training. In this post, we go a step further and apply database-style optimizations to an end-to-end machine learning inference pipeline. We also show a working example of compiling a simple XGBoost model into SQL so that it benefits from database-style optimizations such as predicate pushdown, projection pushdown, and constant folding.

You can find the complete code on GitHub.

Introduction

Machine Learning is transforming every industry by drawing insights from high-value data stored in databases, data warehouses, and lakes.

Data analysts and data scientists use complex analytics queries with pre-trained models to predict outcomes for new data. Often, these queries combine data processing operators with relational operators, like filters or joins, because the data is stored in a relational database.

While real-time ML inference is gaining traction, most enterprise use cases still run as batch workflows. We observe that these batch use cases can benefit from SQL’s relational machinery, and we show an example of compiling the trees of an XGBoost model into CASE statements.

Problem Statement

We will demonstrate how cross-domain optimization can improve the UX and performance of prediction queries using Microsoft’s Length of Stay dataset. The task is to find which patients with a low number of re-admissions within the last 180 days (rcount < 2) are likely to have a long stay (lengthofstay > 1)¹. We already have a simple trained XGBRegressor model ("model.json"), and the data is stored in a PostgreSQL database.

The dataset and the full description of the features can be found here.

A normal workflow

The usual solution involves downloading all the data from Postgres into a pandas DataFrame, applying some preprocessing, running the XGBRegressor on the DataFrame, and finally filtering the result.

import pandas as pd
import xgboost as xgb
from sqlalchemy import create_engine

# Load the pre-trained model
model = xgb.XGBRegressor()
model.load_model("model.json")

# Download the whole table from Postgres
engine = create_engine('postgresql://user:pass@localhost:5432/patients')
query = "SELECT * FROM patients;"
df = pd.read_sql_query(query, engine)

# Preprocessing: drop the facid_* columns before predicting
columns_to_drop = [col for col in df.columns if col.startswith('facid_')]
X = df.set_index('eid').drop(columns=columns_to_drop)

# Predict for every row, then keep only the rows of interest
predictions = model.predict(X)
df_with_predictions = X.assign(lengthofstay=predictions)
mask = df_with_predictions["rcount"].lt(2) & df_with_predictions["lengthofstay"].gt(1)
filtered = df_with_predictions[mask]
result = filtered.reset_index()[['eid']]

What’s wrong with this picture?

The problems with the previous workflow can be put into two major categories: UX and Performance.

Expressivity Problems

A non-exhaustive list of UX problems:

  1. The data scientist has to manually specify which columns to drop or select, even though the model has already encoded this information.
  2. It requires the user to know how to connect to Postgres and the SQL dialect associated with it.
  3. There is no easy way to pipeline the above computations together and express them in one dialect/API.

Performance Problems

Regarding performance, the workflow has two significant issues that need attention:

  1. It fetches all columns of the table, only to remove some of the columns later, wasting bandwidth.
  2. It predicts all rows of the table, only to filter some of the rows later, wasting compute.
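To make the two costs concrete, here is a minimal sketch that counts how many values cross the wire with and without pushdown. It uses an in-memory SQLite table as a stand-in for Postgres; the rows and the choice of needed columns are made up for illustration.

```python
import sqlite3

# In-memory SQLite stands in for Postgres; the rows below are invented.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE patients (eid INTEGER, rcount INTEGER, hematocrit REAL, facid_a INTEGER)"
)
conn.executemany(
    "INSERT INTO patients VALUES (?, ?, ?, ?)",
    [(1, 0, 11.5, 1), (2, 3, 9.1, 0), (3, 1, 12.0, 0), (4, 5, 10.2, 1)],
)

# Naive: fetch every column of every row, then drop and filter on the client.
all_rows = conn.execute("SELECT * FROM patients").fetchall()

# Pushed down: the database returns only the needed columns (projection)
# and only the rows that can appear in the result (predicate).
needed_columns = ["eid", "rcount", "hematocrit"]  # assumed: key plus model features
query = f"SELECT {', '.join(needed_columns)} FROM patients WHERE rcount < 2"
pushed_rows = conn.execute(query).fetchall()

naive_cells = sum(len(row) for row in all_rows)
pushed_cells = sum(len(row) for row in pushed_rows)
print(f"values transferred: naive={naive_cells}, pushed down={pushed_cells}")
```

The same idea applies to the pandas workflow above: narrowing the SELECT list and adding the WHERE clause to the query string reduces both the data transferred and the number of rows the model has to score.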

Some of these areas are the subject of active research, and I encourage readers to check out the academic papers here and here. A good tool should therefore abstract these details away from the user while providing a performant solution that takes advantage of relational algebra and many years of research. There must be a better way!

A better way

Ideally, we would like a tool that lets us specify our query declaratively, resulting in the following SQL query:

select patients.eid
from patients
where predict_xgb('model.json', patients.*) > 1
  and patients.rcount < 2;

The core idea is to read and parse the model.json file and convert it into a SQL query that Postgres can execute. Here is what a single tree looks like when compiled to SQL²:

CASE
  WHEN patients.rcount < 2
  AND (
    patients.psychologicaldisordermajor IS NULL
    OR patients.psychologicaldisordermajor >= 1
  )
  AND (
    patients.hematocrit IS NULL OR patients.hematocrit >= 9.69999981
  )
  THEN 1.24346876
  WHEN patients.rcount < 2
  AND (
    patients.psychologicaldisordermajor IS NULL
    OR patients.psychologicaldisordermajor >= 1
  )
  AND patients.hematocrit < 9.69999981
  THEN 1.66323793
  WHEN patients.rcount < 2
  AND patients.psychologicaldisordermajor < 1
  AND (
    patients.hemo IS NULL OR patients.hemo >= 1
  )
  THEN 1.3903867
  WHEN patients.rcount < 2 AND patients.psychologicaldisordermajor < 1 AND patients.hemo < 1
  THEN 0.698453844
END
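As a rough illustration of how such a CASE expression could be generated, the sketch below walks one tree and emits one WHEN branch per leaf. It assumes the tree is a nested dict in the layout XGBoost produces with get_dump(dump_format="json") (keys nodeid, split, split_condition, yes, no, missing, children, leaf). The function name tree_to_case and the tiny example tree are hypothetical and are not taken from the repo.

```python
def tree_to_case(node, table="patients"):
    """Turn one tree (XGBoost JSON-dump layout) into a SQL CASE expression."""
    branches = []

    def walk(n, conditions):
        if "leaf" in n:
            # A leaf becomes one WHEN branch guarded by the path conditions.
            predicate = " AND ".join(conditions) if conditions else "TRUE"
            branches.append(f"  WHEN {predicate} THEN {n['leaf']}")
            return
        column = f"{table}.{n['split']}"
        for child in n["children"]:
            # XGBoost sends rows with feature < split_condition to the "yes" child.
            if child["nodeid"] == n["yes"]:
                cond = f"{column} < {n['split_condition']}"
            else:
                cond = f"{column} >= {n['split_condition']}"
            # Missing values (NULL in SQL) follow the "missing" child.
            if child["nodeid"] == n["missing"]:
                cond = f"({column} IS NULL OR {cond})"
            walk(child, conditions + [cond])

    walk(node, [])
    return "CASE\n" + "\n".join(branches) + "\nEND"


# Illustrative single-split tree; the leaf values are made up.
example_tree = {
    "nodeid": 0, "split": "rcount", "split_condition": 2,
    "yes": 1, "no": 2, "missing": 2,
    "children": [
        {"nodeid": 1, "leaf": 0.5},
        {"nodeid": 2, "leaf": 1.5},
    ],
}
sql_case = tree_to_case(example_tree)
print(sql_case)
```

Each root-to-leaf path turns into a conjunction of conditions, which is why every WHEN branch in the compiled tree above repeats the conditions of its ancestors.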

Additionally, such a tool could perform the following optimizations out of the box³:

  1. projection pushdowns: the condition rcount < 2 is pushed into the XGBRegressor, which prunes multiple subtrees; the pruned model needs fewer features, so fewer columns have to be read.
  2. model inlining: small models could be transpiled into SQL and executed at the source.
  3. tree-level pushdowns: there is an opportunity for further predicate-based pushdowns into the trees, since we only require lengthofstay above 1. This is a topic of research and is left for future implementation.
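The first optimization can be sketched in a few lines. Assuming the same nested-dict tree layout as XGBoost's JSON dump, the hypothetical helpers below prune a tree given the known predicate rcount < 2 and then collect the features the pruned tree still needs. The tree and its leaf values are invented for illustration and do not come from model.json.

```python
def prune(node, feature, upper_bound):
    """Simplify a tree given the known predicate `feature < upper_bound`.

    Rows satisfying the predicate are never NULL on `feature`, so the
    "missing" direction of splits on that feature can be ignored.
    """
    if "leaf" in node:
        return node
    children = {child["nodeid"]: child for child in node["children"]}
    yes = prune(children[node["yes"]], feature, upper_bound)
    if node["split"] == feature and upper_bound <= node["split_condition"]:
        # feature < upper_bound <= split_condition: the "yes" branch is always taken.
        return yes
    no = prune(children[node["no"]], feature, upper_bound)
    missing = yes if node["missing"] == node["yes"] else no
    # Re-point the ids, since a pruned child may have been replaced by a descendant.
    return {**node, "yes": yes["nodeid"], "no": no["nodeid"],
            "missing": missing["nodeid"], "children": [yes, no]}


def used_features(node):
    """Return the set of feature names a tree still reads."""
    if "leaf" in node:
        return set()
    features = {node["split"]}
    for child in node["children"]:
        features |= used_features(child)
    return features


# Illustrative tree: the root splits on rcount, the subtrees on other features.
tree = {
    "nodeid": 0, "split": "rcount", "split_condition": 2,
    "yes": 1, "no": 2, "missing": 2,
    "children": [
        {"nodeid": 1, "split": "hematocrit", "split_condition": 9.7,
         "yes": 3, "no": 4, "missing": 4,
         "children": [{"nodeid": 3, "leaf": 1.66}, {"nodeid": 4, "leaf": 1.24}]},
        {"nodeid": 2, "split": "bmi", "split_condition": 35,
         "yes": 5, "no": 6, "missing": 6,
         "children": [{"nodeid": 5, "leaf": 0.7}, {"nodeid": 6, "leaf": 1.4}]},
    ],
}

pruned = prune(tree, "rcount", 2)
print(sorted(used_features(tree)))
print(sorted(used_features(pruned)))
```

In this example the bmi subtree becomes unreachable, so the bmi column no longer has to be selected at all. That is the link between pruning the model and pushing the projection down to the database.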

Show me the code

In the rest of the post, we will create a function, transpile_predict, that takes a query such as the one above and converts it into a query that Postgres can execute. For this we will use sqlglot, a SQL transpiler. From the README on GitHub:

SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between 21 different dialects like DuckDB, Presto / Trino, Spark / Databricks, Snowflake, and BigQuery. It aims to read a wide variety of SQL inputs and output syntactically and semantically correct SQL in the targeted dialects.

The full transpile_predict function code is:

# Imports inferred from the names used below; the helper functions
# (_get_model, _get_trees, ...) are defined in the GitHub repo.
import sqlglot as sg
from sqlglot.dialects import Postgres


def transpile_predict(sql):
    node = sg.parse_one(sql, dialect=Postgres)
    predict_expression = _extract_predict_function(node)
    if predict_expression is not None:
        model = _get_model(predict_expression)
        threshold = _extract_prediction_threshold(predict_expression)
        trees = _get_trees(model, threshold=threshold)
        if len(trees) < 5:  # the model is simple enough
            case_expressions = _transform_to_case_expressions(trees)
            case_expressions = _prune_branches(case_expressions, predict_expression)
            return _inline_trees(node, case_expressions).sql(pretty=True)
        else:
            pass  # TODO implement model splitting
    else:
        return sql

The first step is to parse the query. For this, sqlglot provides the function parse_one, which parses the given SQL string and returns an abstract syntax tree (AST). Once we have the AST, we extract the node corresponding to the predict_xgb function and load the model from it:

node = sg.parse_one(sql, dialect=Postgres)
predict_expression = _extract_predict_function(node)
if predict_expression is not None:
    model = _get_model(predict_expression)

The second step is to extract the prediction threshold and get the trees from the model:

threshold = _extract_prediction_threshold(predict_expression)
trees = _get_trees(model, threshold=threshold)

If the model is simple enough, the next step is to inline it. To do so, we first transform the reduced list of trees into CASE expressions. Given the predicate rcount < 2, we can also prune some branches:

case_expressions = _transform_to_case_expressions(trees)
case_expressions = _prune_branches(case_expressions, predict_expression)

And finally replace the predict_xgb function with the inlined trees:

return _inline_trees(node, case_expressions).sql(pretty=True)
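To give a feel for what the inlining step produces, here is a simplified sketch. It is not the repo's _inline_trees: it uses plain string replacement instead of rewriting the sqlglot AST, and it assumes a squared-error regression objective, where the prediction is base_score plus the sum of the per-tree outputs. The function name, the CASE expressions, and the base_score value are all made up for illustration.

```python
def inline_prediction(sql, call, case_expressions, base_score):
    """Replace the textual predict call with base_score + sum of per-tree CASEs.

    Assumes a squared-error objective, so no link function is applied.
    """
    terms = [str(base_score)] + [f"({case})" for case in case_expressions]
    return sql.replace(call, "(" + " + ".join(terms) + ")")


query = (
    "SELECT patients.eid FROM patients "
    "WHERE predict_xgb('model.json', patients.*) > 1 AND patients.rcount < 2"
)
# Illustrative per-tree expressions; real ones come from the compiled model.
cases = [
    "CASE WHEN patients.hematocrit < 9.7 THEN 0.4 ELSE 0.1 END",
    "CASE WHEN patients.hemo < 1 THEN -0.2 ELSE 0.3 END",
]
inlined = inline_prediction(
    query, "predict_xgb('model.json', patients.*)", cases, base_score=0.5
)
print(inlined)
```

Working on the AST, as the post does with sqlglot, avoids the fragility of matching the call as an exact string and makes it possible to re-render the result in the target dialect.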

The code for each of the functions can be found in the GitHub repo.

Ok, but is it performant?

A simple benchmark shows that the Postgres-only solution is almost 5 times faster:

postgres + pandas took:     671 ms
postgres took:              141 ms

Conclusions

The main takeaway should not be the improved performance⁴ but rather the better UX gained by integrating relational algebra and machine learning, with the added benefit of improved security, since only the necessary data, like eid, is transmitted.

Challenges

In the example, we were lucky that the model was small enough to be inlined. But what happens if it is not? If you were paying attention, you probably caught this line:

pass  # TODO implement model splitting

One solution is to split the model in two: a cheap model (bmi > 35) to be inlined and a more complex one (bmi <= 35) to be run by a specialized XGBoost inference operator (like gbdt-rs).

Another challenge is how to push predicates down into the model: since we only require predict_xgb('model.json', patients.*) > 1, we may not need to examine all the trees.
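One simple way to reason about this, offered only as a sketch and not as the approach taken in the post or the papers, is to bound the prediction using the smallest and largest leaf of each tree. If even the lower bound exceeds the threshold, the predicate is always true; if the upper bound does not exceed it, the predicate is always false. The helper names, the trees, and the base_score below are hypothetical, and the additive form again assumes a squared-error objective.

```python
def leaf_values(node):
    """Collect all leaf values of a tree in XGBoost JSON-dump layout."""
    if "leaf" in node:
        return [node["leaf"]]
    return [value for child in node["children"] for value in leaf_values(child)]


def prediction_bounds(trees, base_score=0.0):
    """Lower and upper bound of base_score + sum of one leaf per tree."""
    low = base_score + sum(min(leaf_values(tree)) for tree in trees)
    high = base_score + sum(max(leaf_values(tree)) for tree in trees)
    return low, high


def classify_threshold(trees, threshold, base_score=0.0):
    """Decide whether `prediction > threshold` can be resolved without the data."""
    low, high = prediction_bounds(trees, base_score)
    if low > threshold:
        return "always true"
    if high <= threshold:
        return "always false"
    return "must evaluate"


# Two illustrative stumps; only the leaves matter for the bounds.
trees = [
    {"nodeid": 0, "split": "hematocrit", "split_condition": 9.7,
     "yes": 1, "no": 2, "missing": 2,
     "children": [{"nodeid": 1, "leaf": 0.2}, {"nodeid": 2, "leaf": 0.4}]},
    {"nodeid": 0, "split": "hemo", "split_condition": 1,
     "yes": 1, "no": 2, "missing": 2,
     "children": [{"nodeid": 1, "leaf": 0.1}, {"nodeid": 2, "leaf": 0.3}]},
]

for threshold in (0.7, 1.0, 1.5):
    print(threshold, classify_threshold(trees, threshold, base_score=0.5))
```

The same bounds could be applied incrementally: after evaluating some trees, if the partial sum plus the best case of the remaining trees cannot exceed the threshold, the remaining trees need not be examined.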

Future work: LETSQL

The LETSQL team is working hard to build a multi-engine scheduler and optimizer that can take advantage of the relational algebra and machine learning. But, we can’t do it alone. That’s why we need your feedback and suggestions on how to improve your experience. Share your thoughts with us in the comments section or on our community forum, and let’s redefine the boundaries of data science and machine learning integration together. To stay updated with our latest developments, sign up for our newsletter or visit letsql.dev.

Resources

Glossary

  • Relational Algebra: A set of operations used in relational database systems to manipulate data stored in tables. It forms the theoretical foundation of SQL.
  • Machine Learning (ML): A branch of artificial intelligence (AI) that focuses on building systems that learn from data, identifying patterns, and making decisions with minimal human intervention.
  • Model Inference: The process of making predictions using a trained machine learning model.
  • XGBRegressor: A part of the XGBoost library, the XGBRegressor is used for regression problems. It implements the gradient boosting decision tree algorithm.
  • PostgreSQL: An open-source relational database management system (RDBMS) known for its robustness, scalability, and support for advanced data types and SQL standards.
  • Pandas DataFrame: A two-dimensional, size-mutable, and potentially heterogeneous tabular data structure with labeled axes (rows and columns) in Python. It is part of the Pandas library, widely used for data manipulation and analysis.
  • SQLAlchemy: A SQL toolkit and Object-Relational Mapping (ORM) library for Python. It provides a full suite of well-known enterprise-level persistence patterns and is designed for efficient and high-performing database access.
  • Predicate-based Model Pruning: An optimization technique where conditions (predicates) are used to eliminate unnecessary calculations or data fetching in a model, improving performance.
  • Model Inlining: The process of translating a machine learning model into SQL statements or other executable code that can be run directly within a database engine, eliminating the need to move data outside the database for processing.
  • Predicate Pushdown: An optimization technique used in databases where the filtering conditions (predicates) are “pushed down” to the data retrieval operations, reducing the amount of data that needs to be loaded and processed.
  • SQLGlot: A no-dependency SQL parser, transpiler, optimizer, and engine capable of formatting SQL or translating between different dialects like DuckDB, Presto/Trino, Spark/Databricks, Snowflake, and BigQuery.
  • Transpiler: A type of compiler that transforms source code written in one programming language into another programming language.
  • Benchmarking: The process of comparing the performance of various systems or components, typically by running a series of tests or benchmarks.

Feel free to refer back to this glossary as you read through the blog post to better understand the key concepts and terminology used.

Acknowledgements

Thanks to Dan Lovell and Hussain Sultan for the comments and the thorough review.

Footnotes

  1. The task is a toy problem to showcase the gap between relational algebra and machine learning.↩︎

  2. Perhaps this type of stress-testing is useful when you have a lot of CASE statements. Of course, that is just for UNION statements, but a similar concept may apply to CASE statements.↩︎

  3. The optimizations listed can be found in the papers listed in the Resources section.↩︎

  4. Further improvements could be achieved by tuning the DB for this specific type of workload.↩︎