Last active
December 10, 2019 16:54
-
-
Save rsk97/e2f2f35f069ad2046506c01dc8e3fe9e to your computer and use it in GitHub Desktop.
CGM-Text-embedding.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "CGM-Text-embedding.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/rsk97/e2f2f35f069ad2046506c01dc8e3fe9e/cgm-text-embedding.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QbbEqZaxliDr", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import tensorflow as tf\n", | |
"# import cv2 as cv\n", | |
"import os\n", | |
"import numpy as np\n", | |
"import glob\n", | |
"# # from random import shuffle\n", | |
"from google.colab import drive\n", | |
"import matplotlib.pyplot as plt\n", | |
"import csv\n", | |
"import re\n", | |
"import json\n", | |
"from nltk.corpus import stopwords \n", | |
"from nltk.tokenize import word_tokenize \n", | |
"from keras.preprocessing.text import Tokenizer\n", | |
"from keras.preprocessing.sequence import pad_sequences\n", | |
"from keras.layers import Embedding,Input\n", | |
"from keras.models import Model\n", | |
"from keras.optimizers import Adam\n", | |
"from keras.layers import Bidirectional,LSTM,GlobalMaxPool1D,Dense\n", | |
"from tensorflow.contrib import rnn\n", | |
"from sklearn.utils import shuffle\n", | |
"import pandas as pd\n", | |
"import tensorflow_hub as hub\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"import statistics as st\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rFAHbwFHmXKB", | |
"colab_type": "code", | |
"outputId": "0fd82942-b9d0-4ec0-9846-5d50594966b2", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 126 | |
} | |
}, | |
"source": [ | |
"# Mount Google Drive: every input file below is read from /content/drive/My Drive.\n",
"drive.mount(mountpoint='/content/drive')"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", | |
"\n", | |
"Enter your authorization code:\n", | |
"··········\n", | |
"Mounted at /content/drive\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "eoeNCC1SHory", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Charades v1 training annotations; the 'script' column is the text used below.\n",
"CHARADES_TRAIN_CSV = '/content/drive/My Drive/Charades_v1_train.csv'\n",
"df = pd.read_csv(CHARADES_TRAIN_CSV)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8NJluwraknsj", | |
"colab_type": "code", | |
"outputId": "400a2d96-77eb-4a19-9ee7-a5d01df734d6", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 189 | |
} | |
}, | |
"source": [ | |
"# Peek at the first five annotation rows.\n",
"df.head(n=5)"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>subject</th>\n", | |
" <th>scene</th>\n", | |
" <th>quality</th>\n", | |
" <th>relevance</th>\n", | |
" <th>verified</th>\n", | |
" <th>script</th>\n", | |
" <th>objects</th>\n", | |
" <th>descriptions</th>\n", | |
" <th>actions</th>\n", | |
" <th>length</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>46GP8</td>\n", | |
" <td>HR43</td>\n", | |
" <td>Kitchen</td>\n", | |
" <td>6.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>Yes</td>\n", | |
" <td>A person cooking on a stove while watching som...</td>\n", | |
" <td>food;stove;window</td>\n", | |
" <td>A person cooks food on a stove before looking ...</td>\n", | |
" <td>c092 11.90 21.20;c147 0.00 12.60</td>\n", | |
" <td>24.83</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>N11GT</td>\n", | |
" <td>0KZ7</td>\n", | |
" <td>Stairs</td>\n", | |
" <td>6.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>Yes</td>\n", | |
" <td>One person opens up a folded blanket, then sne...</td>\n", | |
" <td>blanket;broom;floor</td>\n", | |
" <td>Person at the bottom of the staircase shakes a...</td>\n", | |
" <td>c098 8.60 14.20;c075 0.00 11.70;c127 0.00 15.2...</td>\n", | |
" <td>18.33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0IH69</td>\n", | |
" <td>6RE8</td>\n", | |
" <td>Bedroom</td>\n", | |
" <td>6.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>Yes</td>\n", | |
" <td>A person is seen leaving a cabinet. They then ...</td>\n", | |
" <td>book;box;cabinet;shelf</td>\n", | |
" <td>A person is standing in a bedroom. They walk o...</td>\n", | |
" <td>NaN</td>\n", | |
" <td>30.25</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>KRF68</td>\n", | |
" <td>YA10</td>\n", | |
" <td>Laundry room</td>\n", | |
" <td>6.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>Yes</td>\n", | |
" <td>A person runs into their laundry room. They gr...</td>\n", | |
" <td>clothes;door;phone</td>\n", | |
" <td>A person runs in and shuts door. The person gr...</td>\n", | |
" <td>c018 22.60 27.80;c141 4.10 9.60;c148 10.30 25....</td>\n", | |
" <td>30.33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>MJO7C</td>\n", | |
" <td>6RE8</td>\n", | |
" <td>Kitchen</td>\n", | |
" <td>6.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>Yes</td>\n", | |
" <td>A person runs into their pantry holding a bott...</td>\n", | |
" <td>cup;phone</td>\n", | |
" <td>A person runs in place while holding a bottle ...</td>\n", | |
" <td>c015 0.00 32.00;c107 0.00 32.00</td>\n", | |
" <td>31.38</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" id subject ... actions length\n", | |
"0 46GP8 HR43 ... c092 11.90 21.20;c147 0.00 12.60 24.83\n", | |
"1 N11GT 0KZ7 ... c098 8.60 14.20;c075 0.00 11.70;c127 0.00 15.2... 18.33\n", | |
"2 0IH69 6RE8 ... NaN 30.25\n", | |
"3 KRF68 YA10 ... c018 22.60 27.80;c141 4.10 9.60;c148 10.30 25.... 30.33\n", | |
"4 MJO7C 6RE8 ... c015 0.00 32.00;c107 0.00 32.00 31.38\n", | |
"\n", | |
"[5 rows x 11 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 180 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "eqhdcgDdNnKs", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Free-text scripts as a plain Python list, in row order (input to the sentence encoder and the tokenizer).\n",
"scripts = df['script'].tolist()"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Z76gQx6Nt-nf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def embed_useT(module):\n",
"  '''Build a TF-Hub text encoder in a private graph and return an embedding function.\n",
"\n",
"  Args:\n",
"    module: path or URL of the TF-Hub module (here presumably a Universal Sentence\n",
"      Encoder export stored on Drive).\n",
"\n",
"  Returns:\n",
"    A function mapping a list of strings to the module's output array; for the\n",
"    module loaded below that array has shape (len(x), 512).\n",
"\n",
"  The session is never closed: it stays alive as long as the returned function does.\n",
"  '''\n",
"  graph = tf.Graph()\n",
"  with graph.as_default():\n",
"    text_in = tf.placeholder(tf.string)\n",
"    encoder = hub.Module(module)\n",
"    encoded = encoder(text_in)\n",
"    # MonitoredSession finalizes the graph and runs its init / local-init ops.\n",
"    session = tf.train.MonitoredSession()\n",
"\n",
"  def _embed(x):\n",
"    return session.run(encoded, feed_dict={text_in: x})\n",
"\n",
"  return _embed"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Uat9ZQOJvm1d", | |
"colab_type": "code", | |
"outputId": "58429a25-090b-4628-b7a8-c17614541cb0", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 88 | |
} | |
}, | |
"source": [ | |
"# Local copy of the sentence-encoder TF-Hub module on Drive.\n",
"USE_MODULE_DIR = '/content/drive/My Drive/USE/'\n",
"embed_fn = embed_useT(USE_MODULE_DIR)"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n", | |
"INFO:tensorflow:Graph was finalized.\n", | |
"INFO:tensorflow:Running local_init_op.\n", | |
"INFO:tensorflow:Done running local_init_op.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4pK3Nf6kuA7P", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# One embedding row per script; the next cell shows the resulting shape.\n",
"sentence_encoding = embed_fn(x=scripts)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Xc5jJei8uA_m", | |
"colab_type": "code", | |
"outputId": "5505896f-edc3-4f3b-f04e-e20e7ec4ece3", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"# (number of scripts, embedding width)\n",
"np.shape(sentence_encoding)"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(7985, 512)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 186 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "NaeUTaoEuA5M", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0RW4yJWtIGDZ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Distinct raw whitespace-separated tokens over all scripts (case and punctuation kept as-is).\n",
"vocab = {token for script_text in df['script'] for token in script_text.split()}"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "FZoNFkcHI1aZ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Count of distinct raw tokens; not necessarily equal to the lower-cased Tokenizer vocabulary size.\n",
"vocab_len = len(vocab)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2Pd9xgNwNhbE", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Fit on the scripts themselves, not on the unordered `vocab` set. Fitting on the set made\n",
"# every word count reflect only how many raw spellings collapse onto it. It also made the\n",
"# tie order in word_index follow set iteration order, which changes between kernel\n",
"# restarts (string hash randomization), so word indices were not reproducible.\n",
"# The set of learned words is the same either way; only the indices/counts become meaningful.\n",
"tok = Tokenizer(num_words=30000, lower=True)\n",
"tok.fit_on_texts(scripts)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HoqwC_FpNnnU", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Integer word-index sequence for every script, in the same order as `scripts`.\n",
"seq_script = tok.texts_to_sequences(texts=scripts)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pLJrwhVZNnHt", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Token count of every script (`a`) and the longest one (0 if there are no scripts).\n",
"a = [len(seq) for seq in seq_script]\n",
"max_seq_len = max(a, default=0)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HoKgukEofC2U", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# word -> index mapping learned by the tokenizer (same dict object, not a copy).\n",
"# Keras assigns indices from 1, so index 0 never denotes a word.\n",
"word2idx = tok.word_index"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wKiU2f6LRc-D", | |
"colab_type": "code", | |
"outputId": "84f53589-9e51-4a4e-83ea-27627502880e", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"# Longest script in tokens; compare with the 3-sigma cut-off `thresh` computed below.\n",
"max_seq_len"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"84" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 194 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "OH8c_pZfRxlt", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Mean script length in tokens.\n",
"mn = st.mean(data=a)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0oq8LOdpShyX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Sample standard deviation of script length in tokens.\n",
"std = st.stdev(data=a)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "u5kk4pkeUFru", | |
"colab_type": "code", | |
"outputId": "cc83b34a-a94c-4279-f3e6-b9f1e5a153b6", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"# Length cut-off three standard deviations above the mean; longer scripts are truncated below.\n",
"thresh = int(mn + 3 * std)\n",
"thresh"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"43" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 197 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wm-aY8TRcOc3", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EkmBYGsVbz3Q", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Keep only the first `thresh` tokens of each script (slice assignment keeps the same list object).\n",
"seq_script[:] = [seq[:thresh] for seq in seq_script]"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xXh_FkbnUR0G", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Left-pad with zeros to a fixed width of thresh + 1; after truncation to `thresh`,\n",
"# every row therefore starts with at least one padding zero.\n",
"data_feed_script = pad_sequences(seq_script, maxlen=thresh + 1, padding='pre')"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "NCjN07iOVy9N", | |
"colab_type": "code", | |
"outputId": "c6f59229-7628-4893-ef2e-c02159babd22", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"# (number of scripts, thresh + 1)\n",
"np.shape(data_feed_script)"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(7985, 44)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 200 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cSTRsalIeC_g", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Load the 50-d GloVe vectors as {word: float32 array}.\n",
"# GloVe files are distributed as UTF-8, so the encoding is given explicitly instead of\n",
"# relying on the locale default. Parsing each vector to float32 once replaces a list of\n",
"# 50 Python strings per word, which is far heavier for the whole GloVe vocabulary.\n",
"# Fields are single-space separated, so split(' ') keeps each token intact.\n",
"GLOVE_PATH = '/content/drive/My Drive/glove.6B.50d.txt'\n",
"emb = {}\n",
"with open(GLOVE_PATH, encoding='utf8') as f:\n",
"  for line in f:\n",
"    word, *values = line.rstrip().split(' ')\n",
"    emb[word] = np.asarray(values, dtype=np.float32)"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9ozVg1fqeG0f", | |
"colab_type": "code", | |
"outputId": "d200ccc9-c134-447f-c472-71edd5cb164c", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 72 | |
} | |
}, | |
"source": [ | |
"# Embedding matrix indexed exactly like tok.word_index (word indices 1..len(word2idx)).\n",
"# The size used to come from vocab_len, the count of raw case/punctuation-sensitive tokens.\n",
"# That count is unrelated to the tokenizer's indices and would be too small if the tokenizer\n",
"# split tokens further. Row 0 is the pad_sequences value and is kept at zero, so padding\n",
"# no longer feeds a random vector into the LSTM.\n",
"t = {}  # words with no GloVe vector -> their tokenizer index\n",
"emb_dim = 50\n",
"num_words = len(word2idx) + 1\n",
"rng = np.random.RandomState(0)  # seeded so the random fallback rows are reproducible\n",
"embedding_matrix = rng.rand(num_words, emb_dim) * 2 - 1\n",
"embedding_matrix[0] = 0.0\n",
"# GloVe entry for the literal word 'unk'; looked up once. If absent, OOV rows keep their random init.\n",
"unk_vector = emb.get('unk')\n",
"for word, i in word2idx.items():\n",
"  embedding_vector = emb.get(word)\n",
"  if embedding_vector is not None:\n",
"    embedding_matrix[i] = np.asarray(embedding_vector)\n",
"  else:\n",
"    t[word] = i\n",
"    if unk_vector is not None:\n",
"      embedding_matrix[i] = np.asarray(unk_vector)\n",
"print(t)\n",
"print(len(t))"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"{\"they're\": 382, 'selfie': 449, \"that's\": 546, \"it's\": 627, 'selfies': 688, 'processds': 836, \"couldn't\": 852, \"don't\": 874, \"weren't\": 891, 'teddybear': 961, 'laudry': 1024, 'vegins': 1159, \"didn't\": 1181, \"wouldn't\": 1207, 'neaten': 1242, \"won't\": 1250, 'coffe': 1295, \"bedroom's\": 1385, \"get's\": 1392, 'quiddith': 1464, 'decines': 1550, \"pantry's\": 1577, 'begans': 1615, 'bathrom': 1628, 'refridgerator': 1636, \"who's\": 1652, \"person's\": 1658, \"window's\": 1679, \"camera's\": 1688, 'thens': 1723, \"what's\": 1728, 'theu': 1751, 'begings': 1789, 'doornobo': 1790, 'vacum': 1839, \"house's\": 1851, \"fly's\": 1874, \"phone's\": 1915, 'frig': 1941, 'visable': 1977, \"can't\": 1984, \"aren't\": 2003, 'bathrooom': 2009, 'enterway': 2025, \"night's\": 2080, \"doesn't\": 2109, \"home's\": 2123, 'sittiing': 2154, \"they've\": 2161, \"door's\": 2172, \"anyone's\": 2286, \"clock's\": 2294, 'smiliing': 2304}\n", | |
"53\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LetIutfOfuKp", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Number of action classes. Charades v1 labels actions c000 ... c156, i.e. 157 classes.\n",
"# With 156, label 156 (c156) had no output unit and one-hot encoded to all zeros.\n",
"# NOTE(review): confirm against Charades_v1_classes.txt.\n",
"classes = 157"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aAoKRWzalo99", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def network(text, act, sen):\n",
"  '''Fuse action, word-sequence and sentence-embedding branches with soft attention.\n",
"\n",
"  Each branch is reduced to a 100-d vector. A 3-way softmax computed from the\n",
"  concatenated vectors weights the three branches; the weighted branches are then\n",
"  flattened and classified into `classes` labels (module-level constant).\n",
"\n",
"  Args:\n",
"    text: embedded word sequence for a BiLSTM; below it is (batch, thresh + 1, 50).\n",
"    act: action feature sequence for a BiLSTM; below it is (batch, 20, 1280).\n",
"    sen: sentence embedding for two Dense layers; below it is (batch, 512).\n",
"\n",
"  Returns:\n",
"    (logits, softmax probabilities, attention weights of shape (batch, 3)).\n",
"\n",
"  Prints the static shape of several intermediate tensors while building the graph.\n",
"  '''\n",
"  with tf.variable_scope('Action'):\n",
"    action_vec = tf.keras.layers.Bidirectional(\n",
"        tf.keras.layers.LSTM(50, return_sequences=False))(act)\n",
"    print(action_vec.shape)\n",
"\n",
"  with tf.variable_scope('Text'):\n",
"    text_vec = tf.keras.layers.Bidirectional(\n",
"        tf.keras.layers.LSTM(50, return_sequences=False))(text)\n",
"    print(text_vec.shape)\n",
"\n",
"  # Named sentence_vec (not `st`) so the `statistics as st` import is not shadowed.\n",
"  with tf.variable_scope('Sentence'):\n",
"    sentence_vec = tf.keras.layers.Dense(250, activation='relu')(sen)\n",
"    sentence_vec = tf.keras.layers.Dense(100, activation='relu')(sentence_vec)\n",
"\n",
"  branches = [action_vec, text_vec, sentence_vec]\n",
"  stacked = tf.stack(branches, axis=1)  # (batch, 3, 100)\n",
"  print(stacked.shape)\n",
"\n",
"  concatenated = tf.keras.layers.Concatenate()(branches)  # (batch, 300)\n",
"  print(concatenated.shape)\n",
"\n",
"  # One softmax weight per branch, computed from all three branches jointly.\n",
"  attention_hidden = tf.keras.layers.Dense(128, activation='relu')(concatenated)\n",
"  attention = tf.keras.layers.Dense(3, activation='softmax')(attention_hidden)\n",
"  print(attention.shape)\n",
"\n",
"  # (batch, 3, 1) weights broadcast over each branch's 100 features.\n",
"  weighted = tf.multiply(tf.expand_dims(attention, 2), stacked)\n",
"  print(weighted.shape)\n",
"\n",
"  flat_width = int(weighted.shape[1]) * int(weighted.shape[2])\n",
"  flat = tf.reshape(weighted, [tf.shape(weighted)[0], flat_width])\n",
"  print(flat.shape)\n",
"\n",
"  hidden = tf.keras.layers.Dense(170, activation='relu')(flat)\n",
"  logits = tf.keras.layers.Dense(classes)(hidden)\n",
"  probs = tf.nn.softmax(logits)\n",
"  return logits, probs, attention"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sZLUBILmVv9k", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# First encoded script, kept 2-D as a (1, 512) slice.\n",
"sentence_encoding[0:1]"
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lEOEnIwAqi3T", | |
"colab_type": "code", | |
"outputId": "ce4b18f8-25bb-4c45-8cd7-a964cdbea7f8", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"# Smoke test of the graph on stand-in data. ACTION features and labels are random\n",
"# (no action features are loaded in this notebook), so the loss should stay near ln(classes).\n",
"# Start from an empty graph so re-running this cell does not pile up duplicate ops.\n",
"tf.keras.backend.clear_session()\n",
"\n",
"NUM_STEPS = 1000\n",
"LOG_EVERY = 50\n",
"BATCH = 5\n",
"smoke_rng = np.random.RandomState(0)  # seeded stand-in data\n",
"\n",
"ACTION = tf.placeholder(tf.float32, [None, 20, 1280])\n",
"TEXT = tf.placeholder(tf.int32, [None, thresh + 1])\n",
"SENTENCE = tf.placeholder(tf.float32, [None, 512])\n",
"Y = tf.placeholder(tf.int32, [None, 1])\n",
"\n",
"word_embeddings = tf.Variable(embedding_matrix, name='emb', trainable=False, dtype=tf.float32)\n",
"txt = tf.nn.embedding_lookup(word_embeddings, TEXT)\n",
"\n",
"logit, prob, atten = network(txt, ACTION, SENTENCE)\n",
"# Y is (batch, 1). Squeeze it so the one-hot labels are (batch, classes) like the logits.\n",
"# Without the squeeze they are (batch, 1, classes), which only worked because the\n",
"# cross-entropy op flattens outer dimensions.\n",
"one_hot = tf.one_hot(indices=tf.squeeze(Y, axis=1), depth=classes, axis=-1)\n",
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=one_hot, logits=logit))\n",
"\n",
"with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):\n",
"  optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)\n",
"\n",
"with tf.Session() as sess:\n",
"  sess.run(tf.global_variables_initializer())\n",
"  for ep in range(NUM_STEPS):\n",
"    feed = {\n",
"        ACTION: smoke_rng.rand(BATCH, 20, 1280),\n",
"        TEXT: data_feed_script[:BATCH],\n",
"        # labels span the full [0, classes) range instead of a hard-coded 152\n",
"        Y: smoke_rng.randint(classes, size=(BATCH, 1)),\n",
"        SENTENCE: sentence_encoding[:BATCH],\n",
"    }\n",
"    l, _ = sess.run([loss, optimizer], feed_dict=feed)\n",
"    if ep % LOG_EVERY == 0 or ep == NUM_STEPS - 1:\n",
"      print(ep, ':', l)"
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(?, 100)\n", | |
"(?, 100)\n", | |
"(?, 3, 100)\n", | |
"(?, 300)\n", | |
"(?, 3)\n", | |
"(?, 3, 100)\n", | |
"(?, 300)\n", | |
"0 : 5.084094\n", | |
"1 : 5.140191\n", | |
"2 : 5.0464463\n", | |
"3 : 5.066551\n", | |
"4 : 5.040895\n", | |
"5 : 5.055976\n", | |
"6 : 5.041914\n", | |
"7 : 5.057515\n", | |
"8 : 5.028638\n", | |
"9 : 5.0676007\n", | |
"10 : 5.0415964\n", | |
"11 : 5.093951\n", | |
"12 : 5.041438\n", | |
"13 : 5.1062355\n", | |
"14 : 5.0619154\n", | |
"15 : 5.048433\n", | |
"16 : 5.058339\n", | |
"17 : 5.0459924\n", | |
"18 : 5.0554686\n", | |
"19 : 5.047388\n", | |
"20 : 5.0548134\n", | |
"21 : 5.050895\n", | |
"22 : 5.0838923\n", | |
"23 : 5.020631\n", | |
"24 : 5.0613947\n", | |
"25 : 5.0646677\n", | |
"26 : 4.9761286\n", | |
"27 : 5.015681\n", | |
"28 : 5.0151606\n", | |
"29 : 5.091149\n", | |
"30 : 5.072516\n", | |
"31 : 5.1013336\n", | |
"32 : 5.1015444\n", | |
"33 : 5.0422206\n", | |
"34 : 5.0966477\n", | |
"35 : 5.0671477\n", | |
"36 : 5.086529\n", | |
"37 : 5.0568547\n", | |
"38 : 5.0729733\n", | |
"39 : 5.0722356\n", | |
"40 : 5.0251718\n", | |
"41 : 5.037053\n", | |
"42 : 5.062892\n", | |
"43 : 5.025443\n", | |
"44 : 5.062688\n", | |
"45 : 5.091768\n", | |
"46 : 5.0460234\n", | |
"47 : 5.07825\n", | |
"48 : 5.0746784\n", | |
"49 : 5.0552917\n", | |
"50 : 5.092469\n", | |
"51 : 5.049386\n", | |
"52 : 5.0131245\n", | |
"53 : 5.025819\n", | |
"54 : 4.996691\n", | |
"55 : 5.03922\n", | |
"56 : 5.0430665\n", | |
"57 : 5.079389\n", | |
"58 : 5.034032\n", | |
"59 : 5.0283127\n", | |
"60 : 5.070874\n", | |
"61 : 5.0709095\n", | |
"62 : 4.982889\n", | |
"63 : 5.076527\n", | |
"64 : 5.012154\n", | |
"65 : 5.055999\n", | |
"66 : 4.989286\n", | |
"67 : 5.081506\n", | |
"68 : 5.0998077\n", | |
"69 : 5.040742\n", | |
"70 : 5.102376\n", | |
"71 : 5.005675\n", | |
"72 : 5.089303\n", | |
"73 : 5.009126\n", | |
"74 : 5.1388564\n", | |
"75 : 5.0614815\n", | |
"76 : 5.082844\n", | |
"77 : 5.0669847\n", | |
"78 : 5.1167483\n", | |
"79 : 5.062215\n", | |
"80 : 4.97816\n", | |
"81 : 5.0972013\n", | |
"82 : 5.0274596\n", | |
"83 : 5.0594816\n", | |
"84 : 5.052736\n", | |
"85 : 5.1070566\n", | |
"86 : 4.976112\n", | |
"87 : 5.1168394\n", | |
"88 : 5.0060587\n", | |
"89 : 5.0351644\n", | |
"90 : 5.060041\n", | |
"91 : 5.0270715\n", | |
"92 : 4.975452\n", | |
"93 : 5.053933\n", | |
"94 : 5.0228996\n", | |
"95 : 5.1335196\n", | |
"96 : 5.136248\n", | |
"97 : 4.960604\n", | |
"98 : 5.0995083\n", | |
"99 : 5.035681\n", | |
"100 : 5.076349\n", | |
"101 : 5.099015\n", | |
"102 : 5.0169845\n", | |
"103 : 5.084094\n", | |
"104 : 5.0602336\n", | |
"105 : 5.076144\n", | |
"106 : 5.097354\n", | |
"107 : 5.093378\n", | |
"108 : 5.062036\n", | |
"109 : 5.042508\n", | |
"110 : 5.0566034\n", | |
"111 : 5.0681005\n", | |
"112 : 5.0489616\n", | |
"113 : 5.025224\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-211-9a1bd4b85c43>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_variables_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mACTION\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1280\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mTEXT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mdata_feed_script\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m152\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[
0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mSENTENCE\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0msentence_encoding\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mep\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 955\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 956\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 958\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1178\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1179\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1180\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1181\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1357\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1358\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1359\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1360\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1361\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1363\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1364\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1365\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1366\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1367\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1348\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1349\u001b[0m return self._call_tf_sessionrun(options, feed_dict, fetch_list,\n\u001b[0;32m-> 1350\u001b[0;31m target_list, run_metadata)\n\u001b[0m\u001b[1;32m 1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1352\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m 1441\u001b[0m return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,\n\u001b[1;32m 1442\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1445\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AZ2TtTXOq8lQ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment