Glossar der KI-Begriffe: Generative Adversarial Network (GAN)

Kategorien:
No items found.
Freigegeben:
June 16, 2024

Generative Adversarial Network (GAN)

Einleitung

Generative Adversarial Networks (GANs) sind eine Klasse von maschinellen Lernmodellen, die sich in der künstlichen Intelligenz durch ihre Fähigkeit, Daten zu generieren und zu verbessern, hervorgetan haben. Erstmalig von Ian Goodfellow und seinem Team im Jahr 2014 vorgestellt, bestehen GANs aus zwei miteinander konkurrierenden neuronalen Netzwerken: einem Generator und einem Diskriminator. Dieses Zusammenspiel führt zu einem einzigartigen Trainingsprozess, bei dem GANs auf eine "adversariale" Weise arbeiten, die als ein überwacht-lernendes Problem formuliert ist.

Grundlegendes Prinzip

Ein Generative Adversarial Network besteht aus zwei neuralen Netzwerken, dem Generator und dem Diskriminator, die gleichzeitig durch adversariales Training trainiert werden.







Während des Trainings versucht der Generator, Daten zu erzeugen, die der Diskriminator nicht von echten Daten unterscheiden kann, während der Diskriminator versucht, immer besser zwischen echten und gefälschten Daten zu unterscheiden.

Architektur eines GANs

Die Architektur besteht aus zwei neuralen Netzwerken:







Training eines GANs

Das Training eines GANs erfolgt, indem beide Netzwerke gleichzeitig trainiert werden:







Während des Trainingsprozesses versucht der Generator, Daten zu erzeugen, die der Diskriminator nicht als gefälscht erkennen kann. Der Diskriminator hingegen versucht, immer besser darin zu werden, zwischen echten und gefälschten Daten zu unterscheiden. Dieser Prozess führt dazu, dass der Generator immer realistischere Daten erzeugt.

Verlustfunktionen

Generative Adversarial Networks verwenden Verlustfunktionen, um sowohl den Generator als auch den Diskriminator während des Trainings zu optimieren. Beide Netzwerke verwenden die binäre Kreuzentropieverlustfunktion:







Anwendungen von GANs

GANs haben eine Vielzahl von Anwendungen in verschiedenen Bereichen gefunden:









Herausforderungen und Lösungen

Obwohl GANs leistungsfähig sind, gibt es einige Herausforderungen beim Training:







Implementierung eines GANs

Die folgende einfache Implementierung zeigt, wie ein GAN zur Erzeugung von Bildern, die den MNIST-Datensatz-Zahlen ähneln, verwendet werden kann:



import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, Reshape, Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

# Generator
def make_generator_model():
model = keras.Sequential()
model.add(Dense(7 * 7 * 256, activation='relu', input_shape=(100,)))
model.add(Reshape((7, 7, 256)))
model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh'))
return model

# Discriminator
def make_discriminator_model():
model = keras.Sequential()
model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
model.add(LeakyReLU())
model.add(BatchNormalization())
model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(LeakyReLU())
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(1))
return model

def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss

def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)

def train_step(images):
noise = tf.random.normal([256, 100])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)

def generate_and_save_images(model, epoch, test_input):
predictions = model(test_input, training=False)
predictions = tf.reshape(predictions, (-1, 28, 28))
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow(predictions[i] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig(f'image_at_epoch_{epoch:04d}.png')
plt.show()

if __name__ == "main":
(train_images, _), (_, _) = mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
BUFFER_SIZE = 60000
BATCH_SIZE = 256
generator = make_generator_model()
discriminator = make_discriminator_model()
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
EPOCHS = 50
train(train_dataset, EPOCHS)
test_input = tf.random.normal([16, 100])
generate_and_save_images(generator, EPOCHS, test_input)

Fazit

GANs sind ein leistungsfähiges Werkzeug im Bereich des maschinellen Lernens, das in vielen Anwendungen von der Bilderzeugung bis zur Videoproduktion eingesetzt wird. Trotz der Herausforderungen beim Training bieten GANs immense Möglichkeiten zur Generierung realistischer Daten und zur Verbesserung bestehender Modelle.

Was bedeutet das?