3.4 Generative Adversarial Networks

Introduction to GANs

Generative Adversarial Networks (GANs) are a class of artificial intelligence models used for generative tasks: they learn to create new data instances that resemble those in a given dataset. Introduced by Ian Goodfellow and colleagues in 2014, GANs consist of two neural networks, a generator and a discriminator, that compete in a game-like setting. The generator learns to create realistic data, while the discriminator learns to distinguish real data from generated (fake) data.

GANs have had a major impact on fields such as computer vision, art generation, and data augmentation. They can generate high-resolution images, enhance low-quality images, and even create deepfake videos. Understanding how GANs work is essential for using them effectively in AI-driven applications.

Architecture of a GAN

A GAN consists of two main components:

  • Generator: This neural network takes random noise as input and transforms it into synthetic data that resembles the real dataset.
  • Discriminator: This neural network receives both real and generated samples and learns to classify each one as real or fake.

These two networks are trained together in an adversarial process:

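  1. The generator creates a batch of synthetic data from random noise.
  2. The discriminator classifies this batch together with a batch of real samples and updates its parameters to tell the two apart more accurately.
  3. The discriminator's predictions on the generated samples are turned into a training signal (gradients) for the generator.
  4. The generator updates its parameters so that its samples are more likely to be classified as real.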
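The four steps above can be sketched as a single training step written with `tf.GradientTape`. This is a minimal sketch under stated assumptions, not the notebook's method used later in this section: the toy 2-D "real" data, the layer sizes, and the names `toy_generator`, `toy_discriminator`, `train_step`, `g_opt`, and `d_opt` are all illustrative. The toy discriminator returns logits, so the loss is built with `from_logits=True`.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 8  # toy sizes, chosen only for illustration

# Toy generator: noise vector -> hypothetical 2-D data point
toy_generator = keras.Sequential([
    keras.Input(shape=(latent_dim,)),
    layers.Dense(16, activation="relu"),
    layers.Dense(2),
])
# Toy discriminator: 2-D point -> logit (higher means "more likely real")
toy_discriminator = keras.Sequential([
    keras.Input(shape=(2,)),
    layers.Dense(16, activation="relu"),
    layers.Dense(1),
])

bce = keras.losses.BinaryCrossentropy(from_logits=True)
g_opt = keras.optimizers.Adam(1e-3)
d_opt = keras.optimizers.Adam(1e-3)

def train_step(real_batch):
    batch_size = real_batch.shape[0]

    # Steps 1-2: generate fakes, score them alongside real samples,
    # and update only the discriminator's weights
    noise = tf.random.normal((batch_size, latent_dim))
    fake_batch = toy_generator(noise, training=True)
    with tf.GradientTape() as d_tape:
        real_logits = toy_discriminator(real_batch, training=True)
        fake_logits = toy_discriminator(fake_batch, training=True)
        d_loss = (bce(tf.ones_like(real_logits), real_logits)
                  + bce(tf.zeros_like(fake_logits), fake_logits))
    d_grads = d_tape.gradient(d_loss, toy_discriminator.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, toy_discriminator.trainable_variables))

    # Steps 3-4: the discriminator's verdict on fresh fakes is the generator's
    # training signal; labeling them "real" (ones) gives the non-saturating loss.
    # Gradients are applied to the generator's weights only.
    noise = tf.random.normal((batch_size, latent_dim))
    with tf.GradientTape() as g_tape:
        fake_logits = toy_discriminator(toy_generator(noise, training=True), training=True)
        g_loss = bce(tf.ones_like(fake_logits), fake_logits)
    g_grads = g_tape.gradient(g_loss, toy_generator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, toy_generator.trainable_variables))
    return float(d_loss), float(g_loss)

# Hypothetical "real" data: points scattered around (2, -1)
real_batch = np.random.normal(loc=[2.0, -1.0], scale=0.5, size=(64, 2)).astype("float32")
for _ in range(3):
    d_loss, g_loss = train_step(real_batch)
print(d_loss, g_loss)
```

Each call to `train_step` performs one full round of the numbered list: one discriminator update followed by one generator update, returning both losses.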

Mathematically, training is modeled as a two-player minimax game: the generator tries to minimize the probability that the discriminator correctly identifies its samples as fake, while the discriminator tries to maximize its accuracy in distinguishing real from generated data.
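In the notation of the original GAN paper (Goodfellow et al., 2014), let x be a real sample drawn from the data distribution p_data, z a noise vector drawn from a prior p_z, G the generator, and D(·) the discriminator's estimated probability that its input is real. The game described above is then:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator raises V by pushing D(x) toward 1 and D(G(z)) toward 0, and the generator lowers V by pushing D(G(z)) toward 1. Because log(1 - D(G(z))) gives the generator weak gradients early in training, when D easily rejects its samples, the generator is usually trained instead to maximize log D(G(z)); labeling generated samples as "real" during the generator update amounts to this non-saturating variant.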

Training Process

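During training, the generator improves by learning to fool the discriminator, while the discriminator refines its ability to detect generated samples. Ideally, this dynamic continues until the generator's samples are indistinguishable from real ones, at which point the best the discriminator can do is output 0.5 (a coin flip) for every input. In practice, this equilibrium is rarely reached exactly.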
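For a fixed generator, the original GAN analysis shows that the best possible discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)), where p_g is the distribution of generated samples. The NumPy sketch below evaluates this formula for two hypothetical 1-D Gaussian distributions (the means, grid, and helper names are illustrative assumptions). When the generator's distribution is shifted, D* leans toward 1 where real data is denser and toward 0 where generated data is; once p_g equals p_data, D* is 0.5 everywhere and the value function reaches its minimum of -log 4.

```python
import numpy as np

def gaussian_pdf(x, mean, std=1.0):
    # Density of a 1-D normal distribution
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def optimal_discriminator(p_data, p_g):
    # D*(x) = p_data(x) / (p_data(x) + p_g(x)), evaluated pointwise
    return p_data / (p_data + p_g)

def value_at_optimal_d(x, p_data, p_g):
    # V(D*, G) = E_data[log D*(x)] + E_g[log(1 - D*(x))], as a Riemann sum on grid x
    dx = x[1] - x[0]
    log_d = np.log(p_data / (p_data + p_g))
    log_one_minus_d = np.log(p_g / (p_data + p_g))  # avoids computing 1 - D* directly
    return np.sum(p_data * log_d) * dx + np.sum(p_g * log_one_minus_d) * dx

x = np.linspace(-10.0, 10.0, 20001)       # step 0.001
p_data = gaussian_pdf(x, mean=0.0)        # hypothetical "real" distribution
p_g_early = gaussian_pdf(x, mean=1.5)     # generator still off target
p_g_matched = gaussian_pdf(x, mean=0.0)   # generator matches the data

print(optimal_discriminator(p_data, p_g_early)[[7000, 10000, 13000]])  # at x = -3, 0, 3
print(value_at_optimal_d(x, p_data, p_g_early))
print(value_at_optimal_d(x, p_data, p_g_matched), -np.log(4.0))
```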

Applications of GANs

GANs have a wide range of applications across different fields, including:

  • Image Generation: GANs are widely used to generate realistic human faces, landscapes, and artwork.
  • Super-Resolution: GANs can enhance low-resolution images to high-resolution versions (e.g., ESRGAN).
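  • Style Transfer: GAN-based image-to-image translation models such as CycleGAN can convert images from one style to another, for example turning photographs into paintings in the style of Monet.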
  • Data Augmentation: GANs help create synthetic data for training AI models, improving performance in scenarios with limited real data.
  • Deepfake Technology: GANs are behind the generation of realistic fake videos and images, raising ethical concerns.
  • Medical Imaging: GANs assist in generating synthetic medical images for training AI models in healthcare.
  • 3D Model Generation: GANs contribute to generating realistic 3D objects for VR/AR applications.

Despite their impressive capabilities, GANs come with challenges, including instability during training and the potential for generating biased or unethical content. Researchers continue to explore methods for improving GAN performance and ensuring ethical AI development.

Practice: Generating Handwritten Digits with GANs

In this notebook, we will train a Generative Adversarial Network (GAN) to generate images of handwritten digits similar to those in the MNIST dataset.

GANs are a powerful way to generate new content that resembles real data. The Generator creates digit images that do not exist but look real. The Discriminator tries to distinguish between real and fake digits. Over time, the Generator improves until it can fool the Discriminator into believing its digits are real.

The same adversarial approach underlies generative applications such as deepfake generation, digital art creation, and image enhancement.

Import Required Libraries

We need TensorFlow and Keras for building and training the neural networks. NumPy will handle numerical data processing, and Matplotlib will allow us to visualize the generated digits.


import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

Load and Prepare the MNIST Dataset

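The MNIST dataset contains 70,000 grayscale 28x28 images of handwritten digits (60,000 for training and 10,000 for testing); we only need the training images. We normalize pixel values from [0, 255] to [-1, 1], which matches the range of the Generator's tanh output, and add a channel dimension so that each image has shape (28, 28, 1).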


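# Keep only the training images; labels and the test split are not needed
(train_images, _), (_, _) = keras.datasets.mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1]
train_images = (train_images.astype("float32") - 127.5) / 127.5
# Add a channel axis: (60000, 28, 28) -> (60000, 28, 28, 1)
train_images = np.expand_dims(train_images, axis=-1)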
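The scaling step can be checked in isolation. This minimal sketch uses a few hand-picked pixel values and a random stand-in batch instead of the real dataset; `normalize` and `denormalize` are hypothetical helper names. The inverse mapping is handy later for turning Generator output, which lies in [-1, 1] because of its tanh activation, back into 8-bit pixels.

```python
import numpy as np

def normalize(pixels):
    # Map uint8 pixels in [0, 255] to float32 values in [-1, 1] (same formula as above)
    return (pixels.astype("float32") - 127.5) / 127.5

def denormalize(images):
    # Inverse mapping back to uint8 pixels; clipping guards against rounding past the ends
    return np.clip(np.round(images * 127.5 + 127.5), 0, 255).astype("uint8")

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
scaled = normalize(pixels)
print(scaled)                 # black (0) maps to -1.0, white (255) maps to 1.0
print(denormalize(scaled))    # recovers the original pixel values

# Stand-in for a batch of MNIST images: (num_images, 28, 28) -> (num_images, 28, 28, 1)
batch = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
batch = np.expand_dims(normalize(batch), axis=-1)
print(batch.shape)
```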

Build the Generator

The Generator takes a random noise vector as input and transforms it into an image of a handwritten digit. Initially, the generated digits will look like noise, but they will improve over time.


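latent_dim = 100  # size of the random noise vector fed to the Generator

def build_generator():
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        layers.Dense(256, activation="relu"),
        layers.BatchNormalization(),
        # tanh keeps every pixel in [-1, 1], matching the normalized MNIST images
        layers.Dense(28 * 28 * 1, activation="tanh"),
        layers.Reshape((28, 28, 1))  # flat vector -> 28x28 grayscale image
    ])
    return model

generator = build_generator()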
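A quick way to confirm that the Generator is wired correctly is to pass a few noise vectors through it before any training. The standalone sketch below repeats the same layer stack under the hypothetical name `untrained_generator` so it runs on its own without overwriting `generator`; the batch of 3 noise vectors is arbitrary. Because the last Dense layer uses tanh, every output pixel should lie in [-1, 1].

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 100

# Same architecture as build_generator() above
untrained_generator = keras.Sequential([
    keras.Input(shape=(latent_dim,)),
    layers.Dense(256, activation="relu"),
    layers.BatchNormalization(),
    layers.Dense(28 * 28 * 1, activation="tanh"),
    layers.Reshape((28, 28, 1)),
])

noise = np.random.normal(0, 1, (3, latent_dim))
samples = untrained_generator.predict(noise, verbose=0)
print(samples.shape)                  # one 28x28x1 image per noise vector
print(samples.min(), samples.max())   # tanh bounds these to [-1, 1]
untrained_generator.summary()         # layer output shapes and parameter counts
```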

Build the Discriminator

The Discriminator learns to tell real MNIST digits from fake ones produced by the Generator. It outputs a single probability that its input image is real. As the Discriminator improves, the Generator must produce more convincing digits to keep fooling it.


def build_discriminator():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(256, activation="relu"),
        # Single sigmoid unit: estimated probability that the input image is real
        layers.Dense(1, activation="sigmoid")
    ])
    model.compile(optimizer=keras.optimizers.Adam(0.0002), loss="binary_crossentropy")
    return model

discriminator = build_discriminator()
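The Discriminator is compiled with `binary_crossentropy`, so a reference point helps when reading the losses printed during training. For labels y and predicted probabilities p, binary cross-entropy is the mean of -(y log p + (1 - y) log(1 - p)). This NumPy sketch is illustrative: the `binary_crossentropy` helper and the sample probabilities are assumptions, and Keras also clips p away from 0 and 1, approximated here with a small epsilon. It shows that a discriminator that outputs 0.5 for everything, i.e. one that cannot tell real from fake, has a loss of log 2 (about 0.693).

```python
import numpy as np

def binary_crossentropy(y_true, p_pred, eps=1e-7):
    # Mean of -(y log p + (1 - y) log(1 - p)); eps keeps log() away from 0
    p = np.clip(p_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))

real_labels = np.ones(4)
fake_labels = np.zeros(4)

# A confident, correct discriminator: low loss on both batches
confident_on_real = np.array([0.95, 0.9, 0.97, 0.92])
confident_on_fake = np.array([0.05, 0.1, 0.02, 0.08])
print(binary_crossentropy(real_labels, confident_on_real))
print(binary_crossentropy(fake_labels, confident_on_fake))

# A discriminator that outputs 0.5 everywhere cannot tell real from fake
undecided = np.full(4, 0.5)
print(binary_crossentropy(real_labels, undecided), np.log(2.0))
```

Read this way, a D Loss that hovers near 0.69 means the Discriminator is close to guessing, while a D Loss near 0 means it separates real from fake almost perfectly, which tends to leave the Generator little useful signal.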

Train the GAN

Each training step samples a batch of real digits from MNIST and a batch of fake digits from the Generator, trains the Discriminator to tell them apart, and then trains the Generator to produce digits that the Discriminator classifies as real.


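# Stack the Generator and Discriminator into one model used to train the Generator.
# The Discriminator is frozen *before* this model is compiled, so gan.train_on_batch
# updates only the Generator's weights.
discriminator.trainable = False
gan = keras.Sequential([generator, discriminator])
gan.compile(optimizer=keras.optimizers.Adam(0.0002), loss="binary_crossentropy")

steps = 3000      # training iterations (one random batch each), not full passes over MNIST
batch_size = 128

for step in range(steps):
    # 1. Sample real digits and generate the same number of fake ones
    real_images = train_images[np.random.randint(0, train_images.shape[0], batch_size)]
    real_labels = np.ones((batch_size, 1))

    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_images = generator.predict(noise, verbose=0)
    fake_labels = np.zeros((batch_size, 1))

    # 2. Train the Discriminator: real -> 1, fake -> 0
    discriminator.trainable = True
    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
    d_loss = 0.5 * (d_loss_real + d_loss_fake)

    # 3. Train the Generator through the stacked model: fakes are labeled "real",
    #    so the gradients push the Generator toward images the Discriminator accepts
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    misleading_labels = np.ones((batch_size, 1))
    discriminator.trainable = False
    g_loss = gan.train_on_batch(noise, misleading_labels)

    if step % 500 == 0:
        print(f"Step {step}: D Loss = {d_loss:.4f}, G Loss = {g_loss:.4f}")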
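The Generator only learns if the Discriminator's verdict is backpropagated into the Generator's weights. In Keras this is commonly done by stacking both networks into one model and freezing the Discriminator before compiling that stack, so that `train_on_batch` on the stack updates only the Generator. The standalone sketch below checks this behavior with small stand-in networks; the layer sizes, `latent_dim = 16`, and the names `small_gen`, `small_disc`, and `stacked` are illustrative assumptions. After one update, the Generator's first-layer kernel should have changed while the Discriminator's weights have not.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 16  # small stand-in sizes so the check runs quickly

small_gen = keras.Sequential([
    keras.Input(shape=(latent_dim,)),
    layers.Dense(32, activation="relu"),
    layers.Dense(28 * 28, activation="tanh"),
    layers.Reshape((28, 28, 1)),
])
small_disc = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Flatten(),
    layers.Dense(32, activation="relu"),
    layers.Dense(1, activation="sigmoid"),
])
small_disc.compile(optimizer=keras.optimizers.Adam(0.0002), loss="binary_crossentropy")

# Freeze the discriminator *before* compiling the stacked model
small_disc.trainable = False
stacked = keras.Sequential([small_gen, small_disc])
stacked.compile(optimizer=keras.optimizers.Adam(0.0002), loss="binary_crossentropy")

gen_before = small_gen.get_weights()[0].copy()          # first Dense kernel
disc_before = [w.copy() for w in small_disc.get_weights()]

noise = np.random.normal(0, 1, (8, latent_dim))
stacked.train_on_batch(noise, np.ones((8, 1)))           # "real" labels for generated images

gen_changed = not np.allclose(gen_before, small_gen.get_weights()[0])
disc_unchanged = all(np.array_equal(b, a) for b, a in zip(disc_before, small_disc.get_weights()))
print("generator updated:", gen_changed)
print("discriminator unchanged:", disc_unchanged)
```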

Generate AI-Created Digits

Now that our GAN is trained, we can use it to generate new handwritten digits that do not exist in MNIST.


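# Sample 5 noise vectors and map them to images with the trained Generator
noise = np.random.normal(0, 1, (5, latent_dim))
fake_images = generator.predict(noise, verbose=0)

plt.figure(figsize=(10, 2))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    # Pixel values are in [-1, 1]; fix the color range so -1 is black and 1 is white
    plt.imshow(fake_images[i, :, :, 0], cmap="gray", vmin=-1, vmax=1)
    plt.axis("off")
plt.show()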
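To inspect more samples at once, generated digits can be tiled into a single image. This is a minimal NumPy sketch, assuming a batch shaped like the Generator's output, (n, 28, 28, 1) with values in [-1, 1]; the helper name `tile_images` is hypothetical, and random numbers stand in for generated digits so the block runs on its own.

```python
import numpy as np

def tile_images(images, rows, cols):
    # images: (rows * cols, h, w, 1) array -> one (rows * h, cols * w) array
    n, h, w, _ = images.shape
    if n != rows * cols:
        raise ValueError("need exactly rows * cols images")
    grid = images[..., 0].reshape(rows, cols, h, w)  # (rows, cols, h, w)
    grid = grid.transpose(0, 2, 1, 3)                # (rows, h, cols, w)
    return grid.reshape(rows * h, cols * w)

# Stand-in for generator.predict(noise) with 16 noise vectors
fake_batch = np.random.uniform(-1, 1, size=(16, 28, 28, 1)).astype("float32")
grid = tile_images(fake_batch, rows=4, cols=4)
print(grid.shape)
```

With a trained model, passing the Generator's output to `tile_images` and then calling `plt.imshow(grid, cmap="gray", vmin=-1, vmax=1)` shows a 4 x 4 sheet of digits in one figure.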

Summary

In this section, we covered the basics of Generative Adversarial Networks, including the architecture of the generator and discriminator networks and how the two are trained against each other. We also built and trained a simple GAN that generates MNIST-style handwritten digits. These concepts and techniques are fundamental to understanding and working with GANs in a wide range of generative modeling applications.
