Skip to content

Instantly share code, notes, and snippets.

@pcaressa
Created November 18, 2021 15:13
Show Gist options
  • Select an option

  • Save pcaressa/be67ff30dfd9ca4959de9458944c1f29 to your computer and use it in GitHub Desktop.

Select an option

Save pcaressa/be67ff30dfd9ca4959de9458944c1f29 to your computer and use it in GitHub Desktop.
toy_gan.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "toy_gan.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNhkejhEdfQQ2me3lPw6Fcj",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/pcaressa/be67ff30dfd9ca4959de9458944c1f29/toy_gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 282
},
"id": "D4Jr5pSoRT9e",
"outputId": "c8bbe874-3be9-4d7e-f153-eb0223af61ce"
},
"source": [
"\"\"\"\n",
"Sine-curve generator (toy GAN).\n",
"\n",
"A curve is described by a set of points (x, y) of the plane: its y values\n",
"sampled at TRAIN_LEN evenly spaced x coordinates in [X_MIN, X_MAX].\n",
"This single cell builds the data, the discriminator, the generator and the\n",
"GAN, trains them, and plots one generated curve.\n",
"\"\"\"\n",
"\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, LeakyReLU\n",
"\n",
"import numpy as np\n",
"from numpy.random import randint, uniform\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"SEED = 42\n",
"np.random.seed(SEED)  # reproducible numpy sampling; Keras weight init is not seeded here\n",
"\n",
"TRAIN_LEN = 64           # points sampled per curve (= network input/output width)\n",
"TRAIN_SIZE = 8192        # number of real training curves\n",
"BATCH = 128              # samples per discriminator batch: half real, half fake\n",
"BATCH_SIZE = BATCH       # same value under the name the discriminator section used\n",
"HALF_BATCH = BATCH // 2\n",
"X_MIN, X_MAX = -5.0, 5.0\n",
"Y_MIN, Y_MAX = -1.0, 1.0\n",
"\n",
"X_COORDS = np.linspace(X_MIN, X_MAX, TRAIN_LEN)\n",
"\n",
"print(\"Genera il training set campionando delle sinusoidi\")\n",
"\n",
"X_TRAIN = np.zeros((TRAIN_SIZE, TRAIN_LEN))\n",
"for i in range(TRAIN_SIZE):\n",
"    scale = uniform(0.5, 2.0)\n",
"    # Phase in [0, pi). The former `uniform(np.math.pi)` meant low=pi, high=1.0,\n",
"    # and `np.math` no longer exists in NumPy 2.x.\n",
"    phase = uniform(0.0, np.pi)\n",
"    X_TRAIN[i] = np.sin(scale * X_COORDS + phase)\n",
"\n",
"print(\"Crea il discriminatore\")\n",
"\n",
"DIS_DROPOUT = 0.4\n",
"discriminator = Sequential()\n",
"discriminator.add(Dense(TRAIN_LEN, activation = \"relu\"))\n",
"discriminator.add(Dropout(DIS_DROPOUT))\n",
"# Second hidden layer: was Dense(SAMPLE_LEN, ...), an undefined name (NameError).\n",
"discriminator.add(Dense(TRAIN_LEN, activation = \"relu\"))\n",
"discriminator.add(Dropout(DIS_DROPOUT))\n",
"discriminator.add(Dense(1, activation = \"sigmoid\"))\n",
"discriminator.compile(optimizer = \"adam\",\n",
"                      loss = \"binary_crossentropy\",\n",
"                      metrics = [\"accuracy\"])\n",
"\n",
"print(\"Addestra il discriminatore\")\n",
"\n",
"EPOCHS = 32\n",
"\n",
"ONES = np.ones(HALF_BATCH)    # label 1: real curve\n",
"ZEROS = np.zeros(HALF_BATCH)  # label 0: fake sample\n",
"ONEZEROS = (ONES, ZEROS)\n",
"NOISE = uniform(Y_MIN, Y_MAX, size = (TRAIN_SIZE, TRAIN_LEN))\n",
"\n",
"print(\"-------+-------------\")\n",
"print(\" epoca | accuratezza \")\n",
"print(\"-------+-------------\")\n",
"\n",
"for i in range(EPOCHS):\n",
"    # Half the batch is real curves (label 1), the other half uniform noise (label 0).\n",
"    n = randint(0, TRAIN_SIZE, size = HALF_BATCH)\n",
"    x = np.concatenate((X_TRAIN[n], NOISE[n]))\n",
"    y = np.concatenate(ONEZEROS)\n",
"    d_loss, acc = discriminator.train_on_batch(x, y)\n",
"    print(f\" {i:3} | {acc}\")\n",
"\n",
"print(\"Crea il generatore\")\n",
"\n",
"# Maps a TRAIN_LEN-long noise vector to TRAIN_LEN y values in (-1, 1) (tanh output).\n",
"generator = Sequential()\n",
"generator.add(Dense(TRAIN_LEN, activation = \"relu\"))\n",
"generator.add(Dense(256, activation = \"relu\"))\n",
"generator.add(Dense(TRAIN_LEN, activation = \"tanh\"))\n",
"generator.compile(optimizer = \"adam\", loss = \"mse\", metrics = [\"accuracy\"])\n",
"\n",
"print(\"Crea la GAN\")\n",
"\n",
"# Freeze the discriminator before compiling the GAN so that GAN training only\n",
"# updates the generator (Keras reads `trainable` when a model is compiled).\n",
"gan = Sequential()\n",
"gan.add(generator)\n",
"discriminator.trainable = False\n",
"gan.add(discriminator)\n",
"gan.compile(optimizer = \"adam\", loss = \"binary_crossentropy\", metrics = [\"accuracy\"])\n",
"\n",
"print(\"Addestra la GAN\")\n",
"\n",
"EPOCHS = 64\n",
"\n",
"print(\"-------+------------------------+------------------------\")\n",
"print(\" epoca | accuratezza discrimin. | accuratezza generatore \")\n",
"print(\"-------+------------------------+------------------------\")\n",
"\n",
"for e in range(EPOCHS):\n",
"    for k in range(TRAIN_SIZE // BATCH):\n",
"        n = randint(0, TRAIN_SIZE, size = HALF_BATCH)\n",
"\n",
"        # Discriminator step: real curves (label 1) vs. generated curves (label 0).\n",
"        # predict_on_batch avoids predict()'s per-call overhead and progress output.\n",
"        fake = generator.predict_on_batch(NOISE[n])\n",
"        x = np.concatenate((X_TRAIN[n], fake))\n",
"        y = np.concatenate(ONEZEROS)\n",
"        discriminator.trainable = True\n",
"        d_loss, d_acc = discriminator.train_on_batch(x, y)\n",
"        discriminator.trainable = False\n",
"\n",
"        # Generator step: push the frozen discriminator's output on fakes toward 1.\n",
"        g_loss, g_acc = gan.train_on_batch(NOISE[n], ONES)\n",
"\n",
"    print(f\" {e:03n} | {d_acc:.5f} | {g_acc:.5f}\")\n",
"\n",
"# Sample one curve from the trained generator and plot it.\n",
"x = uniform(Y_MIN, Y_MAX, size = (1, TRAIN_LEN))\n",
"y = generator.predict(x)[0]\n",
"\n",
"plt.plot(X_COORDS, y)\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"y\")\n",
"plt.title(\"GAN sample\")\n",
"plt.show()\n"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Genera il training set campionando delle sinusoidi\n",
"Crea il discriminatore\n"
]
},
{
"output_type": "error",
"ename": "NameError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-95e4b138f36d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mdiscriminator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTRAIN_LEN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactivation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"relu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mdiscriminator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDIS_DROPOUT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0mdiscriminator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mSAMPLE_LEN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactivation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"relu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0mdiscriminator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDIS_DROPOUT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mdiscriminator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mactivation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"sigmoid\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'SAMPLE_LEN' is not defined"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zaNrR3ceRfiZ",
"outputId": "32c6fdff-8765-4392-8fda-192ec231e28b"
},
"source": [
"print(\" CREAZIONE DEL DISCRIMINATORE\")\n",
"\n",
"# Silence TensorFlow C++ warnings. This environment variable is generally only\n",
"# honoured if it is set before TensorFlow is first imported, so it has no effect\n",
"# once keras has already been imported earlier in the session.\n",
"import os\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
"\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, LeakyReLU\n",
"\n",
"# Empirical hyperparameters (plain floats; the Dropout/LeakyReLU layer objects\n",
"# previously assigned to these names were overwritten before ever being used).\n",
"DROPOUT = 0.4     # dropout rate after each hidden layer\n",
"LEAKY_RELU = 0.2  # LeakyReLU negative slope, used by the generator\n",
"\n",
"# Binary classifier: real sine curve (1) vs. fake sample (0).\n",
"# Hidden-layer width is TRAIN_LEN (from the first cell); SAMPLE_LEN was never defined.\n",
"discriminator = Sequential()\n",
"discriminator.add(Dense(TRAIN_LEN, activation=\"relu\"))\n",
"discriminator.add(Dropout(DROPOUT))\n",
"discriminator.add(Dense(TRAIN_LEN, activation=\"relu\"))\n",
"discriminator.add(Dropout(DROPOUT))\n",
"discriminator.add(Dense(1, activation = \"sigmoid\"))\n",
"discriminator.compile(optimizer = \"adam\", loss = \"binary_crossentropy\", metrics = [\"accuracy\"])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" CREAZIONE DEL DISCRIMINATORE\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AAThoN9DDAgo",
"outputId": "e8f9bd61-7bb4-4d63-e9ae-fbde0d19bb78"
},
"source": [
"print(\" ADDESTRAMENTO DEL DISCRIMINATORE\")\n",
"\n",
"BATCH_SIZE = 128\n",
"BATCH_SIZE_2 = BATCH_SIZE // 2\n",
"EPOCHS = 16\n",
"\n",
"# Uses TRAIN_SIZE, TRAIN_LEN and X_TRAIN from the first cell; the names\n",
"# SAMPLE_SIZE, SAMPLE_LEN and SAMPLE were never defined in this notebook.\n",
"ONES = np.ones((TRAIN_SIZE))    # label 1: real curve\n",
"ZEROS = np.zeros((TRAIN_SIZE))  # label 0: noise\n",
"NOISE = uniform(Y_MIN, Y_MAX, size = (TRAIN_SIZE, TRAIN_LEN))\n",
"\n",
"print(\"-------+-------------\")\n",
"print(\" epoca | accuratezza \")\n",
"print(\"-------+-------------\")\n",
"\n",
"for i in range(EPOCHS):\n",
"    # Pick BATCH_SIZE_2 random training indices.\n",
"    n = randint(0, TRAIN_SIZE, size = BATCH_SIZE_2)\n",
"    # Discriminator batch: half real curves (label 1), half uniform noise (label 0).\n",
"    x = np.concatenate((X_TRAIN[n], NOISE[n]))\n",
"    y = np.concatenate((ONES[n], ZEROS[n]))\n",
"    d_loss, acc = discriminator.train_on_batch(x, y)\n",
"    print(f\" {i:3} | {acc}\")\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" ADDESTRAMENTO DEL DISCRIMINATORE\n",
"-------+-------------\n",
" epoca | accuratezza \n",
"-------+-------------\n",
" 0 | 0.4140625\n",
" 1 | 0.4921875\n",
" 2 | 0.4765625\n",
" 3 | 0.53125\n",
" 4 | 0.484375\n",
" 5 | 0.546875\n",
" 6 | 0.6171875\n",
" 7 | 0.5859375\n",
" 8 | 0.5546875\n",
" 9 | 0.6171875\n",
" 10 | 0.5625\n",
" 11 | 0.6640625\n",
" 12 | 0.671875\n",
" 13 | 0.671875\n",
" 14 | 0.75\n",
" 15 | 0.71875\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sGUXW8cJRpOR",
"outputId": "09fce87f-05f3-47df-f068-ae6cd3bc64ca"
},
"source": [
"print(\" CREAZIONE DEL GENERATORE\")\n",
"\n",
"# The generator maps a random vector of TRAIN_LEN floats (uniform noise, not\n",
"# integers) to TRAIN_LEN y values in (-1, 1), i.e. one sampled curve.\n",
"# TRAIN_LEN comes from the first cell (SAMPLE_LEN was never defined) and\n",
"# LEAKY_RELU from the discriminator-creation cell.\n",
"generator = Sequential()\n",
"generator.add(Dense(TRAIN_LEN))\n",
"generator.add(LeakyReLU(LEAKY_RELU))\n",
"generator.add(Dense(512))\n",
"generator.add(LeakyReLU(LEAKY_RELU))\n",
"generator.add(Dense(TRAIN_LEN, activation = \"tanh\"))\n",
"# Only trained through the stacked GAN, so this compile just makes the model\n",
"# usable on its own; an accuracy metric is meaningless for a regression output.\n",
"generator.compile(optimizer = \"adam\", loss = \"mse\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" CREAZIONE DEL GENERATORE\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HEmLLc_5RsuH",
"outputId": "982ed204-07d5-48e8-844e-e894335ef408"
},
"source": [
"print(\" COLLEGAMENTO DEL GENERATORE E DEL DISCRIMINATORE\")\n",
"\n",
"# Freeze the discriminator, then stack generator -> discriminator into one\n",
"# model; training `gan` should therefore update only the generator.\n",
"discriminator.trainable = False\n",
"gan = Sequential([generator, discriminator])\n",
"gan.compile(\n",
"    optimizer=\"adam\",\n",
"    loss=\"binary_crossentropy\",\n",
"    metrics=[\"accuracy\"],\n",
")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" COLLEGAMENTO DEL GENERATORE E DEL DISCRIMINATORE\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 599
},
"id": "z3GY0ygSRvZO",
"outputId": "e393c58b-2b1a-4ff0-a624-755e3210d1de"
},
"source": [
"print(\" ADDESTRAMENTO DELLA GAN\")\n",
"\n",
"BATCH_SIZE = 128\n",
"EPOCHS = 4 #64\n",
"PLOT_EVERY = 1  # plot every PLOT_EVERY-th epoch (e.g. 10 when EPOCHS = 64)\n",
"\n",
"# Fixed pool of generator inputs, one noise row per training index.\n",
"# TRAIN_SIZE, TRAIN_LEN and X_TRAIN come from the first cell; the names\n",
"# SAMPLE_SIZE, SAMPLE_LEN and SAMPLE were never defined in this notebook.\n",
"NOISE = uniform(-1.0, 1.0, size = (TRAIN_SIZE, TRAIN_LEN))\n",
"ONES = np.ones((TRAIN_SIZE))\n",
"ZEROS = np.zeros((TRAIN_SIZE))\n",
"\n",
"print(\"-------+------------------------+------------------------\")\n",
"print(\" epoca | accuratezza discrimin. | accuratezza generatore \")\n",
"print(\"-------+------------------------+------------------------\")\n",
"\n",
"# One subplot per plotted epoch plus one for the final sample. The former\n",
"# hard-coded 8 rows made add_subplot fail once more than 7 epochs were plotted.\n",
"N_ROWS = -(-EPOCHS // PLOT_EVERY) + 1\n",
"fig = plt.figure(figsize = (8, 2.4 * N_ROWS))\n",
"ax_index = 1\n",
"for e in range(EPOCHS):\n",
"    for k in range(TRAIN_SIZE // BATCH_SIZE):\n",
"        # Discriminator step: real curves (label 1) vs. generated curves (label 0).\n",
"        n = randint(0, TRAIN_SIZE, BATCH_SIZE)\n",
"        # predict_on_batch avoids predict()'s per-call overhead and progress output.\n",
"        p = generator.predict_on_batch(NOISE[n])\n",
"        x = np.concatenate((X_TRAIN[n], p))\n",
"        y = np.concatenate((ONES[n], ZEROS[n]))\n",
"\n",
"        d_loss, d_acc = discriminator.train_on_batch(x, y)\n",
"\n",
"        # Generator step: train the stacked GAN so the (frozen) discriminator\n",
"        # labels generated curves as real.\n",
"        discriminator.trainable = False\n",
"        g_loss, g_acc = gan.train_on_batch(NOISE[n], ONES[n])\n",
"        discriminator.trainable = True\n",
"    print(f\" {e:03n} | {d_acc:.5f} | {g_acc:.5f}\")\n",
"    # Plot the last curve generated in this epoch.\n",
"    if e % PLOT_EVERY == 0:\n",
"        ax = fig.add_subplot(N_ROWS, 1, ax_index)\n",
"        ax.plot(X_COORDS, p[-1])\n",
"        ax.xaxis.set_visible(False)\n",
"        ax.set_ylabel(f\"Epoch: {e}\")\n",
"        ax_index += 1\n",
"\n",
"# Plot a fresh curve sampled from the trained generator.\n",
"y = generator.predict(uniform(-1, 1, size = (1, TRAIN_LEN)))[0]\n",
"ax = fig.add_subplot(N_ROWS, 1, ax_index)\n",
"ax.plot(X_COORDS, y)\n",
"ax.set_xlabel(\"x\")\n",
"ax.set_ylabel(\"Final\")\n",
"\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" ADDESTRAMENTO DELLA GAN\n",
"-------+------------------------+------------------------\n",
" epoca | accuratezza discrimin. | accuratezza generatore \n",
"-------+------------------------+------------------------\n",
" 000 | 0.84766 | 0.11719\n",
" 001 | 0.82812 | 0.07812\n",
" 002 | 0.81250 | 0.08594\n",
" 003 | 0.82422 | 0.06250\n"
]
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x864 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment