Building a Skin Lesion Classification Web App

Using Keras and TensorFlow.js to classify seven types of skin lesions

Alex Yu
Towards Data Science


After doing research on Convolutional Neural Networks, I became interested in developing an end-to-end machine learning solution. I decided to use the HAM10000 dataset to build a web app to classify skin lesions. In this article, I’ll provide some background information and explain some of the important concepts I learned while working on this project including Transfer Learning, Data Augmentation, Keras Callbacks, and TensorFlow.js.

Photo by rawpixel on Unsplash

Artificial intelligence is shaping the world around us. We interact with things touched by machine learning every day, from song and video recommendations to the smart assistants in our phones. But those are consumer applications of AI; what about AI on a larger scale?

“Just as electricity transformed almost everything 100 years ago, today I actually have a hard time thinking of an industry that I don’t think AI will transform in the next several years.” — Andrew Ng

Personally, I believe that one field with huge potential for deep learning is healthcare. Even though our technology has advanced significantly, diagnostic error remains a serious problem: it’s been reported that about 10 percent of deaths and 6 to 17 percent of hospital complications are due to diagnostic issues. Imagine reducing that number to less than 5% by helping medical professionals diagnose patients with the help of machine learning models. The impact would be huge!

In my last article on Convolutional Neural Networks, I touched on how Computer Vision is being applied in various industries. I highly recommend you check it out here.

A Bit of Background Information on Skin Cancer

  • More people are diagnosed with skin cancer each year in the U.S. than all other cancers combined.
  • One in five Americans will develop skin cancer by the age of 70.
  • Actinic keratosis affects more than 58 million Americans.

I wanted to build a solution that leverages a Convolutional Neural Network to help people classify different types of skin cancers quickly and accurately. My main goal was to create a project that is easily accessible and effective. I ultimately decided on building a web app.

For this project, I used the publicly available HAM10000 dataset which contains approximately 10,000 different images of skin lesions.

The categories of skin lesions include:

  • Actinic keratoses and intraepithelial carcinoma (akiec): common non-invasive variants of squamous cell carcinomas. They are sometimes seen as precursors that may progress to invasive squamous cell carcinoma.
  • Basal cell carcinoma (bcc): a common version of epithelial skin cancer that rarely metastasizes but grows if it isn’t treated.
  • Benign keratosis (bkl): contains three subgroups (seborrheic keratoses, solar lentigo, and lichen-planus like keratoses (LPLK)). These groups may look different but are biologically similar.
  • Dermatofibroma (df): a benign skin lesion that is regarded as a benign proliferation or an inflammatory reaction to minimal trauma.
  • Melanoma (mel): a malignant neoplasm that can appear in different variants. Melanomas are usually, but not always, chaotic, and some criteria depend on the site location.
  • Melanocytic Nevi (nv): these variants can differ significantly from a dermatoscopic point of view but are usually symmetric in terms of distribution of color and structure.
  • Vascular Lesions (vasc): generally characterized by a red or purple color and solid, well-circumscribed structures known as red clods or lacunes.

For more information on the dataset or the skin cancer classifications please refer to this paper.

Using Transfer Learning for a Convolutional Neural Network Model

Photo by Mika Baumeister on Unsplash

If you’ve been working with any kind of data, you’ll know that data is the most important thing when you’re developing deep learning models. However, most of the time your datasets probably won’t be large enough for optimal performance. By large, we’re talking about at least 50,000 images. Networks with a ton of layers are also incredibly expensive to train. It can take a super long time if you don’t have an amazing GPU (or several). 😢

The entire idea behind Transfer Learning is that you can take a model that has already been pre-trained on a large dataset, modify it and retrain it on the dataset you’re currently working with.

As I explained in my previous article, Convolutional Neural Networks look for different features in images, such as edges and shapes. We can take a neural network with millions of connections that has already been trained to identify those features and retrain part of it by “freezing” the first few layers. After adding a new fully connected layer and training only the last few layers, we obtain a model that still recognizes basic features and makes predictions that generalize well to our data.

A diagram depicting the use of transfer learning to retrain layers of a network. Source.

The Keras Applications library includes several deep learning models, including VGG16, VGG19, ResNet50, MobileNet, and a few others. All of them come with weights pre-trained on ImageNet; the standard training subset alone contains roughly 1.2 million labeled images. That’s a pretty noticeable difference when compared to our 10,000-image dataset.

For this project, I chose to use the MobileNet architecture, which is optimized for mobile applications with less computing power. This architecture makes use of depth-wise separable convolutions which essentially helps to reduce the number of training parameters, making the model more lightweight. For more information on MobileNet, check out this paper.

Here’s how we can do it in Keras.
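
A minimal sketch of that setup, assuming the tf.keras API (the size of the new fully connected layer, the dropout rate, and the number of layers left unfrozen are illustrative choices, not necessarily the exact ones used in the project):

    from tensorflow.keras.applications import MobileNet
    from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
    from tensorflow.keras.models import Model

    # Load MobileNet pre-trained on ImageNet, without its final classification layer
    base_model = MobileNet(weights='imagenet', include_top=False,
                           input_shape=(224, 224, 3))

    # Add a new classification head for the seven lesion classes
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    predictions = Dense(7, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)

    # "Freeze" the early layers so only the last few (plus the new head) are retrained;
    # how many layers to leave trainable is a tuning choice
    for layer in base_model.layers[:-20]:
        layer.trainable = False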

Preprocessing the Images of Skin Lesions

One nice thing about the HAM10000 dataset is that all of the images are the same size, 600x450. However, after looking at the distribution of the images, we see that a significant majority of the images belong to the class of melanocytic nevi.

Left: total number of images in each class. Right: number of training images in each class.
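
One quick way to check this imbalance is to count the diagnosis labels in the dataset’s metadata file (this assumes the standard HAM10000_metadata.csv with its dx column):

    import pandas as pd

    # Count how many images fall into each of the seven diagnosis classes
    metadata = pd.read_csv('HAM10000_metadata.csv')
    print(metadata['dx'].value_counts())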

Augmenting the Training Data

Data augmentation is super useful when it comes to increasing the number of training examples we can work with. To augment the training data, we use the ImageDataGenerator class from the Keras preprocessing library, which generates batches of tensor image data with real-time augmentation as it loops through the data. Some of the parameters that we pass in are:

  • rotation_range: which is the degree range for random rotations
  • width_shift_range: this represents a fraction of the total width that the image can be shifted by
  • height_shift_range: this represents a fraction of the total height that the image can be shifted by
  • zoom_range=0.1: the fraction of the image that can be zoomed in or out
  • horizontal_flip=True: randomly flips the input horizontally
  • vertical_flip=True: randomly flips the input vertically
  • fill_mode='nearest': the specification for filling in points outside of the input boundaries

We can declare an augmented data generator by running the following code. Our target size is 224x224 because those are the dimensions that are needed for the MobileNet input layer.
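
A sketch of that generator, assuming the tf.keras API and training images organized into one folder per class (the rotation and shift values, directory name, and batch size are illustrative):

    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    train_datagen = ImageDataGenerator(
        rotation_range=180,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest')

    # 224x224 matches the MobileNet input layer
    train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(224, 224),
        batch_size=32,
        class_mode='categorical')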

Compiling the Model

The Keras Callbacks library provides a bunch of useful functions that can be applied at several stages during the training process of the model. These functions can be used to learn more about the internal states of the model. Two of the callbacks that are used in this program are ReduceLROnPlateau and ModelCheckpoint.

ReduceLROnPlateau is used to reduce the learning rate when one of the model metrics has stopped improving. It’s been shown that models often benefit when the learning rate is reduced by a factor of 2–10 once the model stops improving after several iterations. Some important parameters are:

  • monitor: the metric that will be used to evaluate whether or not the model is improving
  • factor=0.5: the factor by which the learning rate will be reduced
  • patience=2: the number of epochs with no improvement after which the learning rate will be reduced
  • mode='max': the learning rate is reduced when the monitored metric (e.g. validation accuracy) stops increasing

ModelCheckpoint is used to save the model after every epoch. save_best_only=True makes sure that the best model isn’t overwritten.
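
Putting this together, compiling and training the model could look something like the sketch below. The optimizer, learning rate, number of epochs, and monitored metric name are illustrative, and validation_generator is assumed to be a second generator built the same way as train_generator from a held-out validation folder.

    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint

    model.compile(optimizer=Adam(learning_rate=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Halve the learning rate when validation accuracy plateaus for two epochs
    reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5,
                                  patience=2, mode='max')

    # Save the model after each epoch, keeping only the best one seen so far
    checkpoint = ModelCheckpoint('model.h5', monitor='val_accuracy',
                                 save_best_only=True, mode='max')

    model.fit(train_generator,
              validation_data=validation_generator,
              epochs=20,
              callbacks=[reduce_lr, checkpoint])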

Plotting a Confusion Matrix of the Predictions
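
One way to build such a matrix is with scikit-learn, assuming a test_generator created with shuffle=False so the predictions line up with the true labels:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import confusion_matrix

    # Predict class probabilities for the test set and take the most likely class
    predictions = model.predict(test_generator)
    predicted_classes = np.argmax(predictions, axis=1)

    # Compare against the true labels stored on the generator
    cm = confusion_matrix(test_generator.classes, predicted_classes)

    plt.imshow(cm, cmap='Blues')
    plt.xlabel('Predicted class')
    plt.ylabel('True class')
    plt.show()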

We can see that our model has acceptable performance and that a significant number of test examples for label nv were classified correctly. The model mostly predicts akiec, bcc, bkl, nv, and vasc correctly, but struggles with df. The model sometimes confuses Melanoma (mel) with Melanocytic nevi (nv), as well as nv with Benign keratosis (bkl). The model still has much more room for improvement and finer tuning of hyperparameters may help.

Saving the Model and Converting it to TensorFlow.js

After training the model, we can find the Keras model in the local directory as model.h5. We can convert it to a TensorFlow.js file by running the following code.
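
A sketch using the tensorflowjs Python package (the output directory name is arbitrary); the same conversion can also be done from the command line with the tensorflowjs_converter tool:

    import tensorflowjs as tfjs
    from tensorflow.keras.models import load_model

    # Load the saved Keras model and write it out in the TensorFlow.js format
    model = load_model('model.h5')
    tfjs.converters.save_keras_model(model, 'tfjs_model')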

Running Machine Learning in the Browser

TensorFlow.js is the JavaScript version of Google’s popular deep learning framework TensorFlow. It consists of a low-level core API and a high-level layers API. There are three main reasons why I think TensorFlow.js is pretty cool.

  1. TensorFlow.js with WebGL runs on any kind of GPU, including Nvidia, AMD, and phone GPUs.
  2. You can convert existing models into TensorFlow.js models and repurpose them easily.
  3. Models run locally in the browser, meaning that the user’s data never leaves their computer.

The last point is particularly important: it’s easy to imagine how much this would matter if self-diagnosing online becomes widespread. Training and inference on the client side help keep solutions privacy-friendly!

You can find the code for my project here and a live version of the model here.

Key Takeaways:

  • Transfer learning is useful when you don’t necessarily have a lot of data or computing power to work with
  • Data augmentation is another way to make sure you have enough training data for your model to perform well
  • TensorFlow.js lets you run machine learning models in the browser easily and on the client side

Thanks for reading! If you enjoyed it, please:

  • Add me on LinkedIn and follow my Medium to stay updated with my journey
  • Leave some feedback or send me an email (alex@alexyu.ca)
  • Share this article with your network


High school student passionate about tech, business, and making cool things.