Skip to content

Instantly share code, notes, and snippets.

@monk1337
Last active April 6, 2018 19:40
Show Gist options
  • Save monk1337/b42c3d35598e38beb36063f174794273 to your computer and use it in GitHub Desktop.
Save monk1337/b42c3d35598e38beb36063f174794273 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
# Model sizes: width of each character embedding vector, and the number of
# LSTM units used in each direction of the bidirectional encoder.
word_embedding_dim, hidden_dim = 250, 270
# Character -> integer id vocabulary. In the visible script it is only used via
# len(vocab_) to size the embedding table.
# NOTE(review): id 0 is assigned to the character '0', but 0 is also the padding
# value in data_pad and the "empty" marker for count_nonzero -- confirm intended.
vocab_ = {'\xa0': 60, 'S': 26, 'W': 30, 'É': 62, 'Á': 61, 'ò': 75, 'ê': 71, 'õ': 77, 'ñ': 74, 'J': 17, 'o': 48, ',': 3, "'": 2, 'g': 40, 'Q': 24, 'ż': 87, 'B': 9, 'ç': 68, 'O': 22, 'N': 21, 'D': 11, 'd': 37, 'x': 57, 'q': 50, 'L': 19, 'z': 59, 'U': 28, 'F': 13, 'w': 56, 't': 53, 'h': 41, 'j': 43, '1': 6, 'r': 51, 'e': 38, 'K': 18, 'k': 44, 'ú': 80, 'a': 34, 'ü': 81, 'é': 70, 'I': 16, 'Y': 32, 'ì': 72, 'ó': 76, 'A': 8, 'c': 36, 'E': 12, 'i': 42, 'G': 14, 'à': 64, 'y': 58, 'V': 29, 'C': 10, 'X': 31, 'ä': 67, '0': 0, 'b': 35, 's': 52, '/': 5, 'n': 47, 'p': 49, 'ö': 78, 'ą': 82, ' ': 1, 'Ż': 86, 'l': 45, 'á': 65, 'ù': 79, ':': 7, 'u': 54, 'Z': 33, 'è': 69, 'Ś': 85, 'm': 46, '-': 4, 'ł': 83, 'T': 27, 'P': 23, 'ń': 84, 'R': 25, 'í': 73, 'ã': 66, 'ß': 63, 'v': 55, 'M': 20, 'H': 15, 'f': 39}
# Training inputs: one row per example, each a list of character ids (see vocab_)
# right-padded with 0 to a fixed length of 8. Row i is labelled by labels_x[i].
# NOTE(review): the first row appears twice at the start (matching the leading
# "1,1" in labels_x) -- confirm the duplicate is intentional.
data_pad=[[18, 41, 48, 54, 51, 58, 0, 0],[18, 41, 48, 54, 51, 58, 0, 0], [21, 34, 41, 34, 52, 0, 0, 0], [11, 34, 41, 38, 51, 0, 0, 0], [14, 38, 51, 40, 38, 52, 0, 0], [21, 34, 59, 34, 51, 42, 0, 0], [20, 34, 34, 45, 48, 54, 39, 0], [14, 38, 51, 40, 38, 52, 0, 0], [21, 34, 42, 39, 38, 41, 0, 0], [14, 54, 42, 51, 40, 54, 42, 52], [9, 34, 35, 34, 0, 0, 0, 0], [26, 34, 35, 35, 34, 40, 41, 0], [8, 53, 53, 42, 34, 0, 0, 0], [27, 34, 41, 34, 47, 0, 0, 0], [15, 34, 37, 37, 34, 37, 0, 0], [8, 52, 56, 34, 37, 0, 0, 0], [21, 34, 43, 43, 34, 51, 0, 0], [11, 34, 40, 41, 38, 51, 0, 0], [20, 34, 45, 48, 48, 39, 0, 0], [16, 52, 34, 0, 0, 0, 0, 0], [8, 52, 40, 41, 34, 51, 0, 0], [21, 34, 37, 38, 51, 0, 0, 0], [14, 34, 35, 38, 51, 0, 0, 0], [8, 35, 35, 48, 54, 37, 0, 0], [20, 34, 34, 45, 48, 54, 39, 0], [33, 48, 40, 35, 58, 0, 0, 0], [26, 51, 48, 54, 51, 0, 0, 0], [9, 34, 41, 34, 51, 0, 0, 0], [20, 54, 52, 53, 34, 39, 34, 0], [15, 34, 47, 34, 47, 42, 34, 0], [11, 34, 41, 38, 51, 0, 0, 0], [27, 54, 46, 34, 0, 0, 0, 0], [21, 34, 41, 34, 52, 0, 0, 0], [26, 34, 45, 42, 35, 34, 0, 0], [26, 41, 34, 46, 48, 48, 47, 0], [15, 34, 47, 37, 34, 45, 0, 0], [9, 34, 35, 34, 0, 0, 0, 0], [8, 46, 34, 51, 42, 0, 0, 0], [9, 34, 41, 34, 51, 0, 0, 0], [8, 53, 42, 58, 38, 41, 0, 0], [26, 34, 42, 37, 0, 0, 0, 0], [18, 41, 48, 54, 51, 42, 0, 0], [27, 34, 41, 34, 47, 0, 0, 0], [9, 34, 35, 34, 0, 0, 0, 0], [20, 54, 52, 53, 34, 39, 34, 0], [14, 54, 42, 51, 40, 54, 42, 52], [26, 45, 38, 42, 46, 34, 47, 0], [26, 38, 42, 39, 0, 0, 0, 0], [11, 34, 40, 41, 38, 51, 0, 0], [9, 34, 41, 34, 51, 0, 0, 0], [14, 34, 35, 38, 51, 0, 0, 0], [15, 34, 51, 35, 0, 0, 0, 0], [26, 38, 42, 39, 0, 0, 0, 0], [8, 52, 44, 38, 51, 0, 0, 0], [21, 34, 37, 38, 51, 0, 0, 0], [8, 47, 53, 34, 51, 0, 0, 0], [8, 56, 34, 37, 0, 0, 0, 0], [26, 51, 48, 54, 51, 0, 0, 0], [26, 41, 34, 37, 42, 37, 0, 0], [15, 34, 43, 43, 34, 51, 0, 0], [15, 34, 47, 34, 47, 42, 34, 0], [18, 34, 45, 35, 0, 0, 0, 0], [26, 41, 34, 37, 42, 37, 0, 0], [9, 34, 59, 59, 42, 
0, 0, 0], [20, 54, 52, 53, 34, 39, 34, 0], [20, 34, 52, 42, 41, 0, 0, 0], [14, 41, 34, 47, 38, 46, 0, 0], [15, 34, 37, 37, 34, 37, 0, 0], [16, 52, 34, 0, 0, 0, 0, 0], [8, 47, 53, 48, 54, 47, 0, 0], [26, 34, 51, 51, 34, 39, 0, 0], [26, 45, 38, 42, 46, 34, 47, 0], [11, 34, 40, 41, 38, 51, 0, 0], [21, 34, 43, 43, 34, 51, 0, 0], [20, 34, 45, 48, 54, 39, 0, 0], [21, 34, 41, 34, 52, 0, 0, 0], [21, 34, 52, 38, 51, 0, 0, 0], [26, 34, 45, 42, 35, 34, 0, 0], [26, 41, 34, 46, 48, 47, 0, 0], [20, 34, 45, 48, 54, 39, 0, 0], [18, 34, 45, 35, 0, 0, 0, 0], [11, 34, 41, 38, 51, 0, 0, 0], [20, 34, 34, 45, 48, 54, 39, 0], [30, 34, 52, 38, 46, 0, 0, 0], [18, 34, 47, 34, 34, 47, 0, 0], [21, 34, 42, 39, 38, 41, 0, 0], [9, 48, 54, 53, 51, 48, 52, 0], [20, 48, 40, 41, 34, 37, 34, 46], [20, 34, 52, 42, 41, 0, 0, 0], [26, 45, 38, 42, 46, 34, 47, 0], [8, 52, 56, 34, 37, 0, 0, 0], [10, 41, 34, 46, 0, 0, 0, 0], [8, 52, 52, 34, 39, 0, 0, 0], [24, 54, 51, 34, 42, 52, 41, 42], [26, 41, 34, 45, 41, 48, 54, 35], [26, 34, 35, 35, 34, 40, 0, 0], [20, 42, 39, 52, 54, 37, 0, 0], [14, 34, 35, 38, 51, 0, 0, 0], [26, 41, 34, 46, 46, 34, 52, 0], [27, 34, 47, 47, 48, 54, 52, 0], [26, 45, 38, 42, 46, 34, 47, 0], [9, 34, 59, 59, 42, 0, 0, 0], [24, 54, 51, 34, 42, 52, 41, 42], [25, 34, 41, 34, 45, 0, 0, 0], [10, 41, 34, 46, 0, 0, 0, 0], [14, 41, 34, 47, 38, 46, 0, 0], [14, 41, 34, 47, 38, 46, 0, 0], [21, 34, 52, 38, 51, 0, 0, 0], [9, 34, 35, 34, 0, 0, 0, 0], [26, 41, 34, 46, 48, 47, 0, 0], [8, 45, 46, 34, 52, 42, 0, 0], [9, 34, 52, 34, 51, 34, 0, 0], [24, 54, 51, 34, 42, 52, 41, 42], [9, 34, 53, 34, 0, 0, 0, 0], [30, 34, 52, 38, 46, 0, 0, 0], [26, 41, 34, 46, 48, 54, 47, 0], [11, 38, 38, 35, 0, 0, 0, 0], [27, 48, 54, 46, 34, 0, 0, 0], [8, 52, 39, 48, 54, 51, 0, 0], [11, 38, 38, 35, 0, 0, 0, 0], [15, 34, 37, 34, 37, 0, 0, 0], [21, 34, 42, 39, 38, 41, 0, 0], [27, 48, 54, 46, 34, 0, 0, 0], [9, 34, 59, 59, 42, 0, 0, 0], [26, 41, 34, 46, 48, 54, 47, 0], [21, 34, 41, 34, 52, 0, 0, 0], [15, 34, 37, 37, 34, 37, 0, 0], 
[8, 51, 42, 34, 47, 0, 0, 0], [18, 48, 54, 51, 42, 0, 0, 0], [11, 38, 38, 35, 0, 0, 0, 0], [27, 48, 46, 34, 0, 0, 0, 0], [15, 34, 45, 34, 35, 42, 0, 0], [21, 34, 59, 34, 51, 42, 0, 0], [26, 34, 45, 42, 35, 34, 0, 0], [13, 34, 44, 41, 48, 54, 51, 58], [15, 34, 37, 34, 37, 0, 0, 0], [9, 34, 35, 34, 0, 0, 0, 0], [20, 34, 47, 52, 48, 54, 51, 0], [26, 34, 58, 38, 40, 41, 0, 0], [8, 47, 53, 34, 51, 0, 0, 0], [11, 38, 38, 35, 0, 0, 0, 0], [20, 48, 51, 36, 48, 52, 0, 0], [26, 41, 34, 45, 41, 48, 54, 35], [26, 34, 51, 51, 34, 39, 0, 0], [8, 46, 34, 51, 42, 0, 0, 0], [30, 34, 52, 38, 46, 0, 0, 0], [14, 34, 47, 42, 46, 0, 0, 0], [27, 54, 46, 34, 0, 0, 0, 0], [13, 34, 44, 41, 48, 54, 51, 58], [15, 34, 37, 34, 37, 0, 0, 0], [15, 34, 44, 42, 46, 42, 0, 0], [21, 34, 37, 38, 51, 0, 0, 0], [26, 34, 42, 37, 0, 0, 0, 0], [14, 34, 47, 42, 46, 0, 0, 0], [11, 34, 41, 38, 51, 0, 0, 0], [14, 34, 47, 38, 46, 0, 0, 0], [27, 54, 46, 34, 0, 0, 0, 0], [9, 48, 54, 53, 51, 48, 52, 0], [8, 52, 56, 34, 37, 0, 0, 0], [26, 34, 51, 44, 42, 52, 0, 0], [11, 34, 41, 38, 51, 0, 0, 0], [27, 48, 46, 34, 0, 0, 0, 0], [9, 48, 54, 53, 51, 48, 52, 0], [18, 34, 47, 34, 34, 47, 0, 0], [8, 47, 53, 34, 51, 0, 0, 0], [14, 38, 51, 40, 38, 52, 0, 0], [18, 48, 54, 51, 42, 0, 0, 0], [20, 34, 51, 48, 54, 47, 0, 0], [30, 34, 52, 38, 46, 0, 0, 0], [11, 34, 40, 41, 38, 51, 0, 0], [21, 34, 42, 39, 38, 41, 0, 0], [9, 42, 52, 41, 34, 51, 34, 0], [9, 34, 0, 0, 0, 0, 0, 0], [10, 41, 34, 46, 0, 0, 0, 0], [18, 34, 45, 35, 0, 0, 0, 0], [9, 34, 59, 59, 42, 0, 0, 0], [9, 42, 53, 34, 51, 0, 0, 0], [15, 34, 37, 34, 37, 0, 0, 0], [20, 48, 40, 41, 34, 37, 34, 46], [26, 45, 38, 42, 46, 34, 47, 0], [26, 41, 34, 46, 48, 54, 47, 0], [8, 47, 53, 34, 51, 0, 0, 0], [8, 53, 42, 58, 38, 41, 0, 0], [18, 48, 54, 51, 58, 0, 0, 0], [21, 34, 41, 34, 52, 0, 0, 0], [18, 48, 54, 51, 42, 0, 0, 0], [20, 34, 51, 48, 54, 47, 0, 0], [21, 34, 52, 52, 34, 51, 0, 0], [26, 34, 58, 38, 40, 41, 0, 0], [15, 34, 42, 44, 0, 0, 0, 0], [14, 41, 34, 47, 38, 46, 0, 0], 
[26, 34, 58, 38, 40, 41, 0, 0], [26, 34, 45, 42, 35, 0, 0, 0], [10, 41, 34, 46, 0, 0, 0, 0], [9, 34, 53, 34, 0, 0, 0, 0], [27, 48, 54, 46, 34, 0, 0, 0], [8, 47, 53, 48, 54, 47, 0, 0], [8, 47, 53, 34, 51, 0, 0, 0], [9, 34, 53, 34, 0, 0, 0, 0], [9, 48, 53, 51, 48, 52, 0, 0]]
# Class id for each row of data_pad, in the same order (values 0-14 in this data).
# These are dense integer class ids, as expected by sparse_softmax_cross_entropy.
labels_x = [1,1, 13, 5, 6, 7, 14, 10, 14, 8, 11, 8, 11, 7, 3, 9, 14, 11, 11, 4, 9, 2, 4, 9, 12, 3, 14, 7, 13, 14, 1, 6, 13, 14, 7, 9, 0, 12, 4, 8, 12, 6, 1, 6, 7, 11, 14, 8, 4, 0, 5, 7, 12, 2, 5, 3, 9, 14, 1, 10, 12, 12, 14, 2, 2, 12, 13, 0, 2, 11, 2, 5, 6, 1, 3, 6, 6, 10, 14, 14, 2, 5, 0, 5, 6, 4, 3, 12, 9, 0, 11, 6, 8, 10, 6, 10, 2, 2, 5, 4, 3, 11, 3, 1, 0, 1, 4, 11, 10, 9, 13, 5, 4, 13, 12, 5, 12, 11, 9, 5, 11, 10, 8, 0, 9, 3, 6, 3, 5, 12, 8, 11, 6, 5, 6, 5, 1, 10, 5, 6, 2, 9, 7, 10, 3, 2, 6, 2, 8, 8, 3, 14, 3, 5, 8, 1, 12, 13, 3, 3, 14, 7, 7, 14, 1, 4, 5, 3, 10, 6, 0, 3, 0, 10, 3, 14, 8, 3, 6, 9, 3, 12, 12, 0, 10, 6, 0, 2, 2, 5, 7, 5, 4, 5, 6, 0, 13, 0, 11, 10]
# Training schedule.
batch_size = 5
epoch = 500
# Mini-batches per epoch, rounded UP so that a trailing partial batch (when
# len(data_pad) is not a multiple of batch_size) is still trained on instead of
# being silently skipped. The training loop slices
# data_pad[j*batch_size:(j+1)*batch_size], which simply yields a shorter final batch.
iteration = (len(data_pad) + batch_size - 1) // batch_size
# Maps each training example's POSITION to its label: {sample index: class id}.
# NOTE(review): this is keyed by sample index, not by class id, so it must not be
# used to translate a softmax/class index into a label.
indx_tola = {i: j for i, j in enumerate(labels_x)}
# ---------------------------------------------------------------------------
# Graph: character embedding -> bidirectional LSTM -> linear classifier.
# ---------------------------------------------------------------------------
# Labels are dense class ids starting at 0, so the classifier needs one output
# per class id. The original sized it with len(labels_x) -- the number of
# training EXAMPLES -- which spreads probability mass over many outputs that no
# label ever uses.
num_classes = max(labels_x) + 1

# input_x: batch of character-id sequences; output_y: one class id per sequence.
input_x = tf.placeholder(tf.int32, shape=[None, None])
output_y = tf.placeholder(tf.int32, shape=[None, ])

word_embedding = tf.get_variable(
    'embedding', shape=[len(vocab_), word_embedding_dim], dtype=tf.float32,
    initializer=tf.random_uniform_initializer(-0.01, 0.01))
lookup_embedding = tf.nn.embedding_lookup(word_embedding, input_x)

# Effective sequence length = number of non-zero ids per row (padding is 0).
# NOTE(review): id 0 is also vocab_['0'], so a literal '0' character would be
# counted as padding -- consider reserving id 0 for padding only.
sequence_len = tf.count_nonzero(input_x, axis=-1)

with tf.variable_scope('encoder'):
    # Give each direction its own cell instead of passing one LSTMCell object
    # as both the forward and backward cell.
    fw_cell = rnn.LSTMCell(hidden_dim)
    bw_cell = rnn.LSTMCell(hidden_dim)
    model = tf.nn.bidirectional_dynamic_rnn(
        fw_cell, bw_cell, inputs=lookup_embedding,
        sequence_length=sequence_len, dtype=tf.float32)

# model = (outputs, (forward_final_state, backward_final_state))
model_output, (fw_state, bw_state) = model
# Sequence summary: final cell states of both directions -> [batch, 2*hidden_dim].
concat = tf.concat((fw_state.c, bw_state.c), axis=-1)

weights = tf.get_variable(
    'weight', shape=[2 * hidden_dim, num_classes], dtype=tf.float32,
    initializer=tf.random_uniform_initializer(-0.01, 0.01))
bias = tf.get_variable(
    'bias', shape=[num_classes], dtype=tf.float32,
    initializer=tf.random_uniform_initializer(-0.01, 0.01))
final_output = tf.matmul(concat, weights) + bias  # unnormalised logits

# normalization: class probabilities and the predicted class id
normal_a = tf.nn.softmax(final_output)
pred = tf.argmax(normal_a, axis=-1)

# cross entropy (takes the raw logits, not the softmax output)
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=final_output, labels=output_y)
loss = tf.reduce_mean(ce)

# evaluate: fraction of the batch whose predicted class equals the label
evalu = tf.reduce_mean(tf.cast(tf.equal(tf.cast(pred, tf.int32), output_y), tf.float32))

# train
train = tf.train.AdamOptimizer().minimize(loss)
# Train the model, then answer interactive queries from the SAME session.
# The original opened a new session for each query and re-ran
# global_variables_initializer(), which replaced the trained weights with fresh
# random ones -- so every prediction came from an untrained model.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(epoch):
        for j in range(iteration):
            start, end = j * batch_size, (j + 1) * batch_size
            data_batch = data_pad[start:end]
            labels_batch = labels_x[start:end]
            f_new, g_new_, _ = sess.run(
                [loss, evalu, train],
                feed_dict={input_x: data_batch, output_y: labels_batch})
            print("epoch {} loss {} , iteration {} accuracy {}".format(i, f_new, j, g_new_))

    # interaction: type character ids separated by commas and/or spaces
    # (e.g. "18, 41, 48, 54, 51, 58, 0, 0"), or 'q' to quit.
    while True:
        try:
            line = input().strip()
        except EOFError:
            break
        # Check for the quit command BEFORE parsing: the original parsed first,
        # so 'q' raised ValueError and the list-vs-'q' comparison never matched.
        if line == 'q':
            break
        try:
            input11 = [int(tok) for tok in line.replace(',', ' ').split()]
        except ValueError:
            print("expected integer character ids separated by commas/spaces, or 'q'")
            continue
        if not input11:
            continue
        b, c = sess.run([normal_a, pred], feed_dict={input_x: [input11]})
        tolist = np.array(b[0]).tolist()
        # The softmax index IS the class id (training labels are the logit
        # indices), so report it directly; indx_tola is keyed by sample index.
        print(sorted(enumerate(tolist), key=lambda x: x[1], reverse=True)[:3])
        print(tolist[c[0]])
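# Sample output captured from the original script: the last training iterations,
# then interactive queries. Note the near-uniform ~0.005 probabilities, and the
# different top-3 each time the same input is repeated. Together they show the
# interactive session was scoring with freshly re-initialised, untrained weights
# rather than the trained model.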
epoch 499 loss 0.9233143925666809 , iteration 38 accuracy 0.4000000059604645
epoch 499 loss 0.9760797619819641 , iteration 39 accuracy 0.20000000298023224
18, 41, 48, 54, 51, 58, 0, 0
[(3, 0.005048779770731926), (6, 0.005048572551459074), (3, 0.005048364866524935)]
0.005048779770731926
18, 41, 48, 54, 51, 58, 0, 0
[(12, 0.00504953321069479), (7, 0.005048563238233328), (10, 0.00504817022010684)]
0.00504953321069479
18, 41, 48, 54, 51, 58, 0, 0
[(9, 0.005051372107118368), (14, 0.005050850100815296), (7, 0.005049719940871)]
0.005051372107118368
18, 41, 48, 54, 51, 58, 0, 0
[(0, 0.005058418493717909), (10, 0.005055117420852184), (12, 0.0050542913377285)]
0.005058418493717909
18, 41, 48, 54, 51, 58, 0, 0
[(9, 0.005047139711678028), (14, 0.00504682632163167), (3, 0.005046062637120485)]
0.005047139711678028
18, 41, 48, 54, 51, 58, 0, 0
[(6, 0.005050025414675474), (8, 0.005049354396760464), (6, 0.005048769526183605)]
0.005050025414675474
18, 41, 48, 54, 51, 58, 0, 0
[(8, 0.005048401188105345), (8, 0.005047514569014311), (13, 0.005047360435128212)]
0.005048401188105345
18, 41, 48, 54, 51, 58, 0, 0
[(1, 0.005045496858656406), (14, 0.005045333877205849), (3, 0.005045332480221987)]
0.005045496858656406
21, 34, 42, 39, 38, 41, 0, 0
[(8, 0.005050063133239746), (14, 0.005049287807196379), (9, 0.005048781633377075)]
0.005050063133239746
21, 34, 42, 39, 38, 41, 0, 0
[(6, 0.005050109699368477), (3, 0.005049820989370346), (14, 0.005048571154475212)]
0.005050109699368477
21, 34, 42, 39, 38, 41, 0, 0
[(1, 0.005054341163486242), (12, 0.005052998661994934), (6, 0.005051418207585812)]
0.005054341163486242
21, 34, 42, 39, 38, 41, 0, 0
[(9, 0.005053166300058365), (13, 0.00505290599539876), (6, 0.005051507614552975)]
0.005053166300058365
27, 34, 41, 34, 47, 0, 0, 0
[(3, 0.005050933454185724), (4, 0.0050498913042247295), (13, 0.005047991871833801)]
0.005050933454185724
21, 34, 43, 43, 34, 51, 0, 0
[(3, 0.005046783946454525), (10, 0.005045922938734293), (12, 0.005045799538493156)]
0.005046783946454525
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment