A batch too large: Finding the batch size that fits on GPUs

A simple function to identify the batch size for your PyTorch model that can fill the GPU memory

Bryan M. Li
Towards Data Science

--

I am sure many of you have had the following painful experience: you start multiple ML experiments on your GPUs to train overnight, and when you come back to check 10 hours later, you realize that the progress bar has barely moved due to hardware under-utilization, or worse, that all the experiments failed with an out-of-memory (OOM) error. In this mini-guide, we will implement an automated method to find a batch size for your PyTorch model that utilizes the GPU memory sufficiently without causing OOM!

Photo by Nana Dua on Unsplash

On top of the model architecture and the number of parameters, the batch size is the most effective hyperparameter for controlling the amount of GPU memory an experiment uses. The proper way to find the optimal batch size that fully utilizes the accelerator is GPU profiling, the process of monitoring how your code uses the computing device. Both TensorFlow and PyTorch provide detailed guides and tutorials on how to perform profiling in their frameworks. In addition, the batch size can greatly affect the performance of the model; for instance, a large batch size can lead to poor generalization. Check out this blog post by Kevin Shen on the effect of batch size on training dynamics if you are interested in this topic.
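If you do want to profile memory usage properly in PyTorch, torch.profiler can record per-operator CUDA memory consumption. The snippet below is a minimal sketch (not from the original article) using a small stand-in model and dummy inputs:

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

device = torch.device("cuda")
# stand-in model and inputs, purely for illustration
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).to(device)
inputs = torch.rand(64, 512, device=device)

# record CPU/CUDA activity and memory usage for one forward/backward pass
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
) as prof:
    loss = model(inputs).sum()
    loss.backward()

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
```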

Nevertheless, if you simply want to train a model to test an idea, profiling or running a hyperparameter search to find the best batch size might be overkill, especially in the early stages of a project. A common approach to finding a value that lets you fit the model without OOM is to train with a small batch size while monitoring GPU utilization with tools like nvidia-smi or nvitop, increase the value if the model is under-utilizing the GPU memory, and repeat until you hit the memory capacity. However, this manual process can be time-consuming, and, more annoyingly, when you run experiments on GPUs with different memory sizes you have to repeat it for each device. Luckily, we can turn this tedious iterative process into code and run it before the actual experiment, so that you know your model won't cause an OOM error.
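As an aside, you can also query memory usage from inside Python instead of watching nvidia-smi. This short sketch (not part of the original article) uses torch.cuda utilities to report how much device memory a forward/backward pass of a stand-in model consumed:

```python
import torch
from torch import nn

device = torch.device("cuda")
# stand-in model and inputs, purely for illustration
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)).to(device)
inputs = torch.rand(256, 1024, device=device)

torch.cuda.reset_peak_memory_stats(device)
model(inputs).sum().backward()

free, total = torch.cuda.mem_get_info(device)
peak = torch.cuda.max_memory_allocated(device)
print(f"peak allocated by PyTorch: {peak / 1024 ** 2:.1f} MiB")
print(f"device memory in use: {(total - free) / 1024 ** 2:.1f} / {total / 1024 ** 2:.1f} MiB")
```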

The idea is very simple:

  1. Initialize your model.
  2. Set the batch size to 2 (for BatchNorm).
  3. Create dummy data that has the same shape as the real data.
  4. Train the model for n steps (both forward and backward passes).
  5. If the model ran without an error, increase the batch size and go back to Step 3. If an OOM error is raised (i.e. a RuntimeError in PyTorch), set the batch size to the previous working value and terminate.
  6. Return the final batch size.

To put this into code:
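Below is a minimal sketch of such a get_batch_size function. The Adam optimizer, the starting batch size of 2, the doubling/halving logic, and the final clean-up follow the walkthrough below; the dummy MSE loss is an assumption made for the sketch (any loss that matches your output shape will do):

```python
import gc
import typing as t

import torch
from torch import nn
from torch.nn import functional as F


def get_batch_size(
    model: nn.Module,
    device: torch.device,
    input_shape: t.Tuple[int, ...],
    output_shape: t.Tuple[int, ...],
    dataset_size: int,
    max_batch_size: t.Optional[int] = None,
    num_iterations: int = 5,
) -> int:
    """Double the batch size until OOM, then back off and return the last safe value."""
    model.to(device)
    model.train(True)
    optimizer = torch.optim.Adam(model.parameters())

    batch_size = 2  # start at 2 so BatchNorm layers see more than one sample
    while True:
        # stop searching once we reach the dataset size or the requested maximum
        if max_batch_size is not None and batch_size >= max_batch_size:
            batch_size = max_batch_size
            break
        if batch_size >= dataset_size:
            batch_size = batch_size // 2
            break
        try:
            for _ in range(num_iterations):
                # dummy inputs and targets with the same shape as the real data
                inputs = torch.rand(batch_size, *input_shape, device=device)
                targets = torch.rand(batch_size, *output_shape, device=device)
                outputs = model(inputs)
                loss = F.mse_loss(outputs, targets)  # assumed dummy loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size *= 2  # no OOM: try a larger batch
        except RuntimeError:
            batch_size //= 2  # OOM: fall back to the last working value
            break

    # delete the model and optimizer so the real experiment starts from a clean slate
    del model, optimizer
    torch.cuda.empty_cache()
    gc.collect()
    return batch_size
```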

As you can see, this function has 7 arguments:

  • model — the model you want to fit; note that the model will be deleted from memory at the end of the function.
  • device — the torch.device to run on, which should be a CUDA device.
  • input_shape — the input shape of the data.
  • output_shape — the expected output shape of the model.
  • dataset_size — the size of your dataset (we wouldn’t want to continue the search when the batch size is already larger than the size of the dataset).
  • max_batch_size — an optional argument to set the maximum batch size to use.
  • num_iterations — the number of iterations to update the model before increasing the batch size; defaults to 5.

Let’s quickly go through what happens in the function. We first load the model onto the GPU, initialize an Adam optimizer, and set the initial batch size to 2 (you can start with a batch size of 1 if you are not using BatchNorm). We can then begin the iterative process. First, we check whether the current batch size is larger than the size of the dataset or the maximum desired batch size; if so, we break out of the loop. Otherwise, we create dummy inputs and targets, move them to the GPU, and fit the model. We train the model for 5 steps to ensure that neither the forward nor the backward pass causes OOM. If everything is fine, we multiply the batch size by 2 and fit the model again. If OOM occurs during any of these steps, we reduce the batch size by a factor of 2 and exit the loop. Finally, we clear the model and optimizer from memory and return the final batch size. That’s it!

Note that, instead of simply dividing the batch size by 2 in the case of OOM, one could continue to search for the optimal value (i.e. binary-search the batch size: set it to the midpoint between the breaking value and the last working value, and continue from Step 3) to find a batch size that fits the GPU almost perfectly, as sketched below. However, keep in mind that PyTorch/TensorFlow or other processes might request more GPU memory in the middle of an experiment, and you would risk OOM; I hence prefer having some wiggle room.
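For illustration, here is a minimal sketch of that binary-search refinement. The runs_without_oom helper is hypothetical and stands in for training the model for a few steps at a given batch size and reporting whether it hit OOM:

```python
import typing as t


def binary_search_batch_size(
    last_working: int,
    breaking: int,
    runs_without_oom: t.Callable[[int], bool],
) -> int:
    """Refine the batch size between the last working value and the first failing one."""
    low, high = last_working, breaking
    while low + 1 < high:
        mid = (low + high) // 2
        if runs_without_oom(mid):
            low = mid  # mid still fits, search the upper half
        else:
            high = mid  # mid causes OOM, search the lower half
    return low
```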

Now let’s put this function into use. Here we fit ResNet50 on 1,000 synthetic training images of size (3, 224, 224) generated by torchvision’s FakeData dataset. Briefly, we first call get_batch_size(model=resnet50(), device=device, input_shape=IMAGE_SHAPE, output_shape=(NUM_CLASSES,), dataset_size=DATASET_SIZE) to get a batch size that can fill the GPU memory sufficiently. Then we can initialize the model and DataLoaders, and train the model like you normally would!
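An end-to-end sketch of that flow might look like the following, using the get_batch_size function sketched above. NUM_CLASSES = 10, the ToTensor transform, and the two-epoch training loop are assumptions made for the example:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FakeData
from torchvision.models import resnet50

IMAGE_SHAPE = (3, 224, 224)
NUM_CLASSES = 10      # assumed number of classes for the synthetic dataset
DATASET_SIZE = 1000
device = torch.device("cuda")

# find a batch size that fills the GPU memory before the real run
batch_size = get_batch_size(
    model=resnet50(num_classes=NUM_CLASSES),
    device=device,
    input_shape=IMAGE_SHAPE,
    output_shape=(NUM_CLASSES,),
    dataset_size=DATASET_SIZE,
)

# 1,000 synthetic training images of size (3, 224, 224)
dataset = FakeData(
    size=DATASET_SIZE,
    image_size=IMAGE_SHAPE,
    num_classes=NUM_CLASSES,
    transform=transforms.ToTensor(),
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# initialize a fresh model and train as you normally would
model = resnet50(num_classes=NUM_CLASSES).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(2):
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```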

The GIF below is a screen recording of running the example code on the Nvidia RTX 2080 8GB. I added some print statements to the get_batch_size function to show the batch size it is testing; notice the increase in GPU memory usage as the function increases the batch size. Our script identified that a batch size of 16 would cause OOM, and ran the rest of the training code with a batch size of 8 at a GPU memory utilization of ~66.8%.

Running the example code on the Nvidia RTX 2080 8GB, the script identifies that a batch size of 16 would cause OOM. [GIF by author]

When we run the exact same code on the Nvidia RTX 2080Ti 11GB, we are able to run with a batch size of 16 and a GPU memory utilization of 90.3%.

Running the example code on the Nvidia RTX 2080Ti 11GB, the script identifies that a batch size of 32 would cause OOM. [Image by author]

There you have it! A simple function you can add to the beginning of your training script to find a batch size that utilizes the GPU memory sufficiently, without worrying about OOM errors.
