@cemoody
Last active October 30, 2016 21:21
lstm2style_id
from aa import config
import pandas as pd
import numpy as np
import os.path
import pickle
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Variable, serializers
from chainer.optimizers import Adam
from chainer.optimizer import GradientClipping
from lda2vec import preprocess, Corpus
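# Note: this script targets Python 2 and the Chainer 1.x API
# (F.dropout(train=...), optimizer.zero_grads()). lda2vec's preprocess
# tokenizes with spaCy, and `aa.config` is assumed to be an internal
# helper that hands back a SQLAlchemy engine for Redshift.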
query1 = """
SELECT s.shipment_id, s.request_notes
FROM shipment s
WHERE LENGTH(s.request_notes) > 5
AND s.created_ts >= '2015-%02d-01'
AND s.created_ts < '2015-%02d-01'
"""
query2 = """
SELECT si.shipment_id, si.style_id
FROM shipmentitem si
JOIN shipment s USING(shipment_id)
WHERE LENGTH(s.request_notes) > 5
AND s.created_ts > '2015-01-01'
"""
def get_q1(month):
    cnf = config.get_config()
    eng1 = cnf.get_redshift_db_engine()
    q = query1 % (month, month + 1)
    df = pd.read_sql_query(q, eng1)
    return df
if not os.path.exists('data1.df'):
    print "SQL1 start"
    from multiprocessing import Pool
    # Fetch each month (Jan-Sep 2015) in a separate worker process
    pool = Pool()
    dfs = pool.map(get_q1, range(1, 10))
    df1 = pd.concat(dfs)
    df1.to_pickle('data1.df')
    pool.close()
    print "SQL1 done"
if not os.path.exists('data2.df'):
    print "SQL2 start"
    cnf = config.get_config()
    eng = cnf.get_redshift_db_engine()
    df2 = pd.read_sql_query(query2, eng)
    df2.to_pickle('data2.df')
    print "SQL2 done"
if not os.path.exists('data.df'):
    df1 = pd.read_pickle('data1.df')
    df2 = pd.read_pickle('data2.df')
    df = df1.merge(df2, on='shipment_id')
    df.to_pickle('data.df')
    print "SQL done"
if not os.path.exists('pruned.npz'):
    print "Restoring SQL"
    df = pd.read_pickle('data.df')
    style_ids = df['style_id'].values
    texts = [unicode(t) for t in df['request_notes'].values]
    print "Preprocessing"
    # Preprocess data
    max_length = 100   # Limit of 100 words per document
    tokens, vocab = preprocess.tokenize(texts, max_length, tag=False,
                                        parse=False, entity=False)
    corpus = Corpus()
    # Make a ranked list of rare vs frequent words
    corpus.update_word_count(tokens)
    corpus.finalize()
    # The tokenization uses spaCy indices, and so may have gaps
    # between indices for words that aren't present in our dataset.
    # This builds a new compact index
    compact = corpus.to_compact(tokens)
    # Remove extremely rare words
    pruned = corpus.filter_count(compact, min_count=5)
    # Words tend to have power-law frequency, so one could selectively
    # downsample the most prevalent words:
    # clean = corpus.subsample_frequent(pruned)
    # Unlike the lda2vec examples, we keep the 2D (document, word
    # position) layout; the LSTM below consumes it one column at a time.
    pickle.dump(corpus, open('corpus.pkl', 'wb'))
    np.savez('pruned', pruned=pruned, style_ids=style_ids)
    print "Preprocessing done"
corpus = pickle.load(open('corpus.pkl', 'rb'))
data = np.load('pruned.npz')
pruned, style_ids = data['pruned'], data['style_ids']
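# A note on shapes (assumed from lda2vec's Corpus helpers rather than
# checked here): `pruned` is a (n_documents, max_length) int array of
# compact word indices, and `style_ids` is a 1D array with one style
# label per document, aligned row for row.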
n_styles = style_ids.max() + 1
n_vocab = pruned.max() + 1
n_topics = 32  # note: unused in this script
n_units = 256
epochs = 10
batchsize = 1024
class RNNLM(chainer.Chain):
    """Recurrent neural net language model, adapted from the Penn
    Treebank example: a deep LSTM network for arbitrary-length input.
    Here the final cell state is used to predict a style ID.
    """
    def __init__(self, n_vocab, n_units, n_styles, train=True):
        super(RNNLM, self).__init__(
            embed=L.EmbedID(n_vocab, n_units),
            l1=L.LSTM(n_units, n_units),
            l2=L.LSTM(n_units, n_units),
            linear=L.Linear(n_units, n_styles))
        self.train = train

    def reset_state(self):
        self.l1.reset_state()
        self.l2.reset_state()

    def __call__(self, x):
        # Update only the internal state, no outputs
        h0 = self.embed(x)
        h1 = self.l1(F.dropout(h0, train=self.train))
        self.l2(F.dropout(h1, train=self.train))

    def predict(self, t):
        # Read out the last cell state and predict the style id
        prediction = self.linear(self.l2.c)
        # First argument is unnormalized log probability
        return F.softmax_cross_entropy(prediction, t)
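# A minimal sketch of one forward pass (hypothetical 2-document,
# 2-word batch; Chainer's EmbedID and softmax_cross_entropy expect
# int32 inputs):
#
#   batch = np.array([[4, 7], [9, 2]], dtype='int32')
#   model.reset_state()
#   for col in range(batch.shape[1]):       # feed words left to right
#       model(Variable(batch[:, col]))
#   loss = model.predict(Variable(np.array([3, 1], dtype='int32')))
#
# Each call updates only the LSTM state; the loss is computed once,
# from the final cell state, after the whole sequence has been read.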
model = RNNLM(n_vocab, n_units, n_styles)
if os.path.exists('model.hdf5'):
    print "Reloading model"
    serializers.load_hdf5('model.hdf5', model)
else:
    # Fresh model: initialize all parameters uniformly in [-0.1, 0.1]
    for param in model.params():
        data = param.data
        data[:] = np.random.uniform(-0.1, 0.1, data.shape)
# Skip this line if you don't have a GPU
model.to_gpu()
xp = model.xp # xp is either numpy or cupy
optimizer = Adam()
optimizer.setup(model)
optimizer.add_hook(GradientClipping(5.0))
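# GradientClipping(5.0) is an optimizer hook that rescales the full
# gradient whenever its L2 norm exceeds 5, a standard guard against
# exploding gradients in LSTMs; the threshold is this gist's choice,
# not a tuned value.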
def chunks(n, *args):
    # Yield successive n-row slices from each array, in parallel
    for i in xrange(0, len(args[0]), n):
        yield [a[i:i + n] for a in args]
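# For example:
#   >>> list(chunks(2, np.arange(5), np.arange(5) * 10))
#   [[array([0, 1]), array([ 0, 10])],
#    [array([2, 3]), array([20, 30])],
#    [array([4]), array([40])]]
# Note the final chunk can be smaller than n (here, batchsize).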
training = True
if training:
    for epoch in range(epochs):
        for chunk, style_id in chunks(batchsize, pruned, style_ids):
            model.reset_state()
            # Chainer wants int32 indices and labels
            style_id = Variable(xp.asarray(style_id.astype('int32')))
            # Feed the chunk one word-column at a time; only the LSTM
            # state is updated here
            for col in range(chunk.shape[1]):
                word = Variable(xp.asarray(chunk[:, col].astype('int32')))
                model(word)
            # Now we predict the style from the final cell state
            losses = model.predict(style_id)
            optimizer.zero_grads()
            losses.backward()
            optimizer.update()
            # softmax_cross_entropy already averages over the batch,
            # so losses.data is the mean log loss per document
            log_perp = float(losses.data)
            msg = "Current log_perplexity is %1.3e, perp is %1.3e"
            print msg % (log_perp, np.exp(log_perp))
        # Checkpoint at the end of every epoch
        print "Saving model"
        serializers.save_hdf5('model.hdf5', model)
else:
    while True:  # infinite loop for interactive input
        msg = "\n\nWhat kind of item would you like from Stitch Fix?"
        text = raw_input(msg)
        texts = [unicode(text)]
        tokens, vocab = preprocess.tokenize(texts, 100, tag=False,
                                            parse=False, entity=False)
        compact = corpus.to_compact(tokens)
        pruned = corpus.filter_count(compact, min_count=5)
        model.reset_state()
        for col in range(pruned.shape[1]):
            word = Variable(xp.asarray(pruned[:, col].astype('int32')))
            model(word)
        prediction = F.softmax(model.linear(model.l2.c))
        # Copy the (1, n_styles) probabilities back to the host
        probs = chainer.cuda.to_cpu(prediction.data).ravel()
        sids = np.argsort(probs)[::-1][:10]
        url = "https://fashionthing.stitchfix.com/styles/%i\n"
        print "Predicted top style IDs: "
        print "\n".join([url % sid for sid in sids])