Reconstruct Source Galaxy Images from Strong Gravitational Lens Images Using U-Net

How to generate un-lensed images of galaxies far, far away

Madhumita Dange
Towards Data Science


Gravity bends the light from distant galaxies, which makes them appear lensed. The image below shows how light from a distant quasar (the source) is bent by the strong gravitational pull of a foreground lensing galaxy, which acts like an optical lens between the quasar and the observer. This produces the distorted image of the quasar shown on the right side of the figure.
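The bending described above is commonly written as the thin-lens equation, which maps an observed angular position θ on the sky to the true source position β through the deflection angle α. As a simplified illustration (an assumption, not the exact lens model used in this project), a circular singular isothermal sphere (SIS) deflects light by a constant angle θ_E toward the lens, and an external shear with components gamma1 and gamma2 adds a linear term:

```latex
\vec{\beta} = \vec{\theta} - \vec{\alpha}(\vec{\theta}), \qquad
\vec{\alpha}_{\mathrm{SIS}}(\vec{\theta}) = \theta_E \, \frac{\vec{\theta}}{|\vec{\theta}|}, \qquad
\vec{\alpha}_{\mathrm{shear}}(x, y) =
\begin{pmatrix} \gamma_1 x + \gamma_2 y \\ \gamma_2 x - \gamma_1 y \end{pmatrix}
```

For a source directly behind an SIS lens (β = 0), the equation gives |θ| = θ_E, so the source is imaged onto a ring of radius θ_E: the Einstein ring that gives the Einstein radius its name.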

A simple illustration of strong gravitational lensing | Diagram by Author, Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

This project aims to reconstruct the original images of distant sources that have been distorted by gravitational lenses.

Using convolutional neural networks (CNNs), a U-Net, and a loss function that combines mean squared error (MSE) with the structural similarity index (SSIM), we were able to generate un-lensed images of these galaxies, i.e., to undo the effect of strong gravitational lensing and recover the source image. We are also able to predict the parameters of the strong gravitational lens that causes the distortion. We developed this approach to undo the lensing effect, i.e., to reconstruct images of the source galaxies, so that it can help detect other, possibly unidentified, gravitational lenses.

Undoing Strong Gravitational lensing Effect | Diagram by Author, Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

Dataset

Because few images of strong gravitational lenses are available, we decided to simulate data using the COSMOS dataset, which has 60K galaxy images. We filtered galaxies by their redshift properties, then used Lenstronomy with randomly generated lens parameters to simulate lensed images of the selected source galaxies. We also used rotation augmentation to generate more simulated data.

Our dataset consists of pairs of simulated lensed images and their real source galaxy images, all 64x64 grayscale. We split the data into training, validation, and test sets by each galaxy's distinct source ID, so that no source galaxy appears in more than one set (avoiding data leakage between the sets). The lensed images are the inputs for predicting the parameters of the gravitational lens, and the un-lensed source images are the ground truth for reconstructing the source galaxy from its lensed image.
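Splitting by source ID means shuffling galaxies rather than individual images, so rotated or re-lensed copies of one galaxy never end up in different sets. A minimal NumPy sketch, assuming one source ID per sample and a hypothetical 80/10/10 split (the article does not state the proportions); `split_by_source_id` is an illustrative name:

```python
import numpy as np

def split_by_source_id(source_ids, fractions=(0.8, 0.1, 0.1), seed=0):
    """Return (train_idx, val_idx, test_idx) so that all samples sharing a
    source ID land in the same split. The 80/10/10 default is an assumption."""
    source_ids = np.asarray(source_ids)
    unique_ids = np.unique(source_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(unique_ids)  # shuffle galaxies, not individual images

    n_train = int(round(fractions[0] * len(unique_ids)))
    n_val = int(round(fractions[1] * len(unique_ids)))
    train_ids = unique_ids[:n_train]
    val_ids = unique_ids[n_train:n_train + n_val]
    test_ids = unique_ids[n_train + n_val:]

    # Map every sample back to the split that owns its source ID
    train_idx = np.flatnonzero(np.isin(source_ids, train_ids))
    val_idx = np.flatnonzero(np.isin(source_ids, val_ids))
    test_idx = np.flatnonzero(np.isin(source_ids, test_ids))
    return train_idx, val_idx, test_idx
```

For example, with four augmented copies of each of 10 galaxies (`np.repeat(np.arange(10), 4)`), each galaxy's four copies always stay together in a single split.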

Data Augmentation to create simulated strong gravitational lensing data | Diagram by Author, Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143
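The project used Lenstronomy, with a power-law lens and external shear, to produce the lensed images. To show the underlying mechanics without that dependency, here is a minimal NumPy sketch of ray tracing: every pixel of the lensed image is traced back through the lens equation to the source plane and takes the source's value there. It assumes a singular isothermal sphere plus external shear instead of the full power-law model, nearest-neighbour sampling, and a hypothetical pixel scale of 0.05 arcseconds. `lens_image` and `make_pairs` are illustrative names, and rotating by multiples of 90 degrees with placeholder parameter ranges is only one possible way to implement the rotation augmentation:

```python
import numpy as np

def lens_image(source, theta_E, gamma1=0.0, gamma2=0.0,
               center_x=0.0, center_y=0.0, pixel_scale=0.05):
    """Lens a square source image with a singular isothermal sphere plus external shear.

    Each image-plane pixel is traced back with the lens equation
    beta = theta - alpha(theta) and takes the nearest source pixel's value.
    Angles are in arcseconds; pixel_scale (arcsec per pixel) is an assumed value.
    """
    n = source.shape[0]
    # Pixel-centre coordinates in arcseconds, with the origin at the image centre
    coords = (np.arange(n) - (n - 1) / 2) * pixel_scale
    x, y = np.meshgrid(coords, coords)  # x varies along columns, y along rows

    # Offsets from the lens centre
    dx, dy = x - center_x, y - center_y
    r = np.hypot(dx, dy)
    r = np.where(r == 0, 1e-12, r)  # avoid dividing by zero exactly at the lens centre

    # SIS deflection (radial, constant magnitude theta_E) plus external shear
    alpha_x = theta_E * dx / r + gamma1 * dx + gamma2 * dy
    alpha_y = theta_E * dy / r + gamma2 * dx - gamma1 * dy

    # Lens equation: source-plane position of each image pixel, as pixel indices
    col = np.rint((x - alpha_x) / pixel_scale + (n - 1) / 2).astype(int)
    row = np.rint((y - alpha_y) / pixel_scale + (n - 1) / 2).astype(int)
    inside = (col >= 0) & (col < n) & (row >= 0) & (row < n)

    lensed = np.zeros(source.shape, dtype=float)  # rays that miss the source stamp stay dark
    lensed[inside] = source[row[inside], col[inside]]
    return lensed


def make_pairs(source, rng, n_rotations=4):
    """Hypothetical augmentation: rotate the source by k * 90 degrees, then lens
    each copy with randomly drawn parameters (placeholder ranges)."""
    pairs = []
    for k in range(n_rotations):
        rotated = np.rot90(source, k)
        params = {
            "theta_E": rng.uniform(0.5, 1.5),
            "gamma1": rng.uniform(-0.05, 0.05),
            "gamma2": rng.uniform(-0.05, 0.05),
        }
        pairs.append((lens_image(rotated, **params), rotated, params))
    return pairs
```

With `theta_E=1.0` and the assumed 0.05 arcsec pixels, a compact source at the centre of a 64x64 stamp is mapped onto a ring about 20 pixels in radius, while `theta_E=0.0` with no shear returns the source unchanged.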

Approach

We broke our problem down into two steps.

First, predict 8 lensing parameters:

  • Einstein radius (theta_E)
  • Exponent of the lens’s power-law mass distribution (gamma)
  • Ellipticity components of the lens (e1, e2)
  • Location of the lens in the image (center_x, center_y), in arcseconds
  • External shear components (gamma1, gamma2) of the lensed image
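Because all eight parameters are regressed together, each simulated image needs its parameters packed into a label vector in a fixed order, and the network's 8 outputs must be read back in that same order. A minimal sketch; the column order and the helper names `params_to_vector` and `vector_to_params` are assumptions for illustration:

```python
import numpy as np

# Assumed fixed column order for the 8 regression targets (names follow the list above)
PARAM_NAMES = ["theta_E", "gamma", "e1", "e2", "center_x", "center_y", "gamma1", "gamma2"]

def params_to_vector(params):
    """Pack a dict of lens parameters into a length-8 float32 label vector."""
    return np.array([params[name] for name in PARAM_NAMES], dtype=np.float32)

def vector_to_params(vector):
    """Unpack a length-8 label or prediction vector back into a named dict."""
    return {name: float(value) for name, value in zip(PARAM_NAMES, vector)}
```

Stacking the vectors with `np.stack([params_to_vector(p) for p in param_dicts])` (where `param_dicts` is a hypothetical list of per-image parameter dicts) gives an (N, 8) label matrix that lines up with an 8-unit output layer.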

Second, use a U-Net to generate the un-lensed source galaxy image from its paired lensed image.

Architecture Part 1: Predicting lensing parameters

Predicting lensing parameters | Diagram by Author, Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

After preprocessing our images, we trained a convolutional neural network with three blocks on the training set, using the corresponding lensing parameters as the targets of the last layer; this is our Part 1 model. Each block applies a 2D convolution, max pooling, batch normalization, and a ReLU activation. Finally, we flatten the output and pass it through two dense layers, the first with a ReLU activation and the second with a linear activation. We used all 8 parameters as the labels in each dataset and ran the model on all images in the datasets. It predicts the lensing parameters with good metrics across the training, validation, and test sets.

Hyperparameters of Part 1: we tuned the Adam learning rate, the number of CNN layers, the mini-batch size, and the number of epochs.
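The article does not say how the search over these hyperparameters was run. One common approach is an exhaustive grid search that trains one model per combination and keeps the configuration with the lowest validation loss. In this sketch, `train_and_evaluate` is a hypothetical callback you would supply (for example, one that builds the CNN, calls `fit`, and returns the validation loss), and the grid values are placeholders:

```python
import itertools

def grid_search(train_and_evaluate, grid):
    """Try every combination in `grid`; return (best_score, best_config).

    `train_and_evaluate` is a hypothetical callable that trains a model with the
    given keyword arguments and returns a validation loss (lower is better)."""
    keys = list(grid)
    best_score, best_config = float("inf"), None
    for values in itertools.product(*(grid[k] for k in keys)):
        config = dict(zip(keys, values))
        score = train_and_evaluate(**config)
        if score < best_score:
            best_score, best_config = score, config
    return best_score, best_config

# Placeholder search space over the four hyperparameters named above
GRID = {
    "learning_rate": [1e-2, 1e-3, 1e-4],
    "num_conv_blocks": [2, 3, 4],
    "batch_size": [32, 64, 128],
    "epochs": [20, 50],
}
```

A full grid grows multiplicatively (54 training runs here), so random search over the same ranges is a common alternative when each run is expensive.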

from tensorflow.keras import layers, models
model = models.Sequential()
# Block 1: convolution -> max pooling -> batch normalization -> ReLU (no conv bias; batch norm adds its own offset)
model.add(layers.Conv2D(32, (3, 3), input_shape=(64, 64, 1), use_bias=False))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.BatchNormalization())
model.add(layers.Activation("relu"))
# Block 2
model.add(layers.Conv2D(64, (3, 3), use_bias=False))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.BatchNormalization())
model.add(layers.Activation("relu"))
# Block 3
model.add(layers.Conv2D(64, (3, 3), use_bias=False))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.BatchNormalization())
model.add(layers.Activation("relu"))
# Regression head: one linear output per lensing parameter (8 in total)
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(8, activation='linear'))
CNN to Predict Lensing parameters
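To see how the three blocks shrink a 64x64 input: Keras `Conv2D` uses `'valid'` padding by default (a 3x3 kernel trims one pixel from each edge), and `MaxPooling2D((2, 2))` halves the size, rounding down. This small pure-Python trace reproduces that arithmetic for the model above; the helper names are illustrative:

```python
def conv_valid(size, kernel=3):
    """Output size of a stride-1 convolution with 'valid' padding (the Keras default)."""
    return size - kernel + 1

def max_pool(size, pool=2):
    """Output size of 2x2 max pooling with stride 2 and 'valid' padding (rounds down)."""
    return size // pool

def trace_part1(size=64, filters=(32, 64, 64)):
    """Print the feature-map shape after each conv -> pool block; return the flattened length."""
    for f in filters:
        size = max_pool(conv_valid(size))
        print(f"after block with {f} filters: {size}x{size}x{f}")
    return size * size * filters[-1]

print("flattened features:", trace_part1())  # flattened features: 2304
```

The 2304 flattened features feed `Dense(64)`, which therefore has 2304 * 64 + 64 = 147,520 trainable parameters, far more than any of the convolution layers.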

Root mean squared error (RMSE) per lensing parameter on the test set:

RMSE per lensing parameter | Diagram by Author

On the test set, theta_E (the Einstein radius) is easier to predict than the lens center coordinates and gamma, but harder than the lens ellipticity components (e1, e2).
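The per-parameter RMSE is the square root of the mean squared error, taken separately for each output column. A NumPy sketch, assuming predictions and labels are (N, 8) arrays in the parameter order listed earlier (the column order is an assumption):

```python
import numpy as np

PARAM_NAMES = ["theta_E", "gamma", "e1", "e2", "center_x", "center_y", "gamma1", "gamma2"]

def rmse_per_parameter(y_true, y_pred, names=PARAM_NAMES):
    """Root mean squared error for each output column, returned as {name: rmse}."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Average squared errors over samples (axis 0), one value per parameter
    rmse = np.sqrt(np.mean((y_pred - y_true) ** 2, axis=0))
    return {name: float(value) for name, value in zip(names, rmse)}
```

Each value is in the units of its own parameter (arcseconds for theta_E and the lens center coordinates).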

Architecture Part 2: Generating source galaxy image

Lensed image to an un-lensed source galaxy image using U-Net | Diagram by Author, Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

In Part 2, we convert a lensed image into an un-lensed source galaxy image. For this we used a U-Net encoder-decoder network with an image-similarity loss function.
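In a U-Net, the encoder halves the spatial size at each level while the decoder doubles it back, and each decoder level concatenates the encoder output of the same size (the skip connection). This pure-Python sketch traces those shapes for the Keras U-Net defined later in this article (filters 16, 32, 64, 128 with a 256-channel bottleneck and `'same'` padding). It is a shape check only, not a model, and `trace_unet` is an illustrative name:

```python
def trace_unet(size=64, filters=(16, 32, 64, 128), bottleneck=256):
    """Trace feature-map shapes through the U-Net; return (output size, output channels)."""
    skips = []
    # Encoder: 'same'-padded convs keep the size, then 2x2 max pooling halves it
    for f in filters:
        skips.append((size, f))
        size //= 2
    print(f"bottleneck: {size}x{size}x{bottleneck}")

    # Decoder: a stride-2 transposed conv doubles the size, then joins the matching skip
    for f in reversed(filters):
        size *= 2
        skip_size, skip_channels = skips.pop()
        if skip_size != size:
            raise ValueError(f"cannot concatenate {size}x{size} with a {skip_size}x{skip_size} skip")
        print(f"{size}x{size}: concat {f} + {skip_channels} = {f + skip_channels} channels -> conv to {f}")

    # Final 1x1 convolution maps the last feature maps to a single-channel image
    return size, 1

trace_unet()
```

Because 64 is divisible by 2^4, every skip connection lines up; a 60-pixel input would not (60 → 30 → 15 → 7 → 3, and 3 upsamples to 6 rather than 7), so `trace_unet(60)` raises an error.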

U-Net Experiments and hyperparameter tuning

We started with a simple U-Net made of multiple Conv2D, max pooling, and transposed convolution layers, with a linear activation on the last layer, trained with MSE loss and the Adam optimizer. Adding more hidden blocks to the U-Net improved its performance. We then experimented with different activation functions (ReLU, ELU, etc.); ReLU gave us better results for the intermediate layers. Adding dropout did not improve results much, so we removed it, whereas adding batch normalization helped. Manual error analysis showed that the model still did not generate good images of spiral galaxies, so we decided to try different image-similarity loss functions.

Initial Model
Intermediate model with MSE loss generating correct shapes | Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

Metrics and Custom Loss Functions

We experimented with the structural similarity index (SSIM), peak signal-to-noise ratio (PSNR), and mean squared error (MSE) as loss functions. SSIM gave us better results, and we improved our results further with a custom loss function that combines SSIM and MSE.
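For reference, the three metrics fit in a few lines of NumPy. MSE averages squared pixel differences; PSNR converts it to decibels as 10 log10(MAX^2 / MSE), so higher is better; SSIM compares means, variances, and covariance and equals 1 for identical images. This sketch assumes pixel values in [0, 1] (MAX = 1), and `global_ssim` is a simplification that treats the whole image as one window, whereas `tf.image.ssim` averages SSIM over sliding 11x11 Gaussian windows:

```python
import numpy as np

def mse(a, b):
    """Mean squared error between two images."""
    return float(np.mean((np.asarray(a, float) - np.asarray(b, float)) ** 2))

def psnr(a, b, max_val=1.0):
    """Peak signal-to-noise ratio in dB; infinite for identical images."""
    err = mse(a, b)
    return float("inf") if err == 0 else float(10.0 * np.log10(max_val ** 2 / err))

def global_ssim(a, b, max_val=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM over the whole image (simplified; no sliding window)."""
    a = np.asarray(a, float)
    b = np.asarray(b, float)
    c1, c2 = (k1 * max_val) ** 2, (k2 * max_val) ** 2  # stabilizing constants
    mu_a, mu_b = a.mean(), b.mean()
    var_a, var_b = a.var(), b.var()
    cov = np.mean((a - mu_a) * (b - mu_b))
    return float(((2 * mu_a * mu_b + c1) * (2 * cov + c2))
                 / ((mu_a ** 2 + mu_b ** 2 + c1) * (var_a + var_b + c2)))
```

For example, an all-zero image against a constant 0.1 image has MSE 0.01 and PSNR 20 dB.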

U-Net with Custom loss function

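import tensorflow as tf
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Conv2DTranspose,
                                     Input, Lambda, MaxPooling2D, concatenate)
from tensorflow.keras.models import Model
inputs = Input((64, 64, 1))
s = Lambda(lambda x: x / 255) (inputs)  # divide pixel values by 255 (maps [0, 255] to [0, 1])
# Encoder: each level stacks conv -> batch norm -> ReLU, then 2x2 max pooling halves the size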
c1 = Conv2D(16, (3, 3), padding='same',use_bias=False) (s)
c1 = BatchNormalization()(c1)
c1 = Activation('relu')(c1)
c1 = Conv2D(16, (3, 3), padding='same',use_bias=False) (c1)
c1 = BatchNormalization()(c1)
c1 = Activation('relu')(c1)
p1 = MaxPooling2D((2, 2)) (c1)
c2 = Conv2D(32, (3, 3), padding='same',use_bias=False) (p1)
c2 = BatchNormalization()(c2)
c2 = Activation('relu')(c2)
c2 = Conv2D(32, (3, 3), padding='same',use_bias=False) (c2)
c2 = BatchNormalization()(c2)
c2 = Activation('relu')(c2)
p2 = MaxPooling2D((2, 2)) (c2)
c3 = Conv2D(64, (3, 3), padding='same',use_bias=False) (p2)
c3 = BatchNormalization()(c3)
c3 = Activation('relu')(c3)
c3 = Conv2D(64, (3, 3), padding='same',use_bias=False) (c3)
c3 = BatchNormalization()(c3)
c3 = Activation('relu')(c3)
c3 = Conv2D(64, (3, 3), padding='same',use_bias=False) (c3)
c3 = BatchNormalization()(c3)
c3 = Activation('relu')(c3)
p3 = MaxPooling2D((2, 2)) (c3)
c4 = Conv2D(128, (3, 3), padding='same', use_bias=False) (p3)
c4 = BatchNormalization()(c4)
c4 = Activation('relu')(c4)
c4 = Conv2D(128, (3, 3), padding='same',use_bias=False) (c4)
c4 = BatchNormalization()(c4)
c4 = Activation('relu')(c4)
p4 = MaxPooling2D(pool_size=(2, 2)) (c4)
# Bottleneck: 4x4 feature maps with 256 channels
c5 = Conv2D(256, (3, 3), padding='same',use_bias=False) (p4)
c5 = BatchNormalization()(c5)
c5 = Activation('relu')(c5)
c5 = Conv2D(256, (3, 3), padding='same',use_bias=False) (c5)
c5 = BatchNormalization()(c5)
c5 = Activation('relu')(c5)
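# Decoder: each stride-2 transposed conv doubles the size and is concatenated with the matching encoder output (skip connection)
u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same') (c5)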
u6 = concatenate([u6, c4])
c6 = Conv2D(128, (3, 3), padding='same',use_bias=False) (u6)
c6 = BatchNormalization()(c6)
c6 = Activation('relu')(c6)
c6 = Conv2D(128, (3, 3), padding='same',use_bias=False) (c6)
c6 = BatchNormalization()(c6)
c6 = Activation('relu')(c6)
u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same') (c6)
u7 = concatenate([u7, c3])
c7 = Conv2D(64, (3, 3), padding='same',use_bias=False) (u7)
c7 = BatchNormalization()(c7)
c7 = Activation('relu')(c7)
c7 = Conv2D(64, (3, 3), padding='same',use_bias=False) (c7)
c7 = BatchNormalization()(c7)
c7 = Activation('relu')(c7)
c7 = Conv2D(64, (3, 3), padding='same',use_bias=False) (c7)
c7 = BatchNormalization()(c7)
c7 = Activation('relu')(c7)
u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c7)
u8 = concatenate([u8, c2])
c8 = Conv2D(32, (3, 3), padding='same',use_bias=False) (u8)
c8 = BatchNormalization()(c8)
c8 = Activation('relu')(c8)
c8 = Conv2D(32, (3, 3), padding='same',use_bias=False) (c8)
c8 = BatchNormalization()(c8)
c8 = Activation('relu')(c8)
u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same') (c8)
u9 = concatenate([u9, c1], axis=3)
c9 = Conv2D(16, (3, 3), padding='same',use_bias=False) (u9)
c9 = BatchNormalization()(c9)
c9 = Activation('relu')(c9)
c9 = Conv2D(16, (3, 3), padding='same',use_bias=False) (c9)
c9 = BatchNormalization()(c9)
c9 = Activation('relu')(c9)
# 1x1 convolution maps the 16 feature channels to a single-channel output image
out = Conv2D(1, (1, 1), use_bias=False) (c9)
out = BatchNormalization()(out)
outputs = Activation('linear')(out)
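# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model = Model(inputs=[inputs], outputs=[outputs])
# mse_ssim_loss is defined in the next snippet; ssim_loss and psnr_loss are metric
# functions not shown in this article. Define all three before compiling.
model.compile(loss=mse_ssim_loss, optimizer=optimizer, metrics=['mse', ssim_loss, psnr_loss])
model.summary()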

Custom loss function which is a combination of SSIM and MSE

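import tensorflow as tf
def mse_ssim_loss(y_true, y_pred):
    # MSE minus batch-mean SSIM (max_val=1.0 assumes pixels in [0, 1]); lower MSE and higher SSIM both lower the loss
    return tf.reduce_mean(tf.math.squared_difference(y_true, y_pred)) - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))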

This loss function also produces better results for spiral galaxies.
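Written out, the custom loss is the pixel-wise MSE over the batch minus the batch-average SSIM, where N is the total number of pixels in a batch of B images (`tf.image.ssim` returns one SSIM value per image, which `tf.reduce_mean` then averages):

```latex
\mathcal{L}(y, \hat{y}) = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - \hat{y}_i \right)^2
\;-\; \frac{1}{B} \sum_{b=1}^{B} \mathrm{SSIM}\!\left( y^{(b)}, \hat{y}^{(b)} \right)
```

Since SSIM is at most 1 and MSE is at least 0, the loss is bounded below by -1, reached only for a perfect reconstruction (MSE = 0, SSIM = 1), so negative loss values during training are expected.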

Validation metrics: val_mse: 0.0023, val_ssim_loss: 0.0468, val_psnr_loss: 30.6308. Loss curves.

U-Net results on the test set for the improved model (more results below)

Improved model with custom loss function generating patterns as well | Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

Insights and Future work:

We also tried using our trained model to generate sources for real strong gravitational lens images from the internet, but that data needs more preprocessing and noise removal. Next, we will try Gaussian noise removal and augmenting more data (a data-centric approach). We will also experiment with RGB images and continue to refine our models for both parts, so we can get closer to predicting what these source galaxies actually look like. We will also make the simulated data publicly available.

Contributions

Git repository: Generate-Source-Galaxy-Images-from-Strong-Gravitational-Lens-Images

Data: lensingData

Acknowledgements

My team member Anand Bhavsar and I would like to thank our mentor, Mr. Jelle Aalbers, Kavli Postdoctoral Fellow at Stanford University and a domain expert in astrophysics, for the idea, for helping us set up our data, for guidance and advice on building our models, and for feedback on our test results.

More Results:

Sample Output Images | Galaxy Images : Mandelbaum, Rachel, Lackner, Claire, Leauthaud, Alexie, & Rowe, Barnaby. (2012). COSMOS real galaxy dataset [Data set]. Zenodo. http://doi.org/10.5281/zenodo.3242143

References:

COSMOS real galaxy dataset

Structural Similarity (SSIM)

Peak Signal to Noise Ratio (PSNR)


Works as a data scientist and is an SCPD student in AI at Stanford University. Loves astrophysics and learning.