Image credit: Gimages

GAN 3

MNIST Linear GAN

We covered an introduction to GANs and the theory of the game between the Generator and the Discriminator in the previous posts. In this post we will implement a GAN in PyTorch and learn how to train it. We start with the MNIST dataset; future posts will implement different applications of GANs, as well as my research paper on one of those applications.

So the task is to use the MNIST dataset to generate new MNIST-like data samples with a GAN.

Let’s Code GAN

Get the Data

Import the necessary libraries: NumPy, Matplotlib, torch and torchvision.

import numpy as np
import torch
import matplotlib.pyplot as plt

from torchvision import datasets
import torchvision.transforms as transforms

Now let's get the MNIST data from the torchvision datasets.

transform = transforms.ToTensor()
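# download the MNIST training set; ToTensor() scales pixels to [0, 1]
data = datasets.MNIST(root='data', train=True,
                      download=True, transform=transform)
# iterate over the data in batches of 1024 images
data_loader = torch.utils.data.DataLoader(data, batch_size=1024)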
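
To get a feel for what a DataLoader yields without downloading anything, here is a minimal sketch that uses random stand-in tensors shaped like MNIST. The names `fake_images`, `fake_labels` and `loader` are illustrative assumptions, not part of the original code, and the data is random noise rather than real digits.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# random stand-in data shaped like MNIST (not the real dataset)
fake_images = torch.rand(2500, 1, 28, 28)
fake_labels = torch.randint(0, 10, (2500,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=1024)

# collect the size of every batch the loader yields
batch_sizes = []
for images, labels in loader:
    batch_sizes.append(images.size(0))

print(images.shape)   # shape of the last batch
print(batch_sizes)    # the last batch holds the leftover samples
```

The last batch can be smaller than `batch_size`, which is why it is safer to read the batch size from `images.size(0)` inside a training loop than to hard-code it.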

The Model

As we saw in Theory of Game between Generator and Discriminator, a GAN generally has two networks: the Discriminator D and the Generator G. We will code both networks as separate classes in PyTorch.

Discriminator

The discriminator is just a classifier: it takes images as input and classifies them as real or fake (generated). So let's build a classifier network in PyTorch.

import torch.nn as nn
import torch.nn.functional as F

class D(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(D, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(0.3)      
        
    def forward(self, x):
        # flatten image
        x = x.view(-1, 28*28)
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # log-probabilities over the two classes (0 = fake, 1 = real)
        out = F.log_softmax(self.fc4(x), dim=1)
        return out

The D network has 4 linear layers, with leaky ReLU activations and dropout in between.

  • input_size = 28*28*1 (the size of a flattened MNIST image)
  • hidden_dim can be anything of your choice
  • output_size = 2 (real or fake)

I am also adding a log softmax at the end, so that the network outputs log-probabilities that can be fed directly to nn.NLLLoss.
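
To see why the log softmax is convenient, here is a small sketch. The logits are made-up values, not outputs of the network above. It checks that exponentiating the log softmax output gives class probabilities that sum to 1, and that `nn.NLLLoss` applied to log-probabilities matches `nn.CrossEntropyLoss` applied to the raw scores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# made-up raw scores for 3 images and 2 classes (0 = fake, 1 = real)
logits = torch.tensor([[2.0, -1.0], [0.5, 0.5], [-3.0, 1.0]])
labels = torch.tensor([0, 1, 1])

log_probs = F.log_softmax(logits, dim=1)   # log-probabilities per class
probs = log_probs.exp()                    # back to plain probabilities

nll = nn.NLLLoss()(log_probs, labels)      # expects log-probabilities
ce = nn.CrossEntropyLoss()(logits, labels) # expects raw scores

print(probs.sum(dim=1))        # each row sums to 1
print(nll.item(), ce.item())   # the two losses agree
```

So a network that ends in log softmax pairs with `nn.NLLLoss`, while a network that returns raw scores would pair with `nn.CrossEntropyLoss`.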

Let's make a Discriminator object.

D_network = D(28*28*1, 50, 2)
print(D_network)

output :

D(
  (fc1): Linear(in_features=784, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=50, bias=True)
  (fc4): Linear(in_features=50, out_features=2, bias=True)
  (dropout): Dropout(p=0.3)
)

Generator

The Generator takes a random vector z (also called the latent vector) and generates a sample image whose distribution is close to the training data distribution. We want to upsample z to an image of size 1*28*28. Tanh is used as the activation in the output layer (as used in the original paper), but feel free to try other activations and check which gives good results.

class G(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(G, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        self.fc4 = nn.Linear(hidden_dim*4, output_size) 
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # tanh keeps the generated pixels in [-1, 1]
        out = torch.tanh(self.fc4(x))
        return out

The G network architecture mirrors D's architecture, except that now we upsample z to an image of size 28*28*1.

G_network = G(100, 50, 1*28*28)
print(G_network)

output :

G(
  (fc1): Linear(in_features=100, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=200, bias=True)
  (fc4): Linear(in_features=200, out_features=784, bias=True)
  (dropout): Dropout(p=0.3)
)
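
As a quick sanity check on shapes, the sketch below uses a one-layer stand-in (`toy_G`, a hypothetical name, not the `G` class above) that maps a 100-dimensional latent vector to 784 values with a tanh output. It only illustrates the shapes and value range to expect, not the quality of generated images.

```python
import torch
import torch.nn as nn

# stand-in generator: any module mapping 100 -> 784 with a tanh output
toy_G = nn.Sequential(nn.Linear(100, 784), nn.Tanh())

z = torch.rand(16, 100) * 2 - 1        # 16 latent vectors, uniform in [-1, 1)
with torch.no_grad():
    flat = toy_G(z)                    # shape (16, 784)
images = flat.view(-1, 1, 28, 28)      # reshape to (16, 1, 28, 28) images

print(flat.shape, images.shape)
print(flat.min().item() >= -1.0, flat.max().item() <= 1.0)
```

Because tanh bounds the output to [-1, 1], the real images fed to the discriminator need to be rescaled to the same range.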

Loss

The discriminator wants the predicted probability that a fake image is real to be close to 0, while the generator wants that same probability, for the fake images it generates, to be close to 1.

So we define 2 losses

  • Real loss (loss between p and 1)
  • Fake loss (loss between p and 0)

where p is the predicted probability that the image is real.

  • For the Generator: minimize real_loss(p) on fake images, i.e. push p closer to 1 and fool the discriminator by making realistic images.

  • For the Discriminator: minimize real_loss + fake_loss, i.e. p of a real image close to 1 and p of a fake image close to 0.
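
To make the two losses concrete, here is a small worked example in plain Python. It assumes the losses are negative log-likelihoods, which is what a log softmax output combined with NLL loss computes; the helper names `real_loss_value` and `fake_loss_value` are illustrative only.

```python
import math

def real_loss_value(p):
    # loss between p and the target 1: small when p is close to 1
    return -math.log(p)

def fake_loss_value(p):
    # loss between p and the target 0: small when p is close to 0
    return -math.log(1.0 - p)

# p is the predicted probability that the image is real
for p in (0.1, 0.5, 0.9):
    print(p, round(real_loss_value(p), 4), round(fake_loss_value(p), 4))
```

At p = 0.5, where the discriminator cannot tell real from fake, both losses are about 0.693. That is roughly where the losses in the training log further down start: g_loss near 0.67 and d_loss, the sum of two such terms, near 1.39.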

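def real_loss(D_out, smooth=False):
    # D_out holds log-probabilities of shape (batch_size, 2); class 1 = real
    if smooth:
        # label smoothing: soft target of 0.9 on "real" and 0.1 on "fake".
        # A 0.9 label cast to long would truncate to 0, i.e. the "fake" class,
        # so the smoothed loss is computed directly from the log-probabilities.
        loss = -(0.9 * D_out[:, 1] + 0.1 * D_out[:, 0]).mean()
    else:
        # real labels = 1, created on the same device as the predictions
        labels = torch.ones(D_out.size(0), dtype=torch.long, device=D_out.device)
        criterion = nn.NLLLoss()
        loss = criterion(D_out, labels)
    return loss

def fake_loss(D_out):
    # fake labels = 0, created on the same device as the predictions
    labels = torch.zeros(D_out.size(0), dtype=torch.long, device=D_out.device)
    criterion = nn.NLLLoss()
    loss = criterion(D_out, labels)
    return loss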

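Label smoothing (using 0.9 instead of 1 as the target for real images) is also applied, which keeps the discriminator from becoming overconfident and tends to help convergence.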
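
One detail worth checking when smoothing labels with an integer-target loss such as `nn.NLLLoss`: a label of 0.9 cast to an integer becomes 0. The sketch below, with made-up discriminator outputs, shows that truncation and one possible way to express a 0.9/0.1 soft target directly on log-probabilities. It is an illustration under these assumptions, not the only way to smooth labels.

```python
import torch
import torch.nn.functional as F

smooth_labels = torch.ones(4) * 0.9
print(smooth_labels.long())   # casting truncates every 0.9 to 0

# made-up discriminator log-probabilities for 4 real images (class 1 = real)
log_probs = F.log_softmax(torch.tensor([[0.2, 1.5], [1.0, 0.3],
                                        [-0.5, 2.0], [0.0, 0.0]]), dim=1)

# soft target: 0.1 on "fake" (class 0), 0.9 on "real" (class 1)
soft_targets = torch.tensor([0.1, 0.9]).expand_as(log_probs)
smoothed_loss = -(soft_targets * log_probs).sum(dim=1).mean()

# hard target for comparison: every image labelled 1 (real)
hard_loss = F.nll_loss(log_probs, torch.ones(4, dtype=torch.long))
print(smoothed_loss.item(), hard_loss.item())
```

The smoothed loss is a 0.9/0.1 blend of the loss against the "real" label and the loss against the "fake" label, so the discriminator is never rewarded for pushing p all the way to 1.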

Training

We will use 2 optimizers

  • One for the Generator, which minimizes the real_loss of fake images, i.e. it tries to make the discriminator classify fake images as real (1).
  • One for the Discriminator, which minimizes real_loss + fake_loss, i.e. it tries to push the prediction for fake images to 0 and for real images to 1.

Adjust the number of epochs, latent vector size, optimizer parameters, dimensions, etc. as needed.

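num_epochs = 100
print_every = 400
z_size = 100  # latent vector size, matches G(100, ...) above

# optimizer type and learning rate are assumptions; the post does not show them
d_optimizer = torch.optim.Adam(D_network.parameters(), lr=0.002)
g_optimizer = torch.optim.Adam(G_network.parameters(), lr=0.002)

# train the network
D_network.train()
G_network.train()
for epoch in range(num_epochs):
    for batch_i, (images, _) in enumerate(data_loader):
        batch_size = images.size(0)

        ## Important rescaling step ##
        # rescale input images from [0, 1] to [-1, 1] to match G's tanh output
        real_images = images*2 - 1

        # ---- train the discriminator on real and fake images ----
        d_optimizer.zero_grad()
        D_real = D_network(real_images)
        d_real_loss = real_loss(D_real, smooth=True)

        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G_network(z)

        D_fake = D_network(fake_images)
        d_fake_loss = fake_loss(D_fake)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- train the generator to fool the discriminator ----
        g_optimizer.zero_grad()
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G_network(z)

        D_fake = D_network(fake_images)
        g_loss = real_loss(D_fake)  # G wants its fakes classified as real
        g_loss.backward()           # backpropagate through D into G
        g_optimizer.step()

        if batch_i % print_every == 0:
            print('Epoch {:5d}/{:5d}\td_loss: {:6.4f}\tg_loss: {:6.4f}'.format(
                    epoch+1, num_epochs, d_loss.item(), g_loss.item()))

output :
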
Epoch     1/  100 d_loss: 1.3925  g_loss: 0.6747
Epoch     2/  100 d_loss: 1.2275  g_loss: 0.6837
Epoch     3/  100 d_loss: 1.0829  g_loss: 0.6959
Epoch     4/  100 d_loss: 1.0295  g_loss: 0.7128
Epoch     5/  100 d_loss: 1.0443  g_loss: 0.7358
Epoch     6/  100 d_loss: 1.0362  g_loss: 0.7625
Epoch     7/  100 d_loss: 0.9942  g_loss: 0.8000
Epoch     8/  100 d_loss: 0.9445  g_loss: 0.8455
Epoch     9/  100 d_loss: 0.9005  g_loss: 0.9073
Epoch    10/  100 d_loss: 0.8604  g_loss: 0.9908
...
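
The rescaling step in the training loop maps pixel values into the same range as the generator's tanh output. Here is a minimal sketch of that mapping and its inverse, using a random stand-in batch (the variable names are illustrative, and the data is noise rather than MNIST digits).

```python
import torch

images = torch.rand(8, 1, 28, 28)   # stand-in batch with pixels in [0, 1)

scaled = images * 2 - 1             # now in [-1, 1), the range of tanh
restored = (scaled + 1) / 2         # back to [0, 1), e.g. for plotting

print(scaled.min().item(), scaled.max().item())
print(torch.allclose(restored, images, atol=1e-6))
```

Without this step the discriminator could separate real from fake images simply by their value range, since real pixels would never be negative.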

Generate new MNIST Samples

def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        img = img.detach()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')

sample_size=16
rand_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
rand_z = torch.from_numpy(rand_z).float()

G_network.eval()  # eval mode turns dropout off for generation
rand_images = G_network(rand_z)
view_samples(0, [rand_images])

The linear GAN does a decent job of generating MNIST images. In the next post we will look into DCGAN (Deep Convolutional GAN), which uses CNNs to generate new samples.

Check out this awesome repo comparing Linear GAN and DCGAN on MNIST, and this notebook for a PyTorch implementation of a vanilla (linear) GAN.

Shangeth Rajaa
Researcher at Vernacular.ai
