Created
March 15, 2020 19:28
-
-
Save dkohlsdorf/16d044e85c385401dbd7e5a8326a708a to your computer and use it in GitHub Desktop.
Recursive Auto Encoder With Nodes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np\n", | |
| "import os\n", | |
| "import pandas as pd\n", | |
| "import pickle as pkl" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 95, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def merge_encoder(n_in):\n", | |
| " a = tf.keras.layers.Input(n_in)\n", | |
| " b = tf.keras.layers.Input(n_in)\n", | |
| " c = tf.keras.layers.Concatenate()([a,b])\n", | |
| " h = tf.keras.layers.Dense(n_in, activation='relu')(c)\n", | |
| " o = tf.keras.layers.Dense(n_in * 2)(h)\n", | |
| " merge = tf.keras.models.Model(inputs=[a, b], outputs=[h, c, o])\n", | |
| " merge.summary()\n", | |
| " return merge\n", | |
| "\n", | |
| "class Node:\n", | |
| " \n", | |
| " def __init__(self, i, embedding, score, payload, l = None, r = None):\n", | |
| " self.i = i\n", | |
| " self.score = score\n", | |
| " self.embedding = embedding\n", | |
| " self.left = l\n", | |
| " self.right = r\n", | |
| " self.payload = payload\n", | |
| " \n", | |
| " def print(self, offset=\"\"):\n", | |
| " print(\"{} {} {} {}\".format(offset, self.i, self.score, np.mean(self.embedding)))\n", | |
| " if self.left is not None and self.right is not None:\n", | |
| " self.left.print(offset + \"\\t\")\n", | |
| " self.right.print(offset + \"\\t\")\n", | |
| "\n", | |
| " def merge(self, other, merger):\n", | |
| " merged = merger([self.embedding, other.embedding])\n", | |
| " h = merged[0]\n", | |
| " c = merged[1]\n", | |
| " y = merged[2]\n", | |
| " #score = tf.nn.l2_loss(y - c) + self.score + other.score\n", | |
| " score = tf.nn.softmax_cross_entropy_with_logits(c, y) + self.score + other.score\n", | |
| " return Node(-1, h, score, None, self, other)\n", | |
| "\n", | |
| "def ts2leafs(df):\n", | |
| " sequence = []\n", | |
| " for i, row in df.iterrows():\n", | |
| " node = Node(i, row['token'], tf.constant(0.0), row)\n", | |
| " sequence.append(node)\n", | |
| " return sequence\n", | |
| "\n", | |
| "def merge(x, m):\n", | |
| " while len(x) > 1: \n", | |
| " min_loss = float('inf')\n", | |
| " min_node = None\n", | |
| " min_i = 0\n", | |
| " min_j = 0\n", | |
| " for i in range(len(x)):\n", | |
| " for j in range(len(x)):\n", | |
| " if i < j:\n", | |
| " node = x[i].merge(x[j], m)\n", | |
| " if node.score < min_loss:\n", | |
| " min_node = node\n", | |
| " min_loss = node.score\n", | |
| " min_i = i\n", | |
| " min_j = j\n", | |
| " print(\"Merge: {} {}\".format(min_i, min_j))\n", | |
| " x = x[:min_i] + [min_node] + x[min_i + 1:]\n", | |
| " x = [x[idx] for idx in range(0, len(x)) if idx != min_j]\n", | |
| " return x[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Merge: 10 65\n", | |
| "Merge: 10 14\n", | |
| "Merge: 9 85\n", | |
| "Merge: 76 101\n", | |
| "Merge: 6 95\n", | |
| "Merge: 10 85\n", | |
| "Merge: 19 31\n", | |
| "Merge: 37 54\n", | |
| "Merge: 13 52\n", | |
| "Merge: 4 35\n", | |
| "Merge: 47 73\n", | |
| "Merge: 25 53\n", | |
| "Merge: 2 93\n", | |
| "Merge: 16 96\n", | |
| "Merge: 3 73\n", | |
| "Merge: 11 26\n", | |
| "Merge: 14 89\n", | |
| "Merge: 64 78\n", | |
| "Merge: 43 50\n", | |
| "Merge: 30 78\n", | |
| "Merge: 5 53\n", | |
| "Merge: 40 73\n", | |
| "Merge: 21 72\n", | |
| "Merge: 70 74\n", | |
| "Merge: 6 29\n", | |
| "Merge: 9 58\n", | |
| "Merge: 37 55\n", | |
| "Merge: 78 83\n", | |
| "Merge: 35 38\n", | |
| "Merge: 52 71\n", | |
| "Merge: 24 33\n", | |
| "Merge: 62 74\n", | |
| "Merge: 28 30\n", | |
| "Merge: 49 58\n", | |
| "Merge: 38 43\n", | |
| "Merge: 41 45\n", | |
| "Merge: 66 74\n", | |
| "Merge: 30 57\n", | |
| "Merge: 55 62\n", | |
| "Merge: 27 56\n", | |
| "Merge: 7 17\n", | |
| "Merge: 67 69\n", | |
| "Merge: 1 22\n", | |
| "Merge: 12 29\n", | |
| "Merge: 40 42\n", | |
| "Merge: 48 54\n", | |
| "Merge: 60 61\n", | |
| "Merge: 2 17\n", | |
| "Merge: 18 55\n", | |
| "Merge: 44 52\n", | |
| "Merge: 46 60\n", | |
| "Merge: 4 41\n", | |
| "Merge: 22 23\n", | |
| "Merge: 16 51\n", | |
| "Merge: 39 43\n", | |
| "Merge: 45 50\n", | |
| "Merge: 8 29\n", | |
| "Merge: 17 34\n", | |
| "Merge: 20 49\n", | |
| "Merge: 0 13\n", | |
| "Merge: 10 34\n", | |
| "Merge: 3 14\n", | |
| "Merge: 11 29\n", | |
| "Merge: 25 34\n", | |
| "Merge: 27 34\n", | |
| "Merge: 9 30\n", | |
| "Merge: 13 17\n", | |
| "Merge: 29 36\n", | |
| "Merge: 5 25\n", | |
| "Merge: 27 36\n", | |
| "Merge: 18 36\n", | |
| "Merge: 22 38\n", | |
| "Merge: 21 26\n", | |
| "Merge: 20 28\n", | |
| "Merge: 6 29\n", | |
| "Merge: 16 34\n", | |
| "Merge: 12 33\n", | |
| "Merge: 30 32\n", | |
| "Merge: 7 32\n", | |
| "Merge: 4 29\n", | |
| "Merge: 1 2\n", | |
| "Merge: 22 29\n", | |
| "Merge: 7 9\n", | |
| "Merge: 0 15\n", | |
| "Merge: 16 25\n", | |
| "Merge: 12 13\n", | |
| "Merge: 20 21\n", | |
| "Merge: 8 11\n", | |
| "Merge: 2 9\n", | |
| "Merge: 4 12\n", | |
| "Merge: 18 19\n", | |
| "Merge: 13 15\n", | |
| "Merge: 5 14\n", | |
| "Merge: 9 11\n", | |
| "Merge: 6 16\n", | |
| "Merge: 3 13\n", | |
| "Merge: 1 7\n", | |
| "Merge: 0 10\n", | |
| "Merge: 9 11\n", | |
| "Merge: 4 7\n", | |
| "Merge: 2 10\n", | |
| "Merge: 5 9\n", | |
| "Merge: 6 7\n", | |
| "Merge: 1 3\n", | |
| "Merge: 0 6\n", | |
| "Merge: 2 3\n", | |
| "Merge: 3 4\n", | |
| "Merge: 0 1\n", | |
| "Merge: 1 2\n", | |
| "Merge: 0 1\n", | |
| "done merging: [1285.6406]\n", | |
| "Epoch: 5\n", | |
| "Merge: 10 65\n", | |
| "Merge: 10 14\n", | |
| "Merge: 9 85\n", | |
| "Merge: 76 101\n", | |
| "Merge: 6 95\n", | |
| "Merge: 10 85\n", | |
| "Merge: 19 31\n", | |
| "Merge: 37 54\n", | |
| "Merge: 13 52\n", | |
| "Merge: 4 35\n", | |
| "Merge: 47 73\n", | |
| "Merge: 25 53\n", | |
| "Merge: 2 93\n", | |
| "Merge: 16 96\n", | |
| "Merge: 3 73\n", | |
| "Merge: 11 26\n", | |
| "Merge: 14 89\n", | |
| "Merge: 64 78\n", | |
| "Merge: 43 50\n", | |
| "Merge: 30 78\n", | |
| "Merge: 5 53\n", | |
| "Merge: 40 73\n", | |
| "Merge: 6 29\n", | |
| "Merge: 21 71\n", | |
| "Merge: 69 73\n", | |
| "Merge: 9 58\n", | |
| "Merge: 37 55\n", | |
| "Merge: 78 83\n", | |
| "Merge: 35 38\n", | |
| "Merge: 52 71\n", | |
| "Merge: 24 33\n", | |
| "Merge: 62 74\n", | |
| "Merge: 28 30\n", | |
| "Merge: 49 58\n", | |
| "Merge: 38 43\n", | |
| "Merge: 41 45\n", | |
| "Merge: 66 74\n", | |
| "Merge: 30 57\n", | |
| "Merge: 55 62\n", | |
| "Merge: 27 56\n", | |
| "Merge: 7 17\n", | |
| "Merge: 67 69\n", | |
| "Merge: 1 22\n", | |
| "Merge: 12 29\n", | |
| "Merge: 40 42\n", | |
| "Merge: 48 54\n", | |
| "Merge: 60 61\n", | |
| "Merge: 2 17\n", | |
| "Merge: 18 55\n", | |
| "Merge: 44 52\n", | |
| "Merge: 46 60\n", | |
| "Merge: 4 41\n", | |
| "Merge: 22 23\n", | |
| "Merge: 16 51\n", | |
| "Merge: 39 43\n", | |
| "Merge: 45 50\n", | |
| "Merge: 8 29\n", | |
| "Merge: 17 34\n", | |
| "Merge: 20 49\n", | |
| "Merge: 0 13\n", | |
| "Merge: 10 34\n", | |
| "Merge: 3 14\n", | |
| "Merge: 11 29\n", | |
| "Merge: 25 34\n", | |
| "Merge: 27 34\n", | |
| "Merge: 9 30\n", | |
| "Merge: 13 17\n", | |
| "Merge: 29 36\n", | |
| "Merge: 5 25\n", | |
| "Merge: 27 36\n", | |
| "Merge: 18 36\n", | |
| "Merge: 22 38\n", | |
| "Merge: 21 26\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "df = pd.read_csv('models/v2_lstm_v5/seq_clustering_log_06281101C.csv', names=[\"start\", \"stop\", \"file\", \"cluster\"], header=None)\n", | |
| "tokens = dict([(c, i) for i, c in enumerate(sorted(list(set(df['cluster']))))])\n", | |
| "bits = max(1, (len(tokens) - 1).bit_length())\n", | |
| "for c, i in tokens.items():\n", | |
| " tokens[c] = np.float32([int(c) for c in np.binary_repr(i, width = bits)]).reshape(1, bits)\n", | |
| "df['token'] = df['cluster'].apply(lambda x : tokens[x])\n", | |
| "\n", | |
| "m = merge_encoder(bits)\n", | |
| "optimizer = tf.keras.optimizers.Adam()\n", | |
| "x = ts2leafs(df)\n", | |
| "\n", | |
| "print(\"Start Merging\")\n", | |
| "node = None\n", | |
| "for epoch in range(0, 25):\n", | |
| " with tf.GradientTape(watch_accessed_variables=True) as tape:\n", | |
| " print(\"Epoch: {}\".format(epoch))\n", | |
| " tape.watch(m.variables) \n", | |
| " node = merge(x, m)\n", | |
| " print(\"done merging: {}\".format(node.score))\n", | |
| " g = tape.gradient(node.score, m.variables)\n", | |
| " optimizer.apply_gradients(zip(g, m.variables))\n", | |
| " pkl.dump(node, open('epoch_{}_merged_{}.pkl'.format(epoch, \"seq_clustering_log_06281101C\"), \"wb\"))\n", | |
| "m.save('dolphin_merger.h5')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment