Optimising your input pipeline performance with tf.data (part 1)
Improve your input pipeline efficiency and GPU utilisation
Concepts of tf.data
The tf.data API enables you to build complex input pipelines from simple, reusable pieces. tf.data also makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations.
It is no secret that GPUs and TPUs can significantly reduce the time required to train a model. However, as a deep learning developer, one of the worst things to experience is seeing your GPU capacity not fully utilised because of a bottleneck on the CPU, especially if you are paying loads of money for these services on the various cloud platforms.
Hence, it is crucial to ensure that we achieve optimal performance and efficiency in our input pipeline. The tf.data API directly addresses this, and this is why I love it so much.
In this part 1 article, I will explain the different ways in which tf.data achieves optimal performance, and in part 2, I will compare the performance of tf.data and Keras ImageDataGenerator for reading your input data.
There are several ways in which tf.data reduces computational overhead, each of which can be easily implemented into your pipeline:
- Prefetching
- Parallelising data extraction
- Parallelising data transformation
- Caching
- Vectorised mapping
Naive approach
Before we dive into these concepts, we first have to understand how the naive approach behaves when a model is being trained.
This diagram shows that a training step includes opening a file, fetching a data entry from the file and then using the data for training. We can see clear inefficiencies here: while our model is training, the input pipeline is idle, and while the input pipeline is fetching the data, our model is idle.
tf.data solves this issue by using prefetching.
Prefetching
Prefetching solves the inefficiencies of the naive approach by overlapping the preprocessing and model execution of a training step. In other words, while the model is executing training step n, the input pipeline is reading the data for step n+1.
The tf.data API provides the tf.data.Dataset.prefetch transformation. It can be used to decouple the time when data is produced from the time when data is consumed. In particular, the transformation uses a background thread and an internal buffer to prefetch elements from the input dataset ahead of the time they are requested.
The prefetch transformation takes one argument: the number of elements to prefetch. However, we can simply use tf.data.AUTOTUNE, provided by TensorFlow, which prompts the tf.data runtime to tune the value dynamically at runtime.
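As a minimal sketch, using a toy integer dataset in place of a real input pipeline, prefetching with AUTOTUNE looks like this:

```python
import tensorflow as tf

# A toy dataset standing in for a real input pipeline.
dataset = tf.data.Dataset.range(10)

# Overlap preprocessing and model execution: while the model consumes
# step n, a background thread prepares step n+1. AUTOTUNE lets the
# tf.data runtime pick the buffer size dynamically.
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

values = [int(x) for x in dataset]
```

As a rule of thumb, prefetch is placed as the last transformation in the pipeline, so everything upstream is overlapped with model execution.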
Parallelising data extraction
Reading data incurs computational overhead as raw bytes are loaded into memory, since it may be necessary to deserialise and decrypt the data that was read. This overhead exists irrespective of whether data is stored locally or remotely.
To deal with this overhead, tf.data provides the tf.data.Dataset.interleave transformation to parallelise the data loading step, interleaving the contents of multiple datasets.
Like prefetch, the interleave transformation supports tf.data.AUTOTUNE, which delegates the decision about the level of parallelism to the tf.data runtime.
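A small sketch, with shard ids standing in for real files (in practice the lambda would open a file, e.g. with tf.data.TFRecordDataset over file paths):

```python
import tensorflow as tf

# Hypothetical "shards"; here each one simply yields its shard id
# three times, mimicking records read from a file.
shards = tf.data.Dataset.range(4)

dataset = shards.interleave(
    lambda i: tf.data.Dataset.from_tensors(i).repeat(3),
    cycle_length=2,                       # how many "files" are open at once
    num_parallel_calls=tf.data.AUTOTUNE,  # parallelise the reads
)

values = sorted(int(x) for x in dataset)
```

Because the reads are interleaved and may run in parallel, the element order depends on scheduling; the overall contents stay the same.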
Parallelising data transformation
In most scenarios, you will have to do some preprocessing on your dataset before passing it to the model for training. The tf.data API takes care of this by offering the tf.data.Dataset.map transformation, which applies a user-defined function to each element of the input dataset.
Because input elements are independent of one another, the pre-processing can be parallelised across multiple CPU cores.
To utilise multiple CPU cores, you pass in the num_parallel_calls argument to specify the level of parallelism you want. Like the transformations above, map also supports tf.data.AUTOTUNE, which delegates the decision about the level of parallelism to the tf.data runtime.
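A minimal sketch, with a trivial doubling function standing in for a real (and typically much more expensive) preprocessing function:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# Apply the user-defined function to elements in parallel across CPU
# cores; AUTOTUNE picks the degree of parallelism. Order is preserved
# by default (deterministic=True).
dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)

values = [int(x) for x in dataset]
```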
Caching
tf.data also has caching abilities via the tf.data.Dataset.cache transformation. You can cache a dataset either in memory or in local storage. The rule of thumb is to cache a small dataset in memory and a large dataset in local storage. This saves operations such as file opening and data reading from being executed during each epoch; subsequent epochs will reuse the data cached by the cache transformation.
One thing to note: you should cache after preprocessing (especially when the preprocessing functions are computationally expensive) and before augmentation, as you would not want to store any randomness from your augmentations.
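A sketch of this ordering, with a trivial map standing in for expensive deterministic preprocessing and the random augmentation omitted for brevity:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(4)

# Expensive deterministic preprocessing goes BEFORE cache, so it runs
# only on the first epoch. Calling cache() with no argument caches in
# memory; pass a filename (e.g. .cache("/tmp/my_cache")) to cache in
# local storage instead.
dataset = dataset.map(lambda x: x + 1).cache()

# Random augmentation would go AFTER cache, so each epoch re-rolls it.

epoch1 = [int(x) for x in dataset]  # first pass fills the cache
epoch2 = [int(x) for x in dataset]  # served from the cache
```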
Vectorised mapping
When using the tf.data.Dataset.map transformation as mentioned previously under 'Parallelising data transformation', there is a certain overhead related to scheduling and executing the user-defined function. Vectorising this user-defined function, that is, having it operate over a batch of inputs at once, and applying the batch transformation before the map transformation helps to amortise this overhead.
As seen from the images, the overhead appears only once per batch, improving the overall time performance. Hence, invoking the map transformation on batches of samples performs better than invoking the map transformation on each individual sample.
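A minimal sketch of the batch-then-map ordering, again using a trivial doubling function in place of real preprocessing:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)

# Batch first, then map: the function now runs once per batch of 4
# elements instead of once per element, amortising the per-invocation
# scheduling overhead. The function must handle a whole batch tensor.
dataset = dataset.batch(4).map(lambda batch: batch * 2)

batches = [b.numpy().tolist() for b in dataset]
```

Note that the mapped function has to be written to operate on batched tensors, which element-wise TensorFlow ops such as multiplication do naturally.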
In conclusion
tf.data by TensorFlow takes input pipeline performance seriously, providing several ways to optimise efficiency.
In summary, you can use the prefetch transformation to overlap the work done by the pipeline (producer) and the model (consumer), the interleave transformation to parallelise data reading, the map transformation to parallelise data transformation, the cache transformation to cache data in memory or local storage, and vectorise your map transformations with the batch transformation.
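Putting the pieces together, a minimal end-to-end sketch (with shard ids and trivial maps standing in for real files and preprocessing) might look like this:

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Hypothetical shard ids; in a real pipeline these would be file paths
# and the interleave lambda would parse each file.
shards = tf.data.Dataset.range(2)

dataset = (
    shards.interleave(                        # parallel data extraction
        lambda i: tf.data.Dataset.from_tensors(i).repeat(4),
        num_parallel_calls=AUTOTUNE,
    )
    .map(lambda x: x + 1, num_parallel_calls=AUTOTUNE)  # parallel transformation
    .cache()                                  # cache deterministic work
    .batch(4)                                 # batch before any vectorised map
    .prefetch(AUTOTUNE)                       # overlap producer and consumer
)

total = sum(int(v) for batch in dataset for v in batch)
```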
As mentioned at the start, one of the worst things to experience is seeing your GPU capacity not fully utilised because of a bottleneck on the CPU. With tf.data, you will most probably be happy with your GPU utilisation!
In part 2, I will demonstrate using tf.data for your input pipeline and compare the performance of tf.data and Keras ImageDataGenerator.