Teaching Computers to See
Classifying Flowers with CNNs and Transfer Learning
If I asked you what type of flower is pictured above, you would probably know it’s a sunflower. But what if I asked a computer the same question? Wouldn’t it be pretty impressive if it answered correctly? Remarkably, computers can be trained to classify not only sunflowers but a multitude of other flowers as well. And the person training the computer doesn’t even have to know what a sunflower is!
The Secret to the Magic: Convolutional Neural Networks
To identify types of flowers, I developed a Convolutional Neural Network (CNN) that can classify dandelions, daisies, tulips, sunflowers, and roses. Check out the full code HERE
What are neural networks? Artificial neural networks are loosely modeled after the biological neural networks of the human brain and can make predictions based on patterns recognized in past data. These networks have three types of layers: an input layer where the initial data is provided, one or more hidden layers where weights and biases are used to perform computations, and an output layer that produces the final result, usually by applying an activation function. Read a more in-depth description of neural networks here.
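To make the three layer types concrete, here is a minimal sketch of one forward pass through a tiny network in plain Python. The layer sizes, the weights and biases, and the choice of ReLU and sigmoid activations are all made up for illustration; none of them come from the flower model.

```python
import math

def relu(x):
    """Hidden-layer activation: negative values become zero."""
    return max(0.0, x)

def sigmoid(x):
    """Output activation: squashes any number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def dense(inputs, weights, biases, activation):
    """One fully connected layer: activation(weights . inputs + bias) per neuron."""
    outputs = []
    for neuron_weights, bias in zip(weights, biases):
        total = sum(w * x for w, x in zip(neuron_weights, inputs)) + bias
        outputs.append(activation(total))
    return outputs

def forward(inputs):
    # Hidden layer: 2 inputs -> 2 neurons (illustrative weights and biases)
    hidden = dense(inputs, [[0.5, -0.2], [0.3, 0.8]], [0.1, -0.1], relu)
    # Output layer: 2 hidden values -> 1 prediction between 0 and 1
    return dense(hidden, [[1.0, -1.0]], [0.0], sigmoid)[0]

prediction = forward([1.0, 2.0])
print(round(prediction, 3))
```

Training a real network means adjusting those weights and biases automatically so that the predictions match the labeled data.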
A convolutional neural network is a type of neural network that typically includes convolutional layers and max-pooling layers. A convolution slides a small filter (a grid of weights) across the pixels that make up the input image. Each application of the filter produces one activation, and applying it at every position produces a map of activations called a feature map, which tells the computer where a pattern such as an edge appears in the image. Following the convolutional layer is the max-pooling layer. Max pooling looks at each section of the feature map (the size of the section is specified by the programmer), keeps only the greatest value in that section, and uses those maximum values to create a new, smaller image. These smaller images help the computer run the model much faster. Check this video out for a deeper description of CNNs.
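The two operations can be sketched in a few lines of plain Python. This is only an illustration under simple assumptions (one filter, no padding, a stride of 1, and a hand-written 5x5 "image"); a real CNN learns its filter values during training rather than having them typed in.

```python
def convolve2d(image, kernel):
    """Slide a square kernel over the image (no padding, stride 1).

    Like most deep learning libraries, this does not flip the kernel.
    """
    k = len(kernel)
    feature_map = []
    for r in range(len(image) - k + 1):
        row = []
        for c in range(len(image[0]) - k + 1):
            # One activation: multiply the filter with the pixels under it and sum
            row.append(sum(image[r + i][c + j] * kernel[i][j]
                           for i in range(k) for j in range(k)))
        feature_map.append(row)
    return feature_map

def max_pool(feature_map, size=2):
    """Keep the largest value in each non-overlapping size x size section."""
    pooled = []
    for r in range(0, len(feature_map) - size + 1, size):
        row = []
        for c in range(0, len(feature_map[0]) - size + 1, size):
            row.append(max(feature_map[r + i][c + j]
                           for i in range(size) for j in range(size)))
        pooled.append(row)
    return pooled

# Toy image: dark on the left (0), bright on the right (1)
image = [[0, 0, 1, 1, 1] for _ in range(5)]
# Illustrative filter that responds to dark-to-bright vertical edges
kernel = [[-1, 1],
          [-1, 1]]

feature_map = convolve2d(image, kernel)  # 4x4 map, large where the edge is
pooled = max_pool(feature_map)           # 2x2 map after pooling
print(feature_map)
print(pooled)
```

The feature map lights up only in the column where the brightness changes, and pooling shrinks it from 4x4 to 2x2 while keeping that signal.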
When the convolutional layers and max-pooling layers are connected to the input and output layers of a neural network, the model is able to use past labeled data to make predictions of what future images contain. Now that we understand what a CNN is, let’s look at the steps to build one.
Gathering Data
We can code this project using Python and the TensorFlow library. The flowers dataset (containing labeled images of the 5 classes of flowers) is already provided in TensorFlow Datasets, so it can simply be downloaded from there. However, the dataset must be prepared before it can be passed into the model. We split 70% of the data into the training set and the remaining 30% into the validation set. The training set is the set of examples the model uses to fit its parameters; the validation set is held out so we can check how well the model works on data it has not seen before and tune our training choices accordingly. Since the images in this dataset are of different sizes, we also resize all of them to a standard size.
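The idea behind the 70/30 split can be sketched without TensorFlow. The helper name `train_validation_split`, the fixed seed, and the placeholder file names below are hypothetical stand-ins, not the calls used in the full code.

```python
import random

def train_validation_split(examples, train_fraction=0.7, seed=0):
    """Shuffle a copy of the examples, then cut it into training and validation lists."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # fixed seed keeps the split reproducible
    cutoff = round(len(shuffled) * train_fraction)
    return shuffled[:cutoff], shuffled[cutoff:]

# Hypothetical stand-ins for (image, label) pairs; labels 0-4 mimic the 5 classes
examples = [(f"image_{i}.jpg", i % 5) for i in range(100)]
train_set, validation_set = train_validation_split(examples)
print(len(train_set), len(validation_set))
```

Shuffling before cutting matters: if the examples were sorted by class, an unshuffled split could leave some flowers out of the training set entirely.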
Transfer Learning to Build Model
After gathering the data, we can build and train our model. Transfer learning is the process of reusing the layers of an already trained model and replacing only its final layer; that final layer is then trained on the flowers dataset to give the outputs we want. We use transfer learning because the pretrained layers have already learned general image features from a much larger dataset, which can improve the accuracy of our model. Using the MobileNet v2 model and changing only the final layer also makes the code for the actual model very short.
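import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers

# IMAGE_RES (side length of the resized images) and num_classes (5 flower classes) must be defined beforehand
URL = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
feature_extractor = hub.KerasLayer(URL, input_shape=(IMAGE_RES, IMAGE_RES, 3))
feature_extractor.trainable = False  # freeze the pretrained weights
model = tf.keras.Sequential([feature_extractor, layers.Dense(num_classes, activation='softmax')])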
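The new final layer uses a softmax activation, which turns one raw score per class into probabilities that add up to 1. Here is a minimal sketch of that step in plain Python; the five scores are invented for illustration and are not outputs of the actual model.

```python
import math

def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores from the final layer, one per flower class
scores = [2.0, 0.5, 0.1, 3.2, 1.0]
probabilities = softmax(scores)
print([round(p, 3) for p in probabilities])
```

The class with the highest raw score always ends up with the highest probability, so softmax changes how the outputs are read, not which class wins.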
Training Model
After training the model, we can plot its accuracy and loss to see how it is doing. The x-axis of the graph represents the number of epochs, where one epoch is one complete pass of the model through the training set, with the weights updated along the way.
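To show what an epoch means in code, here is a toy training loop in plain Python. It is not the Keras training used for the flower model: the one-weight model `y = w * x`, the learning rate, and the data are all assumptions chosen to keep the sketch short.

```python
def train(data, epochs=20, learning_rate=0.05):
    """Fit y = w * x by gradient descent, recording the loss after each epoch."""
    w = 0.0
    history = []
    for _ in range(epochs):
        # One epoch: visit every training example once, updating the weight each time
        for x, y in data:
            error = w * x - y
            w -= learning_rate * 2 * error * x  # gradient of the squared error
        # Mean squared error over the whole training set at the end of the epoch
        loss = sum((w * x - y) ** 2 for x, y in data) / len(data)
        history.append(loss)
    return w, history

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # toy data following y = 2x
weight, loss_history = train(data)
print(round(weight, 4), loss_history[0], loss_history[-1])
```

Plotting `loss_history` against the epoch number gives the same kind of curve as the loss graph described above: high at first, then flattening as the model fits the data.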
Making Predictions
Let’s look at the results after we run the image batch through the model!
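Turning the model's output into a flower name comes down to picking the class with the highest probability. The sketch below assumes the model returns one list of five probabilities per image; the class order, the helper names, and the numbers in `batch` are hypothetical.

```python
# The order of the class names is an assumption made for this example
CLASS_NAMES = ["dandelion", "daisy", "tulips", "sunflowers", "roses"]

def predict_labels(probability_batch, class_names=CLASS_NAMES):
    """Return the most probable class name for each image in the batch."""
    labels = []
    for probabilities in probability_batch:
        best_index = max(range(len(probabilities)), key=probabilities.__getitem__)
        labels.append(class_names[best_index])
    return labels

def accuracy(predicted, actual):
    """Fraction of predictions that match the true labels."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)

# Hypothetical model outputs for a batch of three images
batch = [
    [0.05, 0.05, 0.10, 0.70, 0.10],
    [0.60, 0.20, 0.10, 0.05, 0.05],
    [0.10, 0.15, 0.20, 0.15, 0.40],
]
predicted = predict_labels(batch)
print(predicted)
print(accuracy(predicted, ["sunflowers", "dandelion", "tulips"]))
```

Comparing the predicted names with the true labels, as `accuracy` does, is how the quality of a batch of predictions can be summarized in one number.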
As you can see, the model performed very well on images it had not been trained on, thanks to transfer learning and the CNN architecture.
So how exactly is this useful? While classifying flowers may mainly be helpful to botanists, CNNs can have life-saving applications such as detecting pneumonia from chest X-rays and making self-driving cars a reality.
Don’t leave yet!
I’m Roshan, a 16-year-old passionate about artificial intelligence, specifically its applications in finance. If you liked reading this, check out my other articles, such as “Fighting Financial Fraud with Artificial Intelligence”, where I describe how I used an autoencoder to detect anomalies in accounting journal entries.
Reach out to me on LinkedIn