Last active
September 24, 2020 12:28
-
-
Save aboucaud/28d20c4bf08eb94eef960573624c59be to your computer and use it in GitHub Desktop.
[GAN model in Keras]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf


class GAN(tf.keras.Model):
    """Generative adversarial network trained through a custom ``train_step``.

    Holds a discriminator and a generator. On each batch passed by
    ``Model.fit`` it performs one discriminator update followed by one
    generator update.

    Label convention used in ``train_step``: generated (fake) images are
    labelled 1 and real images are labelled 0.

    Args:
        discriminator: Keras model scoring a batch of images. Presumably it
            returns one prediction per image, shape ``(batch, 1)``, to match
            the labels built in ``train_step``. TODO confirm against the model.
        generator: Keras model mapping latent vectors of shape
            ``(batch, latent_dim)`` to images.
        latent_dim: Size of the latent vectors fed to the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
        """Configure the optimizers and loss used by ``train_step``.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used as the
                loss for both networks.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``tf.keras.Model.compile`` (e.g. ``run_eagerly``).
        """
        super().compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update, then one generator update.

        Args:
            real_images: A batch of real images, or a tuple whose first
                element is that batch. Any other tuple elements (e.g. labels
                from an ``(x, y)`` dataset) are ignored.

        Returns:
            Dict with the discriminator loss under ``"d_loss"`` and the
            generator loss under ``"g_loss"`` for this batch.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator step ----------------------------------------
        # Sample random points in the latent space ...
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        # ... and decode them to fake images.
        generated_images = self.generator(random_latent_vectors)

        # Stack fakes first, then reals, so both are scored in one pass.
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Labels follow the same order as combined_images: 1 = fake, 0 = real.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add small random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # ---- Generator step --------------------------------------------
        # Sample fresh points in the latent space.
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        # Assemble labels that say "all real images" (0 under the convention
        # above), so the generator is pushed to fool the discriminator.
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator. The discriminator runs forward, but only the
        # generator's weights are handed to the optimizer, so the
        # discriminator's weights are *not* updated here.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(
                self.generator(random_latent_vectors)
            )
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment