Learn RCNNs with this Toy Dataset
Here’s a dataset designed to showcase when a Recurrent Convolutional Neural Network (RCNN) will outperform its non-recurrent counterpart, the Convolutional Neural Network (CNN).
A little primer
Recurrent models are models that are specially designed to use a sequence of data in making their predictions (e.g. a stock market predictor that uses a sequence of data points from the past 3 days).
Convolutional models are models that are specially designed to work well with image data.
So a Recurrent Convolutional model is a model that is specially designed to make predictions using a sequence of images (more commonly known as video).
The dataset
Download it here from Kaggle; it can also be found on my GitHub.
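If you’d rather see how such a dataset could be built than download it, here’s a minimal sketch. The frame size, square size, and slide distance are assumptions (chosen so the step matches the +/-5 errors discussed later), not necessarily the real dataset’s dimensions:

```python
import numpy as np

# All sizes here are assumptions -- the real dataset's dimensions may differ.
FRAME_SIZE = 50   # frames are FRAME_SIZE x FRAME_SIZE pixels
SQUARE = 10       # side length of the sliding square
STEP = 5          # the square slides 5 pixels per frame

def make_frame(x):
    """Render one frame with the square's left edge at column x."""
    frame = np.zeros((FRAME_SIZE, FRAME_SIZE), dtype=np.float32)
    top = (FRAME_SIZE - SQUARE) // 2
    frame[top:top + SQUARE, x:x + SQUARE] = 1.0
    return frame

def make_example(n_input_frames=2, rng=None):
    """Return (input frames, next position) for a square sliding left or right."""
    rng = rng or np.random.default_rng()
    direction = int(rng.choice([-STEP, STEP]))
    # Choose a start so every position (including the target) stays in bounds.
    lo = max(0, -direction * n_input_frames)
    hi = FRAME_SIZE - SQUARE - max(0, direction * n_input_frames)
    start = int(rng.integers(lo, hi + 1))
    positions = [start + i * direction for i in range(n_input_frames + 1)]
    frames = np.stack([make_frame(p) for p in positions[:-1]])
    return frames, positions[-1]   # the last position is the prediction target
```

A CNN would be trained on just the last frame of each example, while an RCNN gets the whole stack.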
The task
Predict the next position of the square!
The input
As a CNN you only get a single frame as input:
As an RCNN you get multiple frames as input:
The intuition
Why should the CNN under-perform on this dataset?
Take a look at the input for the CNN above (the single frame). Can you confidently predict what the next position of the Sliding Square is going to be? I’ll assume your answer is “No”. That’s because the input is perfectly ambiguous, i.e. the square could be moving to the right or the left; there is no way of knowing based on a single image.
Why should the RCNN perform well on this dataset?
Try predicting the next position using the RCNN input from above (the two frames). Pretty easy now? The direction the square is moving is no longer ambiguous.
Performance testing CNN vs RCNN
Check out this notebook to see the code for building and training the simple CNN and RCNN (Keras models) on this dataset.
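As a rough sketch of what those two models might look like in Keras (the layer counts and sizes here are my assumptions, not the notebook’s exact architectures):

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Assumed input dimensions -- adjust to match the actual dataset.
FRAME_SHAPE = (50, 50, 1)   # one grayscale frame
N_FRAMES = 2                # input sequence length for the RCNN

def build_cnn():
    """CNN: a single frame in, the square's next position out."""
    return models.Sequential([
        layers.Input(shape=FRAME_SHAPE),
        layers.Conv2D(16, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(32, activation="relu"),
        layers.Dense(1),   # regression: predicted position
    ])

def build_rcnn():
    """RCNN: the same convolutional stack applied per frame, then an LSTM over time."""
    return models.Sequential([
        layers.Input(shape=(N_FRAMES,) + FRAME_SHAPE),
        layers.TimeDistributed(layers.Conv2D(16, 3, activation="relu")),
        layers.TimeDistributed(layers.MaxPooling2D()),
        layers.TimeDistributed(layers.Flatten()),
        layers.LSTM(32),
        layers.Dense(1),
    ])

cnn, rcnn = build_cnn(), build_rcnn()
cnn.compile(optimizer="adam", loss="mse")
rcnn.compile(optimizer="adam", loss="mse")
```

The key difference is that the RCNN wraps the convolutional layers in `TimeDistributed` so the same filters run on every frame, and an LSTM then summarises the sequence of per-frame features before the final prediction.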
To evaluate the performance we use a random sample of frames from the dataset and get each model to try to predict the next position of the square.
Below is a frequency chart of the errors each model made in its predictions. E.g. you can see that for 7 tests the CNN got the square’s position exactly right (zero error).
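Such a frequency chart boils down to tallying the signed errors. A small sketch, with the helper name and the example numbers being mine rather than the notebook’s:

```python
import numpy as np

def error_frequencies(predictions, targets):
    """Count how often each signed prediction error occurs."""
    errors = np.rint(np.asarray(predictions) - np.asarray(targets)).astype(int)
    values, counts = np.unique(errors, return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))

# Hypothetical numbers: a model that always guesses the current position,
# while the square actually moved +/-5 pixels.
freqs = error_frequencies([20, 20, 20, 20], [25, 15, 25, 15])
print(freqs)   # {-5: 2, 5: 2}
```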
What’s going on with the CNN’s error scores?
You may notice that the CNN is making a lot of errors of +/-5. All is revealed when we inspect one of these predictions:
The CNN seems to always predict that the square is in the same position as it is in the input image. You may think this is strange: the next position of the square can only be to the left or the right, so why would it predict a place where it definitely isn’t going to be!? One word… consistency! Generally, for machine learning models it’s better to be a little wrong all of the time than very wrong some of the time. This preference is achieved in most learning algorithms by squaring the error, as shown below:
The CNN has learned to guess the same position as the input because then it’s always at most +/-5 away from the correct answer. If it always guessed that the square was moving to the right, half of the time it would be exactly right, but the other half of its predictions would be off by 10 (which wouldn’t be very consistent).
Naturally, the RCNN is in its element
You can see from the error chart that the RCNN doesn’t make anywhere near as many errors as the CNN. That’s because it has no trouble figuring out which direction the square is moving.
Hopefully this has given you a bit more intuition around when to use RCNNs instead of standard CNNs.
All of the code used to generate the toy dataset, along with the code for the test models, can be found on my GitHub: