Skip to content

Instantly share code, notes, and snippets.

@hdon
Last active March 15, 2018 07:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hdon/cdfff5e96ffea2f2ff71fc57b406ee11 to your computer and use it in GitHub Desktop.
Save hdon/cdfff5e96ffea2f2ff71fc57b406ee11 to your computer and use it in GitHub Desktop.
garbage technique for training word embedding+predictor
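# Notable flaws:
#   - Totally unorthodox method for training the embedding: the network is
#     trained to regress onto the label's own embedding vector (`ye`), and
#     `cost` adds penalties on the embedding's mean and variance.
#   - Extremely inefficient demo evaluation every 100,000 training steps
#     (the `steps % 100000 == 99999` branch of the training loop).
# Surprises:
#   - Was actually extremely effective on small data sets.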
from __future__ import print_function
import tensorflow as tf
import numpy as np
import re, random, sys
from datetime import datetime
from unidecode import unidecode
path = "datasets/pride_prejudice.txt"

# Read the corpus exactly once and close the file promptly.  The text is
# lower-cased and then passed through unidecode(), and this same string is
# what gets tokenised below (previously the file was re-opened with the
# platform default encoding and the unidecode() result was never tokenised).
with open(path, 'r', encoding='utf-8') as corpus_file:
    text = unidecode(corpus_file.read().lower())
print('corpus length:', len(text))

# A token is either a run of word characters / apostrophes, or a single one
# of the punctuation marks . , ! "
token_re = re.compile(r'(\w+|[\'])+|[.,!"]') # TODO more punctuation and more exotic chars

# corpus_tokens is the corpus as an ordered list of tokens, which for the
# above regexp are words and some punctuation marks.
corpus_tokens = [match.group(0) for match in token_re.finditer(text)]

# Set of all unique tokens in the corpus.
token_set = set(corpus_tokens)

# Assign ordinals from the *sorted* vocabulary.  Iteration order of a set of
# strings changes between interpreter runs (string hash randomisation, unless
# PYTHONHASHSEED is pinned), which would silently re-map every embedding row
# to a different token whenever a checkpoint is restored in a new process.
vocabulary = sorted(token_set)
token_to_ordinal = {token: ordinal for ordinal, token in enumerate(vocabulary)}
ordinal_to_token = dict(enumerate(vocabulary))

# Write the vocabulary to disk as TSV for the tensorboard projector: one
# token per line, in ordinal order so line i labels embedding row i.
# No header row is written.
with open('tokens.tsv', 'w', encoding='utf-8') as f:
    f.writelines(token + '\n' for token in vocabulary)
print("total number of unique tokens", len(token_set))
# ---- Hyperparameters -------------------------------------------------------
num_embedding_dims = 32  # width of each token's embedding vector
num_hidden = 16          # units in each of the two hidden layers
batch_size = 1024        # windows per training step (fixed in the graph)
sequence_len = 16        # context tokens fed in to predict the next token

# ---- Inputs ----------------------------------------------------------------
# features: token ordinals of the context window; labels: ordinal of the token
# that follows the window.
features = tf.placeholder(tf.int32, [batch_size, sequence_len], name='features')
labels = tf.placeholder(tf.int32, [batch_size], name='labels')

# ---- Embedding -------------------------------------------------------------
# One trainable row per vocabulary entry.  The same table embeds both the
# inputs and the regression target.
embedding = tf.Variable(
    tf.random_normal([len(token_set), num_embedding_dims], stddev=1.0, dtype=tf.float32),
    name='embedding')
xe = tf.nn.embedding_lookup(embedding, features, name='x_embedded')  # [batch, seq, dims]
ye = tf.nn.embedding_lookup(embedding, labels, name='y_embedded')    # [batch, dims]

# Flatten the embedded window into one vector per sample.
xr = tf.reshape(xe, [batch_size, sequence_len * num_embedding_dims])

# ---- Feed-forward predictor -------------------------------------------------
# NOTE: the variables deliberately keep their repeated 'weights' / 'bias'
# names and their creation order.  TensorFlow makes repeated names unique by
# suffixing them in creation order and tf.train.Saver stores variables under
# those names, so renaming or reordering would break existing checkpoints.
w0 = tf.Variable(tf.random_normal([sequence_len * num_embedding_dims, num_hidden], stddev=0.1, dtype=tf.float32), name='weights')
b0 = tf.Variable(tf.random_normal([num_hidden], stddev=0.1, dtype=tf.float32), name='bias')
w1 = tf.Variable(tf.random_normal([num_hidden, num_hidden], stddev=0.1, dtype=tf.float32), name='weights')
b1 = tf.Variable(tf.random_normal([num_hidden], stddev=0.1, dtype=tf.float32), name='bias')
w2 = tf.Variable(tf.random_normal([num_hidden, num_embedding_dims], stddev=0.1, dtype=tf.float32), name='weights')
b2 = tf.Variable(tf.random_normal([num_embedding_dims], stddev=0.1, dtype=tf.float32), name='bias')

l1 = tf.tanh(tf.matmul(xr, w0) + b0)
l2 = tf.tanh(tf.matmul(l1, w1) + b1)
# Network output: a predicted embedding vector for the next token.
y = tf.tanh(tf.matmul(l2, w2) + b2)

# ---- Decoding: embedding vector -> nearest token ordinal --------------------
# Squared Euclidean distance from each prediction to every embedding row:
# [batch, 1, dims] - [1, vocab, dims] -> [batch, vocab, dims], summed over
# dims -> [batch, vocab]; argmin over vocab gives one ordinal per sample.
decoded_y = tf.argmin(
    tf.reduce_sum(
        (tf.reshape(y, [batch_size, 1, num_embedding_dims]) -
         tf.reshape(embedding, [1, len(token_set), num_embedding_dims])) ** 2,
        axis=2),
    axis=1,
    output_type=tf.int32,
    name='decoded_y')

# Same nearest-row decoding for every position of the embedded input window:
# [batch, 1, seq, dims] - [1, vocab, 1, dims] -> [batch, vocab, seq, dims].
# The squared differences must be summed over the embedding axis (axis 3);
# summing over axis 2 would collapse the sequence positions instead and the
# argmin would no longer yield one ordinal per position.  Result: [batch, seq].
# This tensor is very large when evaluated (batch * vocab * seq * dims floats).
decoded_x = tf.argmin(
    tf.reduce_sum(
        (tf.reshape(xe, [batch_size, 1, sequence_len, num_embedding_dims]) -
         tf.reshape(embedding, [1, len(token_set), 1, num_embedding_dims])) ** 2,
        axis=3),
    axis=1,
    output_type=tf.int32,
    name='decoded_x')

# ---- Loss --------------------------------------------------------------------
# Main term: mean squared error between the predicted vector and the label's
# embedding vector.
foo = (ye - y) ** 2
embedding_mean, embedding_variance = tf.nn.moments(embedding, axes=[0])
main_cost = tf.reduce_mean(foo)
# Extra terms penalise the per-dimension mean of the embedding table for
# straying from 0 and its per-dimension variance for straying from 1.
cost = main_cost \
    + tf.reduce_mean((embedding_mean - 0.0) ** 2) \
    + tf.reduce_mean((embedding_variance - 1.0) ** 2)

optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost)
#optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.000001).minimize(cost)

# Reusable host-side buffers that the training loop fills before each step.
features_data = np.zeros([batch_size, sequence_len], int)
labels_data = np.zeros([batch_size], int)
# Cached log directory for this process; filled in on first use.
_logDir = None


def getLogDir():
    """Return this run's log directory, choosing it on the first call.

    The path is 'log/' followed by a timestamp taken when the function is
    first called; every later call returns that same string.

    The timestamp uses '-' between hours, minutes and seconds because ':' is
    not a legal character in Windows path components.
    """
    global _logDir
    if _logDir is None:
        _logDir = 'log/' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    return _logDir
tf.summary.scalar('cost', cost)
summary_op = tf.summary.merge_all()

checkpoint_dir = 'checkpoints'
checkpoint_prefix = checkpoint_dir + '/latest.ckpt'

# Each training sample needs sequence_len context tokens plus one label token.
if len(corpus_tokens) <= sequence_len:
    raise ValueError('corpus is too short for the configured sequence_len')

with tf.Session() as sess:
    # writer for summaries
    writer = tf.summary.FileWriter(getLogDir(), graph=sess.graph)
    # checkpoint writer for all trainable model variables
    saver = tf.train.Saver()
    # initialize tf variables
    sess.run(tf.global_variables_initializer())

    # Saver.save needs the target directory to exist.
    tf.gfile.MakeDirs(checkpoint_dir)

    # Resume from the most recent checkpoint when there is one, instead of
    # unconditionally restoring one hard-coded file (which made a fresh run
    # crash).  The step counter continues from the checkpoint's step suffix
    # so summaries and checkpoint numbers keep increasing across runs.
    steps = 0
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is not None:
        saver.restore(sess, latest_checkpoint)
        print('loaded checkpoint')
        try:
            steps = int(latest_checkpoint.rsplit('-', 1)[-1]) + 1
        except ValueError:
            # Checkpoint name carries no numeric step suffix; start counting
            # from zero.
            steps = 0

    # Largest window start that still leaves room for the label token.
    max_train_start = len(corpus_tokens) - sequence_len - 1
    # Largest window start for a seed window (no label token needed).
    max_seed_start = len(corpus_tokens) - sequence_len

    keep_training = True
    try:
        while keep_training:
            # Fill the batch with randomly positioned windows: sequence_len
            # context tokens as features, the following token as the label.
            for iBatchSample in range(batch_size):
                seq_start = random.randint(0, max_train_start)
                window = corpus_tokens[seq_start:seq_start + sequence_len + 1]
                features_data[iBatchSample] = [token_to_ordinal[token] for token in window[:-1]]
                labels_data[iBatchSample] = token_to_ordinal[window[-1]]

            _, summary = sess.run([optimizer, summary_op], feed_dict={
                features: features_data,
                labels: labels_data,
            })
            writer.add_summary(summary, steps)

            if steps % 100000 == 99999:
                # Demo: seed with a random corpus window, then repeatedly
                # predict the next token and slide the window forward.
                # The seed is drawn from the whole corpus (the bound used to
                # be derived from batch_size by mistake).
                seq_start = random.randint(0, max_seed_start)
                seed_tokens = corpus_tokens[seq_start:seq_start + sequence_len]
                print('seed:', ' '.join(seed_tokens))
                print('sentence:')
                for _ in range(50):
                    # The placeholders have a fixed batch dimension, so the
                    # whole batch is fed although only row 0 is read back.
                    features_data[0] = [token_to_ordinal[token] for token in seed_tokens]
                    (decoded_y_run,) = sess.run(
                        [decoded_y], feed_dict={features: features_data})
                    predicted_token = ordinal_to_token[int(decoded_y_run[0])]
                    sys.stdout.write(predicted_token + ' ')
                    # Drop the oldest token and append the prediction.
                    seed_tokens = seed_tokens[1:] + [predicted_token]
                print()

            if steps % 1000 == 0:
                saver.save(sess, checkpoint_prefix, steps)

            steps += 1
    finally:
        # Flush and close the event file even when training is interrupted.
        writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment