Skip to content

Instantly share code, notes, and snippets.

@rsk97
Last active December 10, 2019 16:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rsk97/e2f2f35f069ad2046506c01dc8e3fe9e to your computer and use it in GitHub Desktop.
Save rsk97/e2f2f35f069ad2046506c01dc8e3fe9e to your computer and use it in GitHub Desktop.
CGM-Text-embedding.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "CGM-Text-embedding.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rsk97/e2f2f35f069ad2046506c01dc8e3fe9e/cgm-text-embedding.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QbbEqZaxliDr",
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"# import cv2 as cv\n",
"import os\n",
"import numpy as np\n",
"import glob\n",
"# # from random import shuffle\n",
"from google.colab import drive\n",
"import matplotlib.pyplot as plt\n",
"import csv\n",
"import re\n",
"import json\n",
"from nltk.corpus import stopwords \n",
"from nltk.tokenize import word_tokenize \n",
"from keras.preprocessing.text import Tokenizer\n",
"from keras.preprocessing.sequence import pad_sequences\n",
"from keras.layers import Embedding,Input\n",
"from keras.models import Model\n",
"from keras.optimizers import Adam\n",
"from keras.layers import Bidirectional,LSTM,GlobalMaxPool1D,Dense\n",
"from tensorflow.contrib import rnn\n",
"from sklearn.utils import shuffle\n",
"import pandas as pd\n",
"import tensorflow_hub as hub\n",
"from sklearn.model_selection import train_test_split\n",
"import statistics as st\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rFAHbwFHmXKB",
"colab_type": "code",
"outputId": "0fd82942-b9d0-4ec0-9846-5d50594966b2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 126
}
},
"source": [
"# Mount Google Drive: the Charades CSV, the USE module and the GloVe file are read from it.\n",
"drive.mount('/content/drive')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
"\n",
"Enter your authorization code:\n",
"··········\n",
"Mounted at /content/drive\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "eoeNCC1SHory",
"colab_type": "code",
"colab": {}
},
"source": [
"# Charades v1 training annotations (id, scene, script, objects, actions, length, ...).\n",
"df = pd.read_csv('/content/drive/My Drive/Charades_v1_train.csv')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8NJluwraknsj",
"colab_type": "code",
"outputId": "400a2d96-77eb-4a19-9ee7-a5d01df734d6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 189
}
},
"source": [
"# Quick look at the first annotation rows.\n",
"df.head(5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>subject</th>\n",
" <th>scene</th>\n",
" <th>quality</th>\n",
" <th>relevance</th>\n",
" <th>verified</th>\n",
" <th>script</th>\n",
" <th>objects</th>\n",
" <th>descriptions</th>\n",
" <th>actions</th>\n",
" <th>length</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>46GP8</td>\n",
" <td>HR43</td>\n",
" <td>Kitchen</td>\n",
" <td>6.0</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>A person cooking on a stove while watching som...</td>\n",
" <td>food;stove;window</td>\n",
" <td>A person cooks food on a stove before looking ...</td>\n",
" <td>c092 11.90 21.20;c147 0.00 12.60</td>\n",
" <td>24.83</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>N11GT</td>\n",
" <td>0KZ7</td>\n",
" <td>Stairs</td>\n",
" <td>6.0</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>One person opens up a folded blanket, then sne...</td>\n",
" <td>blanket;broom;floor</td>\n",
" <td>Person at the bottom of the staircase shakes a...</td>\n",
" <td>c098 8.60 14.20;c075 0.00 11.70;c127 0.00 15.2...</td>\n",
" <td>18.33</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0IH69</td>\n",
" <td>6RE8</td>\n",
" <td>Bedroom</td>\n",
" <td>6.0</td>\n",
" <td>5.0</td>\n",
" <td>Yes</td>\n",
" <td>A person is seen leaving a cabinet. They then ...</td>\n",
" <td>book;box;cabinet;shelf</td>\n",
" <td>A person is standing in a bedroom. They walk o...</td>\n",
" <td>NaN</td>\n",
" <td>30.25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>KRF68</td>\n",
" <td>YA10</td>\n",
" <td>Laundry room</td>\n",
" <td>6.0</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>A person runs into their laundry room. They gr...</td>\n",
" <td>clothes;door;phone</td>\n",
" <td>A person runs in and shuts door. The person gr...</td>\n",
" <td>c018 22.60 27.80;c141 4.10 9.60;c148 10.30 25....</td>\n",
" <td>30.33</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>MJO7C</td>\n",
" <td>6RE8</td>\n",
" <td>Kitchen</td>\n",
" <td>6.0</td>\n",
" <td>6.0</td>\n",
" <td>Yes</td>\n",
" <td>A person runs into their pantry holding a bott...</td>\n",
" <td>cup;phone</td>\n",
" <td>A person runs in place while holding a bottle ...</td>\n",
" <td>c015 0.00 32.00;c107 0.00 32.00</td>\n",
" <td>31.38</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id subject ... actions length\n",
"0 46GP8 HR43 ... c092 11.90 21.20;c147 0.00 12.60 24.83\n",
"1 N11GT 0KZ7 ... c098 8.60 14.20;c075 0.00 11.70;c127 0.00 15.2... 18.33\n",
"2 0IH69 6RE8 ... NaN 30.25\n",
"3 KRF68 YA10 ... c018 22.60 27.80;c141 4.10 9.60;c148 10.30 25.... 30.33\n",
"4 MJO7C 6RE8 ... c015 0.00 32.00;c107 0.00 32.00 31.38\n",
"\n",
"[5 rows x 11 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 180
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "eqhdcgDdNnKs",
"colab_type": "code",
"colab": {}
},
"source": [
"# Free-text script of every row, in dataframe order.\n",
"scripts = df['script'].tolist()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Z76gQx6Nt-nf",
"colab_type": "code",
"colab": {}
},
"source": [
"def embed_useT(module):\n",
"    \"\"\"Load a TF-Hub sentence-encoder module and return a text -> embedding function.\n",
"\n",
"    The module is built in its own tf.Graph with a MonitoredSession that stays open for\n",
"    the lifetime of the returned function (it is never closed explicitly), so it does\n",
"    not touch the default graph used for training.\n",
"\n",
"    Args:\n",
"        module: path or URL of the TF-Hub module.\n",
"\n",
"    Returns:\n",
"        A callable mapping a list of strings to a NumPy array with one row per string.\n",
"    \"\"\"\n",
"    with tf.Graph().as_default():\n",
"        sentences = tf.placeholder(tf.string)\n",
"        embed = hub.Module(module)\n",
"        embeddings = embed(sentences)\n",
"        session = tf.train.MonitoredSession()\n",
"\n",
"    def embed_texts(texts):\n",
"        return session.run(embeddings, {sentences: texts})\n",
"\n",
"    return embed_texts"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Uat9ZQOJvm1d",
"colab_type": "code",
"outputId": "58429a25-090b-4628-b7a8-c17614541cb0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 88
}
},
"source": [
"# Local copy of the Universal Sentence Encoder (USE) TF-Hub module on Drive.\n",
"embed_fn = embed_useT(module='/content/drive/My Drive/USE/')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4pK3Nf6kuA7P",
"colab_type": "code",
"colab": {}
},
"source": [
"# One 512-d sentence vector per script.\n",
"sentence_encoding = embed_fn(scripts)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xc5jJei8uA_m",
"colab_type": "code",
"outputId": "5505896f-edc3-4f3b-f04e-e20e7ec4ece3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"np.shape(sentence_encoding)  # (n_scripts, 512)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(7985, 512)"
]
},
"metadata": {
"tags": []
},
"execution_count": 186
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NaeUTaoEuA5M",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0RW4yJWtIGDZ",
"colab_type": "code",
"colab": {}
},
"source": [
"# Raw vocabulary: whitespace-split tokens of every script. Case and punctuation are kept,\n",
"# so its size differs from the Keras tokenizer's word index. Missing scripts are skipped.\n",
"vocab = set()\n",
"for script in df['script'].dropna():\n",
"    vocab.update(script.split())"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FZoNFkcHI1aZ",
"colab_type": "code",
"colab": {}
},
"source": [
"vocab_len = len(vocab)  # number of distinct raw (case-sensitive) tokens"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2Pd9xgNwNhbE",
"colab_type": "code",
"colab": {}
},
"source": [
"# Fit on the scripts themselves rather than on the `vocab` set. A set of strings iterates\n",
"# in hash order, which changes between Python sessions, so the resulting word indices were\n",
"# not reproducible, and the word counts no longer reflected how often words occur.\n",
"tok = Tokenizer(num_words=30000, lower=True)\n",
"tok.fit_on_texts(scripts)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HoqwC_FpNnnU",
"colab_type": "code",
"colab": {}
},
"source": [
"# Each script becomes a list of word indices (words unknown to `tok` are dropped).\n",
"seq_script = tok.texts_to_sequences(scripts)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pLJrwhVZNnHt",
"colab_type": "code",
"colab": {}
},
"source": [
"# Token count of every script (`a` feeds the mean / stdev cells) and the longest one.\n",
"a = [len(seq) for seq in seq_script]\n",
"max_seq_len = max(a, default=0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HoKgukEofC2U",
"colab_type": "code",
"colab": {}
},
"source": [
"word2idx = tok.word_index  # word -> index, 1-based (Keras reserves 0 for padding)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wKiU2f6LRc-D",
"colab_type": "code",
"outputId": "84f53589-9e51-4a4e-83ea-27627502880e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"max_seq_len  # longest tokenized script, before truncation"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"84"
]
},
"metadata": {
"tags": []
},
"execution_count": 194
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OH8c_pZfRxlt",
"colab_type": "code",
"colab": {}
},
"source": [
"mn = st.mean(a)  # mean script length in tokens"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0oq8LOdpShyX",
"colab_type": "code",
"colab": {}
},
"source": [
"std = st.stdev(a)  # sample standard deviation of script length"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "u5kk4pkeUFru",
"colab_type": "code",
"outputId": "cc83b34a-a94c-4279-f3e6-b9f1e5a153b6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Truncation length: mean + 3 standard deviations of the token length.\n",
"thresh = int(mn + 3 * std)\n",
"thresh"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"43"
]
},
"metadata": {
"tags": []
},
"execution_count": 197
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wm-aY8TRcOc3",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EkmBYGsVbz3Q",
"colab_type": "code",
"colab": {}
},
"source": [
"# Keep only the first `thresh` tokens of each script; without this, pad_sequences would\n",
"# cut over-long scripts from the front (its default is truncating='pre').\n",
"seq_script = [seq[:thresh] for seq in seq_script]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xXh_FkbnUR0G",
"colab_type": "code",
"colab": {}
},
"source": [
"# Left-pad with 0 to thresh + 1 columns (the TEXT placeholder uses the same width).\n",
"data_feed_script = pad_sequences(seq_script, maxlen=thresh + 1, padding='pre')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NCjN07iOVy9N",
"colab_type": "code",
"outputId": "c6f59229-7628-4893-ef2e-c02159babd22",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"data_feed_script.shape  # (n_scripts, thresh + 1)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(7985, 44)"
]
},
"metadata": {
"tags": []
},
"execution_count": 200
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cSTRsalIeC_g",
"colab_type": "code",
"colab": {}
},
"source": [
"# GloVe 6B, 50-d: each line is a word followed by its vector components.\n",
"# Parse the components to float32 once here instead of keeping them as strings.\n",
"emb = {}\n",
"with open('/content/drive/My Drive/glove.6B.50d.txt', encoding='utf8') as f:\n",
"    for line in f:\n",
"        word, *values = line.split()\n",
"        emb[word] = np.asarray(values, dtype=np.float32)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9ozVg1fqeG0f",
"colab_type": "code",
"outputId": "d200ccc9-c134-447f-c472-71edd5cb164c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
}
},
"source": [
"# Frozen GloVe embedding matrix indexed by the tokenizer's word ids.\n",
"# - Sized from word2idx (the mapping that produced the ids), not from the raw\n",
"#   case-sensitive `vocab`, whose size can differ.\n",
"# - Row 0 is the padding id used by pad_sequences and is left all-zero.\n",
"# - Words missing from GloVe get GloVe's 'unk' vector and are collected in `t`.\n",
"t = {}\n",
"emb_dim = 50\n",
"num_words = len(word2idx) + 1\n",
"embedding_matrix = np.zeros((num_words, emb_dim), dtype=np.float32)\n",
"unk_vector = np.asarray(emb['unk'], dtype=np.float32)\n",
"for word, i in word2idx.items():\n",
"    embedding_vector = emb.get(word)\n",
"    if embedding_vector is not None:\n",
"        embedding_matrix[i] = np.asarray(embedding_vector, dtype=np.float32)\n",
"    else:\n",
"        t[word] = i\n",
"        embedding_matrix[i] = unk_vector\n",
"print(t)\n",
"print(len(t))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"{\"they're\": 382, 'selfie': 449, \"that's\": 546, \"it's\": 627, 'selfies': 688, 'processds': 836, \"couldn't\": 852, \"don't\": 874, \"weren't\": 891, 'teddybear': 961, 'laudry': 1024, 'vegins': 1159, \"didn't\": 1181, \"wouldn't\": 1207, 'neaten': 1242, \"won't\": 1250, 'coffe': 1295, \"bedroom's\": 1385, \"get's\": 1392, 'quiddith': 1464, 'decines': 1550, \"pantry's\": 1577, 'begans': 1615, 'bathrom': 1628, 'refridgerator': 1636, \"who's\": 1652, \"person's\": 1658, \"window's\": 1679, \"camera's\": 1688, 'thens': 1723, \"what's\": 1728, 'theu': 1751, 'begings': 1789, 'doornobo': 1790, 'vacum': 1839, \"house's\": 1851, \"fly's\": 1874, \"phone's\": 1915, 'frig': 1941, 'visable': 1977, \"can't\": 1984, \"aren't\": 2003, 'bathrooom': 2009, 'enterway': 2025, \"night's\": 2080, \"doesn't\": 2109, \"home's\": 2123, 'sittiing': 2154, \"they've\": 2161, \"door's\": 2172, \"anyone's\": 2286, \"clock's\": 2294, 'smiliing': 2304}\n",
"53\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LetIutfOfuKp",
"colab_type": "code",
"colab": {}
},
"source": [
"# Number of action classes, derived from the ids in the `actions` column\n",
"# (entries look like 'c092 11.90 21.20;c147 0.00 12.60'). The Charades dataset\n",
"# defines 157 classes (c000-c156), so the previous hard-coded 156 was one short.\n",
"classes = 1 + max(\n",
"    int(segment.split()[0][1:])\n",
"    for actions in df['actions'].dropna()\n",
"    for segment in actions.split(';')\n",
")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aAoKRWzalo99",
"colab_type": "code",
"colab": {}
},
"source": [
"def network(text, act, sen, num_classes=None):\n",
"    \"\"\"Attention-weighted fusion of three input branches into class logits.\n",
"\n",
"    Args:\n",
"        text: embedded script tokens; the training cell passes (batch, thresh + 1, 50).\n",
"        act: action features; the ACTION placeholder is (batch, 20, 1280).\n",
"        sen: sentence embeddings; the SENTENCE placeholder is (batch, 512).\n",
"        num_classes: number of output classes; defaults to the notebook-level `classes`.\n",
"\n",
"    Returns:\n",
"        (logits, softmax probabilities, branch attention weights of shape (batch, 3)).\n",
"    \"\"\"\n",
"    if num_classes is None:\n",
"        num_classes = classes\n",
"\n",
"    # Each branch is reduced to a 100-d vector so the three can be stacked.\n",
"    with tf.variable_scope('Action'):\n",
"        act_out = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=False))(act)\n",
"    with tf.variable_scope('Text'):\n",
"        text_out = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=False))(text)\n",
"    with tf.variable_scope('Sentence'):\n",
"        # Named sen_out (not `st`) so the `statistics as st` import is not shadowed.\n",
"        sen_out = tf.keras.layers.Dense(250, activation='relu')(sen)\n",
"        sen_out = tf.keras.layers.Dense(100, activation='relu')(sen_out)\n",
"\n",
"    branches = [act_out, text_out, sen_out]\n",
"    fused_stacked_layer = tf.stack(branches, axis=1)              # (batch, 3, 100)\n",
"    fused_concat_layer = tf.keras.layers.Concatenate()(branches)  # (batch, 300)\n",
"\n",
"    # One softmax weight per branch, computed from all three branches jointly.\n",
"    fused_atten_dense = tf.keras.layers.Dense(128, activation='relu')(fused_concat_layer)\n",
"    fused_atten_out = tf.keras.layers.Dense(3, activation='softmax')(fused_atten_dense)\n",
"\n",
"    # Scale each branch by its weight, then flatten back to (batch, 300).\n",
"    fused_out_fin = tf.multiply(tf.expand_dims(fused_atten_out, 2), fused_stacked_layer)\n",
"    fused_out_fin = tf.keras.layers.Flatten()(fused_out_fin)\n",
"\n",
"    fused_fc = tf.keras.layers.Dense(170, activation='relu')(fused_out_fin)\n",
"    fused_fc_logits = tf.keras.layers.Dense(num_classes)(fused_fc)\n",
"    fused_fc_sft = tf.nn.softmax(fused_fc_logits)\n",
"    return fused_fc_logits, fused_fc_sft, fused_atten_out\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sZLUBILmVv9k",
"colab_type": "code",
"colab": {}
},
"source": [
"sentence_encoding[0:1]  # first script's embedding, kept 2-D"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lEOEnIwAqi3T",
"colab_type": "code",
"outputId": "ce4b18f8-25bb-4c45-8cd7-a964cdbea7f8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# Rebuild the graph from scratch so re-running this cell does not pile up\n",
"# duplicate placeholders and layers in the default graph.\n",
"tf.keras.backend.clear_session()\n",
"\n",
"ACTION = tf.placeholder(tf.float32, [None, 20, 1280])\n",
"TEXT = tf.placeholder(tf.int32, [None, thresh + 1])  # same width as data_feed_script\n",
"SENTENCE = tf.placeholder(tf.float32, [None, 512])  # USE embedding size\n",
"Y = tf.placeholder(tf.int32, [None, 1])\n",
"\n",
"# Frozen GloVe lookup for the padded token ids.\n",
"word_embeddings = tf.Variable(embedding_matrix, name='emb', trainable=False, dtype=tf.float32)\n",
"txt = tf.nn.embedding_lookup(word_embeddings, TEXT)\n",
"\n",
"logit, prob, atten = network(txt, ACTION, SENTENCE)\n",
"# Y is (batch, 1); squeeze it so the one-hot labels are (batch, classes) like `logit`.\n",
"one_hot = tf.one_hot(indices=tf.squeeze(Y, axis=1), depth=classes, axis=-1)\n",
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=one_hot, logits=logit))\n",
"\n",
"with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):\n",
"    optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)\n",
"\n",
"# Smoke test only: real action features and labels are not wired in yet, so train on\n",
"# ONE fixed, seeded random batch. If the graph is sound the loss should fall below the\n",
"# chance level ln(classes) (~5.05). Drawing fresh random labels every step, as before,\n",
"# keeps it pinned at chance.\n",
"SMOKE_BATCH = 5\n",
"SMOKE_EPOCHS = 1000\n",
"rng = np.random.RandomState(0)\n",
"smoke_feed = {\n",
"    ACTION: rng.rand(SMOKE_BATCH, 20, 1280),\n",
"    TEXT: data_feed_script[:SMOKE_BATCH],\n",
"    SENTENCE: sentence_encoding[:SMOKE_BATCH],\n",
"    Y: rng.randint(classes, size=(SMOKE_BATCH, 1)),  # every class id, not only 0-151\n",
"}\n",
"\n",
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"    for ep in range(SMOKE_EPOCHS):\n",
"        batch_loss, _ = sess.run([loss, optimizer], feed_dict=smoke_feed)\n",
"        if ep % 50 == 0 or ep == SMOKE_EPOCHS - 1:\n",
"            print(ep, ':', batch_loss)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"(?, 100)\n",
"(?, 100)\n",
"(?, 3, 100)\n",
"(?, 300)\n",
"(?, 3)\n",
"(?, 3, 100)\n",
"(?, 300)\n",
"0 : 5.084094\n",
"1 : 5.140191\n",
"2 : 5.0464463\n",
"3 : 5.066551\n",
"4 : 5.040895\n",
"5 : 5.055976\n",
"6 : 5.041914\n",
"7 : 5.057515\n",
"8 : 5.028638\n",
"9 : 5.0676007\n",
"10 : 5.0415964\n",
"11 : 5.093951\n",
"12 : 5.041438\n",
"13 : 5.1062355\n",
"14 : 5.0619154\n",
"15 : 5.048433\n",
"16 : 5.058339\n",
"17 : 5.0459924\n",
"18 : 5.0554686\n",
"19 : 5.047388\n",
"20 : 5.0548134\n",
"21 : 5.050895\n",
"22 : 5.0838923\n",
"23 : 5.020631\n",
"24 : 5.0613947\n",
"25 : 5.0646677\n",
"26 : 4.9761286\n",
"27 : 5.015681\n",
"28 : 5.0151606\n",
"29 : 5.091149\n",
"30 : 5.072516\n",
"31 : 5.1013336\n",
"32 : 5.1015444\n",
"33 : 5.0422206\n",
"34 : 5.0966477\n",
"35 : 5.0671477\n",
"36 : 5.086529\n",
"37 : 5.0568547\n",
"38 : 5.0729733\n",
"39 : 5.0722356\n",
"40 : 5.0251718\n",
"41 : 5.037053\n",
"42 : 5.062892\n",
"43 : 5.025443\n",
"44 : 5.062688\n",
"45 : 5.091768\n",
"46 : 5.0460234\n",
"47 : 5.07825\n",
"48 : 5.0746784\n",
"49 : 5.0552917\n",
"50 : 5.092469\n",
"51 : 5.049386\n",
"52 : 5.0131245\n",
"53 : 5.025819\n",
"54 : 4.996691\n",
"55 : 5.03922\n",
"56 : 5.0430665\n",
"57 : 5.079389\n",
"58 : 5.034032\n",
"59 : 5.0283127\n",
"60 : 5.070874\n",
"61 : 5.0709095\n",
"62 : 4.982889\n",
"63 : 5.076527\n",
"64 : 5.012154\n",
"65 : 5.055999\n",
"66 : 4.989286\n",
"67 : 5.081506\n",
"68 : 5.0998077\n",
"69 : 5.040742\n",
"70 : 5.102376\n",
"71 : 5.005675\n",
"72 : 5.089303\n",
"73 : 5.009126\n",
"74 : 5.1388564\n",
"75 : 5.0614815\n",
"76 : 5.082844\n",
"77 : 5.0669847\n",
"78 : 5.1167483\n",
"79 : 5.062215\n",
"80 : 4.97816\n",
"81 : 5.0972013\n",
"82 : 5.0274596\n",
"83 : 5.0594816\n",
"84 : 5.052736\n",
"85 : 5.1070566\n",
"86 : 4.976112\n",
"87 : 5.1168394\n",
"88 : 5.0060587\n",
"89 : 5.0351644\n",
"90 : 5.060041\n",
"91 : 5.0270715\n",
"92 : 4.975452\n",
"93 : 5.053933\n",
"94 : 5.0228996\n",
"95 : 5.1335196\n",
"96 : 5.136248\n",
"97 : 4.960604\n",
"98 : 5.0995083\n",
"99 : 5.035681\n",
"100 : 5.076349\n",
"101 : 5.099015\n",
"102 : 5.0169845\n",
"103 : 5.084094\n",
"104 : 5.0602336\n",
"105 : 5.076144\n",
"106 : 5.097354\n",
"107 : 5.093378\n",
"108 : 5.062036\n",
"109 : 5.042508\n",
"110 : 5.0566034\n",
"111 : 5.0681005\n",
"112 : 5.0489616\n",
"113 : 5.025224\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-211-9a1bd4b85c43>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_variables_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mACTION\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1280\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mTEXT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mdata_feed_script\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m152\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[
0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mSENTENCE\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0msentence_encoding\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mep\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 955\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 956\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 958\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1178\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1179\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1180\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1181\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1357\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1358\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1359\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1360\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1361\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1363\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1364\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1365\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1366\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1367\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1348\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1349\u001b[0m return self._call_tf_sessionrun(options, feed_dict, fetch_list,\n\u001b[0;32m-> 1350\u001b[0;31m target_list, run_metadata)\n\u001b[0m\u001b[1;32m 1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1352\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m 1441\u001b[0m return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,\n\u001b[1;32m 1442\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1445\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AZ2TtTXOq8lQ",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment