Capturing a Training State in TensorFlow

How to debug occurrences of a NaN loss in your training project

Chaim Rand
Towards Data Science

Photo by Alan Emery on Unsplash

In a previous post, we attempted to offer some support in the often difficult, sometimes impossible, and always maddening task of debugging in TensorFlow. That post includes a description of what I believe to be the ultimate example of the potential suffering of the modern-day machine learning developer: the sudden appearance of a NaN loss. In this cursed scenario, after a (possibly extended) period of successful model convergence, the model loss is suddenly reported to be NaN ("Not a Number"), and the model state (its weights) becomes disastrously and irreparably corrupted. We proposed an innovative method for capturing the precise state of the model training just before the weights become corrupted. Under the assumption that the training step does not include randomness (e.g., no dropout layers), the method allows for reproducing and investigating the issue that led to the NaN in a reliable and efficient manner. (We will return to the possibility of randomness in the training step below.)

Since then, we have developed an additional technique for capturing the training state of a model that extends to scenarios not addressed by the original method. The intention of this post is to share this technique and to compare it with its predecessor.

We will start by briefly reviewing the problem and previous solution. For the sake of brevity, I will refrain from going into the details of the unique challenges of debugging the NaN loss and the original solution we proposed, and instead recommend that you read the previous post. Next, we will describe a few of the limitations of the original solution and present a new technique that addresses these limitations. Finally, we will briefly demonstrate a sample debugging flow (based on a true story) using this technique.

The post includes a few blocks of code that demonstrate how to program the solution we discuss. The code is based on TensorFlow version 2.6. As always, you should verify both the code and the general arguments that we make against the most up-to-date version at the time that you read this.

Debugging a NaN Loss can be Hard

While debugging in general is hard, there are a number of reasons that make debugging an occurrence of a NaN loss in TensorFlow especially hard.

The use of a symbolic computation graph

TensorFlow includes two modes of execution: eager execution and graph execution. Although it is easier to debug when training in eager execution, graph mode has a number of benefits that make training significantly more efficient. The disadvantage of graph mode is that it is hard to debug. Specifically, in graph mode you build the model up front and send it off to the training device (e.g., GPU). During training, you cannot freely access and measure the tensors of the computation graph. The obvious conclusion is that we should always train in graph mode and debug in eager execution mode. However, a prerequisite of this strategy is the ability to easily capture a precise training state for reproduction and debugging.

The challenge of reproducing a training state

In order to fully reproduce a training step we need to know three things: the state of the model, the precise batch of input data, and the seeds that control whatever stochasticity is included in the training step.

Capturing model state: TensorFlow provides tools for saving the values of its model weights and optimizer variables to file (e.g., tf.keras.Model.save_weights). It is good practice to capture these regularly during training (e.g., using the tf.keras.callbacks.ModelCheckpoint callback).
Ideally, we would like to be able to capture the precise model state that led to the NaN loss. This would simplify and shorten the reproduction. Unfortunately, in a typical scenario, once the NaN loss is reported the model state has already been ruined. In our previous post we demonstrated a way to detect a NaN loss before the weights are irreversibly corrupted. We will adopt a similar technique in this post as well.

Capturing the data input: Due to the sequential nature of tf.data.Dataset, once a training batch passes through the train step we no longer have access to it. In our previous post we demonstrated a mechanism for capturing a training batch for reproduction.

Capturing the seeds that control the stochasticity: Stochasticity is an inherent property of machine learning training algorithms, often a key to their success. At the same time, stochasticity can be a nightmare for reproducing bugs. One way to overcome this challenge is to explicitly set and capture the specific seeds that determine the stochasticity (e.g. tf.random.Generator.from_seed).
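To make the three ingredients concrete, here is a minimal NumPy sketch (not TensorFlow) built on simplifying assumptions: a toy linear model stands in for the network, a random input mask stands in for dropout, and the hypothetical helpers `train_step`, `capture_state` and `replay` stand in for save_weights, the input capture described later in this post, and explicit seeding such as tf.random.Generator.from_seed. Given the same weights, batch and seed, the replayed step reproduces the original exactly.

```python
import io

import numpy as np


def train_step(weights, batch_x, batch_y, seed, lr=0.1, keep_prob=0.8):
    """Toy linear-regression step with a dropout-style random input mask.

    The result depends only on (weights, batch, seed), which is the point.
    """
    rng = np.random.default_rng(seed)
    mask = rng.random(batch_x.shape) < keep_prob  # stochastic part of the step
    x = batch_x * mask / keep_prob
    err = x @ weights - batch_y
    loss = np.mean(err ** 2)
    grad = 2.0 * x.T @ err / len(batch_y)
    return weights - lr * grad, loss


def capture_state(weights, batch_x, batch_y, seed):
    """Serialize weights, input batch and seed (here to memory, not to disk)."""
    buf = io.BytesIO()
    np.savez(buf, weights=weights, batch_x=batch_x, batch_y=batch_y,
             seed=np.array(seed))
    buf.seek(0)
    return buf


def replay(buf):
    """Reload a captured state and rerun the exact same step."""
    state = np.load(buf)
    return train_step(state["weights"], state["batch_x"], state["batch_y"],
                      int(state["seed"]))
```

Dropping any one of the three pieces (for example, re-running with a different seed) generally gives a different step, which is why the seed has to be captured alongside the weights and the batch.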

Limitations of our previous proposal

In our previous post we proposed a method that involved customizing the train_step and make_train_function routines of the tf.keras.Model class so as to:

  1. save the input batch as an eager (accessible) tensor,
  2. test the gradients for NaN values before applying them to the model weights, and,
  3. terminate the training and save the last input batch and current weights.

There are two notable limitations to this method:

Capturing state in distributed training: In its current form, the method does not support capturing the state of a training session that is running on multiple cores using a tf.distribute distribution strategy.

Capturing state on TPUs: The method requires part of the train step to run outside of a tf.function scope so as to capture the input data. This scheme is not supported on TPUs.

In the following section we describe a new technique that overcomes these limitations.

New Proposal for Capturing Training State

Our new proposal consists of four components:

1. A tf.keras.models.Model with a customized train_step.

2. A customized TensorFlow input data capturing layer which we append to the end of the input data pipeline.

3. Setting the dataset prefetch value to 0 (to make sure that the last data batch is not overwritten).

4. The standard tf.keras.callbacks.TerminateOnNaN callback.

Custom Train Step

We customize the train step to test for NaN gradients before applying them to the model weights. If a NaN gradient is discovered, we explicitly set the loss metric to NaN and do not update the weights. The NaN loss will be identified by the tf.keras.callbacks.TerminateOnNaN callback and the training will be halted. Unlike our previous solution, the custom train step runs fully within the tf.function scope. The code block below demonstrates the custom train step. The main changes from the default train step are the separate gradient computation and aggregation, the NaN check over all gradients, and the conditional update (tf.cond) at the end.

```python
import tensorflow as tf
from tensorflow.python.keras.engine import data_adapter


class DebugModel(tf.keras.models.Model):
    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(
                y, y_pred,
                sample_weight, regularization_losses=self.losses)
        # Compute and aggregate (across replicas) the gradients,
        # but do not apply them yet.
        grads_and_vars = self.optimizer._compute_gradients(
            loss, var_list=self.trainable_variables, tape=tape)
        grads_and_vars = self.optimizer._aggregate_gradients(
            grads_and_vars)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result

        def t_fn():
            # If any of the gradients are NaN, set the loss metric to NaN
            # (the weights are left untouched)
            return tf.constant(float('NaN'))

        def f_fn():
            # If all gradients are valid, apply them
            self.optimizer.apply_gradients(
                grads_and_vars,
                experimental_aggregate_gradients=False)
            return return_metrics['loss']

        grad_nan = [tf.reduce_any(tf.math.is_nan(g)) for g, _ in
                    grads_and_vars if g is not None]
        grad_nan = tf.reduce_any(grad_nan)
        return_metrics['loss'] = tf.cond(grad_nan,
                                         t_fn,
                                         f_fn)
        return return_metrics
```
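Stripped of the Keras machinery, the decision this train step makes can be sketched in NumPy. This is an illustrative sketch only, with hypothetical helper names and plain SGD standing in for the real optimizer: a single NaN in a gradient is enough to poison the weight it touches, while the guarded version leaves all weights unchanged and reports a NaN loss for TerminateOnNaN to act on.

```python
import numpy as np


def unguarded_sgd_step(weights, grads, lr=0.1):
    """Plain SGD: NaN gradients flow straight into the weights."""
    return [w - lr * g for w, g in zip(weights, grads)]


def guarded_sgd_step(weights, grads, loss, lr=0.1):
    """Mirror of the DebugModel logic: skip the update on any NaN gradient.

    Returns (new_weights, reported_loss). On a NaN gradient the weights are
    returned unchanged and the reported loss is forced to NaN.
    """
    if any(np.isnan(g).any() for g in grads if g is not None):
        return weights, float("nan")
    return unguarded_sgd_step(weights, grads, lr), loss
```

Like the train step, the sketch checks only for NaN; replacing `np.isnan(g)` with `~np.isfinite(g)` (or tf.math.is_nan with a check based on tf.math.is_finite) would also catch infinite gradients.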

Data Input Capture Layer

In a previous post we introduced a unique method for capturing the values of graph tensors using a custom tf.keras.layers.Layer that assigns tensor values to variables that remain accessible in eager mode. We apply such tensor capturing layers at the end of the data input pipeline (before prefetch) in order to capture input data batches right before they enter the training step.

```python
import tensorflow as tf


class TensorCaptureLayer(tf.keras.layers.Layer):
    def __init__(self, shape, dtype, **kwargs):
        super(TensorCaptureLayer, self).__init__(dtype=dtype, **kwargs)
        # Non-trainable variable holding the most recently seen value
        self.record_tensor_var = tf.Variable(shape=shape,
                                             initial_value=tf.zeros(shape=shape, dtype=dtype),
                                             validate_shape=False,
                                             dtype=dtype,
                                             trainable=False)

    def call(self, inputs, **kwargs):
        # Record the input and pass it through unchanged
        self.record_tensor_var.assign(inputs)
        return inputs


# Here we demonstrate a case where the input consists of a single batch of
# 32 frames (512x512x3, float32) and a single batch of 32 labels (int32)
class DataInputCaptureLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(DataInputCaptureLayer, self).__init__(**kwargs)
        self.frame_input = TensorCaptureLayer([32, 512, 512, 3], tf.float32)
        self.label_input = TensorCaptureLayer([32], tf.int32)

    def call(self, x, y, **kwargs):
        self.frame_input(x)
        self.label_input(y)
        return x, y
```

Set Dataset Prefetch to Zero

The tf.data.Dataset.prefetch routine is a common technique that is used to streamline the training flow and maximize resource utilization. However, the capture layer records a batch when the pipeline produces it, while the training step consumes batches from the prefetch buffer. If we set the prefetch to a value greater than zero, we therefore run the risk of populating our DataInputCaptureLayer with an upcoming batch rather than the current one, undermining our ability to capture the input batch that caused the NaN loss.
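The ordering problem can be seen with a toy model of the pipeline written in plain Python rather than tf.data. The `ToyPipeline` class is hypothetical and only mimics the order of events: the capture step runs whenever a batch is produced, and the train step takes batches from the front of a buffer that holds `prefetch` look-ahead batches. With a look-ahead of zero the recorded batch is the one being trained on; with a look-ahead of two it is already two batches ahead.

```python
from collections import deque


class ToyPipeline:
    """Toy stand-in for dataset.map(capture_layer).prefetch(n).

    Not tf.data: it only models when the capture step runs relative to
    when the train step consumes a batch.
    """

    def __init__(self, batches, prefetch):
        self._source = iter(batches)
        self._buffer = deque()
        self._prefetch = prefetch
        self.captured = None  # last batch seen by the capture step

    def _fill(self):
        # Keep the batch about to be consumed plus `prefetch` look-ahead batches.
        while len(self._buffer) < 1 + self._prefetch:
            try:
                batch = next(self._source)
            except StopIteration:
                break
            self.captured = batch  # the capture step overwrites its record
            self._buffer.append(batch)

    def next_batch(self):
        """Return the batch the train step would consume next."""
        self._fill()
        return self._buffer.popleft()
```

For example, `ToyPipeline(range(10), prefetch=2).next_batch()` hands batch 0 to the train step while `captured` already holds batch 2.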

Use the TerminateOnNaN Callback

We use the standard TerminateOnNaN callback to monitor the loss and terminate the training if it is NaN (the callback also stops training on an infinite loss).

The code block below demonstrates how we put it all together.

```python
import numpy as np
import tensorflow as tf

dataset = ...
capture_layer = DataInputCaptureLayer()
dataset = dataset.map(lambda x, y: capture_layer(x, y))
dataset = dataset.prefetch(0)
strategy = ...
with strategy.scope():
    model = DebugModel(inputs=..., outputs=...)
    model.compile(loss=..., optimizer=...)
model.fit(dataset, epochs=..., steps_per_epoch=...,
          callbacks=[tf.keras.callbacks.TerminateOnNaN()])
if model.stop_training:  # set by the callback in case of NaN
    # capture the model state
    model.save_weights('model_weights.ckpt')
    # capture the last data batch (np.save writes .npy files)
    np.save('frm.npy', capture_layer.frame_input.record_tensor_var.numpy())
    np.save('lbl.npy', capture_layer.label_input.record_tensor_var.numpy())
```

Reproducing the NaN loss

With the training state now in hand we can go about reproducing the NaN loss. The model state is easily restored using the tf.keras.Model.load_weights routine, and the input data is injected into the pipeline using a custom layer as demonstrated below:

```python
import numpy as np
import tensorflow as tf


class InputInjectionLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(InputInjectionLayer, self).__init__(**kwargs)
        # Load the batch saved during the capture run
        self.frames = tf.convert_to_tensor(np.load('frm.npy'))
        self.labels = tf.convert_to_tensor(np.load('lbl.npy'))

    def call(self, x, y, **kwargs):
        # Ignore the incoming batch and return the captured one instead
        return self.frames, self.labels


dataset = ...
inject_layer = InputInjectionLayer()
dataset = dataset.map(lambda x, y: inject_layer(x, y))
dataset = dataset.prefetch(0)
strategy = ...
with strategy.scope():
    model = DebugModel(inputs=..., outputs=...)
    model.compile(loss=..., optimizer=...)
    # restore the captured model state
    model.load_weights('model_weights.ckpt')
model.fit(dataset, epochs=1, steps_per_epoch=1)
```

Comparison to Previous Solution

Unlike the previous solution we proposed, this solution can be used to debug a NaN loss encountered in a distributed training setting, and can also be used on TPUs.

An additional difference worth pointing out is the resources required by the two solutions. In the first solution the data batches were stored as part of the training step in the memory of the training worker (e.g., GPU). In the second solution the data batches are stored as part of the data input pipeline, which runs on the host CPUs. Depending on your workload, the data batches might require a significant amount of memory, and the resource on which they are stored might have implications for how you build your model (e.g., on the maximum batch size).
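For a sense of scale, the memory held by one captured copy of the example batch used in the DataInputCaptureLayer above (32 frames of 512x512x3 float32 plus 32 int32 labels) is easy to compute. The helper name is ours; the numbers follow from the shapes and dtypes alone.

```python
import numpy as np


def batch_nbytes(shape, dtype):
    """Bytes needed to hold one tensor of the given shape and dtype."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


frame_bytes = batch_nbytes([32, 512, 512, 3], np.float32)
label_bytes = batch_nbytes([32], np.int32)
print(f"frames: {frame_bytes / 2**20:.0f} MiB, labels: {label_bytes} bytes")

# The cost grows linearly with batch size: 1024 frames of the same shape
# would need 32 times as much.
big = batch_nbytes([1024, 512, 512, 3], np.float32)
print(f"1024 frames: {big / 2**30:.0f} GiB")
```

This works out to 96 MiB for the frames and 128 bytes for the labels, and 3 GiB if the same frames were captured at a batch size of 1024, which is the kind of figure that decides whether the capture buffer belongs on the host or on the accelerator.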

Lastly, the different solutions might incur different penalties on your training step time. The performance penalty of each solution is very much dependent on the properties of your model. For example, if you have a CPU bottleneck, then adding the data input capture layer is likely to have a higher penalty than our original solution. On the other hand, if your CPU is underutilized, the penalty is likely to be lower. For our models (which are typically relatively large) we did not notice a significant difference in the penalties of the two solutions, both of which were low (<10%).

Solution Extensions

A number of enhancements can be made to this solution to extract more debug information and/or extend its coverage to additional use cases. Here we will note a few of them.

Extract more information using tf.print
When detecting a NaN gradient, the custom train step we implemented above does nothing more than set the loss to NaN. However, we could enhance our function to collect and print additional diagnostics using the tf.print utility. For example, we could print the full list of gradient tensors and whether they include NaN values. This might provide some hints to the source of the problem. It is worth noting that (as of the time of this writing) tf.print is not supported on TPU, so while capturing more information is still possible, it requires a bit more creativity (e.g. using custom metrics).
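A sketch of the kind of per-variable report meant here, written in NumPy so it can run on gradients pulled out of a debugging session. The function name and report format are hypothetical; inside the train step, the same count could be computed with tf.math.is_nan and tf.reduce_sum and emitted with tf.print where that is supported.

```python
import numpy as np


def nan_report(named_grads):
    """Return (name, nan_count, size) for every gradient containing NaNs.

    `named_grads` is an iterable of (variable_name, gradient_array) pairs;
    gradients that are None (variables not connected to the loss) are skipped.
    """
    report = []
    for name, g in named_grads:
        if g is None:
            continue
        n = int(np.isnan(g).sum())
        if n:
            report.append((name, n, g.size))
    return report
```

Knowing which variables carry NaN gradients, and whether only a few elements or the whole tensor is affected, often narrows the search to a single layer.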

Support Alternative Data Loading Schemes
In the example code we shared, we assumed that the data distribution is managed by a tf.distribute strategy and that the default data loading mechanism is used. However, the solution can be easily extended to cover other data loading schemes. For example, if Horovod is used for data distribution, then each training core (e.g., GPU) has its own individual dataset, and all that is required is to have each Horovod process save (and inject) its input data to a unique file path. If a tf.distribute strategy is used along with distribute_datasets_from_function, a single dataset pipeline may feed multiple training cores. In this case, the DataInputCaptureLayer could be enhanced with a queue mechanism for capturing multiple data batches (according to the number of training cores that a single dataset pipeline feeds).
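The bookkeeping for both variants can be sketched in plain Python. The names are hypothetical, and inside a tf.data pipeline the queue would have to be built from tf.Variables rather than a Python deque (an assumption about how one would port it, not something shown in this post). A bounded queue keeps the last `num_replicas` batches produced by one pipeline, and a per-process path keeps Horovod workers from overwriting each other's captures; the rank would come from Horovod itself (e.g., `hvd.rank()`).

```python
from collections import deque

import numpy as np


class MultiBatchCapture:
    """Keep the last `num_replicas` batches produced by one input pipeline."""

    def __init__(self, num_replicas):
        self._batches = deque(maxlen=num_replicas)  # oldest batch drops out

    def __call__(self, batch):
        self._batches.append(np.array(batch, copy=True))
        return batch  # pass the batch through unchanged

    def batches(self):
        return list(self._batches)


def capture_path(prefix, rank):
    """Hypothetical per-process file name, e.g. frm_rank3.npy."""
    return f"{prefix}_rank{rank}.npy"
```

With one pipeline feeding, say, eight cores, the queue length would be eight, so that every core's last batch is still available when training stops.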

Support Modifying Batch Size During Injection
Sometimes you may wish to reproduce and debug in an environment with less available memory than your training device. To enable this, you may need to run with a smaller input batch size. The InputInjectionLayer we described above can be easily modified to split the captured training batch into smaller batches. Of course, there is always the possibility that the issue depends on the full batch (e.g., through batch-level statistics) and will not reproduce on any of the smaller batches. If you do adopt this technique, make sure that you test each sub-batch with the original model weights, i.e., that you do not update the weights between batches.
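A minimal sketch of the sub-batch search in NumPy, assuming a hypothetical `loss_fn(weights, x, y)` that stands in for a forward pass of the real model: every sub-batch is evaluated against the same, unmodified weights, and the index ranges that produce a NaN are reported. The toy loss takes a square root, so any sample whose projection is negative yields a NaN.

```python
import numpy as np


def find_nan_sub_batches(frames, labels, weights, loss_fn, sub_batch_size):
    """Evaluate loss_fn on each sub-batch with the same, unchanged weights.

    Returns the (start, stop) index ranges whose loss is NaN.
    """
    bad = []
    n = len(frames)
    for start in range(0, n, sub_batch_size):
        stop = min(start + sub_batch_size, n)
        loss = loss_fn(weights, frames[start:stop], labels[start:stop])
        if np.isnan(loss):
            bad.append((start, stop))
    return bad


def toy_loss(weights, x, y):
    """Hypothetical loss: NaN whenever x @ weights is negative."""
    with np.errstate(invalid="ignore"):
        return np.mean((np.sqrt(x @ weights) - y) ** 2)
```

Shrinking `sub_batch_size` step by step (1024, 256, 64, ..., 1) narrows the failure down to individual samples, as long as the issue does not depend on the batch as a whole.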

Reproducing Random Operations
Some models include layers that are controlled by random variables. A classic example is the dropout layer, a regularization technique that randomly sets a portion of the input units to zero. In some cases you may suspect that your NaN loss is caused by an unfortunate combination of the random variables. In its current form, the solution we described does not address this situation. You could, of course, run the train_step many times and hope that the issue reproduces, but that is not a recipe for success, especially if you have a large number of random variables. Here are two approaches for solving this issue. The first is to capture the random variables of each random layer. For example, in the case of a dropout layer, you would capture the units that were set to zero. The second option is to explicitly set the RNG seed for each random layer and for each training step (e.g., as a function of the iteration). Both options are simple in theory but require a bit of mastery and creativity to implement in TensorFlow, and are outside the scope of this post.
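The second option can be sketched in NumPy: derive each step's random state only from a base seed and the iteration number, so that the mask of any single step can be regenerated without replaying the steps before it. The `dropout_mask` helper is hypothetical; in TensorFlow the analogous building blocks would be the stateless random ops, such as tf.random.stateless_uniform, which take an explicit two-element seed.

```python
import numpy as np


def dropout_mask(shape, base_seed, step, rate=0.5):
    """Keep-mask for one training step, derived only from (base_seed, step).

    True marks units that are kept; False marks units set to zero.
    """
    rng = np.random.default_rng([base_seed, step])
    return rng.random(shape) >= rate
```

If the NaN appears at step 7, `dropout_mask(shape, base_seed, 7)` recreates exactly the mask that step used, which is what makes the captured state reproducible even with dropout in the model.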

Debugging Flow Example

In order to give you a better idea of how this solution can be used, I will describe a typical debugging flow that is based on a real use case.

We were recently tasked with training a fairly large computer vision model on TPUs. (The model did not include random layers such as dropout.) The training did not go as planned. Early on, we found that nearly half of our training experiments were failing on NaN losses. At that point we started running our experiments in the capture mode we described above, using the custom DebugModel instead of the default Keras model, together with the DataInputCaptureLayer. Fortunately, this enabled us to capture a training state that led to the NaN loss quite easily. The batch size of the captured data was 1024. Since we were interested in debugging in a CPU environment, we worked on identifying a sub-batch of size 1 that reproduced the NaN loss. Once on the CPU, we used tf.gradients to explicitly calculate the gradients of specific variables and tf.print to print out diagnostics. Typically this was done inside a customized layer. (Sometimes we would extend a standard layer and override its call function so as to include these operations.) We also took advantage of eager execution mode to traverse the graph and analyze tensor values. In this manner we were able to identify the precise operation that was the source of the NaN gradients. We had implemented a custom loss function that included a call to tf.norm. It turned out that, on specific data samples and with the captured model state, the value of the norm operation was zero. Since the norm operation is not differentiable at zero, we were, very appropriately, getting NaNs. (One way to fix this problem is to break down the norm op and add a small epsilon before performing the square root.)
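The failure mode is easy to reproduce outside TensorFlow. The sketch below writes out the chain rule for sqrt(sum(x**2)) in NumPy (the function names are ours): at x = 0 the derivative of the square root is infinite, and multiplying it by the zero derivative of x**2 yields NaN. Adding a small epsilon under the square root keeps the gradient finite; in TensorFlow this would look like `tf.sqrt(tf.reduce_sum(tf.square(x)) + eps)`, with `eps` a value you choose.

```python
import numpy as np


def norm_grad(x):
    """Gradient of sqrt(sum(x**2)) via the chain rule, as autodiff computes it."""
    s = np.sum(x * x)
    with np.errstate(divide="ignore", invalid="ignore"):
        scale = np.divide(1.0, 2.0 * np.sqrt(s))  # d sqrt(s) / ds, inf at s == 0
        return scale * (2.0 * x)                  # times d s / dx: inf * 0 -> NaN


def safe_norm_grad(x, eps=1e-12):
    """Same gradient for sqrt(sum(x**2) + eps), which stays finite at x == 0."""
    s = np.sum(x * x)
    return (1.0 / (2.0 * np.sqrt(s + eps))) * (2.0 * x)
```

Away from zero both versions agree (for x = [3, 4] the gradient is [0.6, 0.8]); at zero only the epsilon version returns a usable value.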

In another situation we used the same technique to investigate NaN losses we were getting when using mixed precision on GPUs. We discovered that the use of tf.float16 was causing an overflow in our loss calculation. We promptly set the loss function to use tf.float32 and the problem was solved.

This method does not, by any means, guarantee success. In the anecdote above, we could have gotten unlucky at several steps. We could have failed at reproducing the issue on a smaller batch size, we could have found that the issue reproduces only on TPUs (e.g., due to their use of bfloat16) or only in graph mode, or the model could have been dependent on random layers. In any of these cases it would have been back to the drawing board to see how to improve our method to address the challenges we faced.

Summary

This post discusses one of the more difficult challenges of training in TensorFlow, the advent of a NaN loss. We proposed a new technique for reproducing and debugging this issue. As we demonstrated, the technique can be implemented without much difficulty. While it does not address all possible scenarios, it is well worth a try when all else fails. Don’t you think?


I am a Machine Learning Algorithm Developer working on Autonomous Vehicle technologies at Mobileye. The views expressed in my posts are my own.