Colorizing B/W Images With GANs in TensorFlow

An experiment with GANs to unveil their superpowers

GANs are one of the most interesting topics in machine learning today. They have been applied to many problems (and not just to generate MNIST digits!), often with impressive results. A GAN (Generative Adversarial Network) consists of a generator and a discriminator, which compete against each other during training. Here, we’ll take a mathematical approach towards understanding the GAN and its loss functions. As the idea behind training a GAN comes from game theory, we’ll have a quick look at the minimax optimization strategy too.

In this article, we’ll explore GANs for colorizing B/W images and also learn about the loss functions required for our model. So, get ready for some GANs!

GANs have some of the most amazing applications, like turning a horse into a zebra, as seen below.

Here’s a basic architecture of GANs used for generating realistic human faces:

This story will only be able to give you a glimpse of how GANs work, as we’ll focus more on the use-case rather than a complete explanation of how GANs work. In order to gain in-depth knowledge on the topic, refer to these blogs:

I have tried image-colorization with AutoEncoders before, but the results were not up to the mark. A single color appeared in the whole image with different shades or tints. The code for the project is available here ->

The Data and The Code

Our dataset will consist of 3000 RGB images from various domains (mountains, forests, cities, etc.). You may download it from here. In the Colab notebook, we’ll convert these RGB images to grayscale using PIL. The grayscale images will act as inputs for our model, and the original RGB images will act as labels.
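As a rough sketch of this preprocessing step, the snippet below builds one (grayscale input, RGB label) pair with NumPy. It uses the ITU-R 601 luma weights (0.299, 0.587, 0.114), which is the weighting PIL documents for convert("L"). The function name make_training_pair, the random stand-in image, and the scaling of pixel values to [0, 1] are assumptions made for illustration; they are not taken from the notebook.

```python
import numpy as np

# ITU-R 601 luma weights (the weighting PIL documents for Image.convert("L")).
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)

def make_training_pair(rgb_uint8):
    """Turn one RGB image (H, W, 3, uint8) into a (grayscale input, RGB label) pair."""
    # Label y: the color image, scaled to [0, 1] (assumed scaling).
    rgb = rgb_uint8.astype(np.float32) / 255.0
    # Weighted sum over the channel axis gives an (H, W) luminance image.
    gray = rgb @ LUMA_WEIGHTS
    # Input x: add an explicit channel axis so the shape is (H, W, 1).
    gray = gray[..., np.newaxis]
    return gray, rgb

# A random stand-in for a real 120 x 120 photo.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(120, 120, 3), dtype=np.uint8)
x, y = make_training_pair(image)
print(x.shape, y.shape)
```

Keeping the channel axis on x means the array can be fed directly to a model whose input shape is (img_size, img_size, 1).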

The TensorFlow implementation of this project can be found in this Colab notebook.

The Generator

The first thing our GAN will require is a generator. This generator will take in a grayscale (B/W) image and output an RGB image. Our generator will have an encoder-decoder structure with layers placed symmetrically, just like a U-Net.

  • The encoder will take a grayscale image and produce a latent representation of it (also called the bottleneck representation). The decoder’s job is to produce an RGB image by enlarging this latent representation. This approach is used by most autoencoders as well as other encoder-decoder structures.
  • While constructing the RGB image from the latent representation, some finer details might be missing. It would be interesting to observe the results if the information could come directly from the encoder to the decoder. Here’s where skip-connections come into the picture.
  • Skip connections bring outputs of the convolution layer (present in the encoder) to the decoder, where they are concatenated with previous outputs of the decoder itself.
def get_generator_model():

    inputs = tf.keras.layers.Input( shape=( img_size , img_size , 1 ) )

    conv1 = tf.keras.layers.Conv2D( 16 , kernel_size=( 5 , 5 ) , strides=1 , dilation_rate=4 )( inputs )
    conv1 = tf.keras.layers.LeakyReLU()( conv1 )
    conv1 = tf.keras.layers.Conv2D( 32 , kernel_size=( 3 , 3 ) , strides=1 , dilation_rate=2 )( conv1 )
    conv1 = tf.keras.layers.LeakyReLU()( conv1 )
    conv1 = tf.keras.layers.Conv2D( 32 , kernel_size=( 3 , 3 ) , strides=1)( conv1 )
    conv1 = tf.keras.layers.LeakyReLU()( conv1 )

    conv2 = tf.keras.layers.Conv2D( 32 , kernel_size=( 5 , 5 ) , strides=1)( conv1 )
    conv2 = tf.keras.layers.LeakyReLU()( conv2 )
    conv2 = tf.keras.layers.Conv2D( 64 , kernel_size=( 3 , 3 ) , strides=1 )( conv2 )
    conv2 = tf.keras.layers.LeakyReLU()( conv2 )
    conv2 = tf.keras.layers.Conv2D( 64 , kernel_size=( 3 , 3 ) , strides=1 )( conv2 )
    conv2 = tf.keras.layers.LeakyReLU()( conv2 )

    conv3 = tf.keras.layers.Conv2D( 64 , kernel_size=( 5 , 5 ) , strides=1 )( conv2 )
    conv3 = tf.keras.layers.LeakyReLU()( conv3 )
    conv3 = tf.keras.layers.Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 )( conv3 )
    conv3 = tf.keras.layers.LeakyReLU()( conv3 )
    conv3 = tf.keras.layers.Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 )( conv3 )
    conv3 = tf.keras.layers.LeakyReLU()( conv3 )

    bottleneck = tf.keras.layers.Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' , padding='same' )( conv3 )

    concat_1 = tf.keras.layers.Concatenate()( [ bottleneck , conv3 ] )
    conv_up_3 = tf.keras.layers.Conv2DTranspose( 128 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' )( concat_1 )
    conv_up_3 = tf.keras.layers.Conv2DTranspose( 128 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' )( conv_up_3 )
    conv_up_3 = tf.keras.layers.Conv2DTranspose( 64 , kernel_size=( 5 , 5 ) , strides=1 , activation='relu' )( conv_up_3 )

    concat_2 = tf.keras.layers.Concatenate()( [ conv_up_3 , conv2 ] )
    conv_up_2 = tf.keras.layers.Conv2DTranspose( 64 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' )( concat_2 )
    conv_up_2 = tf.keras.layers.Conv2DTranspose( 64 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' )( conv_up_2 )
    conv_up_2 = tf.keras.layers.Conv2DTranspose( 32 , kernel_size=( 5 , 5 ) , strides=1 , activation='relu' )( conv_up_2 )

    concat_3 = tf.keras.layers.Concatenate()( [ conv_up_2 , conv1 ] )
    conv_up_1 = tf.keras.layers.Conv2DTranspose( 32 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu')( concat_3 )
    conv_up_1 = tf.keras.layers.Conv2DTranspose( 32 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' , dilation_rate=2 )( conv_up_1 )
    conv_up_1 = tf.keras.layers.Conv2DTranspose( 3 , kernel_size=( 5 , 5 ) , strides=1 , activation='relu' , dilation_rate=4 )( conv_up_1 )

    model = tf.keras.models.Model( inputs , conv_up_1 )
    return model
  • You may observe the skip connections at the three Concatenate layers (concat_1, concat_2 and concat_3; lines 28, 33 and 38 of the snippet above).
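Concatenation only works when the encoder and decoder tensors have the same height and width, so it can help to trace the spatial sizes through the generator. The sketch below does this in plain Python, assuming a 120 x 120 input (taken from the discriminator's input_shape) and the usual size formulas for stride-1 layers with 'valid' padding; the helper names are made up for this illustration.

```python
def effective_kernel(kernel, dilation=1):
    """Side length of a kernel once dilation spreads its taps apart."""
    return kernel + (kernel - 1) * (dilation - 1)

def conv_size(size, kernel, dilation=1):
    """Output side of a stride-1 Conv2D with 'valid' padding."""
    return size - effective_kernel(kernel, dilation) + 1

def deconv_size(size, kernel, dilation=1):
    """Output side of a stride-1 Conv2DTranspose with 'valid' padding."""
    return size + effective_kernel(kernel, dilation) - 1

# (kernel, dilation) for each layer, read off the generator snippet.
encoder_blocks = [
    [(5, 4), (3, 2), (3, 1)],  # conv1
    [(5, 1), (3, 1), (3, 1)],  # conv2
    [(5, 1), (3, 1), (3, 1)],  # conv3
]
decoder_blocks = [
    [(3, 1), (3, 1), (5, 1)],  # conv_up_3
    [(3, 1), (3, 1), (5, 1)],  # conv_up_2
    [(3, 1), (3, 2), (5, 4)],  # conv_up_1
]

size = 120  # assumed input side length
skip_sizes = []
for block in encoder_blocks:
    for kernel, dilation in block:
        size = conv_size(size, kernel, dilation)
    skip_sizes.append(size)  # side of conv1, conv2, conv3

# The bottleneck uses padding='same', so it leaves the size unchanged.
decoder_input_sizes = []
for block in decoder_blocks:
    decoder_input_sizes.append(size)  # side of the decoder tensor entering Concatenate
    for kernel, dilation in block:
        size = deconv_size(size, kernel, dilation)

print("encoder skips:", skip_sizes)
print("decoder inputs:", decoder_input_sizes)
print("generator output side:", size)
```

Under these assumptions each decoder tensor meets an encoder output of the same size, and the final layer restores the input resolution, which is why the generator's output can be passed straight to the discriminator.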

The Discriminator

Our discriminator will be a standard CNN of the kind we use for classification. It will take an image and output the probability that the given image is an original rather than one generated by the generator.

def get_discriminator_model():
    layers = [
        tf.keras.layers.Conv2D( 32 , kernel_size=( 7 , 7 ) , strides=1 , activation='relu' , input_shape=( 120 , 120 , 3 ) ),
        tf.keras.layers.Conv2D( 32 , kernel_size=( 7, 7 ) , strides=1, activation='relu'  ),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D( 64 , kernel_size=( 5 , 5 ) , strides=1, activation='relu'  ),
        tf.keras.layers.Conv2D( 64 , kernel_size=( 5 , 5 ) , strides=1, activation='relu'  ),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1, activation='relu'  ),
        tf.keras.layers.Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1, activation='relu'  ),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D( 256 , kernel_size=( 3 , 3 ) , strides=1, activation='relu'  ),
        tf.keras.layers.Conv2D( 256 , kernel_size=( 3 , 3 ) , strides=1, activation='relu'  ),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense( 512, activation='relu'  )  ,
        tf.keras.layers.Dense( 128 , activation='relu' ) ,
        tf.keras.layers.Dense( 16 , activation='relu' ) ,
        tf.keras.layers.Dense( 1 , activation='sigmoid' ) 
    ]
    model = tf.keras.models.Sequential( layers )
    return model
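To get a feel for how quickly this network shrinks the image, here is a small plain-Python trace of the feature-map side length through the discriminator. It assumes Keras defaults that the snippet relies on implicitly ('valid' padding for Conv2D, and 2 x 2 pooling with stride 2 for MaxPooling2D); the variable names are illustrative only.

```python
# ("conv", kernel) or ("pool", window), in the order used by the discriminator.
layers = [
    ("conv", 7), ("conv", 7), ("pool", 2),
    ("conv", 5), ("conv", 5), ("pool", 2),
    ("conv", 3), ("conv", 3), ("pool", 2),
    ("conv", 3), ("conv", 3), ("pool", 2),
]

side = 120  # input_shape=(120, 120, 3)
trace = [side]
for kind, k in layers:
    if kind == "conv":
        side = side - k + 1   # stride-1 convolution, 'valid' padding
    else:
        side = side // k      # non-overlapping pooling drops any leftover row/column
    trace.append(side)

last_conv_filters = 256
flattened_units = side * side * last_conv_filters

print("side lengths:", trace)
print("units after Flatten:", flattened_units)
```

The Flatten layer therefore hands a fairly small vector to the Dense layers, which keeps the fully connected part of the discriminator modest in size.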

The Math

I know math can be scary, especially in machine learning, but you need not worry: I’ll keep things as simple as possible. Suppose we have a sample ( x, y ) from our dataset. Here x represents a grayscale image and y is the same image in color, i.e., in RGB format. The shapes of x and y are shown below.
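In symbols, and assuming the 120 x 120 image size that the discriminator's input_shape suggests, the shapes can be written as:

```latex
x \in \mathbb{R}^{120 \times 120 \times 1}, \qquad
y \in \mathbb{R}^{120 \times 120 \times 3}
```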

We represent the generator as G and the discriminator as D. For a single step, we’ll run the generator once and the discriminator twice.
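A compact way to write that single step, using only the symbols defined in this section, might be:

```latex
y_p = G(x), \qquad
D(y) = P(\text{real} \mid y), \qquad
D(y_p) = P(\text{real} \mid y_p)
```

The generator runs once to produce y_p, and the discriminator runs twice: once on y and once on y_p.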

Here y_p represents the generated image. P( real | y) is the probability that the given image y is the one from the data. Here, “real” indicates that the image is not generated.

We can consider y as the real image and y_p as the generated/fake image from the generator. We’ll use this terminology, as it is what you’ll find in most resources explaining GANs.

Now let’s take a look at the loss functions. First, for the generator, we’ll use the MSE loss function. It is generally used in regression, but we can also use it in our case.
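Written out, a mean squared error between the generated image and the target, matching the call mse( fake_output , real_y ) in the code, would look roughly like this. Here N is assumed to be the number of pixel values in an image (height x width x 3), and Keras additionally averages over the batch:

```latex
L_G = \frac{1}{N} \sum_{i=1}^{N} \left( y_{p,i} - y_i \right)^2
```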

In the above equation, y_p is the generated image and y is the target RGB image (in the code, the MSE is computed between the generator’s output and real_y). The loss functions for the discriminator are shown below. We use the binary cross-entropy loss for both the outputs of the discriminator.

We add these loss functions to get the final expression for the loss function of the discriminator:
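As a sketch in cross-entropy form (the form that is minimized, as in the code), with a target of 1 for real images and 0 for generated ones, the two terms and their sum can be written as follows. Maximizing log D(y) + log(1 - D(y_p)) is the same thing with the sign flipped:

```latex
L_{\text{real}} = -\log D(y), \qquad
L_{\text{fake}} = -\log\bigl(1 - D(y_p)\bigr), \qquad
L_D = L_{\text{real}} + L_{\text{fake}}
```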

  • Minimax refers to an optimization strategy in two-player turn-based games for minimizing the loss or cost for the worst case of the other player. Here, our generator and discriminator are the two players competing against each other.
  • For the discriminator, maximizing its objective (equivalently, minimizing its cross-entropy loss) means classifying generated images (y_p) as fake as well as producing a high probability (closer to 1.0) for images ( y ) from the dataset.
  • The generator, by minimizing its loss, improves itself to such an extent that it can fool the discriminator. Fooling the discriminator means that the discriminator will produce probabilities (closer to 1.0) even for generated images (y_p).

We’ll train the discriminator in such a manner that it will output probabilities closer to 1.0 for real images (from our dataset) and output probabilities closer to 0.0 for images coming from the generator.

If the discriminator is “smart” enough, it will output probabilities closer to 1.0 for real images (coming from our dataset). So, we are training our generator to forge such realistic images which will make the discriminator output probabilities closer to 1.0 even when the images are forged (not from our dataset, but from the generator).

Our final loss function will be:
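For reference, the standard minimax objective from the GAN literature, rewritten with this article's symbols, is sketched below. Treat it as the textbook formulation rather than a transcription of the training code, where the generator is updated with the MSE loss described earlier:

```latex
\min_G \max_D \; V(D, G) =
\mathbb{E}_{y}\bigl[\log D(y)\bigr] +
\mathbb{E}_{x}\bigl[\log\bigl(1 - D(G(x))\bigr)\bigr]
```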

That’s all! We’ll now head towards the code for training the GAN.

The Code

We saw the Keras implementation of the generator and the discriminator in snippets 1 and 2. Now let us take a look at the implementation of the loss functions.

cross_entropy = tf.keras.losses.BinaryCrossentropy()
mse = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output) - tf.random.uniform( shape=tf.shape( real_output ) , maxval=0.1 ) , real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output) + tf.random.uniform( shape=tf.shape( fake_output ) , maxval=0.1  ) , fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output , real_y):
    real_y = tf.cast( real_y , 'float32' )
    return mse( fake_output , real_y )

generator_optimizer = tf.keras.optimizers.Adam( 0.001 )
discriminator_optimizer = tf.keras.optimizers.Adam( 0.001 )

Noticed something different in the snippet above, at lines 5 and 6 (the real_loss and fake_loss computations)?

We’re adding/subtracting small random values from tf.ones and tf.zeros. So, instead of using hard labels like 1s and 0s, we use noisy labels like 0.93 or 0.04 (real labels fall in [0.9, 1.0] and fake labels in [0.0, 0.1], since maxval=0.1). This helps the discriminator learn better; otherwise its outputs could saturate at 1 or 0 in the initial epochs and little learning would occur.
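To see why softened labels discourage saturation, here is a small plain-Python sketch of binary cross-entropy with noisy targets. The helper names, the eps clipping value and the sample probabilities are assumptions chosen for illustration; the real training code uses tf.keras.losses.BinaryCrossentropy.

```python
import math
import random

def binary_cross_entropy(label, prob, eps=1e-7):
    """Cross-entropy between a (possibly soft) label and a predicted probability."""
    prob = min(max(prob, eps), 1.0 - eps)  # avoid log(0)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))

def noisy_labels(batch_size, real, max_noise=0.1, rng=random):
    """Labels in [1 - max_noise, 1] for real images, [0, max_noise] for fake ones."""
    if real:
        return [1.0 - rng.uniform(0.0, max_noise) for _ in range(batch_size)]
    return [rng.uniform(0.0, max_noise) for _ in range(batch_size)]

rng = random.Random(0)
real_targets = noisy_labels(4, real=True, rng=rng)
fake_targets = noisy_labels(4, real=False, rng=rng)
print("real targets:", real_targets)
print("fake targets:", fake_targets)

# With a hard label of 1.0 the loss keeps falling as the output approaches 1.0.
# With a soft label of 0.95 the loss is smallest at 0.95 and rises again beyond it,
# so the discriminator is not rewarded for pushing its output all the way to 1.0.
for prob in (0.90, 0.95, 0.999):
    hard = binary_cross_entropy(1.0, prob)
    soft = binary_cross_entropy(0.95, prob)
    print(f"D output {prob}: hard-label loss {hard:.4f}, soft-label loss {soft:.4f}")
```

Because each batch draws fresh noise, the discriminator sees slightly different targets every step, which is the effect the two random.uniform terms in discriminator_loss aim for.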

We use the Adam optimizer for both the generator and the discriminator, with a learning rate of 0.001.

Next comes the training loop. The training loop will generate predictions, both from the generator and discriminator, calculate the losses, and optimize both the models.

@tf.function
def train_step( input_x , real_y ):
   
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate an image -> G( x )
        generated_images = generator( input_x , training=True)
        # Probability that the real image is real -> D( y )
        real_output = discriminator( real_y, training=True)
        # Probability that the generated image is real -> D( G( x ) )
        generated_output = discriminator(generated_images, training=True)
        
        # MSE
        gen_loss = generator_loss( generated_images , real_y )
        # Log loss for the discriminator
        disc_loss = discriminator_loss( real_output, generated_output )
    
    #tf.keras.backend.print_tensor( tf.keras.backend.mean( gen_loss ) )
    #tf.keras.backend.print_tensor( gen_loss + disc_loss )

    # Compute the gradients
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Optimize with Adam
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

We are now ready to start training. To print the loss values for every forward pass, uncomment the print_tensor lines in train_step.

num_epochs = 60

for e in range( num_epochs ):
    print( 'Epoch ' , e )
    for ( x , y ) in dataset:
        # Here ( x , y ) represents a batch from our training dataset.
        train_step( x , y )

The Results

The results are quite good and showcase the amazing power of GANs. But you’ll see some disturbance (the black/yellow colored patches distinct from their background) in the images below.

We also got some surprising results in which daytime was transformed to evening — we hadn’t trained the model for this!

Conclusion

Really, GANs have the power to change the face of machine learning. As I said earlier, they are flexible and can be used to solve various problems. Wondering where else we can use GANs?

GANs have been used for super-resolution of images. Here, we convert a low-resolution image to a high-resolution image, as illustrated below.

GANs (more precisely CycleGANs) can even create human-like paintings and artwork, as shown below.

If you liked this blog, consider some further reading, both on ML in general and on Android.

If you have more ideas on how we can improve this GAN to yield better results, make sure you leave them in the comments below! Thank you and have a wonderful ML journey ahead.
