How to do batch predictions of TensorFlow models directly in BigQuery

BigQuery ML now supports TensorFlow SavedModel

Lak Lakshmanan
Towards Data Science


If you have trained a model in TensorFlow and exported it as a SavedModel, you can now use the ML.PREDICT SQL function in BigQuery to make predictions. This is very useful if you want to make batch predictions (e.g., to make predictions for all the data collected in the past hour), since any SQL query can be scheduled in BigQuery.

Steps:

  • Train a model in TensorFlow and export it as a SavedModel
  • In BigQuery, create a model, passing in the location of the SavedModel
  • Use ML.EVALUATE, ML.PREDICT, etc. just as if you had trained the model in BigQuery using its other (built-in) model types.

Note: this feature is currently in public alpha. Contact your GCP representative to get whitelisted.

1. Train and export SavedModel in TensorFlow/Keras

I’ll demonstrate using the text classification model that I describe in this blog post and for which the Keras code is on GitHub. I trained the model on Cloud ML Engine, but you can train it wherever and however you want. The important bit is the line that sets up the export of the model as a SavedModel into Google Cloud Storage:

exporter = tf.estimator.LatestExporter('exporter', serving_input_fn)

2. Create Model

Creating the model in BigQuery is simply a matter of specifying a different model_type and pointing it at the model_path where the SavedModel was exported (note the wildcard at the end to pick up the assets, vocabulary, etc.):

CREATE OR REPLACE MODEL advdata.txtclass_tf
OPTIONS (model_type='tensorflow',
model_path='gs://cloud-training-demos/txtclass/export/exporter/1549825580/*')

I’ve made the bucket above public, so you can try out the query above as-is (create a dataset named advdata first). This creates a model in BigQuery that works like any built-in model:

In particular, the schema indicates that the required input to the model is called “input” and is a string.
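The trailing wildcard in model_path is easy to drop when the export directory comes from a training job. As a minimal sketch, the CREATE MODEL statement in this section could be assembled from an export directory as follows. The helper name create_model_sql is hypothetical and not part of any BigQuery client library; it only formats a string and does not run anything in BigQuery.

```python
def create_model_sql(dataset: str, model_name: str, export_dir: str) -> str:
    """Build a CREATE MODEL statement that imports a TensorFlow SavedModel.

    Hypothetical helper for illustration only: it formats SQL text and
    does not contact BigQuery or Cloud Storage.
    """
    if not export_dir.startswith("gs://"):
        raise ValueError("export_dir must be a Cloud Storage path (gs://...)")
    # Normalize the path so it ends with exactly one '/*' wildcard, which
    # is what picks up the assets and vocabulary files next to the model.
    model_path = export_dir.rstrip("*").rstrip("/") + "/*"
    return (
        f"CREATE OR REPLACE MODEL {dataset}.{model_name}\n"
        f"OPTIONS (model_type='tensorflow',\n"
        f"         model_path='{model_path}')"
    )


sql = create_model_sql(
    "advdata",
    "txtclass_tf",
    "gs://cloud-training-demos/txtclass/export/exporter/1549825580",
)
print(sql)
```

The resulting string can be pasted into the BigQuery console or submitted through whichever client you already use.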

3. Predicting with the model

Predicting with the model is quite straightforward. For example, we can pull some data from a BigQuery table, making sure that the columns in our SELECT are named the way the TensorFlow model requires. In this case, my TensorFlow model’s serving_input_fn specifies that the model expects a single input string called “input”.

Given that, we can now do a prediction:

WITH extracted AS (
  SELECT
    source,
    REGEXP_REPLACE(LOWER(REGEXP_REPLACE(title, '[^a-zA-Z0-9 $.-]', ' ')), "  ", " ") AS title
  FROM (
    SELECT
      ARRAY_REVERSE(SPLIT(REGEXP_EXTRACT(url, '.*://(.[^/]+)/'), '.'))[OFFSET(1)] AS source,
      title
    FROM
      `bigquery-public-data.hacker_news.stories`
    WHERE
      REGEXP_CONTAINS(REGEXP_EXTRACT(url, '.*://(.[^/]+)/'), '.com$')
      AND LENGTH(title) > 10
  )
),
input_data AS (
  SELECT title AS input FROM extracted LIMIT 5
)
SELECT *
FROM ML.PREDICT(MODEL advdata.txtclass_tf,
                (SELECT * FROM input_data))
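To make the preprocessing in the query easier to follow, here is a rough Python mirror of its two regular-expression steps. This is only a sketch: the function names clean_title and extract_source are hypothetical, and it assumes the outer REGEXP_REPLACE is meant to turn double spaces into single spaces.

```python
import re
from typing import Optional


def clean_title(title: str) -> str:
    """Replace disallowed characters with spaces, lowercase the text,
    then replace double spaces with single ones (one pass, like the SQL)."""
    spaced = re.sub(r"[^a-zA-Z0-9 $.-]", " ", title)
    return re.sub("  ", " ", spaced.lower())


def extract_source(url: str) -> Optional[str]:
    """Return the second-to-last part of the host name for .com hosts,
    e.g. 'github' for github.com; return None for anything else."""
    match = re.search(r".*://(.[^/]+)/", url)
    if match is None or re.search(r".com$", match.group(1)) is None:
        return None
    # Same idea as ARRAY_REVERSE(SPLIT(host, '.'))[OFFSET(1)] in the query.
    return match.group(1).split(".")[::-1][1]


print(clean_title("Show HN: Hello, a CLI tool"))       # show hn hello a cli tool
print(extract_source("https://www.nytimes.com/2019/"))  # nytimes
```

Only the cleaned title is passed to the model, renamed to “input” so that it matches the serving signature.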

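The query returns one row per input title: the input column together with the model’s output, an array of class probabilities named dense_1.

Since we know which label each position in that array corresponds to, we can improve the query so that it returns the most likely label along with its probability: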

SELECT
  input,
  (SELECT AS STRUCT(p, ['github', 'nytimes', 'techcrunch'][ORDINAL(s)]) prediction FROM
    (SELECT p, ROW_NUMBER() OVER() AS s FROM
      (SELECT * FROM UNNEST(dense_1) AS p))
   ORDER BY p DESC LIMIT 1).*
FROM ML.PREDICT(MODEL advdata.txtclass_tf,
  (
    SELECT 'Unlikely Partnership in House Gives Lawmakers Hope for Border Deal' AS input
    UNION ALL SELECT "Fitbit\'s newest fitness tracker is just for employees and health insurance members"
    UNION ALL SELECT "Show HN: Hello, a CLI tool for managing social media"
  ))
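The nested subquery is an argmax over the probability array, written in SQL. As a sketch of the same logic in Python (the function name top_prediction and the probability values are made up for illustration and are not output from the model):

```python
from typing import Sequence, Tuple

# Label order taken from the array literal in the query.
LABELS = ["github", "nytimes", "techcrunch"]


def top_prediction(
    probabilities: Sequence[float], labels: Sequence[str] = LABELS
) -> Tuple[float, str]:
    """Return (probability, label) for the highest-scoring class."""
    if len(probabilities) != len(labels):
        raise ValueError("expected one probability per label")
    # Equivalent of ORDER BY p DESC LIMIT 1, keeping track of the position.
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return probabilities[best], labels[best]


print(top_prediction([0.1, 0.7, 0.2]))  # (0.7, 'nytimes')
```

Doing this step in SQL keeps the whole batch prediction inside BigQuery, so nothing has to be post-processed elsewhere.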

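Note that this query supplies its input values directly in the SQL instead of reading them from a table. For each input, the result is the most likely of the three sources together with its probability.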

That’s it!
