How to customize distributed training when using the TensorFlow Estimator API

Lak Lakshmanan
Towards Data Science
6 min read · Apr 9, 2018


TensorFlow’s Estimator API provides an easy, high-level way to train machine learning models. You can use the train(), evaluate() or predict() methods on an Estimator. However, most often, training is carried out in a loop, in a distributed way, with evaluation done periodically during the training process. To do this, you will use the train_and_evaluate loop:

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

In this article, I will talk about what kinds of things you might want to specify in the train_spec and the eval_spec. Taken together, these options allow train_and_evaluate to provide a powerful, customizable way to do distributed training.

Here’s a complete train_and_evaluate function (the full code is on GitHub). I’ll pick out different parts of it to talk about:

import tensorflow as tf  # TensorFlow 1.x (tf.contrib does not exist in TF 2)

def train_and_evaluate(output_dir):
    EVAL_INTERVAL = 300  # seconds
    run_config = tf.estimator.RunConfig(save_checkpoints_secs = EVAL_INTERVAL,
                                        keep_checkpoint_max = 3)
    estimator = tf.estimator.DNNLinearCombinedRegressor(
        model_dir = output_dir,
        ...
        config = run_config)

    estimator = tf.contrib.estimator.add_metrics(estimator, my_rmse)
    train_spec = tf.estimator.TrainSpec(
        input_fn = read_dataset('train', tf.estimator.ModeKeys.TRAIN, BATCH_SIZE),
        max_steps = TRAIN_STEPS)
    exporter = tf.estimator.LatestExporter('exporter', serving_input_fn, exports_to_keep=None)
    eval_spec = tf.estimator.EvalSpec(
        input_fn = read_dataset('eval', tf.estimator.ModeKeys.EVAL, 2**15),  # large batch: no gradients in eval
        steps = None,
        start_delay_secs = 60,  # start evaluating after N seconds
        throttle_secs = EVAL_INTERVAL,  # evaluate at most every N seconds
        exporters = exporter)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

train_and_evaluate() is highly customizable

RunConfig

The RunConfig allows you to control how often checkpoints are written out. Checkpoints are how Estimator supports fault-tolerance. If the chief (or master node) of the training cluster fails, training will resume from the checkpoint. The more often you checkpoint, the less you will lose from machine failure. Of course, checkpointing itself consumes CPU and storage, so it’s a tradeoff. In my example, I am asking that checkpointing be done every 300 seconds (to limit CPU overhead) and that only the last 3 checkpoints are saved (to limit storage overhead):

run_config = tf.estimator.RunConfig(save_checkpoints_secs = 300,
                                    keep_checkpoint_max = 3)

You could also specify the checkpoint interval in terms of the number of training steps (save_checkpoints_steps), and initially this appears simpler and more appealing. However, once you recognize that checkpointing is about fault recovery, you’ll see that specifying it by time is the better option: what a failure costs you is wall-clock time, and how long a given number of steps takes depends on your hardware and cluster. Estimator is smart enough not to write a checkpoint unless the training job has actually progressed.
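To make the time-based rule concrete, here is a minimal pure-Python model of it. It is a hypothetical simulation, not Estimator’s actual code: the function name checkpoint_schedule and the (time, global step) observations are made up for illustration. A checkpoint becomes due once the interval has elapsed, but it is only written if the global step has advanced since the previous checkpoint.

```python
from typing import List, Tuple

def checkpoint_schedule(observations: List[Tuple[float, int]],
                        interval_secs: float) -> List[Tuple[float, int]]:
    """Simplified model of time-based checkpointing (hypothetical, not Estimator's code).

    observations: (wall-clock seconds, global step) pairs in time order.
    Returns the (time, step) pairs at which a checkpoint would be written.
    """
    saved: List[Tuple[float, int]] = []
    last_time, last_step = 0.0, 0
    for t, step in observations:
        due = t - last_time >= interval_secs
        progressed = step > last_step
        if due and progressed:
            saved.append((t, step))
            last_time, last_step = t, step
        # If due but training has not progressed, keep waiting: nothing new to save.
    return saved

obs = [(0, 0), (100, 50), (300, 150), (450, 150), (600, 150), (700, 200), (900, 300)]
print(checkpoint_schedule(obs, interval_secs=300))
# [(300, 150), (700, 200)]
```

In this made-up trace, the checkpoint that falls due at 600 seconds is skipped because training is stalled at step 150. The next one is written at 700 seconds, once the step advances.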

Note that the run_config is passed as a parameter to the Estimator:

estimator = tf.estimator.DNNLinearCombinedRegressor(
    model_dir = output_dir,
    ...
    config = run_config)

Eval Metrics

By default, prebuilt estimators such as LinearRegressor and DNNClassifier come with a fixed set of evaluation metrics. For LinearRegressor, for example, these include average_loss, which is the mean squared error. A common need is to evaluate other metrics too. You can do that with add_metrics:

estimator = tf.contrib.estimator.add_metrics(estimator, my_rmse)

This is a common pattern: the way to extend an Estimator is to wrap it and add extra functionality to it. In this case, a contributed function (in tf.contrib) does the wrapping and adds the metrics. my_rmse is a function that returns a dictionary of the extra metrics:

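def my_rmse(labels, predictions):
    # A regressor's predictions dict holds the predicted values under 'predictions'
    pred_values = predictions['predictions']
    # tf.metrics functions return a (value, update_op) pair, which add_metrics expects
    return {
        'rmse': tf.metrics.root_mean_squared_error(labels, pred_values)
    }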
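The wrapping pattern itself is easy to see outside TensorFlow. Below is a minimal pure-Python sketch, assuming metrics are plain dictionaries of floats. The names base_metrics, rmse_metric and add_metrics_sketch are hypothetical stand-ins, not the tf.contrib implementation. The real add_metrics returns a new Estimator and works on TensorFlow metric ops.

```python
import math
from typing import Callable, Dict, Sequence

Metrics = Dict[str, float]
MetricFn = Callable[[Sequence[float], Sequence[float]], Metrics]

def base_metrics(labels, predictions) -> Metrics:
    """Stand-in for a prebuilt regressor's built-in metric: mean squared error."""
    mse = sum((p - y) ** 2 for y, p in zip(labels, predictions)) / len(labels)
    return {"average_loss": mse}

def rmse_metric(labels, predictions) -> Metrics:
    """Plain-Python equivalent of the extra metric that my_rmse adds."""
    mse = sum((p - y) ** 2 for y, p in zip(labels, predictions)) / len(labels)
    return {"rmse": math.sqrt(mse)}

def add_metrics_sketch(evaluate: MetricFn, metric_fn: MetricFn) -> MetricFn:
    """Wrap an evaluation function so that it also reports metric_fn's metrics."""
    def wrapped(labels, predictions) -> Metrics:
        metrics = dict(evaluate(labels, predictions))   # keep the built-in metrics
        metrics.update(metric_fn(labels, predictions))  # add the new ones
        return metrics
    return wrapped

evaluate = add_metrics_sketch(base_metrics, rmse_metric)
print(evaluate([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```

The wrapped function still reports average_loss and now also reports rmse, without the original function being modified.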

Train batch size

Gradients are computed on your training examples one batch at a time. The size of this batch is controlled by the training input function:

train_spec = tf.estimator.TrainSpec(
    input_fn = read_dataset('train', tf.estimator.ModeKeys.TRAIN, 40),
    max_steps = TRAIN_STEPS)

Your input function probably uses the TensorFlow Dataset API, where the batching is simply a call to the Dataset:

dataset = dataset.repeat(num_epochs).batch(batch_size)
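If the chained call is unfamiliar, this pure-Python sketch mimics what repeat(num_epochs).batch(batch_size) produces. It is only an illustration, assuming an in-memory list of examples, and the helper names repeat_examples and batch_examples are hypothetical. Because repeat comes before batch, a batch can straddle the boundary between two epochs. Also, num_epochs=None means repeat forever.

```python
import itertools
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")

def repeat_examples(examples: List[T], num_epochs: Optional[int]) -> Iterator[T]:
    """Like Dataset.repeat(num_epochs): None means repeat forever."""
    if num_epochs is None:
        return itertools.cycle(examples)
    return itertools.chain.from_iterable(itertools.repeat(examples, num_epochs))

def batch_examples(stream: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Like Dataset.batch(batch_size): consecutive groups; the last one may be short."""
    it = iter(stream)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield chunk

# Two epochs of five examples in batches of four: the middle batch spans both epochs.
print(list(batch_examples(repeat_examples([1, 2, 3, 4, 5], 2), 4)))
# [[1, 2, 3, 4], [5, 1, 2, 3], [4, 5]]
```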

Note, however, that num_epochs behaves differently in distributed training. The time that you train for is controlled by the number of training steps and not by repeating the input data. So, it is important to pass the mode to the input function — you will want to shuffle the training examples and read the training data indefinitely:

if mode == tf.estimator.ModeKeys.TRAIN:
    num_epochs = None  # read indefinitely
    dataset = dataset.shuffle(buffer_size = NWORKERS * batch_size)
else:
    num_epochs = 1  # a single pass over the evaluation data

Why shuffle? In distributed training, each worker computes a gradient on its own batch, and the updates from all the workers are applied to the same shared model. There is no point in this if all your workers are processing the same data. So, you ask the Dataset to shuffle the data so that each worker sees a different batch.
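Dataset.shuffle does not shuffle the whole dataset. It keeps a buffer of buffer_size examples and emits a random one from that buffer each time. The following pure-Python sketch models that behaviour. buffered_shuffle is a hypothetical name, and the two seeds stand in for each worker’s own random state. The sketch also shows why the buffer size matters: the first example out can only come from the first buffer_size examples in.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Emit elements in random order using a fixed-size buffer (models Dataset.shuffle)."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in stream:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a uniformly chosen element to make room.
            i = rng.randrange(buffer_size)
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Input exhausted: drain whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer

# Two workers reading the same data with different seeds see different orders.
worker_a = list(buffered_shuffle(range(100), buffer_size=10, seed=1))
worker_b = list(buffered_shuffle(range(100), buffer_size=10, seed=2))
print(worker_a[:10])
print(worker_b[:10])
```

With buffer_size=1, nothing is shuffled at all, which is why the buffer has to be reasonably large relative to the batch size.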

num_epochs will be 1 during evaluation, but you could evaluate on a fraction of the dataset by setting the number of eval steps appropriately (this is covered below).

Train Steps

The second argument to the TrainSpec specifies how many steps to train for:

train_spec = tf.estimator.TrainSpec(
    input_fn = read_dataset('train', tf.estimator.ModeKeys.TRAIN, BATCH_SIZE),
    max_steps = TRAIN_STEPS)

Remember that a training step involves gradient computation on a batch of training examples. The TrainSpec takes the maximum number of steps for which to train the model. Note that this is max_steps, and not steps. If you have a checkpoint corresponding to step 1800 and you specify max_steps=2000, then training will resume at step 1800 and go on for just 200 steps! The ability to resume training is an important property, but make sure that you understand exactly what this means: to train from scratch, you will want to clean out the output directory so that there are no preexisting checkpoints.
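The resume arithmetic is simple enough to write down. This is a minimal sketch, assuming you know the global step of the latest checkpoint (None meaning there is no checkpoint). steps_to_run and start_fresh are hypothetical helpers. start_fresh simply deletes the output directory with shutil.rmtree, so use it with care.

```python
import shutil
from typing import Optional

def steps_to_run(checkpoint_step: Optional[int], max_steps: int) -> int:
    """Steps left to train when resuming from checkpoint_step toward max_steps."""
    start = checkpoint_step or 0       # no checkpoint: start from step 0
    return max(0, max_steps - start)   # already past max_steps: nothing to do

def start_fresh(output_dir: str) -> None:
    """Delete output_dir (and every checkpoint in it) so training starts from scratch."""
    shutil.rmtree(output_dir, ignore_errors=True)

print(steps_to_run(1800, 2000))  # 200: resumes at 1800, not a fresh 2000 steps
print(steps_to_run(None, 2000))  # 2000
```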

Also note that this is what controls how long you are training for — it essentially overrides num_epochs in the input function.
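If you think in epochs, you can convert to the max_steps that train_and_evaluate actually uses. This is a minimal sketch, assuming every global step consumes exactly one batch of batch_size examples, whichever worker ran it. The dataset size and batch size below are made-up numbers.

```python
import math

def train_steps_for_epochs(num_epochs: float, num_examples: int, batch_size: int) -> int:
    """max_steps needed to see the training data num_epochs times (rounded up)."""
    return math.ceil(num_epochs * num_examples / batch_size)

def epochs_covered(max_steps: int, batch_size: int, num_examples: int) -> float:
    """How many passes over the training data max_steps corresponds to."""
    return max_steps * batch_size / num_examples

# Hypothetical numbers: 1 million training examples, batches of 100.
TRAIN_STEPS = train_steps_for_epochs(5, 1_000_000, 100)
print(TRAIN_STEPS)                                  # 50000
print(epochs_covered(TRAIN_STEPS, 100, 1_000_000))  # 5.0
```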

Exporter

Checkpointing is not the same as exporting. Checkpoints are about fault recovery and involve saving the complete training state (weights, global step number, etc.). Exporting is about creating a model with a serving signature. Most commonly, you’ll export the prediction head or the output node of a neural network, but sometimes, you will want to export an embedding node too.

In my example, I am exporting the latest model each time it is evaluated:

exporter = tf.estimator.LatestExporter('exporter', serving_input_fn, exports_to_keep=None)

The exported model will be written to a directory called “exporter” (under export/ in the model directory), and the serving input function specifies what the end user will be expected to provide to the prediction service. With exports_to_keep=None, every export is kept. You can specify exports_to_keep=5 if you want to retain only the exports corresponding to the last 5 checkpoints.
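To see what exports_to_keep controls, here is a minimal pure-Python model of the clean-up. It assumes each export is identified by an increasing timestamp, since export directories are named by timestamp. The function exports_to_delete is hypothetical, and LatestExporter does this garbage collection for you. The timestamps are made up.

```python
from typing import List, Optional

def exports_to_delete(timestamps: List[int], exports_to_keep: Optional[int]) -> List[int]:
    """Return the export timestamps that a keep-the-newest-N policy would remove."""
    if exports_to_keep is None:
        return []  # None disables garbage collection: keep every export
    newest_first = sorted(timestamps, reverse=True)
    return sorted(newest_first[exports_to_keep:])

exports = [1523200000, 1523200300, 1523200600, 1523200900, 1523201200, 1523201500]
print(exports_to_delete(exports, 5))     # [1523200000]
print(exports_to_delete(exports, None))  # []
```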

Eval batch size

The training batch size is about the number of examples over which the gradient is computed. During evaluation, though, no gradients are computed, and the only reason to read the data in batches is to avoid over-allocating memory buffers. So, specify a much larger batch size:

input_fn = read_dataset('eval', tf.estimator.ModeKeys.EVAL, 2**15)

Whereas the training batch size is on the order of 100, the evaluation batch size can be set so that a batch takes up on the order of 32MB of memory. Obviously, the exact number depends on how much space each example takes in memory.
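A quick way to pick the number is to divide the memory budget by the size of one example. The sketch below assumes you know, or have estimated, the in-memory size of an example. The 1 KiB and 3,000-byte figures are hypothetical. Rounding down to a power of two is only a convention, matching the 2**15 used above.

```python
def eval_batch_size(budget_bytes: int, bytes_per_example: int) -> int:
    """Largest power-of-two batch whose examples fit in budget_bytes (at least 1)."""
    max_examples = budget_bytes // bytes_per_example
    if max_examples < 1:
        return 1
    return 1 << (max_examples.bit_length() - 1)

BUDGET = 32 * 2**20  # 32 MiB per batch
print(eval_batch_size(BUDGET, 1024))  # 32768, i.e. 2**15
print(eval_batch_size(BUDGET, 3000))  # 8192
```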

Eval Steps

When evaluating, you can evaluate on the entire dataset by specifying steps=None in the EvalSpec:

eval_spec = tf.estimator.EvalSpec(
    input_fn = ...,
    steps = None)

However, if your datasets are large, this can impose quite a bit of overhead. The only reason to evaluate during training is for monitoring, and you want monitoring to be somewhat lightweight. So, a good compromise is to specify a large enough number for steps that the evaluation is statistically robust. This depends on your data, but 100 tends to be a common choice:

eval_spec = tf.estimator.EvalSpec(
    input_fn = ...,
    steps = 100)

Because our batch_size was 2**15 (32,768), 100 steps correspond to about 3.3 million examples (100 × 32,768 = 3,276,800), and the idea is that evaluating on this many examples is stable enough.
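Here is the arithmetic behind that estimate as a small sketch. It assumes the evaluation input reads the data once (num_epochs=1), so evaluation also stops early if the data runs out before the requested number of batches. The 10-million and 1-million dataset sizes are hypothetical.

```python
import math
from typing import Optional

def examples_evaluated(steps: Optional[int], batch_size: int, num_examples: int) -> int:
    """Examples seen by one evaluation when the eval input is read once."""
    if steps is None:
        return num_examples  # steps=None: run until the input is exhausted
    return min(steps * batch_size, num_examples)

def steps_for_full_pass(num_examples: int, batch_size: int) -> int:
    """Batches needed to cover the whole eval set (the last batch may be partial)."""
    return math.ceil(num_examples / batch_size)

print(examples_evaluated(100, 2**15, 10_000_000))   # 3276800
print(examples_evaluated(None, 2**15, 10_000_000))  # 10000000
print(examples_evaluated(100, 2**15, 1_000_000))    # 1000000: data ran out first
print(steps_for_full_pass(10_000_000, 2**15))       # 306
```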

Eval throttle

By default, Estimator will evaluate every time it checkpoints. You can specify a throttle on the evaluation so that you checkpoint more often (for fault recovery) but evaluate less often (for monitoring). The reason to do this is that checkpointing is relatively fast, but evaluation can be slow, especially if you evaluate on the entire dataset:

eval_spec = tf.estimator.EvalSpec(
    input_fn = ...,
    steps = EVAL_STEPS,
    start_delay_secs = 60,  # start evaluating after N seconds
    throttle_secs = 600,  # evaluate at most every N seconds
    exporters = exporter)
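To see how the two intervals interact, here is a simplified pure-Python model. Every checkpoint is treated as an opportunity to evaluate. Opportunities before start_delay_secs are skipped, and so is any opportunity that comes less than throttle_secs after the previous evaluation. This is a hypothetical illustration: evaluation_times is not a TensorFlow function, evaluation is assumed to take no time, and the real scheduling differs in detail (for example, between local and distributed runs).

```python
from typing import Iterable, List

def evaluation_times(checkpoint_times: Iterable[float],
                     start_delay_secs: float,
                     throttle_secs: float) -> List[float]:
    """Checkpoint times that would trigger an evaluation in this simplified model."""
    evals: List[float] = []
    next_allowed = start_delay_secs
    for t in sorted(checkpoint_times):
        if t >= next_allowed:
            evals.append(t)                   # evaluate this (new) checkpoint
            next_allowed = t + throttle_secs  # then wait at least throttle_secs
    return evals

checkpoints = range(300, 1801, 300)  # save_checkpoints_secs = 300
print(evaluation_times(checkpoints, start_delay_secs=60, throttle_secs=600))
# [300, 900, 1500]
print(evaluation_times(checkpoints, start_delay_secs=60, throttle_secs=300))
# [300, 600, 900, 1200, 1500, 1800]
```

With checkpoints every 300 seconds and throttle_secs=600, only every other checkpoint is evaluated. With throttle_secs equal to the checkpoint interval, as in the full listing at the top, every checkpoint is evaluated.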

As you can see, train_and_evaluate provides a powerful, customizable way to do distributed training.
