Skip to content

Instantly share code, notes, and snippets.

@aboucaud
Last active September 24, 2020 12:28
Show Gist options
  • Select an option

  • Save aboucaud/28d20c4bf08eb94eef960573624c59be to your computer and use it in GitHub Desktop.

Select an option

Save aboucaud/28d20c4bf08eb94eef960573624c59be to your computer and use it in GitHub Desktop.
[GAN model in Keras]
import tensorflow as tf
class GAN(tf.keras.Model):
    """Generative adversarial network trained with a custom Keras step.

    Wraps a discriminator model and a generator model. Each call to
    ``train_step`` (i.e. each batch seen by ``fit``) performs one
    discriminator update followed by one generator update.

    Args:
        discriminator: Keras model scoring images. Its output is compared
            against labels of shape ``(batch, 1)`` by ``loss_fn``, so it
            presumably returns one score per image — TODO confirm.
        generator: Keras model mapping latent vectors of size ``latent_dim``
            to images with the same per-sample shape as the real images
            (the two are concatenated along the batch axis).
        latent_dim: Size of the random latent vectors fed to the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Configure the GAN for training.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both
                the discriminator and the generator losses.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: A batch of real images, or a tuple whose first
                element is that batch (other tuple elements are ignored).

        Returns:
            dict: ``{"d_loss": ..., "g_loss": ...}``, the discriminator and
            generator losses computed for this batch.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Sample random points in the latent space.
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        # Decode them to fake images, then stack fakes and reals along the
        # batch axis so the discriminator sees both in one pass.
        generated_images = self.generator(random_latent_vectors)
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images. The fakes
        # come first in ``combined_images`` and get label 1; reals get 0.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample fresh random points in the latent space.
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        # Assemble labels that say "all real images" (real == 0 above), so
        # the generator is pushed to fool the discriminator.
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator. Gradients are taken (and applied) only for
        # the generator's weights, so the discriminator is not updated here.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(
                self.generator(random_latent_vectors)
            )
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment