Using Diffusion to generate images

You have probably heard of DALL-E 2, a model published by OpenAI that generates realistic-looking images from a text prompt. You can check out a smaller version of the model here.
Ever wondered how it works under the hood? Well… it uses a relatively new class of generative techniques called 'diffusion'. The idea was proposed by Sohl-Dickstein et al. in 2015: essentially, a model learns to generate an image from noise.
But why use Diffusion Models when there are GANs around?
GANs are great at generating high-fidelity images. But, as outlined in the OpenAI paper Diffusion Models Beat GANs on Image Synthesis, diffusion models can achieve better image quality than GANs. A GAN has to produce an image in one go and generally has no way to refine it during generation. Diffusion, on the other hand, is a slow, iterative process during which noise is converted into an image step by step. This gives diffusion models more opportunities to guide the image towards the desired result.
In this article we will look at how to create our own diffusion model based on Denoising Diffusion Probabilistic Models (DDPM; Ho et al., 2020) and Denoising Diffusion Implicit Models (DDIM; Song et al., 2021) using Keras and TensorFlow. So let's get started…
The process behind diffusion models is divided into two parts:
- Forward Noising process, and
- Backward Denoising process.
Forward Noising:
The concept of diffusion models is based on the well-researched concept of diffusion in physics.
In physics, diffusion is the process by which particles of a newly introduced substance spread from regions of higher concentration to regions of lower concentration, until the system becomes homogeneous.

Using diffusion models, we try to reverse this process of homogenization by predicting the movements of the new element one step at a time.
Consider the series of images below. We gradually add small amounts of random noise to the image until it becomes indistinguishable from noise. Our diffusion model will try to learn how to reverse this noising process.

For the forward noising process $q$, we define a Markov chain over a predefined number of steps $T$. At each step, it adds a small amount of Gaussian noise to the image according to a variance schedule $\beta_1, \beta_2, \dots, \beta_T$, where $\beta_1 < \beta_2 < \dots < \beta_T$.
We then train a model that learns to remove these small amounts of noise at every timestep (this is feasible because the noise is added in small increments). We will explore this in the Backward Denoising section.
But first, what is a Markov chain?
A Markov chain is a sequence of events in which each event depends only on the event immediately before it.

Here, the state $x_1$ is determined only by $x_0$, $x_2$ only by $x_1$, and so on until we reach $x_T$. For our purpose, $x_0$ is the original image, and as we move forward along the Markov chain, the image gets noisier until we reach the state $x_T$.
Addition of Noise:
According to our Markov chain, the state $x_t$ is determined only by the state $x_{t-1}$. So we need the distribution $q(x_t \mid x_{t-1})$ that produces a slightly noisier image at timestep $t$ than at $t-1$. This slightly noisier image is generated by sampling a small amount of noise from a Gaussian distribution $\mathcal{N}$ and adding it to the image. A Gaussian distribution is fully determined by its mean and variance, and this is where the variance schedule $\beta_1, \dots, \beta_T$ comes in: the mean depends on $\beta_t$ and the previous image $x_{t-1}$, and the variance is $\beta_t$. So $q(x_t \mid x_{t-1})$ is defined as:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)$$
And, by the Markov property, the probability of the whole chain from $x_1$ to $x_T$, given the initial state $x_0$, is:

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
Reparameterization:
The role of our model is to undo the added noise at every timestep. To generate the noisy image at a given timestep naively, we would have to iterate through the Markov chain until we reach it, which is very inefficient. As a workaround, we use a reparameterization trick that gives the noisy image at any timestep in closed form. This works because the sum of independent Gaussian samples is also Gaussian. Defining $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$, the reparameterization is:

$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\right), \qquad x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

Therefore, we can pre-calculate the values of $\alpha_t$ and $\bar\alpha_t$ and use the formula for $q(x_t \mid x_0)$ to obtain the noised image $x_t$ at timestep $t$ directly from the original image $x_0$.
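As a quick numerical sanity check of the reparameterization (a sketch, not part of the original code), we can push many copies of one pixel value through the Markov chain one step at a time and compare the resulting mean and variance with the closed form. The schedule, the pixel value 0.5 and the target step are illustrative assumptions; here `alpha_bar[t]` is the product of `alpha[0]` through `alpha[t]`.

```python
import numpy as np

# Illustrative linear schedule (same values as used later in the article)
T = 200
beta = np.linspace(0.0001, 0.02, T)
alpha = 1.0 - beta
alpha_bar = np.cumprod(alpha)  # alpha_bar[t] = alpha[0] * ... * alpha[t]

rng = np.random.default_rng(0)
x0 = np.full(100_000, 0.5)  # many copies of one "pixel" value
t = 99                      # hypothetical target step (0-indexed)

# Iterative: apply q(x_t | x_{t-1}) one step at a time (the Markov chain)
x = x0.copy()
for s in range(t + 1):
    x = np.sqrt(alpha[s]) * x + np.sqrt(beta[s]) * rng.standard_normal(x.shape)
iterative = x

# Closed form: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
eps = rng.standard_normal(x0.shape)
closed_form = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

print("mean:", iterative.mean(), closed_form.mean(), np.sqrt(alpha_bar[t]) * 0.5)
print("var: ", iterative.var(), closed_form.var(), 1.0 - alpha_bar[t])
```

Both routes give the same distribution, but the closed form needs one draw instead of `t + 1`.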
Enough theory, let’s code this…
Here are the dependencies we will need to build our model.
!pip install tensorflow
!pip install tensorflow_datasets
!pip install tensorflow_addons
!pip install einops
Let’s start with the imports
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from PIL import Image
import tensorflow as tf
from tensorflow import keras, einsum
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Layer
import tensorflow.keras.layers as nn
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from einops import rearrange
from einops.layers.tensorflow import Rearrange
from functools import partial
from inspect import isfunction
# Suppress TensorFlow warnings
tf.get_logger().setLevel("ERROR")
# configure the GPU: in TF2, a tf.compat.v1.Session config does not affect eager execution,
# so instead we let TensorFlow allocate GPU memory on demand
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
For this implementation, we will use the MNIST digits dataset.
# set the parameters for dataset
target_size = (32, 32)
channels = 1
BATCH_SIZE = 64
# Normalization helper: scale pixels to [-1, 1], resize to target_size and drop the label
def preprocess(x, y):
    return tf.image.resize(tf.cast(x, tf.float32) / 127.5 - 1, target_size)
def get_datasets():
    # Load the MNIST dataset
    train_ds = tfds.load('mnist', as_supervised=True, split="train")
    # Normalize to [-1, 1], shuffle and batch
    train_ds = train_ds.map(preprocess, tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(5000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    # Return numpy arrays instead of TF tensors while iterating
    return tfds.as_numpy(train_ds)
dataset = get_datasets()
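The `preprocess` step maps pixel values from [0, 255] to [-1, 1], and later we map model outputs back before plotting or saving. A minimal sketch of both directions, using hypothetical helper names `to_model_range` and `to_pixel_range` (the rounding before the cast is our addition, to avoid off-by-one pixel values caused by floating-point error):

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1], as preprocess does."""
    return pixels.astype(np.float32) / 127.5 - 1.0

def to_pixel_range(x):
    """Inverse map for display: [-1, 1] -> uint8 in [0, 255] (rounded, then clipped)."""
    return np.clip(np.round((x + 1.0) * 127.5), 0, 255).astype(np.uint8)

pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
restored = to_pixel_range(scaled)
print(scaled, restored)
```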
As per the description of the forward diffusion process, we need to create a fixed beta schedule. Along with that, let us also set up the forward noising process and timestep generation.
timesteps = 200
# create a fixed beta schedule
beta = np.linspace(0.0001, 0.02, timesteps)
# this will be used as discussed in the reparameterization trick
alpha = 1 - beta
# alpha_bar[t] = alpha[0] * ... * alpha[t], so beta, alpha and alpha_bar share the same index t
alpha_bar = np.cumprod(alpha, 0)
sqrt_alpha_bar = np.sqrt(alpha_bar)
# note: despite its name, this holds sqrt(1 - alpha_bar)
one_minus_sqrt_alpha_bar = np.sqrt(1 - alpha_bar)
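Since the closed form is $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, it can help to print how the signal and noise coefficients evolve under this schedule. This standalone sketch recomputes the linear schedule with $\bar\alpha_t$ as a plain cumulative product. With only 200 steps and $\beta$ topping out at 0.02, $\bar\alpha_t$ at the last step stays around 0.13, so a faint trace of the original image remains.

```python
import numpy as np

timesteps = 200
beta = np.linspace(0.0001, 0.02, timesteps)
alpha = 1 - beta
alpha_bar = np.cumprod(alpha)  # alpha_bar[t] = prod_{s <= t} alpha[s]

# Coefficients of x_0 (signal) and of the noise in the closed-form x_t
signal_coef = np.sqrt(alpha_bar)
noise_coef = np.sqrt(1 - alpha_bar)

for t in [0, 50, 100, 150, 199]:
    print(f"t={t:3d}  signal={signal_coef[t]:.3f}  noise={noise_coef[t]:.3f}")
```

The squared coefficients always sum to 1, which keeps the variance of $x_t$ at 1 when $x_0$ has unit variance.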
Now let's visualize the forward noising process.
# this function will help us set the RNG key for NumPy
def set_key(key):
    np.random.seed(key)
# this function will add noise to the input as per the given timestep
def forward_noise(key, x_0, t):
    set_key(key)
    noise = np.random.normal(size=x_0.shape).astype(np.float32)
    # reshape the per-sample coefficients to (batch, 1, 1, 1) so they broadcast over H, W, C
    reshaped_sqrt_alpha_bar_t = np.reshape(np.take(sqrt_alpha_bar, t), (-1, 1, 1, 1))
    reshaped_one_minus_sqrt_alpha_bar_t = np.reshape(np.take(one_minus_sqrt_alpha_bar, t), (-1, 1, 1, 1))
    noisy_image = reshaped_sqrt_alpha_bar_t * x_0 + reshaped_one_minus_sqrt_alpha_bar_t * noise
    return noisy_image.astype(np.float32), noise
# this function will be used to create sample timesteps between 0 and T - 1
def generate_timestamp(key, num):
    set_key(key)
    # use NumPy so that set_key actually makes the sampled timesteps reproducible
    return np.random.randint(0, timesteps, size=(num,)).astype(np.int32)
# Let us visualize the output image at a few timesteps
sample_mnist = next(iter(dataset))[0]
fig = plt.figure(figsize=(15, 30))
for index, i in enumerate([10, 100, 150, 199]):
    noisy_im, noise = forward_noise(0, np.expand_dims(sample_mnist, 0), np.array([i,]))
    plt.subplot(1, 4, index+1)
    plt.imshow(np.squeeze(np.squeeze(noisy_im, -1), 0), cmap='gray')
plt.show()
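The `set_key` helper reseeds NumPy's global random state, which also affects any other code that uses `np.random`. A possible alternative, sketched below under the assumption that a NumPy-only version is acceptable, passes a local `np.random.default_rng(key)` generator instead; `forward_noise_np` is a hypothetical name. The checks confirm that the same key reproduces the same noisy batch and that the per-sample coefficients broadcast over height, width and channels.

```python
import numpy as np

timesteps = 200
beta = np.linspace(0.0001, 0.02, timesteps)
alpha_bar = np.cumprod(1 - beta)
sqrt_alpha_bar = np.sqrt(alpha_bar)
sqrt_one_minus_alpha_bar = np.sqrt(1 - alpha_bar)

def forward_noise_np(key, x_0, t):
    """Hypothetical NumPy-only variant of forward_noise using a local Generator."""
    rng = np.random.default_rng(key)
    noise = rng.standard_normal(x_0.shape).astype(np.float32)
    # per-sample coefficients reshaped to (batch, 1, 1, 1) to broadcast over H, W, C
    a = sqrt_alpha_bar[t].reshape(-1, 1, 1, 1)
    b = sqrt_one_minus_alpha_bar[t].reshape(-1, 1, 1, 1)
    return (a * x_0 + b * noise).astype(np.float32), noise

batch = np.zeros((4, 32, 32, 1), dtype=np.float32)
t = np.array([0, 50, 100, 199])
x_t, eps = forward_noise_np(42, batch, t)
x_t_again, _ = forward_noise_np(42, batch, t)
```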

Backward Denoising:
Let’s understand what exactly will our model do…
We want a model that, given a noised image and its timestep, predicts the noise that was added to the image at that timestep. A U-Net-style model is well suited to this job because its output has the same shape as its input. We can make some changes to the base architecture: replacing the plain convolutional blocks with ResNet blocks, adding a mechanism to condition on timestep encodings, and adding attention layers. The U-Net was first proposed for biomedical image segmentation, but since then it has been modified and used for many different applications.

Let’s code up our U-Net
1) Helper modules
# helper functions
def exists(x):
    return x is not None
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
# We will use this to convert timesteps to time encodings
class SinusoidalPosEmb(Layer):
    def __init__(self, dim, max_positions=10000):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim
        self.max_positions = max_positions
    def call(self, x, training=True):
        x = tf.cast(x, tf.float32)
        half_dim = self.dim // 2
        emb = math.log(self.max_positions) / (half_dim - 1)
        emb = tf.exp(tf.range(half_dim, dtype=tf.float32) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb
# small helper modules
class Identity(Layer):
    def __init__(self):
        super(Identity, self).__init__()
    def call(self, x, training=True):
        return tf.identity(x)
class Residual(Layer):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn
    def call(self, x, training=True):
        return self.fn(x, training=training) + x
def Upsample(dim):
    return nn.Conv2DTranspose(filters=dim, kernel_size=4, strides=2, padding='SAME')
def Downsample(dim):
    return nn.Conv2D(filters=dim, kernel_size=4, strides=2, padding='SAME')
class LayerNorm(Layer):
    def __init__(self, dim, eps=1e-5, **kwargs):
        super(LayerNorm, self).__init__(**kwargs)
        self.eps = eps
        self.g = tf.Variable(tf.ones([1, 1, 1, dim]))
        self.b = tf.Variable(tf.zeros([1, 1, 1, dim]))
    def call(self, x, training=True):
        # normalize over the channel axis, then apply the learned scale and bias
        var = tf.math.reduce_variance(x, axis=-1, keepdims=True)
        mean = tf.reduce_mean(x, axis=-1, keepdims=True)
        x = (x - mean) / tf.sqrt((var + self.eps)) * self.g + self.b
        return x
class PreNorm(Layer):
    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)
    def call(self, x, training=True):
        x = self.norm(x)
        return self.fn(x)
class SiLU(Layer):
    def __init__(self):
        super(SiLU, self).__init__()
    def call(self, x, training=True):
        return x * tf.nn.sigmoid(x)
def gelu(x, approximate=False):
    if approximate:
        coeff = tf.cast(0.044715, x.dtype)
        return 0.5 * x * (1.0 + tf.tanh(0.7978845608028654 * (x + coeff * tf.pow(x, 3))))
    else:
        return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(1.4142135623730951, x.dtype)))
class GELU(Layer):
    def __init__(self, approximate=False):
        super(GELU, self).__init__()
        self.approximate = approximate
    def call(self, x, training=True):
        return gelu(x, self.approximate)
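To see what `SinusoidalPosEmb` produces, here is a NumPy mirror of its `call` method (the function name `sinusoidal_embedding` is ours, not part of the model). Each timestep becomes a `dim`-sized vector of sines followed by cosines at geometrically spaced frequencies, from 1 down to `1 / max_positions`, so every timestep gets a distinct code; timestep 0 maps to zeros followed by ones.

```python
import math
import numpy as np

def sinusoidal_embedding(t, dim, max_positions=10000):
    """NumPy mirror of SinusoidalPosEmb.call; returns an array of shape (len(t), dim)."""
    t = np.asarray(t, dtype=np.float32)
    half_dim = dim // 2
    # frequencies exp(-k * log(max_positions) / (half_dim - 1)), k = 0 .. half_dim - 1
    scale = math.log(max_positions) / (half_dim - 1)
    freqs = np.exp(np.arange(half_dim, dtype=np.float32) * -scale)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = sinusoidal_embedding([0, 1, 199], dim=64)
print(emb.shape)
```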
2) Building blocks of the U-Net model: Here we incorporate the time embedding by scaling and shifting the features inside each ResNet block. The scale and shift factors are obtained by passing the time embedding through a multi-layer perceptron (MLP) inside the ResNet block. This MLP converts the fixed-size time embedding into a vector with two values per channel of the block, one for the scale and one for the shift. Scale and shift are called 'gamma' and 'beta' in the code below.
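The scale-and-shift conditioning can be illustrated with plain NumPy arrays (the shapes and the random stand-in for the MLP output are assumptions for illustration): the time vector is reshaped so it broadcasts over height and width, split into `gamma` and `beta`, and applied per channel. Using `gamma + 1` means that an MLP output of zero leaves the features unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels = 2, 8, 8, 16

feature_map = rng.standard_normal((batch, height, width, channels))
# Random stand-in for the MLP output: one (2 * channels) vector per image
time_vec = rng.standard_normal((batch, 2 * channels))

# Reshape to (batch, 1, 1, 2 * channels) so it broadcasts over height and width,
# then split into scale (gamma) and shift (beta), as ResnetBlock does
time_vec = time_vec.reshape(batch, 1, 1, 2 * channels)
gamma, beta = np.split(time_vec, 2, axis=-1)

modulated = feature_map * (gamma + 1) + beta
```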
# building block modules
class Block(Layer):
    def __init__(self, dim, groups=8):
        super(Block, self).__init__()
        self.proj = nn.Conv2D(dim, kernel_size=3, strides=1, padding='SAME')
        self.norm = tfa.layers.GroupNormalization(groups, epsilon=1e-05)
        self.act = SiLU()
    def call(self, x, gamma_beta=None, training=True):
        x = self.proj(x)
        x = self.norm(x, training=training)
        if exists(gamma_beta):
            # scale and shift the normalized features using the time embedding
            gamma, beta = gamma_beta
            x = x * (gamma + 1) + beta
        x = self.act(x)
        return x
class ResnetBlock(Layer):
    def __init__(self, dim, dim_out, time_emb_dim=None, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = Sequential([
            SiLU(),
            nn.Dense(units=dim_out * 2)
        ]) if exists(time_emb_dim) else None
        self.block1 = Block(dim_out, groups=groups)
        self.block2 = Block(dim_out, groups=groups)
        self.res_conv = nn.Conv2D(filters=dim_out, kernel_size=1, strides=1) if dim != dim_out else Identity()
    def call(self, x, time_emb=None, training=True):
        gamma_beta = None
        if exists(self.mlp) and exists(time_emb):
            # project the time embedding to 2 * dim_out values: one scale and one shift per channel
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b 1 1 c')
            gamma_beta = tf.split(time_emb, num_or_size_splits=2, axis=-1)
        h = self.block1(x, gamma_beta=gamma_beta, training=training)
        h = self.block2(h, training=training)
        return h + self.res_conv(x)
class LinearAttention(Layer):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2D(filters=self.hidden_dim * 3, kernel_size=1, strides=1, use_bias=False)
        self.to_out = Sequential([
            nn.Conv2D(filters=dim, kernel_size=1, strides=1),
            LayerNorm(dim)
        ])
    def call(self, x, training=True):
        b, h, w, c = x.shape
        qkv = self.to_qkv(x)
        qkv = tf.split(qkv, num_or_size_splits=3, axis=-1)
        q, k, v = map(lambda t: rearrange(t, 'b x y (h c) -> b h c (x y)', h=self.heads), qkv)
        q = tf.nn.softmax(q, axis=-2)
        k = tf.nn.softmax(k, axis=-1)
        q = q * self.scale
        # summarize keys and values into a small (dim_head x dim_head) context per head,
        # so the cost grows linearly with the number of pixels
        context = einsum('b h d n, b h e n -> b h d e', k, v)
        out = einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b x y (h c)', h=self.heads, x=h, y=w)
        out = self.to_out(out, training=training)
        return out
class Attention(Layer):
    def __init__(self, dim, heads=4, dim_head=32):
        super(Attention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2D(filters=self.hidden_dim * 3, kernel_size=1, strides=1, use_bias=False)
        self.to_out = nn.Conv2D(filters=dim, kernel_size=1, strides=1)
    def call(self, x, training=True):
        b, h, w, c = x.shape
        qkv = self.to_qkv(x)
        qkv = tf.split(qkv, num_or_size_splits=3, axis=-1)
        q, k, v = map(lambda t: rearrange(t, 'b x y (h c) -> b h c (x y)', h=self.heads), qkv)
        q = q * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        # subtract the row-wise maximum (not its argmax index) for numerical stability;
        # softmax is invariant to this shift
        sim = sim - tf.stop_gradient(tf.reduce_max(sim, axis=-1, keepdims=True))
        attn = tf.nn.softmax(sim, axis=-1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b x y (h d)', x=h, y=w)
        out = self.to_out(out, training=training)
        return out
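`LinearAttention` avoids building an (n × n) attention matrix over all n pixel positions. Because it applies softmax to the queries and keys separately, the product can be regrouped so that keys and values are first summarized into a small (d × d) context. This NumPy sketch of a single head (sizes are illustrative) shows that both groupings give the same output, while computing the context first costs O(n·d²) rather than O(n²·d). Note that this is not identical to the standard softmax attention in `Attention`, where the softmax is applied to the full (n × n) score matrix.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d, n = 32, 256  # feature dim per head, number of pixels (e.g. 16 x 16)
q = softmax(rng.standard_normal((d, n)), axis=0)  # softmax over features
k = softmax(rng.standard_normal((d, n)), axis=1)  # softmax over positions
v = rng.standard_normal((d, n))

# LinearAttention order: build a (d x d) context first, then apply it to q
context = k @ v.T               # (d, d), like einsum('d n, e n -> d e')
out_linear = context.T @ q      # (d, n), like einsum('d e, d n -> e n')

# Same product through an explicit (n x n) matrix
pairwise = k.T @ q              # (n, n)
out_quadratic = v @ pairwise    # (d, n)
```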
3) U-Net model
class Unet(Model):
    def __init__(self,
                 dim=64,
                 init_dim=None,
                 out_dim=None,
                 dim_mults=(1, 2, 4, 8),
                 channels=3,
                 resnet_block_groups=8,
                 learned_variance=False,
                 sinusoidal_cond_mlp=True
                 ):
        super(Unet, self).__init__()
        # determine dimensions
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2D(filters=init_dim, kernel_size=7, strides=1, padding='SAME')
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        # time embeddings
        time_dim = dim * 4
        self.sinusoidal_cond_mlp = sinusoidal_cond_mlp
        self.time_mlp = Sequential([
            SinusoidalPosEmb(dim),
            nn.Dense(units=time_dim),
            GELU(),
            nn.Dense(units=time_dim)
        ], name="time_embeddings")
        # layers
        self.downs = []
        self.ups = []
        num_resolutions = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append([
                block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else Identity()
            ])
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            self.ups.append([
                block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else Identity()
            ])
        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)
        self.final_conv = Sequential([
            block_klass(dim * 2, dim),
            nn.Conv2D(filters=self.out_dim, kernel_size=1, strides=1)
        ], name="output")
    def call(self, x, time=None, training=True, **kwargs):
        x = self.init_conv(x)
        t = self.time_mlp(time)
        h = []
        # encoder: keep a skip connection at every resolution
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        # decoder: concatenate the matching skip connection at each level
        for block1, block2, attn, upsample in self.ups:
            x = tf.concat([x, h.pop()], axis=-1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        x = tf.concat([x, h.pop()], axis=-1)
        x = self.final_conv(x)
        return x
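The channel bookkeeping in `Unet.__init__` is easy to get wrong when changing `dim` or `dim_mults`. This pure-Python sketch mirrors it for the default settings and a 32×32 input (the `image_size` variable is an assumption) and checks that each decoder level receives a skip connection at the matching resolution, and that the concatenated channels match what `block_klass(dim_out * 2, dim_in)` expects.

```python
# Pure-Python mirror of the dimension logic in Unet.__init__ and Unet.call
dim, dim_mults, image_size = 64, (1, 2, 4, 8), 32

init_dim = dim // 3 * 2
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

# Encoder: store a skip at each level; every level except the last halves the size
size = image_size
skip_shapes = []
for ind, (dim_in, dim_out) in enumerate(in_out):
    skip_shapes.append((size, dim_out))
    if ind < len(in_out) - 1:
        size //= 2

# Decoder: concatenate a skip, run the blocks, then upsample
x_ch = dims[-1]  # channels coming out of the middle blocks
up_shapes = []   # (resolution, channels into block1, channels out)
for dim_in, dim_out in reversed(in_out[1:]):
    skip_size, skip_ch = skip_shapes.pop()
    assert skip_size == size              # concat needs matching spatial size
    assert x_ch + skip_ch == dim_out * 2  # matches block_klass(dim_out * 2, dim_in)
    up_shapes.append((size, x_ch + skip_ch, dim_in))
    x_ch = dim_in
    size *= 2

# The last skip feeds final_conv, whose first block expects dim * 2 channels
final_skip_size, final_skip_ch = skip_shapes.pop()
print(dims, up_shapes)
```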
Once we have defined our U-Net model, we can create an instance of it along with a checkpoint manager to save checkpoints during training. While we are at it, let's also create our optimizer: Adam with a learning rate of 1e-4.
# create our unet model
unet = Unet(channels=1)
# create our checkpoint manager
ckpt = tf.train.Checkpoint(unet=unet)
ckpt_manager = tf.train.CheckpointManager(ckpt, "./checkpoints", max_to_keep=2)
# load from a previous checkpoint if it exists, else initialize the model from scratch
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    start_iteration = int(ckpt_manager.latest_checkpoint.split("-")[-1])
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")
# build the model's weights by running it once on a dummy input
test_images = np.ones([1, 32, 32, 1])
test_timestamps = generate_timestamp(0, 1)
k = unet(test_images, test_timestamps)
# create our optimizer, we will use Adam with a learning rate of 1e-4
opt = keras.optimizers.Adam(learning_rate=1e-4)
Training our model:
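The backward denoising step for our model is defined by $p_\theta$, where:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)$$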
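Here we want our U-Net to predict the noise in the input image $x_t$ at a given timestep $t$. The reverse step is characterized by the mean $\mu_\theta(x_t, t)$ and variance $\Sigma_\theta(x_t, t)$ of $x_{t-1}$; in DDPM the variance is fixed (to $\beta_t I$ in our implementation) and the mean can be computed from the predicted noise, so predicting the noise is enough. We train with the following loss between the predicted noise $\epsilon_\theta$ and the original noise $\epsilon$:

$$L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right)\right\rVert^2\right]$$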
The formula may look intimidating, but it boils down to the mean squared error between the predicted noise and the real noise. So let's code this up!
def loss_fn(real, generated):
    loss = tf.math.reduce_mean((real - generated) ** 2)
    return loss
For the training process, we will use the following algorithm:
1) Draw random seeds for generating the timesteps and the noise.
2) Sample a random timestep for each image in the batch.
3) Run the input images through the forward noising process with those timesteps.
4) Get the predictions from the U-Net model using the noised images and the timesteps.
5) Calculate the loss between the predicted noise and the real noise.
6) Update the trainable variables of the U-Net model.
7) Repeat for all training batches.
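Steps 1 to 5 can be sketched without TensorFlow to make the data flow explicit. In this NumPy sketch, `make_training_example` and `mse` are hypothetical helpers, and a predictor that always outputs zeros stands in for the U-Net (steps 4 and 6 need the real model and optimizer). Its loss comes out near 1, the variance of the noise, while a perfect predictor would score 0.

```python
import numpy as np

timesteps = 200
beta = np.linspace(0.0001, 0.02, timesteps)
alpha_bar = np.cumprod(1 - beta)

def make_training_example(batch, rng):
    """Steps 1-3: sample one timestep per image and noise, then build the noised batch."""
    t = rng.integers(0, timesteps, size=batch.shape[0])
    noise = rng.standard_normal(batch.shape)
    a = np.sqrt(alpha_bar[t]).reshape(-1, 1, 1, 1)
    b = np.sqrt(1 - alpha_bar[t]).reshape(-1, 1, 1, 1)
    return a * batch + b * noise, t, noise

def mse(real, generated):
    """Step 5: mean squared error, the same quantity loss_fn computes."""
    return np.mean((real - generated) ** 2)

rng = np.random.default_rng(0)
batch = rng.uniform(-1, 1, size=(8, 32, 32, 1))
noised, t, noise = make_training_example(batch, rng)

# Hypothetical stand-in for the U-Net's prediction
zero_prediction = np.zeros_like(noised)
print(mse(noise, zero_prediction))
```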
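rng = 0
def train_step(batch):
    # 1) draw seeds for the timestep and noise generators
    rng, tsrng = np.random.randint(0, 100000, size=(2,))
    # 2) one random timestep per image
    timestep_values = generate_timestamp(tsrng, batch.shape[0])
    # 3) forward noising; cast to float32 to match the model's dtype
    noised_image, noise = forward_noise(rng, batch, timestep_values)
    noised_image = tf.cast(noised_image, tf.float32)
    noise = tf.cast(noise, tf.float32)
    with tf.GradientTape() as tape:
        # 4) predict the noise, 5) compute the loss
        prediction = unet(noised_image, timestep_values)
        loss_value = loss_fn(noise, prediction)
    # 6) update the U-Net's weights
    gradients = tape.gradient(loss_value, unet.trainable_variables)
    opt.apply_gradients(zip(gradients, unet.trainable_variables))
    return loss_value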
epochs = 10
for e in range(1, epochs+1):
    # this is a cool utility in TensorFlow that will create a nice-looking progress bar
    bar = tf.keras.utils.Progbar(len(dataset)-1)
    losses = []
    for i, batch in enumerate(iter(dataset)):
        # run the training loop
        loss = train_step(batch)
        losses.append(loss)
        bar.update(i, values=[("loss", loss)])
    avg = np.mean(losses)
    print(f"Average loss for epoch {e}/{epochs}: {avg}")
    ckpt_manager.save(checkpoint_number=e)
Now that our model is trained, let's run it in inference mode. In the DDPM paper, the authors outline an algorithm for inference (sampling).

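We start from $x_T$, a sample of pure Gaussian noise. At each step we pass $x_t$ through our U-Net to obtain $\epsilon_\theta$, then calculate $x_{t-1}$ according to the formula:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

where $\sigma_t^2 = \beta_t$, and the DDPM algorithm sets $z = 0$ on the final step.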
Before we code this, let’s create a helper function that will create and save a GIF file from a list of images.
# Save a GIF using logged images
def save_gif(img_list, path="", interval=200):
    # Transform images from [-1, 1] to [0, 255]
    imgs = []
    for im in img_list:
        im = np.array(im)
        im = (im + 1) * 127.5
        # PIL expects 8-bit data for a grayscale GIF frame
        im = np.clip(im, 0, 255).astype(np.uint8)
        im = Image.fromarray(im)
        imgs.append(im)
    imgs = iter(imgs)
    # Extract first image from iterator
    img = next(imgs)
    # Append the other images and save as GIF
    img.save(fp=path, format='GIF', append_images=imgs,
             save_all=True, duration=interval, loop=0)
Now let’s make our backward denoising algorithm using the DDPM approach.
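def ddpm(x_t, pred_noise, t):
    t = int(np.squeeze(t))
    alpha_t = float(alpha[t])
    alpha_t_bar = float(alpha_bar[t])
    # coefficient of the predicted noise: beta_t / sqrt(1 - alpha_bar_t)
    eps_coef = (1 - alpha_t) / (1 - alpha_t_bar) ** .5
    mean = (1 / (alpha_t ** .5)) * (x_t - eps_coef * pred_noise)
    # the DDPM sampler adds no noise on the final step
    if t == 0:
        return mean
    var = float(beta[t])
    z = tf.random.normal(tf.shape(x_t))
    return mean + (var ** .5) * z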
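A useful check on `ddpm` is that, when the predicted noise equals the true noise, its mean equals the mean of the DDPM posterior $q(x_{t-1} \mid x_t, x_0)$, which the paper writes in terms of $x_0$. This standalone NumPy sketch (the timestep 120 and the 32×32 size are arbitrary choices) confirms that the two forms agree. It computes $\bar\alpha_t$ as a plain cumulative product, so `alpha_bar[t - 1]` is the previous step.

```python
import numpy as np

timesteps = 200
beta = np.linspace(0.0001, 0.02, timesteps)
alpha = 1 - beta
alpha_bar = np.cumprod(alpha)

rng = np.random.default_rng(0)
t = 120
x0 = rng.uniform(-1, 1, size=(32, 32))
eps = rng.standard_normal(x0.shape)
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1 - alpha_bar[t]) * eps

# Form used by ddpm(): subtract the scaled noise, then rescale
mean_eps_form = (x_t - (1 - alpha[t]) / np.sqrt(1 - alpha_bar[t]) * eps) / np.sqrt(alpha[t])

# Posterior mean of q(x_{t-1} | x_t, x_0), written in terms of x_0 and x_t
coef_x0 = np.sqrt(alpha_bar[t - 1]) * beta[t] / (1 - alpha_bar[t])
coef_xt = np.sqrt(alpha[t]) * (1 - alpha_bar[t - 1]) / (1 - alpha_bar[t])
mean_posterior_form = coef_x0 * x0 + coef_xt * x_t
```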
Now, for inference, let's start from a random noise image and denoise it step by step using the function defined above.
x = tf.random.normal((1, 32, 32, 1))
img_list = []
img_list.append(np.squeeze(np.squeeze(x, 0), -1))
for i in tqdm(range(timesteps-1)):
    t = np.expand_dims(np.array(timesteps-i-1, np.int32), 0)
    pred_noise = unet(x, t)
    x = ddpm(x, pred_noise, t)
    img_list.append(np.squeeze(np.squeeze(x, 0), -1))
    if i % 25 == 0:
        plt.imshow(np.array(np.clip((np.squeeze(x[0], -1) + 1) * 127.5, 0, 255), np.uint8), cmap="gray")
        plt.show()
save_gif(img_list + ([img_list[-1]] * 100), "ddpm.gif", interval=20)
plt.imshow(np.array(np.clip((np.squeeze(x[0], -1) + 1) * 127.5, 0, 255), np.uint8), cmap="gray")
plt.show()
Here’s an example GIF generated by using the DDPM inference algorithm:

There’s one problem with the inference algorithm proposed in the DDPM paper. The process is very slow since we have to loop through all 200 timesteps to get the result. To make this process faster, an improved inference loop was proposed in the DDIM paper. Let’s discuss that…
DDIM:
In the DDIM paper, the authors propose a non-Markovian process for backward denoising, which removes the constraint that each state depends only on the one immediately before it. They start from a generalization of the DDPM objective, defined for a whole family of inference processes:

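From this objective, we can see that the loss depends only on the marginals $q(x_t \mid x_0)$ and not on the joint distribution $q(x_{1:T} \mid x_0)$. So the authors propose a different, non-Markovian inference process with the same marginals. Complicated-looking math coming up:

$$q_\sigma(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(\sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\cdot\frac{x_t - \sqrt{\bar\alpha_t}\,x_0}{\sqrt{1-\bar\alpha_t}},\ \sigma_t^2 I\right)$$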
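The above changes make the forward process non-Markovian as well, where $\sigma_t$ controls the stochasticity of the forward process. When $\sigma_t \to 0$, $x_{t-1}$ becomes known and fixed once $x_t$ and $x_0$ are known. For the generative process with a fixed prior $p_\theta(x_T) = \mathcal{N}(0, I)$, we first predict $x_0$ from $x_t$ and then sample from $q_\sigma$:

$$f_\theta^{(t)}(x_t) = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\bar\alpha_t}}, \qquad p_\theta^{(t)}(x_{t-1} \mid x_t) = \begin{cases} \mathcal{N}\left(f_\theta^{(1)}(x_1),\ \sigma_1^2 I\right) & \text{if } t = 1 \\ q_\sigma\left(x_{t-1} \mid x_t,\ f_\theta^{(t)}(x_t)\right) & \text{otherwise} \end{cases}$$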
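Finally, the formula for inference is given by:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\left(\frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\bar\alpha_t}}\right) + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\cdot\epsilon_\theta^{(t)}(x_t) + \sigma_t\,\epsilon_t$$

Here, if we set $\sigma_t = 0$ for all $t$, the forward process becomes deterministic given $x_{t-1}$ and $x_0$, and the sampling process no longer adds random noise: the same $x_T$ always produces the same image. The above formulae are taken from [1].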
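Why DDIM can skip timesteps shows up clearly with $\sigma_t = 0$: if the predicted noise were exact, a single update maps $x_t$ onto $\sqrt{\bar\alpha_{t'}}\,x_0 + \sqrt{1-\bar\alpha_{t'}}\,\epsilon$ for any earlier timestep $t'$, built from the same noise. A NumPy sketch under stated assumptions: the helper `ddim_step`, the timesteps 180 and 160, and the convention that `t_prev = -1` means "jump to the clean image" are ours.

```python
import numpy as np

timesteps = 200
beta = np.linspace(0.0001, 0.02, timesteps)
alpha_bar = np.cumprod(1 - beta)

def ddim_step(x_t, eps_pred, t, t_prev, sigma_t=0.0, rng=None):
    """One DDIM update from timestep t to an earlier t_prev (t_prev < 0 means x_0)."""
    a_t = alpha_bar[t]
    a_prev = alpha_bar[t_prev] if t_prev >= 0 else 1.0
    x0_pred = (x_t - np.sqrt(1 - a_t) * eps_pred) / np.sqrt(a_t)  # predicted x_0
    direction = np.sqrt(1 - a_prev - sigma_t ** 2) * eps_pred     # points back to x_t
    noise = 0.0 if sigma_t == 0 else sigma_t * rng.standard_normal(x_t.shape)
    return np.sqrt(a_prev) * x0_pred + direction + noise

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(32, 32))
eps = rng.standard_normal(x0.shape)
x_180 = np.sqrt(alpha_bar[180]) * x0 + np.sqrt(1 - alpha_bar[180]) * eps

# With sigma = 0 and the true noise, one step lands exactly on x_160 (same eps)
x_160 = ddim_step(x_180, eps, 180, 160)
expected_160 = np.sqrt(alpha_bar[160]) * x0 + np.sqrt(1 - alpha_bar[160]) * eps
x_0_rec = ddim_step(x_180, eps, 180, -1)
```

In practice the U-Net's prediction is not exact, which is why the sampler still takes several steps rather than one.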
Enough mathematics, let’s code this up.
def ddim(x_t, pred_noise, t, t_prev, sigma_t):
    # alpha_bar at the current timestep and at the previous timestep of the strided
    # sequence; t_prev < 0 means the next state is the clean image (alpha_bar = 1)
    t = int(np.squeeze(t))
    alpha_t_bar = float(alpha_bar[t])
    alpha_prev_bar = float(alpha_bar[t_prev]) if t_prev >= 0 else 1.0
    # predicted x_0
    pred = (x_t - ((1 - alpha_t_bar) ** 0.5) * pred_noise) / (alpha_t_bar ** 0.5)
    pred = (alpha_prev_bar ** 0.5) * pred
    # direction pointing to x_t
    pred = pred + ((1 - alpha_prev_bar - (sigma_t ** 2)) ** 0.5) * pred_noise
    # random noise, which vanishes when sigma_t = 0
    eps_t = tf.random.normal(tf.shape(x_t))
    pred = pred + (sigma_t * eps_t)
    return pred
Now let's run a backward denoising loop similar to the DDPM one. Note that we use only 10 steps for this inference loop instead of the 200 steps of DDPM, so each step jumps over roughly 22 timesteps.
# Define number of inference loops to run
inference_timesteps = 10
# Evenly spaced timesteps to sample at, from 0 up to the last training timestep
inference_range = np.linspace(0, timesteps - 1, inference_timesteps, dtype=np.int32)
x = tf.random.normal((1, 32, 32, 1))
img_list = []
img_list.append(np.squeeze(np.squeeze(x, 0), -1))
# Iterate over inference_timesteps, from the noisiest timestep down to 0
for index, i in tqdm(enumerate(reversed(range(inference_timesteps))), total=inference_timesteps):
    t = np.expand_dims(inference_range[i], 0)
    # previous (less noisy) timestep in the sequence; -1 means "jump to the clean image"
    t_prev = int(inference_range[i - 1]) if i > 0 else -1
    pred_noise = unet(x, t)
    x = ddim(x, pred_noise, t, t_prev, 0)
    img_list.append(np.squeeze(np.squeeze(x, 0), -1))
    plt.imshow(np.array(np.clip((np.squeeze(np.squeeze(x, 0), -1) + 1) * 127.5, 0, 255), np.uint8), cmap="gray")
    plt.show()
save_gif(img_list + ([img_list[-1]] * 100), "ddim.gif", interval=20)
plt.imshow(np.array(np.clip((np.squeeze(x[0], -1) + 1) * 127.5, 0, 255), np.uint8), cmap="gray")
plt.show()
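One way to choose the strided timesteps, sketched below with a hypothetical helper `strided_pairs`, is to space them evenly from 0 to the last training timestep, so that sampling starts at the step whose noise level best matches pure noise, and to pair each timestep with the next smaller one (using -1 for the final jump to the clean image). `np.linspace` with an integer `dtype` rounds the intermediate values down.

```python
import numpy as np

def strided_pairs(timesteps, num_steps):
    """Hypothetical helper: (t, t_prev) pairs for a strided sampler, noisiest first."""
    ts = np.linspace(0, timesteps - 1, num_steps, dtype=np.int64)[::-1]
    prev = list(ts[1:]) + [-1]  # -1 marks the final jump to the clean image
    return [(int(t), int(p)) for t, p in zip(ts, prev)]

pairs = strided_pairs(200, 10)
print(pairs)
```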
Here's a sample GIF from the DDIM inference:

This model can be trained on other datasets as well, and the code in this post can handle higher-resolution RGB images once the image size and number of channels are changed (`target_size`, `channels`, and the hard-coded 32×32 shapes). For example, I trained a model on the CelebA dataset to generate 64×64 RGB images; here are some of the results:


With that, we can conclude this topic. A lot of related literature has sprung up around the concept of diffusion models. Here are some interesting reads: 1) GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models 2) Image Super-Resolution via Iterative Refinement 3) Diffusion Models Beat GANs on Image Synthesis 4) Imagen 5) DALL-E 2
[1] Exploring Diffusion Models with JAX, by Darshan Deshpande.
- Unless otherwise noted, all images are made by me.
You can also read the follow-up to this story, where I discuss how to generate images from class labels.
Want to connect? Please write to me at [email protected]