restoration gains with GANs

Rrohan.Arrora
Towards Data Science
10 min read · Nov 21, 2019


In this blog post, I will discuss GANs and how they are used to restore images. I will use the fastai library for the practical part, and I suggest that readers explore fastai as well. Let us start.

Image by Thomas B. from Pixabay

What is Image Restoration?

Image restoration is the process of recovering the original image from a distorted one. It comes in various kinds, such as:

  • taking a low-resolution image and converting it into a high-resolution image
  • converting a black-and-white image into a coloured image
  • restoring broken parts of an image

and many more.

Let us deal with the first kind, i.e. converting a low-resolution image with some unwanted text on it into a clear, high-resolution image. We need a dataset of images to restore, and fastai provides one (URLs.PETS, the Oxford-IIIT Pet dataset). Let us use it.

import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
path = untar_data(URLs.PETS)
path_hr = path/'images'
path_lr = path/'crappy'

path_hr: PosixPath('/root/.fastai/data/oxford-iiit-pet/images')
path_lr: PosixPath('/root/.fastai/data/oxford-iiit-pet/crappy')

Now we need to crappify the images, and any function that distorts them will do. Below, I use a function that shrinks each image to a low resolution and writes a number over it. Let us understand it.

from fastai.vision import *
import random
import PIL
from PIL import Image, ImageDraw

class crappifier(object):
    def __init__(self, path_lr, path_hr):
        self.path_lr = path_lr
        self.path_hr = path_hr

    def __call__(self, fn, i):
        # mirror the location of the high-res file under path_lr
        dest = self.path_lr/fn.relative_to(self.path_hr)
        dest.parent.mkdir(parents=True, exist_ok=True)
        img_open = PIL.Image.open(fn)
        # target size: shorter side of 96 pixels, same aspect ratio
        targ_sz = resize_to(img_open, 96, use_min=True)
        img = img_open.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
        w, h = img.size
        q = random.randint(10, 70)
        # write q in white somewhere in the top-left quarter of the image
        ImageDraw.Draw(img).text((random.randint(0, w//2), random.randint(0, h//2)), str(q), fill=(255, 255, 255))
        # save as JPEG with quality q: a lower q means heavier compression artifacts
        img.save(dest, quality=q)
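  • dest: the destination path for the crappy image; it mirrors the image's location under path_hr, but under path_lr.
  • img_open: the original image, opened with the PIL library.
  • targ_sz: the target size for the crappy image. With use_min=True, resize_to scales the image so that its shorter side is 96 pixels while keeping the aspect ratio.
  • img: the image resized to targ_sz with bilinear resampling and converted to RGB.
  • w, h: the width and height of the resized image.
  • q: a random integer between 10 and 70; it is written over the image and also used as the JPEG quality.
  • Draw(): creates a drawing context on the image.
  • text(): draws the text on the image. The first parameter is the (x, y) position of the text, chosen at random within the top-left quarter of the image; the second is the text itself, the number q, drawn in white.
  • Finally, we save the image to the destination folder with JPEG quality q, so a lower number means a more heavily compressed image.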
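To make the targ_sz step concrete, here is a pure-Python sketch of the size calculation. It mirrors my understanding of resize_to with use_min=True (both sides are scaled by the same ratio so that the shorter side hits the target); target_size is a hypothetical name, not a fastai function.

```python
def target_size(w, h, targ_sz, use_min=False):
    """Return (new_w, new_h) with the same aspect ratio as (w, h).

    With use_min=True the shorter side becomes targ_sz;
    otherwise the longer side does.
    """
    ref = min(w, h) if use_min else max(w, h)
    # Multiply before dividing so that exact ratios stay exact in floating point.
    return int(w * targ_sz / ref), int(h * targ_sz / ref)


# A 500x375 photo is shrunk so that its shorter side is 96 pixels.
print(target_size(500, 375, 96, use_min=True))  # (128, 96)
```

Keeping the aspect ratio means the crappy image is a faithful, just smaller, version of the original, so the model only has to learn to undo the resolution loss, the text and the JPEG artifacts.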

So, this is how we can crappify the images. You can distort the images in any way you like.

Remember one thing: the model won't learn to fix anything you do not include in the crappifier.

il = ImageList.from_folder(path_hr)
parallel(crappifier(path_lr, path_hr), il.items)

Crappifying the images can take a while, but fastai has a function called parallel. If you pass parallel a function and a list of items, it runs that function on all of them in parallel, which saves a lot of time. It calls the function with each item and its index, which is why crappifier's __call__ takes both fn and i.
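The idea behind parallel can be sketched with the standard library. This is not fastai's implementation: parallel_map and fake_crappify are hypothetical names, and a thread pool keeps the sketch simple, whereas CPU-heavy image work usually benefits more from a ProcessPoolExecutor. As with the crappifier above, the mapped function receives each item together with its index.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def parallel_map(func, items, max_workers=4):
    """Call func(item, index) for every item concurrently.

    Results are returned in the same order as items.
    """
    results = [None] * len(items)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember which position each future belongs to.
        futures = {executor.submit(func, item, i): i for i, item in enumerate(items)}
        for future in as_completed(futures):
            results[futures[future]] = future.result()  # re-raises worker errors
    return results


def fake_crappify(fn, i):
    # Stand-in for crappifier(path_lr, path_hr): same (item, index) signature.
    return f"{i}: crappified {fn}"


print(parallel_map(fake_crappify, ["cat.jpg", "dog.jpg"]))
# ['0: crappified cat.jpg', '1: crappified dog.jpg']
```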

Now let's pre-train the generator. We start with the usual step after processing the data: creating the DataBunch.

bs, size = 32, 128
arch = models.resnet34
src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)

def get_data(bs, size):
    # label each crappy image with the high-res image of the same name;
    # tfm_y=True applies the same transforms to the target image
    data = (src.label_from_func(lambda x: path_hr/x.name)
            .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
            .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    return data

Let us use get_data() to create the pre-processed dataset of images.

data_gen = get_data(bs,size)
data_gen.show_batch(2)

Now we have created the DataBunch, and we use it to create the learner.

Here, I expect readers to know about U-Nets and why we use them. If you have little or no knowledge of them, please refer here.

We learn to restore an image by comparing the output with the original image, and this kind of image-to-image restoration is what a U-Net performs. So we need to pass our data to a U-Net. Let us use U-Nets and, along the way, see why we need GANs.

The data we have obtained is just a list of images from different folders, normalised and transformed as required. We use the ImageNet stats in normalize(imagenet_stats, do_y=True) because we will use a pre-trained model, and do_y=True normalises the target (high-resolution) images as well. Why use a pre-trained model? We want to restore distorted images, and it is always better to start from a model that already knows something about animals (not everything, but a lot) than from one that knows nothing about them. Moreover, removing unwanted text from images is something our model should learn in general.
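Here is what normalising with ImageNet stats does to a single pixel, sketched in plain Python. The mean and standard deviation below are the standard ImageNet channel statistics (the values fastai's imagenet_stats holds); normalize_pixel and denormalize_pixel are hypothetical helpers. Because do_y=True applies the same transform to the targets, the model predicts normalised pixels, which must be denormalised before viewing.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map an RGB pixel with channels in [0, 1] to ImageNet-normalised values."""
    return tuple((c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def denormalize_pixel(values):
    """Invert normalize_pixel, e.g. before displaying a prediction."""
    return tuple(v * s + m for v, m, s in zip(values, IMAGENET_MEAN, IMAGENET_STD))


black = normalize_pixel((0.0, 0.0, 0.0))
white = normalize_pixel((1.0, 1.0, 1.0))
print([round(v, 2) for v in black])  # [-2.12, -2.04, -1.8]
print([round(v, 2) for v in white])  # [2.25, 2.43, 2.64]
```

Every normalised pixel therefore lies roughly between -2.1 and 2.6, which is useful to keep in mind when choosing the output range of the generator.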

A word of advice: transfer learning helps in almost every kind of computer vision problem.

Let us declare the parameters.

wd = 1e-3
y_range = (-3., 3.)
loss_gen = MSELossFlat()

def create_gen_learner():
    return unet_learner(data_gen, arch, wd=wd,
                        blur=True,
                        norm_type=NormType.Weight,
                        self_attention=True,
                        y_range=y_range,
                        loss_func=loss_gen)

learn_gen = create_gen_learner()
  • wd: the weight decay, i.e. how strongly the weights are regularised.
  • y_range: the range of the output activations. fastai applies a sigmoid, scaled to this range, to the final activations, so every output lies between -3 and 3. Because the targets are normalised with ImageNet stats, their pixel values lie roughly between -2.1 and 2.6, and (-3, 3) comfortably covers them.
  • loss_gen: the loss function. Since we want the output to match the original image, we compare the two pixel by pixel, and MSE (mean squared error) loss works well for that. MSE compares two vectors, but our inputs are images, so MSELossFlat() flattens the prediction and the target before computing the MSE.
  • If you want to know more about it, read here.
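The y_range and MSELossFlat bullets can be made concrete with a small plain-Python sketch. sigmoid_range follows the formula I understand fastai uses for y_range (a sigmoid rescaled to the interval from low to high), and mse_flat flattens nested lists before averaging the squared differences, standing in for what MSELossFlat does with tensors. All names here are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sigmoid_range(x, low=-3.0, high=3.0):
    """Squash a real activation into (low, high); huge inputs may round to a bound."""
    return sigmoid(x) * (high - low) + low


def flatten(nested):
    """Flatten arbitrarily nested lists (e.g. channel x height x width)."""
    for item in nested:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item


def mse_flat(pred, target):
    """Mean squared error after flattening both inputs to 1-D."""
    p, t = list(flatten(pred)), list(flatten(target))
    if len(p) != len(t):
        raise ValueError("prediction and target must have the same number of values")
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)


print(sigmoid_range(0.0))  # 0.0, the middle of (-3, 3)
# Two tiny 2x2 single-channel "images" that differ in one pixel by 1.0.
print(mse_flat([[0.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]))  # 0.25
```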

So, we have created a U-Net learner whose encoder is a predefined, pre-trained model, ResNet34. Broadly speaking (this is not an exact definition), this is generative learning: the U-Net learner generates images. Let us now fit the model.

learn_gen.fit_one_cycle(2, pct_start=0.8)
  • pct_start: the fraction of all the training iterations in the cycle (across every epoch, not per epoch) during which the learning rate rises; for the remaining iterations, it decreases. Let us understand it using the above example.

Let the number of iterations per epoch = 100
Total number of iterations for 2 epochs = 2 * 100 = 200
Number of iterations for which the learning rate increases = 200 * 0.8 = 160
Number of iterations for which the learning rate decreases = 200 - 160 = 40
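The one-cycle schedule can be sketched in plain Python. The rise-then-fall shape with cosine interpolation matches how I understand fit_one_cycle, but the start and end learning rates (lr_max / 25, then a further division by 1e4) are assumptions based on my reading of fastai v1's defaults, and the helper names are my own.

```python
import math


def annealing_cos(start, end, pct):
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lrs(n_iter, lr_max, pct_start=0.8, div_factor=25.0, final_div=1e4):
    """Learning rate for every iteration of a one-cycle schedule (sketch).

    n_iter counts iterations over ALL epochs of the cycle, not per epoch.
    """
    n_up = int(n_iter * pct_start)  # iterations where the lr rises
    n_down = n_iter - n_up          # iterations where the lr falls
    lr_start = lr_max / div_factor  # assumed starting lr
    lr_end = lr_start / final_div   # assumed final lr
    lrs = [annealing_cos(lr_start, lr_max, i / n_up) for i in range(n_up)]
    lrs += [annealing_cos(lr_max, lr_end, i / n_down) for i in range(n_down)]
    return lrs


# 2 epochs of 100 iterations each, as in fit_one_cycle(2, pct_start=0.8)
lrs = one_cycle_lrs(2 * 100, lr_max=1e-3, pct_start=0.8)
print(lrs.index(max(lrs)))  # 160
```

With 200 iterations and pct_start=0.8, the peak lands at iteration 160: the learning rate rises for 160 iterations and falls for the last 40.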

The fitting above happens with the encoder of the U-Net, i.e. its ResNet part, frozen. Since we are using transfer learning, we can now unfreeze that pre-trained part of the U-Net. The pre-trained part of a U-Net is the downsampling part; that is where the ResNet is.

learn_gen.unfreeze()
learn_gen.fit_one_cycle(3, slice(1e-6, 1e-3))
learn_gen.show_results(rows=2)
learn_gen.save('gen-pre2')  # reloaded later with .load('gen-pre2')
  • We have obtained a relatively good prediction model above. The images are still not sharp, and the reason is the loss function: with MSE loss, a slightly blurry image already has a very small pixel difference from the actual image. If removing the text were the only task, we would have accomplished our objective, but we also want to improve the image resolution.
  • Basically, the model does an excellent job of removing the text but not of upsampling.
  • What we have to improve is the loss function, because the loss defines what a neural network learns, and a better loss function would give us better results. This establishes the need for GANs.
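A toy calculation shows why a pixel-wise MSE loss favours blurry outputs. This is a simplified illustration of my own, not part of the fastai code: when the input is ambiguous, the prediction that minimises the expected MSE is the average of the plausible targets, which for fine image detail means a grey, smeared value rather than a crisp one.

```python
def expected_mse(prediction, possible_targets):
    """Average squared error of one predicted pixel over equally likely targets."""
    return sum((prediction - t) ** 2 for t in possible_targets) / len(possible_targets)


# Suppose a pixel on a sharp edge could be black (0.0) or white (1.0) with equal
# probability, given only the low-resolution input.
outcomes = [0.0, 1.0]
print(expected_mse(0.0, outcomes))  # 0.5  (commit to black: sharp but risky)
print(expected_mse(1.0, outcomes))  # 0.5  (commit to white)
print(expected_mse(0.5, outcomes))  # 0.25 (grey hedge: blurry, yet the lowest MSE)
```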

Generative Adversarial Network

Image source: fastai

Let us understand the semantics behind the GANs.

  • So far, we have a model that predicts images not very different from the original images. As in the image above, we have already created the crappy images and a generator that produces not-so-terrible images. We then use the pixel MSE to compare the predicted images with the high-res images.
  • Imagine if, rather than comparing pixels, we had something that classified an image as either a generated image or a real high-resolution image. It would be even more interesting if we could fool that binary classifier into classifying the generated images as high-resolution images.
  • Once we start fooling the classifier, we train the classifier more so that it predicts the actual class of each image: a generated image as generated, and a high-resolution image as high-resolution.
  • Once the classifier predicts the class of the generated images correctly, we cannot fool it any more. At that point, we train the generator more so that it generates images closer to the high-res images. Once the generator is trained enough, it fools the classifier again.
  • When we start fooling the classifier again, we train the classifier more this time. This back-and-forth of training the generator and the classifier is what a GAN is.
  • So, basically, the loss function of a GAN is another model (the classifier, also called the critic or discriminator), and that model keeps getting better as training goes on. The whole game is to get a better and better loss function.

Believe me! That is all a GAN is.
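The back-and-forth described above comes down to two opposing losses on the critic's scores. The sketch below uses plain binary cross-entropy on raw scores (logits), the same kind of loss as nn.BCEWithLogitsLoss; the function names are hypothetical and the scores are made-up numbers, not outputs of a trained model.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw score, in the numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def critic_loss(real_logits, fake_logits):
    """The critic wants real high-res images scored as 1 and generated ones as 0."""
    losses = [bce_with_logits(x, 1.0) for x in real_logits]
    losses += [bce_with_logits(x, 0.0) for x in fake_logits]
    return sum(losses) / len(losses)


def generator_adv_loss(fake_logits):
    """The generator wants its images scored as real (1): it tries to fool the critic."""
    return sum(bce_with_logits(x, 1.0) for x in fake_logits) / len(fake_logits)


# Made-up critic scores for two real and two generated images.
real_scores, fake_scores = [4.0, 3.0], [-4.0, -3.0]
print(critic_loss(real_scores, fake_scores))  # small: the critic separates them well
print(generator_adv_loss(fake_scores))        # large: the generator is not fooling it
```

Lowering one loss raises the other, which is why training has to alternate between the two models.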

We have already created the generator, so let us now create the classifier. But before that, we need to store our predictions somewhere, because the classifier has to distinguish between the predicted images and the high-resolution images.

name_gen = 'image_gen'
path_gen = path/name_gen
path_gen.mkdir(exist_ok=True)
  • We have created a path PosixPath('/root/.fastai/data/oxford-iiit-pet/image_gen') where we will store the generated images.
  • We already have a path PosixPath('/root/.fastai/data/oxford-iiit-pet/images') where the high-resolution images are stored.
def save_preds(dl):
    i = 0
    names = dl.dataset.items
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1

save_preds(data_gen.fix_dl)
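  • We save the predictions in the image_gen folder.
  • data_gen.fix_dl is a data loader over the training set that is not shuffled (and, as far as I know, uses the validation transforms, so there is no random augmentation). Because its order is fixed, the running index i matches each prediction with the right file name.
  • We iterate over the data loader one batch at a time and pass each batch to pred_batch to get the predictions; reconstruct=True turns the outputs back into images that can be saved.
  • Each predicted image is stored in the folder under the same file name as its source image.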
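The running counter in save_preds only works because the batches arrive in the same order as dl.dataset.items. A framework-free sketch of that bookkeeping, with a hypothetical function name and made-up file names and predictions, pairs each prediction with the file name at the same position:

```python
from pathlib import Path


def output_paths(batches, names, out_dir):
    """Pair every prediction, batch by batch, with the path it should be saved to.

    Mirrors the running index in save_preds: the i-th prediction overall
    belongs to the i-th item of the unshuffled dataset.
    """
    pairs = []
    i = 0
    for batch in batches:
        for pred in batch:
            pairs.append((pred, Path(out_dir) / Path(names[i]).name))
            i += 1
    if i != len(names):
        raise ValueError(f"got {i} predictions for {len(names)} items")
    return pairs


names = ["images/Abyssinian_1.jpg", "images/beagle_2.jpg", "images/pug_3.jpg"]
batches = [["pred0", "pred1"], ["pred2"]]  # batch size 2, last batch smaller
for pred, dest in output_paths(batches, names, "image_gen"):
    print(pred, dest)
```

If the loader were shuffled, the same code would silently save predictions under the wrong names, which is why a fixed-order loader is used.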

Now we need to code the classifier. If you want to know more about classifiers, read here.

def get_crit_data(classes, bs, size):
    src = (ImageList.from_folder(path, include=classes)
           .split_by_rand_pct(0.1, seed=42))
    # the label of each image is the name of the folder it comes from
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=size)
            .databunch(bs=bs).normalize(imagenet_stats))
    return data
data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)
  • from_folder: collects the images in the folders under path, including only the folders whose names are listed in include=classes.
  • label_from_folder: labels each image with its class, which is simply the name of the folder it comes from.
  • Then we transform the data, create the DataBunch, and finally normalise the data.
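Here is a simplified, framework-free sketch of what include=classes and label_from_folder amount to: keep only the files whose top-level folder is one of the classes, and use that folder's position in classes as the label. The function name and file names are hypothetical, and fastai's actual implementation differs in its details.

```python
from pathlib import PurePosixPath


def label_from_folder(paths, classes):
    """Keep files whose top-level folder is in classes and label each by that folder.

    Returns (path, class_index) pairs, with class_index following the order of classes.
    """
    labelled = []
    for p in paths:
        folder = PurePosixPath(p).parts[0]
        if folder in classes:  # plays the role of include=classes
            labelled.append((p, classes.index(folder)))
    return labelled


files = ["image_gen/pug_3.jpg", "images/pug_3.jpg", "crappy/pug_3.jpg"]
print(label_from_folder(files, ["image_gen", "images"]))
# [('image_gen/pug_3.jpg', 0), ('images/pug_3.jpg', 1)]
```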
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

Let us define the learner for the classifier.

loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

def create_critic_learner(data, metrics):
    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)

learn_critic = create_critic_learner(data_crit, accuracy_thresh_expand)
  • gan_critic() gives you a binary classifier suitable for GANs. Its output contains several scores per image rather than a single one, and AdaptiveLoss expands each image's label to match, so that nn.BCEWithLogitsLoss can compare them.
  • Because we have a slightly different architecture and a slightly different loss function, we also use a slightly different metric: accuracy_thresh_expand is the GAN version of accuracy for critics.
  • Finally, we call create_critic_learner to create the learner for the critic. The fastai data block API makes it easy to build the data for such learners.
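Because the critic produces several scores per image, both the loss and the metric first repeat each image's single label for every score. Here is a plain-Python sketch of that idea, written to mirror my understanding of AdaptiveLoss and accuracy_thresh_expand (a 0.5 threshold applied after a sigmoid); expanded_bce, expanded_accuracy and the example scores are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logit, target):
    """Stable binary cross-entropy on a raw score."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def expanded_bce(outputs, labels):
    """outputs[i] holds all scores for image i; labels[i] is 0 or 1 for that image.

    Each image's label is repeated once per score before averaging.
    """
    losses = [bce_with_logits(s, y) for scores, y in zip(outputs, labels) for s in scores]
    return sum(losses) / len(losses)


def expanded_accuracy(outputs, labels, thresh=0.5):
    """Fraction of all scores whose thresholded sigmoid matches the repeated label."""
    hits = [(sigmoid(s) > thresh) == bool(y)
            for scores, y in zip(outputs, labels) for s in scores]
    return sum(hits) / len(hits)


# Two images with three scores each: image 0 is real (label 1), image 1 is generated (0).
outputs = [[2.0, 1.5, -0.5], [-3.0, -1.0, 0.2]]
labels = [1, 0]
print(expanded_accuracy(outputs, labels))  # 4 of the 6 scores agree with their labels
```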
learn_critic.fit_one_cycle(6, 1e-3)
learn_critic.save('critic-pre2')  # reloaded later with .load('critic-pre2')
  • Now we have a critic that is pretty good at telling the predicted images apart from the high-resolution images.
  • This is the expected behaviour, because the generated images are still easy to tell apart from the originals.

Now we have a generator and a classifier (critic). Let us move to the last part of the section.

Finishing up GAN

learn_critic = None  # drop the old learners to free memory
learn_gen = None
data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre2')
learn_gen = create_gen_learner().load('gen-pre2')
  • Now that we have pre-trained both the generator and the critic, we need them to ping-pong: train the generator a little, then the critic a little, and so on.
  • How much time to spend on each and which learning rates to use is still a little fuzzy, so fastai provides a GANLearner: you just pass in your generator and your critic (which we have loaded here from the ones we just trained), and it takes care of the rest.
  • When you call learn.fit, it figures out how long to train the generator and when to switch to training the critic (discriminator), and goes back and forth between them.
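The switching rule can be simulated without any training. This sketch reflects my understanding of AdaptiveGANSwitcher when only critic_thresh is set: after every generator batch it switches to the critic, and the critic keeps training until its batch loss falls below the threshold. The function, the gen_first flag and the loss values are hypothetical; which model trains first is simply a parameter here.

```python
def simulate_switching(critic_losses, critic_thresh=0.65, gen_first=False):
    """Return which model trains on each batch.

    critic_losses: made-up critic batch losses, consumed in order on critic batches.
    The simulation stops when they run out.
    """
    schedule = []
    gen_mode = gen_first
    losses = iter(critic_losses)
    while True:
        if gen_mode:
            schedule.append("generator")
            gen_mode = False  # no generator threshold: switch after every generator batch
        else:
            loss = next(losses, None)
            if loss is None:
                break
            schedule.append("critic")
            if loss < critic_thresh:  # critic is good enough: hand back to the generator
                gen_mode = True
    return schedule


print(simulate_switching([0.9, 0.7, 0.6, 0.5]))
# ['critic', 'critic', 'critic', 'generator', 'critic', 'generator']
```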
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit,
                                 weights_gen=(1., 50.),
                                 show_img=False,
                                 switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0., 0.99)),
                                 wd=wd)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
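  • GANs do not like momentum. Because you keep switching between the generator and the critic, momentum built up while training one of them makes little sense for the next steps. In Adam, the first value of betas is where the momentum goes, so betas=(0., 0.99) sets it to zero.
  • weights_gen=(1.,50.) weights the two parts of the generator's loss: 1 for the critic's verdict and 50 for the pixel MSE loss (loss_gen). critic_thresh=0.65 tells AdaptiveGANSwitcher to keep training the critic until its loss drops below 0.65 before switching back to the generator, and GANDiscriminativeLR with mult_lr=5. multiplies the critic's learning rate by 5.
  • The hyperparameters above generally work well for GANs, so you can use them as a starting point for most GAN problems.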
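The weights_gen argument combines two terms into the generator's loss. A minimal sketch of that combination, assuming (as I understand fastai v1's GAN loss) that the first weight multiplies the adversarial loss from the critic and the second multiplies the generator's own loss_func, here the pixel MSE; the function name and numbers are hypothetical.

```python
def generator_total_loss(adv_loss, pixel_loss, weights_gen=(1.0, 50.0)):
    """Weighted sum of the critic-based (adversarial) loss and the pixel loss."""
    w_adv, w_pix = weights_gen
    return w_adv * adv_loss + w_pix * pixel_loss


# Made-up values: even a small pixel MSE matters because it is scaled by 50.
print(generator_total_loss(adv_loss=0.7, pixel_loss=0.02))
```

Presumably, the large weight on the pixel loss keeps the output close to the target image, so the critic term mainly pushes towards realistic detail instead of letting the generator drift away from the original content.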
lr = 1e-4
learn.fit(40,lr)
  • As the generator gets better, it gets harder for the discriminator (i.e. the critic), and as the critic gets better, it gets harder for the generator. This is one of the tough things about training GANs: the losses do not tell you much about how well they are doing, because each model's loss depends on how good the other one currently is. The only way to know is to look at the results from time to time.
  • If you set show_img=True in GANLearner, it will show a sample image after every epoch.

After all this, you may check the results below.

learn.show_results(rows=1)

Finally, you can compare these results with the earlier ones from the pre-trained generator alone. That is where GANs come into play.

Shortcoming of GANs

Our critic does not use a pre-trained model like ResNet34; instead, it is built with gan_critic() and trained from scratch. So, we need a pre-trained classifier model that is based on a network like ResNet34 and is also compatible with GANs.

So, that is my take on GANs. Thanks for reading it and keep exploring fastai.
