
Stepping into the magical world of GANs

A step-by-step tutorial on GANs

Photo source: https://unsplash.com/photos/6dN9l-gopyo

A generative model can potentially do magic: trained properly, it may write poetry, compose music, or draw images like an expert. The objective of a GAN is to generate synthetic samples that are very realistic. GAN models learn this trick in an adversarial setting. There are two multi-layer neural networks, one acting as the generator and the other as its adversary, called the discriminator. Both are trained using the regular backpropagation method, although with different and conflicting loss functions (the adversarial setup). Once training is done, the discriminator is removed and the generator is used to produce the samples.

This is illustrated in the following diagram. The task here is to create handwritten digits like those in MNIST.

Fig 1a: Training Phase of GAN (Image Source: Author)

The diagram shows two neural networks: the first is the generator and the second is the discriminator.

The generator takes a random noise vector as input and generates an image with the same dimensions as the real data. Why noise as input? It ensures the generator does not end up producing a replica of images already present in the real-world data.
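
For intuition, a latent point is nothing more than a vector of samples drawn from a standard normal distribution (the 100-dimensional size used throughout this tutorial is a common choice, not a requirement):

from numpy.random import randn
z = randn(100)  # one 100-dimensional latent point, fed to the generator as input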

The discriminator is a simple classifier performing binary classification: class ‘0’ (fake) for images coming from the generator, and class ‘1’ (real) for images from the real dataset, in this case MNIST.

The discriminator uses a regular cross-entropy loss function, whereas the generator trains through the discriminator with the discriminator’s weights kept constant (otherwise convergence would be like chasing a moving target); its loss measures the probability of the fake images being assigned class 1. This is where the conflict comes in: the discriminator wants real samples classified as real and fake samples classified as fake, while the generator wants its fake samples classified as real.
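
To make the two objectives concrete, here is a minimal sketch of the conflicting losses, assuming tf.keras; real_out and fake_out are dummy discriminator outputs purely for illustration, not part of the model code below:

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()

# dummy discriminator outputs (probability of 'real'), for illustration only
real_out = tf.constant([[0.9], [0.8]])   # discriminator scores on real images
fake_out = tf.constant([[0.3], [0.1]])   # discriminator scores on generated images

# discriminator objective: real -> 1, fake -> 0
d_loss = bce(tf.ones_like(real_out), real_out) + bce(tf.zeros_like(fake_out), fake_out)
# generator objective: fakes should be classified as real (label 1)
g_loss = bce(tf.ones_like(fake_out), fake_out)
print(float(d_loss), float(g_loss))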

Once the training is done, the discriminator can be discarded and the generator used to produce realistic samples, as shown in the diagram below.

Fig 1b: Generating Phase of GAN

Now for the implementation part. The code follows the blog by Jason Brownlee.
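
All the snippets below assume the following imports (a reasonable set for TensorFlow 2.x / Keras and NumPy; exact module paths may differ slightly across versions):

from numpy import expand_dims, ones, zeros, vstack
from numpy.random import randn, randint
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout
from tensorflow.keras.optimizers import Adam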

Step 1: Defining the Generator

def define_generator(latent_dim):
    model = Sequential()
    # foundation for 7x7 image
    n_nodes = 128 * 7 * 7
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 128)))
    # upsample to 14x14
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 28x28
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return model

The generator takes a noise vector as input and generates a 28 x 28 image. Below is the code for testing the (still untrained) generator:

noise = tf.random.normal([1, 100])
generator = define_generator(100)
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')

The image generated is as follows

Fig 2: Random Noise Generated (Image Source: Author)

Step 2: Defining the discriminator model

It’s a simple binary classification network.

def define_discriminator(in_shape=(28,28,1)):
    model = Sequential()
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Conv2D(64, (3,3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
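
As an optional sanity check, the untrained discriminator can score the noise image produced in Step 1 (this reuses the generated_image variable from the earlier snippet); the output is a single probability of the ‘real’ class:

discriminator = define_discriminator()
decision = discriminator(generated_image, training=False)
print(decision)  # a value in (0, 1); around 0.5 is typical before any training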

Step 3: Creating a combined Model

The purpose here is to train the generator through the discriminator while keeping the discriminator’s weights frozen (not trainable). Note that setting d_model.trainable = False only affects models compiled afterwards, i.e. the combined model; the standalone discriminator was already compiled, so its weights still update when d_model.train_on_batch is called directly in the training loop.

def define_gan(g_model, d_model):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # connect them
    model = Sequential()
    # add generator
    model.add(g_model)
    # add the discriminator
    model.add(d_model)
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
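
A quick, purely illustrative way to confirm the freezing works as intended (using throwaway model instances, not the ones trained in Step 6): the combined model should expose only the generator's trainable weights:

tmp_d = define_discriminator()
tmp_g = define_generator(100)
tmp_gan = define_gan(tmp_g, tmp_d)
print(len(tmp_gan.trainable_weights) == len(tmp_g.trainable_weights))  # True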

Step 4: Setting the stage.

  • Load Real Samples: gets the data from MNIST and scales it to [0, 1]
  • Generate Real Samples: draws a random batch from the real dataset and appends the class label ‘1’
  • Generate Latent Points: creates random noise vectors as input to the generator
  • Generate Fake Samples: gets a batch from the generator and appends the class label ‘0’

def load_real_samples():
    # load mnist dataset
    (trainX, _), (_, _) = mnist.load_data()
    # expand to 3d, e.g. add channels dimension
    X = expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [0,1]
    X = X / 255.0
    return X

def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, 1))
    return X, y

def generate_latent_points(latent_dim, n_samples):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

def generate_fake_samples(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    X = g_model.predict(x_input)
    # create 'fake' class labels (0)
    y = zeros((n_samples, 1))
    return X, y
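
A quick shape check of these helpers (optional; it reuses the generator instance created while testing Step 1):

dataset = load_real_samples()
X_real, y_real = generate_real_samples(dataset, 4)
X_fake, y_fake = generate_fake_samples(generator, 100, 4)
print(X_real.shape, y_real.shape)   # (4, 28, 28, 1) (4, 1)
print(X_fake.shape, y_fake.shape)   # (4, 28, 28, 1) (4, 1)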

Step 5: The GAN training method

The code below is the important piece. There are two loops: the outer one over epochs and the inner one over batches. Two models are trained: the discriminator model, and the combined model with the discriminator weights held constant. The number of epochs is set to 5, which will not give a good result but will give you some idea of whether it is working or not.

def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=5, n_batch=256):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            X_real, y_real = generate_real_samples(dataset, half_batch)
            # generate 'fake' examples
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # create training set for the discriminator
            X, y = vstack((X_real, X_fake)), vstack((y_real, y_fake))
            # update discriminator model weights
            d_loss, _ = d_model.train_on_batch(X, y)
            # prepare points in latent space as input for the generator
            X_gan = generate_latent_points(latent_dim, n_batch)
            # create inverted labels for the fake samples
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, j+1, bat_per_epo, d_loss, g_loss))

Step 6: Running the training method with parameters

# size of the latent space
latent_dim = 100
# create the discriminator
d_model = define_discriminator()
# create the generator
g_model = define_generator(latent_dim)
# create the gan
gan_model = define_gan(g_model, d_model)
# load image data
dataset = load_real_samples()
# train model
train(g_model, d_model, gan_model, dataset, latent_dim)
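
Since only the generator is needed after training (see Fig 1b), it can be saved once train(...) finishes using the standard Keras save API (the filename here is just an example):

# save the trained generator for later use
g_model.save('mnist_generator.h5')
# later, for generation only:
# g_model = tf.keras.models.load_model('mnist_generator.h5')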

Generation Phase:

num_examples_to_generate = 16   # 4 x 4 grid of samples
noise_dim = latent_dim          # 100, same size as the training latent space
seed = tf.random.normal([num_examples_to_generate, noise_dim])
predictions = g_model(seed, training=False)

fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
    plt.subplot(4, 4, i+1)
    plt.imshow(predictions[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.show()

Below are the images produced by the generator after 5 epochs; training for more epochs will certainly give something more meaningful.

Fig 3: Images produced by Generator (Image Source: Author)

Conclusion:

GANs are one of the coolest additions to deep learning, with a lot of potential and active development; this primer is just to get you started. There are known issues around convergence and an overly strict discriminator, which interested readers can explore further.

You can get started with our video lecture as well.

