When I started with TFRecords, it took me a while to understand the concept; there were so many new things at once. To save others from this hassle, I have created a hands-on walkthrough based on the MNIST dataset.
Note: this blog post is now available in a more general version that includes more up-to-date concepts. Also note that the TFRecord format is not that hard once you get started, which is why I have created a hands-on introduction to it. To learn more after this tutorial, I recommend consulting these two resources.
Overview
The MNIST dataset consists of digitized handwritten digits in black and white. With 28x28x1 pixels per image, they are pretty small. The memory footprint of the complete dataset is only 32 MB.

Imports and helper functions
Let us start with the necessary imports: two libraries, os and TensorFlow. Additionally, we set a global variable, AUTOTUNE, which we will use later.
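A minimal sketch of this setup, assuming the usual import names (depending on your TensorFlow version, tf.data.AUTOTUNE works just as well):

```python
import os

import tensorflow as tf

# Lets tf.data choose parallelism and prefetching values at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
```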
First, we download the MNIST dataset to our local machine. Then, we set two options to True: shuffle_files and as_supervised.
The first option is used when we create our TFRecord dataset; the second makes iterating slightly more comfortable, since we get (image, label) pairs directly.
We can also check how many examples each split holds by calling .cardinality().
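Roughly, the download and the cardinality check could look like this; I assume tfds.load here, and the variable names are my own:

```python
import tensorflow_datasets as tfds

# Download MNIST; as_supervised=True yields (image, label) pairs,
# shuffle_files=True shuffles the underlying source files.
datasets = tfds.load("mnist", shuffle_files=True, as_supervised=True)
train_split, test_split = datasets["train"], datasets["test"]

# The number of examples per split is known for tfds-shipped datasets.
print(train_split.cardinality().numpy())  # 60000
print(test_split.cardinality().numpy())   # 10000
```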
The following four functions improve readability and are used to create the individual examples we write into the TFRecord files.
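The four helpers might look like this; the names follow the common TensorFlow example-writing pattern and are assumptions on my part:

```python
def _bytes_feature(value):
    """Returns a bytes_list Feature from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # extract the raw bytes from an eager tensor
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list Feature from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list Feature from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_array(array):
    """Serializes a tensor or array into a byte string."""
    return tf.io.serialize_tensor(array)
```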
Writing TFRecords
Next, we go over all splits (here, only "train" and "test").
For each split, we create a TFRecordWriter, which writes the parsed examples to file. Note that we add the currently processed split to the filename – this allows us to glob the files by a string pattern later. The additional index we use is for counting how many samples we wrote to disk. This little trick is helpful for custom datasets, where the .cardinality() operation won’t return the number of elements.
Since we set as_supervised earlier, we can now iterate over the (image, label) pairs.
The primary example generation happens during the creation of the temporary dictionary data. First, we transform each data field we want to use later into a tf.train.Feature. For the height and the width of our images, we use an _int64_feature (since these numbers are integers); for the actual image data, we first serialize the array and then turn it into a bytes list. This conversion is required for storing non-scalar data.
With all our features defined, we can now create a single Example and write it to the TFRecord file.
We proceed until all examples of the current split are processed, then we repeat the process with the following subset.
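Putting these steps together, the writing loop could look roughly like this; the file names and the field names in the dictionary are assumptions based on the description above:

```python
splits = {"train": train_split, "test": test_split}

for split_name, split_data in splits.items():
    filename = f"mnist_{split_name}.tfrecords"  # split name in the file name
    writer = tf.io.TFRecordWriter(filename)
    index = 0  # counts how many examples we write to disk

    for image, label in split_data:
        # Collect all fields of a single example in a dictionary;
        # the serialized image goes in last, as a bytes list.
        data = {
            "height": _int64_feature(int(image.shape[0])),
            "width": _int64_feature(int(image.shape[1])),
            "label": _int64_feature(int(label)),
            "raw_image": _bytes_feature(serialize_array(image)),
        }
        # Wrap the features into a single Example and write it to disk.
        example = tf.train.Example(features=tf.train.Features(feature=data))
        writer.write(example.SerializeToString())
        index += 1

    writer.close()
    print(f"Wrote {index} examples to {filename}")
```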
After both subsets are processed, we have created our first two TFRecord files (yeah!), one holding the training and one holding the test data. The index we incremented for each record per subset is useful when training the model:
For datasets shipped with tensorflow_datasets, you can query the cardinality, that is, the number of examples, by simply calling .cardinality(). This won't report the actual size for custom datasets like ours but instead returns tf.data.experimental.UNKNOWN_CARDINALITY, meaning that the number of examples is unknown.
However, when training on such a dataset, we have to know how many batches our dataset can deliver. Otherwise, we could run into an infinite loop; see further below for clarification.
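For illustration, a dataset built directly from one of our TFRecord files reports an unknown cardinality (file name as in the writing sketch above):

```python
custom = tf.data.TFRecordDataset("mnist_train.tfrecords")
print(custom.cardinality().numpy())  # tf.data.experimental.UNKNOWN_CARDINALITY
```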
Short recap until here: We used the MNIST dataset and wrote all examples to TFRecord files.
Reading TFRecords
After creation, we want to read them back into memory. This process is similar to the above, but in reverse:
We create a function that reads the examples from the TFRecord file. Herein we create a dictionary with all the fields we want to use from the example; the dictionary is similar to the one we used to write our data. For each key, we define a placeholder. Note the last field: It is of type tf.string (even though we stored it as a bytes list); all other fields are initialized with the same type as before.
With our dictionary ready, we can extract our examples from the TFRecord file. Lastly, we obtain the original image data. Note that we set uint8 as the datatype. If our image contained floats, we would set float64 as the datatype. Since the MNIST data is scaled between 0 and 255, we are okay with integers:
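A sketch of such a parsing function, assuming the field names from the writing sketch above:

```python
def parse_tfrecord(element):
    # Mirrors the dictionary we used for writing; the serialized image
    # is read back as a tf.string, all other fields as tf.int64.
    feature_description = {
        "height": tf.io.FixedLenFeature([], tf.int64),
        "width": tf.io.FixedLenFeature([], tf.int64),
        "label": tf.io.FixedLenFeature([], tf.int64),
        "raw_image": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(element, feature_description)

    # Restore the original image; MNIST pixel values are uint8 (0-255).
    image = tf.io.parse_tensor(example["raw_image"], out_type=tf.uint8)
    image = tf.reshape(image, [example["height"], example["width"], 1])
    return image, example["label"]
```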
Creating a dataset
With the following function, we create a dataset around our TFRecord files. Previously, we had only defined a function that parses a single example; now we create a TFRecordDataset and map all examples through this function.
We do this in the middle statement. Here, we use the AUTOTUNE setting we created above: during training, it automatically determines how many examples we can process in parallel, which can reduce GPU idle times.
Afterward, we shuffle our data, set a batch size, and set repeat with no argument; this means repeating it endlessly. This requires us to remember how many examples our dataset has (as written above).
[As a hacky alternative, we can set repeat here to the number of epochs we want to train for and set the number of epochs in the fit() function to 1. The dataset then passes through the network "epochs" times (because we called repeat() with the number of epochs), while fit() itself runs only a single epoch (because we set epochs=1).]
Last, we let the AUTOTUNE optimizer determine the most appropriate number of examples to prefetch.
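A possible get_dataset() function along these lines; the shuffle buffer size and the batch size of 32 (which matches the shapes below) are assumptions, and the split parameter mirrors the behavior described in the training section:

```python
def get_dataset(file_pattern, split):
    # Glob the TFRecord files by the string pattern we baked into their names.
    files = tf.io.gfile.glob(file_pattern)
    dataset = tf.data.TFRecordDataset(files)

    # Map every serialized example through our parsing function.
    dataset = dataset.map(parse_tfrecord, num_parallel_calls=AUTOTUNE)

    if split == "train":
        # Only the training data is shuffled, batched, and repeated endlessly.
        dataset = dataset.shuffle(1024).batch(32).repeat()
    else:
        dataset = dataset.batch(32)

    # Let AUTOTUNE decide how many batches to prefetch.
    return dataset.prefetch(AUTOTUNE)
```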
Until now, we have only created a dataset and mapped a data-generating function to it. As a sanity check, let’s look at one sample the dataset gives us.
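Using the sketches above, this could look like:

```python
train_dataset = get_dataset("mnist_train*.tfrecords", split="train")

# Pull a single batch from the dataset and inspect its shapes.
images, labels = next(iter(train_dataset))
print(images.shape, labels.shape)
```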
It returns two Tensors. The first Tensor is of shape (32, 28, 28, 1) (because we took one batch, and the batch size is 32), the second Tensor is of shape (32,) (since we have 32 labels, one per example in our batch).
To recap until here:
We created two TFRecord files, one for the training data, one for the test data. We did this by iterating over all (image, label) pairs in the original MNIST dataset, transforming them to a tf.train.Feature. All features together form an example. We then created a function that reverses this; it pulls the features from the examples stored in a TFRecord file. Finally, we mapped the function to our dataset and did a sanity check to see if all works correctly.
Training
Our next step is training a network on our TFRecordDataset.
We use a simple Convolutional Neural Network; feel free to experiment with different architectures.
To maintain readability, we write a function that returns our network, and we check in the model's summary whether the output layer fits the label shape.
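One possible get_model() function along these lines; the exact architecture is illustrative, as long as it ends in a Dense layer with ten outputs, one per digit class:

```python
def get_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",  # labels are plain integers
        metrics=["accuracy"],
    )
    return model


model = get_model()
model.summary()  # the final Dense(10) matches the ten digit classes
```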
We then fit the network on our train dataset. We set the number of epochs (the number of times the network sees the complete dataset) to 2.
Now, we need to know how many examples our train dataset holds. Since we set the dataset to repeat infinitely, we have to tell the network how many batches to query until one epoch is done; this is the steps_per_epoch parameter:
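Putting this together, the training call could look like this, with the example count taken from the index we tracked while writing (60,000 for the MNIST training split):

```python
num_train_examples = 60000                  # the index we counted while writing
steps_per_epoch = num_train_examples // 32  # batches per epoch at batch size 32

model.fit(train_dataset, epochs=2, steps_per_epoch=steps_per_epoch)
```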
After training completes, which takes around 30 seconds on Colab with a GPU enabled, we test the model. Since we also wrote our test data to TFRecord files, we can quickly create another TFRecordDataset from these files using our get_dataset() function.
Since we set the second parameter in the function call to "test", our dataset won’t repeat; we don’t need to determine the number of steps.
We then call model.evaluate(), which returns an array of two values.
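Using the sketches above:

```python
test_dataset = get_dataset("mnist_test*.tfrecords", split="test")

# The test dataset is finite (no repeat), so no steps argument is needed.
results = model.evaluate(test_dataset)
print(results)
```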
The first number is the loss; the second number is the one we are interested in: the accuracy. After only two epochs, it hovers around 95%, a good start.
Summary
We focused on the MNIST dataset as our ongoing example. With it, we created two TFRecord files, one for the training data and one for the testing data. Next, we covered reading the data back into memory to finally train a CNN.
A Colab notebook covering the whole process is available here.
This post covered TFRecords in the context of an existing dataset. If you are interested in seeing this file format used for a custom dataset, have a look at this code.