Image Super-Resolution using Convolutional Neural Networks and Auto-encoders

A guide to enhancing image quality with Deep Learning!

Harshil Patel
Towards Data Science

The problem statement is quite familiar. At some point, most of us have had to deal with a blurry or distorted image and wished we could improve its quality. Thanks to advances in deep learning, we can now try to do exactly that. Here, we'll enhance the resolution of images by training a convolutional neural network structured as an auto-encoder!

Prerequisites

  1. A basic understanding of Convolutional Neural Networks (CNNs)
  2. Working knowledge of TensorFlow, Keras and a few other common Python libraries

What are Auto-encoders?

Auto-encoders are a type of neural network used for unsupervised learning.

In layman's terms, these models take some input x and learn a compact set of latent features from it. They then use those learned features to reconstruct x, producing the desired output X.

Here, the input image is being reconstructed (Source)

We will use the auto-encoder architecture to increase the resolution of an image. For a more detailed explanation of auto-encoders, click here.
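To make the encode, latent features, decode idea concrete, here is a small NumPy sketch of a *linear* auto-encoder. This is not the convolutional model built later in the article. Instead of training it with gradient descent, it uses the closed-form optimum: the top principal directions from an SVD, which is the well-known link between linear auto-encoders and PCA. The helper names (`fit_linear_autoencoder`, `encode`, `decode`) and the synthetic data are illustrative assumptions.

```python
import numpy as np

def fit_linear_autoencoder(x, latent_dim):
    """Closed-form optimal linear auto-encoder (equivalent to PCA)."""
    mean = x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    w_enc = vt[:latent_dim].T  # (d, latent_dim): compresses a sample into latent features
    w_dec = vt[:latent_dim]    # (latent_dim, d): reconstructs the sample from them
    return mean, w_enc, w_dec

def encode(x, mean, w_enc):
    return (x - mean) @ w_enc

def decode(z, mean, w_dec):
    return z @ w_dec + mean

rng = np.random.default_rng(0)
# 200 samples with 6 features that secretly lie on a 2-D plane (plus an offset)
x = rng.normal(size=(200, 2)) @ rng.normal(size=(2, 6)) + 3.0

errors = {}
for k in (1, 2):
    mean, w_enc, w_dec = fit_linear_autoencoder(x, k)
    z = encode(x, mean, w_enc)      # latent features, shape (200, k)
    x_hat = decode(z, mean, w_dec)  # reconstruction, shape (200, 6)
    errors[k] = np.mean((x - x_hat) ** 2)
    print(f"latent_dim={k}: z shape {z.shape}, reconstruction MSE {errors[k]:.2e}")
```

With two latent features, the 6-feature input is reconstructed almost perfectly, because the data really has only two degrees of freedom. With one latent feature, information is lost and the error is clearly non-zero. A convolutional auto-encoder applies the same compress-then-reconstruct principle, but with non-linear layers learned by gradient descent.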

Implementation:

Library Imports

Let’s open Jupyter Notebook and import some required libraries.

import numpy as np
import cv2
import glob
import tensorflow as tf
from tensorflow.keras import Model, Input, regularizers
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, UpSampling2D, Add, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing import image  # use tf.keras consistently instead of standalone keras
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pickle

Download Dataset

We will be working with the ‘Labeled Faces in the Wild’ (LFW) dataset. It contains a database of labelled face images and is generally used for face recognition and detection. However, our aim is not to detect faces but to build a model that enhances image resolution.

Download the dataset by clicking here.

The dataset consists of one subdirectory per person, each containing that person's images. Hence, we need to collect the image paths from all of these directories.

face_images = glob.glob('lfw/lfw/**/*.jpg') #returns path of images
print(len(face_images)) #contains 13243 images

Load and Preprocess Images

The original images are 250 x 250 pixels. Processing images of that size takes a lot of computational power on an ordinary computer, so we will resize all of them to 80 x 80 pixels.

As there are around 13,000 images, processing them one at a time would take a long time. Hence, we use Python's multiprocessing library to load them in parallel.

tqdm is a progress-bar library that we use to track how much of the work is done.

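from tqdm import tqdm
from multiprocessing import Pool

def read(path):
    # Load the image resized to 80 x 80 pixels (target_size is (height, width))
    img = image.load_img(path, target_size=(80, 80))
    img = image.img_to_array(img)
    return img / 255.  # scale pixel values to [0, 1]

# Load images in 10 worker processes. A tqdm bar updated inside the workers would
# not be visible in the main process, so we wrap the result iterator instead.
with Pool(10) as p:
    img_array = list(tqdm(p.imap(read, face_images), total=len(face_images), position=0))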

To save time in the future, let’s store img_array (the list of loaded images) using the pickle library:

with open('img_array.pickle', 'wb') as f:
    pickle.dump(img_array, f)
print(len(img_array))
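An alternative to pickle is to stack the images into a single NumPy array and save it with `np.save`. This is not what the article does. The result is a compact binary `.npy` file that loads straight back as one array, which is the form the training code needs anyway. The sketch below uses random stand-in images and a temporary file path, both of which are assumptions made for illustration.

```python
import os
import tempfile
import numpy as np

# Stand-in for img_array: a list of 80 x 80 RGB float32 images in [0, 1]
img_list = [np.random.rand(80, 80, 3).astype(np.float32) for _ in range(5)]

stacked = np.stack(img_list)  # shape (n_images, 80, 80, 3)
path = os.path.join(tempfile.gettempdir(), 'img_array.npy')  # hypothetical cache location
np.save(path, stacked)

# Later sessions can skip image decoding entirely
loaded = np.load(path)
print(loaded.shape, loaded.dtype)
```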

Data preparation for Model Training

Now, we will split our dataset into training and validation sets. We will use the training data to train our model and the validation data to evaluate it.

# all_images will be our output (target) images
all_images = np.array(img_array)
# Split into training and validation data, holding out 20% for validation
train_x, val_x = train_test_split(all_images, random_state=32, test_size=0.2)
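`train_test_split` shuffles the rows before splitting off the validation portion. The NumPy-only sketch below shows that idea. `shuffled_split` is a hypothetical helper that rounds the validation count up. Even with the same seed, it will not reproduce scikit-learn's exact split.

```python
import numpy as np

def shuffled_split(data, test_size=0.2, seed=32):
    """Shuffle the rows of data, then hold out a test_size fraction for validation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(data))           # random order of row indices
    n_val = int(np.ceil(test_size * len(data)))
    return data[idx[n_val:]], data[idx[:n_val]]  # (train, validation)

data = np.arange(10).reshape(10, 1)
train, val = shuffled_split(data)
print(train.ravel(), val.ravel())
```

With around 13,000 images, a 20% hold-out leaves roughly 2,650 images for validation.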

Since this is a resolution-enhancement task, we will degrade our images and use the degraded versions as input images. The original images will serve as our output (target) images.

# Create input images by lowering the resolution without changing the size
def pixalate_image(img, scale_percent=40):
    # Shrink the image to scale_percent of its original width and height
    width = int(img.shape[1] * scale_percent / 100)
    height = int(img.shape[0] * scale_percent / 100)
    small_image = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    # Scale back to the original size (cv2.resize expects (width, height));
    # the detail lost while shrinking is not recovered
    original_dim = (img.shape[1], img.shape[0])
    low_res_image = cv2.resize(small_image, original_dim, interpolation=cv2.INTER_AREA)
    return low_res_image
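To see what this degradation does to the pixels, here is a NumPy-only sketch of the same downscale-then-upscale idea. It is a simplification, not the cv2 code above. It assumes an integer factor: 80 x 80 with factor 4 gives a 20 x 20 intermediate, whereas `scale_percent=40` gives 32 x 32. It averages blocks on the way down and copies each average back on the way up, roughly what `cv2.INTER_AREA` does in each direction. `block_pixelate` is a hypothetical name.

```python
import numpy as np

def block_pixelate(img, factor=2):
    """Average factor x factor blocks, then repeat each average back to full size."""
    h, w, c = img.shape
    assert h % factor == 0 and w % factor == 0, "size must be divisible by factor"
    # Downscale: mean over each non-overlapping factor x factor block
    small = img.reshape(h // factor, factor, w // factor, factor, c).mean(axis=(1, 3))
    # Upscale: copy each block average into every pixel of its block
    return small.repeat(factor, axis=0).repeat(factor, axis=1)

rng = np.random.default_rng(0)
img = rng.random((80, 80, 3))
low = block_pixelate(img, factor=4)
print(img.shape, low.shape)  # same size, but only 20 x 20 distinct pixel values
```

The output keeps the input's shape, which is why the model can map it pixel-for-pixel back to the original. The fine detail inside each block, however, is gone, and the network has to learn to restore it.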

The idea is to feed these distorted images to our model and have it learn to recover the original images.

# Get low-resolution images for the training set
train_x_px = []
for i in range(train_x.shape[0]):
    temp = pixalate_image(train_x[i, :, :, :])
    train_x_px.append(temp)
train_x_px = np.array(train_x_px)  # distorted images

# Get low-resolution images for the validation set
val_x_px = []
for i in range(val_x.shape[0]):
    temp = pixalate_image(val_x[i, :, :, :])
    val_x_px.append(temp)
val_x_px = np.array(val_x_px)  # distorted images

Figures: Input Image and Original Image

Model building

Let's define the structure of the model. To reduce the risk of over-fitting, we apply L1 regularization to the kernels of our convolution layers.
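In Keras, `regularizers.l1(l)` adds `l` times the sum of the absolute values of a layer's kernel weights to the training loss, which nudges weights toward zero. The sketch below computes that penalty by hand for a made-up kernel; the kernel values and the `l1_penalty` helper are hypothetical. Note that `10e-10` equals `1e-9`, so the penalty in this model is very small.

```python
import numpy as np

def l1_penalty(weights, lam=10e-10):
    """L1 term added to the loss: lam * sum of |w| over all given weight arrays."""
    return lam * sum(np.abs(w).sum() for w in weights)

kernel = np.array([[0.5, -1.0],
                   [2.0, -0.5]])  # hypothetical 2 x 2 kernel, |w| sums to 4.0
penalty = l1_penalty([kernel])
print(penalty)  # 4.0 * 1e-9
```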

Input_img = Input(shape=(80, 80, 3))  

#encoding architecture
x1 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(Input_img)
x2 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x1)
x3 = MaxPool2D(padding='same')(x2)
x4 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x3)
x5 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x4)
x6 = MaxPool2D(padding='same')(x5)
encoded = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x6)
#encoded = Conv2D(64, (3, 3), activation='relu', padding='same')(x2)
# decoding architecture
x7 = UpSampling2D()(encoded)
x8 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x7)
x9 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x8)
x10 = Add()([x5, x9])
x11 = UpSampling2D()(x10)
x12 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x11)
x13 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l1(10e-10))(x12)
x14 = Add()([x2, x13])
# x3 = UpSampling2D((2, 2))(x3)
# x2 = Conv2D(128, (3, 3), activation='relu', padding='same')(x3)
# x1 = Conv2D(256, (3, 3), activation='relu', padding='same')(x2)
decoded = Conv2D(3, (3, 3), padding='same',activation='relu', kernel_regularizer=regularizers.l1(10e-10))(x14)
autoencoder = Model(Input_img, decoded)
autoencoder.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
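The decoder adds encoder feature maps back in through the two `Add()` skip connections (`x5 + x9` and `x2 + x13`), so those shapes must match exactly. The pure-Python trace below is a sketch that assumes the default 2 x 2 pooling and upsampling factors used above. It shows why an input side divisible by 4, like 80, works. `trace_shapes` is a hypothetical helper, not part of Keras.

```python
import math

def trace_shapes(size):
    """Feature-map shapes (height, width, channels) for a size x size RGB input."""
    pooled1 = math.ceil(size / 2)     # MaxPool2D(padding='same') halves, rounding up
    pooled2 = math.ceil(pooled1 / 2)
    return {
        'x2': (size, size, 64),                  # 3x3 convs with padding='same' keep the size
        'x5': (pooled1, pooled1, 128),
        'encoded': (pooled2, pooled2, 256),      # bottleneck
        'x9': (2 * pooled2, 2 * pooled2, 128),   # after the first UpSampling2D
        'x13': (4 * pooled2, 4 * pooled2, 64),   # after the second UpSampling2D
        'decoded': (4 * pooled2, 4 * pooled2, 3),
    }

for size in (80, 78):
    s = trace_shapes(size)
    ok = s['x5'] == s['x9'] and s['x2'] == s['x13']
    print(size, s['encoded'], 'skip connections match' if ok else 'Add() would fail')
```

For an 80 x 80 input, the bottleneck is 20 x 20 x 256 and both skip connections line up. For 78 x 78, the pooled size 39 upsamples back to 40, and the first `Add()` would fail.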

You can modify this model to suit your needs and get better results: change the number of layers, the number of filters per layer, or the regularization technique. For the time being, let’s move forward and see what our model looks like!

autoencoder.summary()
Screenshot of the model summary
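The parameter counts in the summary can be checked by hand. A `Conv2D` layer with a k x k kernel has `k * k * in_channels` weights per filter plus one bias per filter. `MaxPool2D`, `UpSampling2D` and `Add` layers have no trainable parameters. The sketch below lists the input channels and filter count of each `Conv2D` layer in the model above. `conv2d_params` is an illustrative helper.

```python
def conv2d_params(in_channels, filters, kernel=3):
    """Weights (kernel * kernel * in_channels per filter) plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

# (in_channels, filters) for every Conv2D layer, in the order they are defined
conv_layers = [(3, 64), (64, 64), (64, 128), (128, 128), (128, 256),
               (256, 128), (128, 128), (128, 64), (64, 64), (64, 3)]
per_layer = [conv2d_params(c_in, f) for c_in, f in conv_layers]
total = sum(per_layer)
print(per_layer)
print(total)  # should match the "Total params" line of autoencoder.summary()
```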

Model Training

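First, we define two callbacks. EarlyStopping halts training once the validation loss stops improving, and ModelCheckpoint saves the best model seen so far to disk so that it can be reloaded for evaluation later.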

early_stopper = EarlyStopping(monitor='val_loss', min_delta=0.01, patience=50, verbose=1, mode='min')
model_checkpoint = ModelCheckpoint('superResolution_checkpoint3.h5', save_best_only=True)
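The stopping rule is easier to reason about with a simplified re-implementation of its logic. This is a sketch, not Keras's own code, and `early_stopping_epoch` is a hypothetical helper. An epoch counts as an improvement only if `val_loss` falls below the best value so far by more than `min_delta`. Training stops once `patience` epochs pass without one. Because `min_delta` is in the units of the loss, a value of 0.01 is large next to MSE values around 0.002. Later, smaller improvements therefore do not reset the patience counter.

```python
def early_stopping_epoch(val_losses, min_delta=0.01, patience=50):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:  # improvement must beat the best by more than min_delta
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# A big early improvement, then a plateau: stops 50 epochs after the last improvement
print(early_stopping_epoch([0.05, 0.03] + [0.025] * 60))
# Steady but small improvements (< min_delta each) never reset the counter
print(early_stopping_epoch([0.009 - 0.0001 * i for i in range(80)]))
```

Also note that, without `restore_best_weights=True`, the model in memory after training holds the last epoch's weights. The best weights are the ones saved by ModelCheckpoint.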

Let's train our model:

history = autoencoder.fit(train_x_px,train_x,
epochs=500,
validation_data=(val_x_px, val_x),
callbacks=[early_stopper, model_checkpoint])

Training took around 21 seconds per epoch on a 12 GB NVIDIA Tesla K80 GPU, and early stopping was triggered at the 65th epoch.

Now, let's evaluate our model on the validation set:

results = autoencoder.evaluate(val_x_px, val_x)
print('val_loss, val_accuracy', results)

val_loss, val_accuracy [0.002111854264512658, 0.9279356002807617]

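We are getting pretty good results from our model, with a validation loss (MSE) of about 0.0021 and around 93% validation accuracy. Keep in mind that "accuracy" is a classification metric and says little about pixel-level reconstruction quality. The MSE is the more informative number here.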
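Accuracy does not describe image quality well, so super-resolution work more commonly reports PSNR (peak signal-to-noise ratio), computed from the MSE as 10 * log10(MAX² / MSE). TensorFlow also provides `tf.image.psnr` and `tf.image.ssim`. The NumPy sketch below assumes pixels scaled to [0, 1], as in this article.

```python
import numpy as np

def psnr(original, reconstructed, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_val]."""
    diff = np.asarray(original, dtype=np.float64) - np.asarray(reconstructed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10 * np.log10(max_val ** 2 / mse)

a = np.zeros((2, 2, 3))
b = np.full((2, 2, 3), 0.1)  # every pixel off by 0.1, so MSE = 0.01
print(psnr(a, b))            # about 20 dB

# Treating the reported validation loss as a pure pixel MSE (it also includes the
# small L1 penalty), it corresponds to roughly 26.8 dB
psnr_at_reported_mse = 10 * np.log10(1.0 / 0.002111854264512658)
print(psnr_at_reported_mse)
```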

Make Predictions

predictions = autoencoder.predict(val_x_px)
# The ReLU output is not bounded above, so clip to [0, 1] for display
predictions = np.clip(predictions, 0.0, 1.0)

n = 4
plt.figure(figsize=(20, 10))
for i in range(n):
    # Row 1: low-resolution input images
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(val_x_px[i + 20])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # Row 2: the model's reconstructions
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(predictions[i + 20])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
1st row: Input Images; 2nd row: Output Images

Ending Notes

In this story, we learned the basic functionality of auto-encoders and implemented an image super-resolution task. This technique has many everyday use cases; for example, it could also be used to enhance the quality of low-resolution videos. And since the original images act as their own targets, we can work with image data and solve several real-world problems without any manual labels.

If you have other use cases or techniques for working with image data, or if you find an improved model for image enhancement, please share them in the responses below!

The entire code for this article is available here.

Reference

  1. A course by Snehan Kekre on Coursera.

A Deep Learning enthusiast with a profound background in Computer Science. Loves learning new and creative concepts about programming, science and life.