Generating Handwritten Digits with GANs
In this notebook, we will train a Generative Adversarial Network (GAN) to generate images of handwritten digits similar to those in the MNIST dataset.
A GAN consists of two neural networks trained against each other. The Generator creates digit images that do not exist in the dataset but look real, while the Discriminator tries to distinguish real digits from generated ones. Over the course of training, the Generator improves until the Discriminator can no longer reliably tell its digits from real ones.
This technique is widely used in generative AI for applications such as deepfake generation, digital art creation, and image enhancement.
Import Required Libraries
We need TensorFlow and Keras for building and training the neural networks. NumPy will handle numerical data processing, and Matplotlib will allow us to visualize the generated digits.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
Load and Prepare the MNIST Dataset
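The MNIST dataset contains 70,000 grayscale images of handwritten digits, each 28×28 pixels: 60,000 for training and 10,000 for testing. We load the training images, scale pixel values from [0, 255] to [-1, 1] so they match the range of the Generator's tanh output, and add a channel dimension so each image has shape (28, 28, 1).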
(train_images, _), (_, _) = keras.datasets.mnist.load_data()
train_images = (train_images.astype("float32") - 127.5) / 127.5
train_images = np.expand_dims(train_images, axis=-1)
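As a quick sanity check on the scaling above, the sketch below applies the same two steps to a small synthetic array rather than the real dataset. The array `sample` and its pixel values are made up for illustration; only the arithmetic mirrors the cell above.

```python
import numpy as np

# Hypothetical stand-in for a batch of two 28x28 grayscale images (uint8, 0..255).
sample = np.zeros((2, 28, 28), dtype="uint8")
sample[0, 0, 0] = 255   # brightest possible pixel
sample[1, 0, 0] = 128   # roughly mid-gray

# Same scaling as above: 0 maps to -1.0 and 255 maps to 1.0.
scaled = (sample.astype("float32") - 127.5) / 127.5

# Same reshape as above: add a trailing channel axis.
scaled = np.expand_dims(scaled, axis=-1)

print(scaled.shape)                 # (2, 28, 28, 1)
print(scaled.min(), scaled.max())   # -1.0 1.0
```

Running the same checks on `train_images` is a cheap way to confirm the data is in the range the Generator is expected to produce.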
Build the Generator
The Generator takes a random noise vector as input and transforms it into a 28×28 image of a handwritten digit. Before training, its output looks like random noise; the digits become recognizable as training progresses.
latent_dim = 100
def build_generator():
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        layers.Dense(256, activation="relu"),
        layers.BatchNormalization(),
        layers.Dense(28 * 28 * 1, activation="tanh"),
        layers.Reshape((28, 28, 1))
    ])
    return model
generator = build_generator()
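To make the shape flow through the Generator concrete, here is a rough NumPy sketch of the same sequence of operations: noise, a ReLU layer, a tanh layer, and a reshape. It is not the Keras model: the weights `w1` and `w2` are random stand-ins, and biases and batch normalization are left out for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 100   # same size as the notebook's noise vector
batch = 4          # hypothetical batch size for this demo

# Random stand-in weights; a real generator learns these during training.
w1 = rng.normal(0, 0.05, (latent_dim, 256))
w2 = rng.normal(0, 0.05, (256, 28 * 28))

noise = rng.normal(0, 1, (batch, latent_dim))   # one noise vector per image
hidden = np.maximum(noise @ w1, 0)              # like Dense(256) with ReLU
flat = np.tanh(hidden @ w2)                     # like Dense(784) with tanh
images = flat.reshape(batch, 28, 28, 1)         # like Reshape((28, 28, 1))

print(images.shape)   # (4, 28, 28, 1)
```

Because the last activation is tanh, every output value lies in [-1, 1], the same range as the normalized training images.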
Build the Discriminator
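The Discriminator learns to differentiate between real MNIST digits and fake ones produced by the Generator. It outputs a single probability that the input image is real. As the Discriminator gets better at spotting fakes, the Generator must produce more convincing digits to keep up.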
def build_discriminator():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(256, activation="relu"),
        layers.Dense(1, activation="sigmoid")
    ])
    model.compile(optimizer=keras.optimizers.Adam(0.0002), loss="binary_crossentropy")
    return model
discriminator = build_discriminator()
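The Discriminator is compiled with a binary cross-entropy loss. The sketch below shows how that loss scores a few predictions, using a small NumPy helper written for this illustration. The function `binary_crossentropy` here is a hypothetical re-implementation of the textbook formula, and Keras's own version may differ in details such as clipping.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; eps keeps log() away from zero."""
    y_true = np.asarray(y_true, dtype="float64")
    y_pred = np.clip(np.asarray(y_pred, dtype="float64"), eps, 1 - eps)
    losses = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(losses))

# Label 1 = real image, label 0 = fake image.
confident_right = binary_crossentropy([1, 0], [0.95, 0.05])
unsure = binary_crossentropy([1, 0], [0.5, 0.5])
confident_wrong = binary_crossentropy([1, 0], [0.05, 0.95])

print(confident_right < unsure < confident_wrong)   # True
```

The loss is small when the predicted probability agrees with the label and grows quickly when the Discriminator is confidently wrong.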
Train the GAN
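Each training step draws a batch of real digits from MNIST, generates a batch of fake digits, trains the Discriminator to tell them apart, and then trains the Generator to produce digits that the Discriminator labels as real. The Generator is updated through a combined model in which the Discriminator's weights are frozen, so only the Generator changes during that step. Note that each "epoch" in the loop below is a single batch, not a full pass over the dataset.
epochs = 3000  # number of training steps; each one uses a single batch
batch_size = 128

# Combined model: noise -> generator -> discriminator. Used to update the generator.
discriminator.trainable = False  # freeze the discriminator inside the combined model
gan = keras.Sequential([generator, discriminator])
gan.compile(optimizer=keras.optimizers.Adam(0.0002), loss="binary_crossentropy")

for epoch in range(epochs):
    # Train the discriminator on one batch of real images and one batch of fakes.
    real_images = train_images[np.random.randint(0, train_images.shape[0], batch_size)]
    real_labels = np.ones((batch_size, 1))
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)
    fake_labels = np.zeros((batch_size, 1))
    discriminator.trainable = True
    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
    d_loss = 0.5 * (d_loss_real + d_loss_fake)

    # Train the generator through the combined model by labeling its fakes as real.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    misleading_labels = np.ones((batch_size, 1))
    discriminator.trainable = False
    g_loss = gan.train_on_batch(noise, misleading_labels)

    if epoch % 500 == 0:
        print(f'Epoch {epoch}: D Loss = {d_loss:.4f}, G Loss = {g_loss:.4f}')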
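When reading the printed losses, two reference values can help. The sketch below computes them directly from the binary cross-entropy formula. The specific probabilities (0.5 and 0.01) are illustrative choices, not values taken from a training run.

```python
import math

# If the discriminator outputs 0.5 for every image (it cannot tell real
# from fake), each cross-entropy term equals -log(0.5).
chance_level = -math.log(0.5)

# If the discriminator gives a fake only a 1% chance of being real, the
# generator's loss on that sample (target label 1) is -log(0.01).
generator_losing = -math.log(0.01)

print(f"{chance_level:.4f}")       # 0.6931
print(f"{generator_losing:.4f}")   # 4.6052
```

As a rough guide, losses hovering near 0.69 suggest the two networks are evenly matched, while a Generator loss that keeps climbing suggests the Discriminator is winning.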
Generate AI-Created Digits
Now that our GAN is trained, we can use it to generate new handwritten digits that do not exist in MNIST.
noise = np.random.normal(0, 1, (5, latent_dim))
fake_images = generator.predict(noise)
plt.figure(figsize=(10, 2))
for i in range(5):
    plt.subplot(1, 5, i+1)
    plt.imshow(fake_images[i, :, :, 0], cmap="gray")
    plt.axis("off")
plt.show()
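The Generator's tanh output lies in [-1, 1], which Matplotlib rescales automatically for display. To save the digits as ordinary 8-bit images, the scaling from the data preparation step has to be inverted. The helper `to_uint8` below is a hypothetical name introduced for this sketch, and `demo` is a synthetic stand-in for `fake_images`.

```python
import numpy as np

def to_uint8(images):
    """Map images in the tanh range [-1, 1] back to 8-bit pixels in [0, 255]."""
    scaled = np.asarray(images, dtype="float32") * 127.5 + 127.5
    return np.clip(np.rint(scaled), 0, 255).astype("uint8")

# Synthetic stand-in for generator output: 5 images of shape (28, 28, 1).
demo = np.linspace(-1.0, 1.0, 5 * 28 * 28, dtype="float32").reshape(5, 28, 28, 1)
pixels = to_uint8(demo)

print(pixels.dtype, pixels.min(), pixels.max())   # uint8 0 255
```

In the notebook, calling `to_uint8(fake_images)` should give arrays in the same format as the original MNIST images.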