Skip to content

Instantly share code, notes, and snippets.

@dkohlsdorf
Created March 15, 2020 19:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dkohlsdorf/16d044e85c385401dbd7e5a8326a708a to your computer and use it in GitHub Desktop.
Recursive Auto Encoder With Nodes
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"import pickle as pkl"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"def merge_encoder(n_in):\n",
"    # Pairwise autoencoder: two n_in-dim inputs are concatenated,\n",
"    # compressed to n_in (the merged embedding) and decoded back\n",
"    # to 2 * n_in for the reconstruction loss.\n",
"    left = tf.keras.layers.Input(n_in)\n",
"    right = tf.keras.layers.Input(n_in)\n",
"    joined = tf.keras.layers.Concatenate()([left, right])\n",
"    hidden = tf.keras.layers.Dense(n_in, activation='relu')(joined)\n",
"    decoded = tf.keras.layers.Dense(n_in * 2)(hidden)\n",
"    model = tf.keras.models.Model(inputs=[left, right], outputs=[hidden, joined, decoded])\n",
"    model.summary()\n",
"    return model\n",
"\n",
"class Node:\n",
"    # One node of the greedy merge tree: leaves wrap a dataframe row\n",
"    # (payload) and its binary token embedding; inner nodes carry the\n",
"    # encoder's hidden state and the accumulated reconstruction score.\n",
"\n",
"    def __init__(self, i, embedding, score, payload, l = None, r = None):\n",
"        self.i = i                  # leaf index into the sequence, -1 for merged nodes\n",
"        self.score = score          # accumulated reconstruction loss up to this node\n",
"        self.embedding = embedding\n",
"        self.left = l\n",
"        self.right = r\n",
"        self.payload = payload      # original dataframe row (leaves only, None for inner nodes)\n",
"\n",
"    def print(self, offset=\"\"):\n",
"        # FIX: attribute was misspelled 'embeding', which raised\n",
"        # AttributeError on any attempt to print the tree.\n",
"        print(\"{} {} {} {}\".format(offset, self.i, self.score, np.mean(self.embedding)))\n",
"        if self.left is not None and self.right is not None:\n",
"            self.left.print(offset + \"\\t\")\n",
"            self.right.print(offset + \"\\t\")\n",
"\n",
"    def merge(self, other, merger):\n",
"        # Run the merge encoder on the two child embeddings;\n",
"        # outputs are [hidden h, concatenation c, reconstruction y].\n",
"        merged = merger([self.embedding, other.embedding])\n",
"        h = merged[0]\n",
"        c = merged[1]\n",
"        y = merged[2]\n",
"        #score = tf.nn.l2_loss(y - c) + self.score + other.score\n",
"        # NOTE(review): c is a raw concatenation of bit vectors, not a\n",
"        # normalized distribution -- confirm cross-entropy is intended\n",
"        # over the commented-out l2 reconstruction loss.\n",
"        score = tf.nn.softmax_cross_entropy_with_logits(labels=c, logits=y) + self.score + other.score\n",
"        # FIX: 'payload' was omitted in the original call, so 'self' and\n",
"        # 'other' were bound to payload/l and 'right' stayed None, which\n",
"        # silently broke recursion over merged subtrees.\n",
"        return Node(-1, h, score, None, self, other)\n",
"\n",
"def ts2leafs(df):\n",
"    # One zero-score leaf Node per dataframe row; the row itself\n",
"    # rides along as the node's payload.\n",
"    return [Node(i, row['token'], tf.constant(0.0), row) for i, row in df.iterrows()]\n",
"\n",
"def merge(x, m):\n",
"    # Greedily merge the pair of nodes with the lowest accumulated\n",
"    # reconstruction score until a single root remains; returns the root.\n",
"    #\n",
"    # FIX: work on a copy of x. The original did 'x[min_i] = min_node'\n",
"    # on the caller's list before the first rebuild, so the leaf list\n",
"    # reused across training epochs was corrupted with a stale merged\n",
"    # node from the previous epoch's gradient tape.\n",
"    x = list(x)\n",
"    while len(x) > 1:\n",
"        min_loss = float('inf')\n",
"        min_node = None\n",
"        min_i = 0\n",
"        min_j = 0\n",
"        # Exhaustive O(n^2) candidate scan over unordered pairs i < j.\n",
"        for i in range(len(x)):\n",
"            for j in range(i + 1, len(x)):\n",
"                node = x[i].merge(x[j], m)\n",
"                if node.score < min_loss:\n",
"                    min_node = node\n",
"                    min_loss = node.score\n",
"                    min_i = i\n",
"                    min_j = j\n",
"        print(\"Merge: {} {}\".format(min_i, min_j))\n",
"        # Replace the left partner with the merged node, drop the right\n",
"        # (min_j > min_i always, so the replacement index is unaffected).\n",
"        x[min_i] = min_node\n",
"        del x[min_j]\n",
"    return x[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Merge: 10 65\n",
"Merge: 10 14\n",
"Merge: 9 85\n",
"Merge: 76 101\n",
"Merge: 6 95\n",
"Merge: 10 85\n",
"Merge: 19 31\n",
"Merge: 37 54\n",
"Merge: 13 52\n",
"Merge: 4 35\n",
"Merge: 47 73\n",
"Merge: 25 53\n",
"Merge: 2 93\n",
"Merge: 16 96\n",
"Merge: 3 73\n",
"Merge: 11 26\n",
"Merge: 14 89\n",
"Merge: 64 78\n",
"Merge: 43 50\n",
"Merge: 30 78\n",
"Merge: 5 53\n",
"Merge: 40 73\n",
"Merge: 21 72\n",
"Merge: 70 74\n",
"Merge: 6 29\n",
"Merge: 9 58\n",
"Merge: 37 55\n",
"Merge: 78 83\n",
"Merge: 35 38\n",
"Merge: 52 71\n",
"Merge: 24 33\n",
"Merge: 62 74\n",
"Merge: 28 30\n",
"Merge: 49 58\n",
"Merge: 38 43\n",
"Merge: 41 45\n",
"Merge: 66 74\n",
"Merge: 30 57\n",
"Merge: 55 62\n",
"Merge: 27 56\n",
"Merge: 7 17\n",
"Merge: 67 69\n",
"Merge: 1 22\n",
"Merge: 12 29\n",
"Merge: 40 42\n",
"Merge: 48 54\n",
"Merge: 60 61\n",
"Merge: 2 17\n",
"Merge: 18 55\n",
"Merge: 44 52\n",
"Merge: 46 60\n",
"Merge: 4 41\n",
"Merge: 22 23\n",
"Merge: 16 51\n",
"Merge: 39 43\n",
"Merge: 45 50\n",
"Merge: 8 29\n",
"Merge: 17 34\n",
"Merge: 20 49\n",
"Merge: 0 13\n",
"Merge: 10 34\n",
"Merge: 3 14\n",
"Merge: 11 29\n",
"Merge: 25 34\n",
"Merge: 27 34\n",
"Merge: 9 30\n",
"Merge: 13 17\n",
"Merge: 29 36\n",
"Merge: 5 25\n",
"Merge: 27 36\n",
"Merge: 18 36\n",
"Merge: 22 38\n",
"Merge: 21 26\n",
"Merge: 20 28\n",
"Merge: 6 29\n",
"Merge: 16 34\n",
"Merge: 12 33\n",
"Merge: 30 32\n",
"Merge: 7 32\n",
"Merge: 4 29\n",
"Merge: 1 2\n",
"Merge: 22 29\n",
"Merge: 7 9\n",
"Merge: 0 15\n",
"Merge: 16 25\n",
"Merge: 12 13\n",
"Merge: 20 21\n",
"Merge: 8 11\n",
"Merge: 2 9\n",
"Merge: 4 12\n",
"Merge: 18 19\n",
"Merge: 13 15\n",
"Merge: 5 14\n",
"Merge: 9 11\n",
"Merge: 6 16\n",
"Merge: 3 13\n",
"Merge: 1 7\n",
"Merge: 0 10\n",
"Merge: 9 11\n",
"Merge: 4 7\n",
"Merge: 2 10\n",
"Merge: 5 9\n",
"Merge: 6 7\n",
"Merge: 1 3\n",
"Merge: 0 6\n",
"Merge: 2 3\n",
"Merge: 3 4\n",
"Merge: 0 1\n",
"Merge: 1 2\n",
"Merge: 0 1\n",
"done merging: [1285.6406]\n",
"Epoch: 5\n",
"Merge: 10 65\n",
"Merge: 10 14\n",
"Merge: 9 85\n",
"Merge: 76 101\n",
"Merge: 6 95\n",
"Merge: 10 85\n",
"Merge: 19 31\n",
"Merge: 37 54\n",
"Merge: 13 52\n",
"Merge: 4 35\n",
"Merge: 47 73\n",
"Merge: 25 53\n",
"Merge: 2 93\n",
"Merge: 16 96\n",
"Merge: 3 73\n",
"Merge: 11 26\n",
"Merge: 14 89\n",
"Merge: 64 78\n",
"Merge: 43 50\n",
"Merge: 30 78\n",
"Merge: 5 53\n",
"Merge: 40 73\n",
"Merge: 6 29\n",
"Merge: 21 71\n",
"Merge: 69 73\n",
"Merge: 9 58\n",
"Merge: 37 55\n",
"Merge: 78 83\n",
"Merge: 35 38\n",
"Merge: 52 71\n",
"Merge: 24 33\n",
"Merge: 62 74\n",
"Merge: 28 30\n",
"Merge: 49 58\n",
"Merge: 38 43\n",
"Merge: 41 45\n",
"Merge: 66 74\n",
"Merge: 30 57\n",
"Merge: 55 62\n",
"Merge: 27 56\n",
"Merge: 7 17\n",
"Merge: 67 69\n",
"Merge: 1 22\n",
"Merge: 12 29\n",
"Merge: 40 42\n",
"Merge: 48 54\n",
"Merge: 60 61\n",
"Merge: 2 17\n",
"Merge: 18 55\n",
"Merge: 44 52\n",
"Merge: 46 60\n",
"Merge: 4 41\n",
"Merge: 22 23\n",
"Merge: 16 51\n",
"Merge: 39 43\n",
"Merge: 45 50\n",
"Merge: 8 29\n",
"Merge: 17 34\n",
"Merge: 20 49\n",
"Merge: 0 13\n",
"Merge: 10 34\n",
"Merge: 3 14\n",
"Merge: 11 29\n",
"Merge: 25 34\n",
"Merge: 27 34\n",
"Merge: 9 30\n",
"Merge: 13 17\n",
"Merge: 29 36\n",
"Merge: 5 25\n",
"Merge: 27 36\n",
"Merge: 18 36\n",
"Merge: 22 38\n",
"Merge: 21 26\n"
]
}
],
"source": [
"df = pd.read_csv('models/v2_lstm_v5/seq_clustering_log_06281101C.csv', names=[\"start\", \"stop\", \"file\", \"cluster\"], header=None)\n",
"tokens = dict([(c, i) for i, c in enumerate(sorted(list(set(df['cluster']))))])\n",
"bits = int(np.ceil(np.log(len(tokens)) / np.log(2)))\n",
"for c, i in tokens.items():\n",
" tokens[c] = np.float32([int(c) for c in np.binary_repr(i, width = bits)]).reshape(1, bits)\n",
"df['token'] = df['cluster'].apply(lambda x : tokens[x])\n",
"\n",
"m = merge_encoder(bits)\n",
"optimizer = tf.keras.optimizers.Adam()\n",
"x = ts2leafs(df)\n",
"\n",
"print(\"Start Merging\")\n",
"node = None\n",
"for epoch in range(0, 25):\n",
" with tf.GradientTape(watch_accessed_variables=True) as tape:\n",
" print(\"Epoch: {}\".format(epoch))\n",
" tape.watch(m.variables) \n",
" node = merge(x, m)\n",
" print(\"done merging: {}\".format(node.score))\n",
" g = tape.gradient(node.score, m.variables)\n",
" optimizer.apply_gradients(zip(g, m.variables))\n",
" pkl.dump(node, open('epoch_{}_merged_{}.pkl'.format(epoch, \"seq_clustering_log_06281101C\"), \"wb\"))\n",
"m.save('dolphin_merger.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment