Learn RCNNs with this Toy Dataset

Zack Akil
Towards Data Science
4 min read · Mar 18, 2018


Here’s a dataset designed to showcase when a Recurrent Convolutional Neural Network (RCNN) will outperform its non-recurrent counterpart, the Convolutional Neural Network (CNN).

A little primer

Recurrent models are models specially designed to use a sequence of data points when making predictions (e.g. a stock market predictor that uses data from the past 3 days).
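For instance, a minimal recurrent model in Keras (purely illustrative, not the article’s code) might look like this:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, Dense

# Predict the next value in a series from the previous 3 time steps.
model = Sequential([
    Input(shape=(3, 1)),  # 3 past time steps, 1 feature each
    LSTM(16),             # recurrent layer reads the steps in order
    Dense(1),             # the predicted next value
])
model.compile(optimizer="adam", loss="mse")
```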

Convolutional models are models specially designed to work well with image data.

So a Recurrent Convolutional model is a model specially designed to make predictions using a sequence of images (more commonly known as video).
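Concretely, in array terms (the dimensions here are made up for illustration), a single image is a 3-D tensor, while a video simply adds a leading time axis:

```python
import numpy as np

image = np.zeros((10, 50, 1))     # a single frame: (height, width, channels)
video = np.zeros((2, 10, 50, 1))  # a 2-frame sequence: (time, height, width, channels)
```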

The dataset

The aptly named `Sliding Square` toy dataset

Download it from Kaggle; it can also be found on my GitHub.
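If you’d rather generate the data yourself, a minimal sketch might look like the following (the frame size, square size, and step are my assumptions, not the dataset’s exact spec):

```python
import numpy as np

def make_frame(pos, width=50, height=10, square=4):
    """Render one frame with the square's left edge at column `pos`."""
    frame = np.zeros((height, width))
    frame[3:3 + square, pos:pos + square] = 1.0
    return frame

# A square sliding to the right in steps of 5 pixels per frame.
positions = list(range(0, 45, 5))
frames = np.stack([make_frame(p) for p in positions])
```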

The task

Predict the next position of the square!

The input

As a CNN you only get a single frame as input:

Input to CNN (single image)
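A minimal Keras CNN for this setup could look something like this (the layer sizes are assumptions; the actual model is in the notebook):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense

cnn = Sequential([
    Input(shape=(10, 50, 1)),              # a single frame
    Conv2D(8, (3, 3), activation="relu"),  # extract spatial features
    Flatten(),
    Dense(1),                              # predicted next x-position of the square
])
cnn.compile(optimizer="adam", loss="mse")
```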

As an RCNN you get multiple frames as input:

Input to RCNN (multiple sequential images)
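And a matching RCNN sketch: the same kind of convolutional feature extractor wrapped in TimeDistributed so it runs on every frame, with the per-frame features fed to an LSTM (again, layer sizes are assumptions):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Input, TimeDistributed, Conv2D,
                                     Flatten, LSTM, Dense)

rcnn = Sequential([
    Input(shape=(2, 10, 50, 1)),                            # 2 frames in sequence
    TimeDistributed(Conv2D(8, (3, 3), activation="relu")),  # same CNN applied to each frame
    TimeDistributed(Flatten()),                             # one feature vector per frame
    LSTM(16),                                               # combine the frames in order
    Dense(1),                                               # predicted next x-position
])
rcnn.compile(optimizer="adam", loss="mse")
```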

The intuition

Why should the CNN under-perform on this dataset?

Take a look at the input for the CNN above (the single frame). Can you confidently predict the next position of the Sliding Square? I’ll assume your answer is “No”. That’s because it’s perfectly ambiguous: the square could be moving right or left, and there is no way of knowing from a single image.

Ambiguity in trying to predict the next position based on a single frame.

Why should the RCNN perform well on this dataset?

Try predicting the next position using the RCNN input from above (the two frames). Pretty easy now? The direction the square is moving is no longer ambiguous.

Showing how using multiple frames disambiguates the next position.

Performance testing CNN vs RCNN

Check out this notebook to see the code for building and training the simple CNN and RCNN (Keras models) on this dataset.

To evaluate the performance we use a random sample of frames from the dataset and get each model to try to predict the next position of the square.
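In code, the evaluation amounts to something like this sketch (`model`, `X_test`, and `y_test` are assumed names, not the notebook’s exact variables):

```python
import numpy as np

def prediction_errors(model, X_test, y_test):
    """Signed difference between predicted and true next positions."""
    predictions = model.predict(X_test).ravel()
    return predictions - np.asarray(y_test)
```

A histogram of these errors gives the frequency chart below.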

Below is a frequency chart of the errors each model made in its predictions. E.g. you can see that in 7 tests the CNN got the square’s position exactly right (zero error).

Frequency chart of prediction errors for a sample of test data

What’s going on with the CNN’s error scores?

You may notice that the CNN makes a lot of errors of +/-4. All is revealed when we inspect one of these predictions:

CNN prediction (always predicts it to be the same position as the input)

The CNN seems to always predict that the square is in the same position as in the input image. You may think this is strange: the next position of the square can only be to the left or to the right, so why would it predict a place where it definitely isn’t going to be!? One word… consistency! Generally, it’s better for a machine learning model to be a little wrong all of the time than very wrong some of the time. This preference is built into most learning algorithms by squaring the error, as shown below:

How consistent errors are encouraged by squaring the original error. See how both models have the same Sum(Error) but Model B has a significantly higher Sum(Error x Error)
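A quick numeric check of the figure’s point (the error values here are made up): both models rack up the same total error, but squaring punishes the inconsistent one.

```python
import numpy as np

model_a = np.array([5, 5, 5, 5])    # consistently a little wrong
model_b = np.array([0, 10, 0, 10])  # sometimes exact, sometimes very wrong

print(model_a.sum(), model_b.sum())                # 20 20   -> same Sum(Error)
print((model_a ** 2).sum(), (model_b ** 2).sum())  # 100 200 -> Model B penalised more
```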

The CNN has learned to guess the same position as the input because then it’s always at most +/-5 away from the correct answer. If it always guessed that the square was moving to the right, half of the time it would be exactly right, but the other half of its predictions would be off by 10! (Which wouldn’t be very consistent.)
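Putting numbers on that trade-off (assuming the square moves 5 pixels per frame and is equally likely to be heading either way):

```python
# Expected squared error per prediction for each strategy:
stay_put    = 5 ** 2                  # 25   -- always off by exactly 5
guess_right = (0 ** 2 + 10 ** 2) / 2  # 50.0 -- exact half the time, off by 10 otherwise
print(stay_put, guess_right)          # 25 50.0
```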

Naturally the RCNN is in its element

You can see from the error chart that the RCNN doesn’t make anywhere near as many errors as the CNN. That’s because it has no trouble figuring out which direction the square is moving.

RCNN not having any trouble in predicting the next position of the square.

Hopefully this has given you a bit more intuition around when to use RCNNs instead of standard CNNs.

All of the code used to generate the toy dataset, along with the code for the test models, can be found on my GitHub.
