lstm2style_id
from aa import config
import pandas as pd
import numpy as np
import os.path
import pickle

import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Variable, serializers
from chainer.optimizers import Adam
from chainer.optimizer import GradientClipping

from lda2vec import preprocess, Corpus
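# Note: `preprocess` and `Corpus` come from the lda2vec package
# (https://github.com/cemoody/lda2vec); `aa.config` appears to be an
# internal module that hands back a database engine for Redshift.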
query1 = """ | |
SELECT s.shipment_id, s.request_notes | |
FROM shipment s | |
WHERE LENGTH(s.request_notes) > 5 | |
AND s.created_ts >= '2015-%02d-01' | |
AND s.created_ts < '2015-%02d-01' | |
""" | |
query2 = """ | |
SELECT si.shipment_id, si.style_id | |
FROM shipmentitem si | |
JOIN shipment s USING(shipment_id) | |
WHERE LENGTH(s.request_notes) > 5 | |
AND s.created_ts > '2015-01-01' | |
""" | |

def get_q1(month):
    # Pull one calendar month of shipment notes from Redshift
    cnf = config.get_config()
    eng1 = cnf.get_redshift_db_engine()
    q = query1 % (month, month + 1)
    df = pd.read_sql_query(q, eng1)
    return df

if not os.path.exists('data1.df'):
    print "SQL1 start"
    from multiprocessing import Pool
    pool = Pool()
    # Fetch January through September in parallel, one month per worker
    dfs = pool.map(get_q1, range(1, 10))
    df1 = pd.concat(dfs)
    df1.to_pickle('data1.df')
    del pool
    print "SQL1 done"
if not os.path.exists('data2.df'):
    print "SQL2 start"
    cnf = config.get_config()
    eng = cnf.get_redshift_db_engine()
    df2 = pd.read_sql_query(query2, eng)
    df2.to_pickle('data2.df')
    print "SQL2 done"
if not os.path.exists('data.df'):
    # Join each shipment's notes to the style IDs of its items
    df1 = pd.read_pickle('data1.df')
    df2 = pd.read_pickle('data2.df')
    df = df1.merge(df2, on='shipment_id')
    df.to_pickle('data.df')
    print "SQL done"
if not os.path.exists('pruned.npz'):
    print "Restoring SQL"
    df = pd.read_pickle('data.df')
    style_ids = df['style_id'].values
    texts = [unicode(t) for t in df['request_notes'].values]
    print "Preprocessing"
    # Preprocess data
    max_length = 100  # Limit of 100 words per document
    tokens, vocab = preprocess.tokenize(texts, max_length, tag=False,
                                        parse=False, entity=False)
    corpus = Corpus()
    # Make a ranked list of rare vs frequent words
    corpus.update_word_count(tokens)
    corpus.finalize()
    # The tokenization uses spaCy indices, and so may have gaps
    # between indices for words that aren't present in our dataset.
    # This builds a new compact index
    compact = corpus.to_compact(tokens)
    # Remove extremely rare words
    pruned = corpus.filter_count(compact, min_count=5)
    # Words tend to have power law frequency, so we could selectively
    # downsample the most prevalent words; that step is skipped here
    # clean = corpus.subsample_frequent(pruned)
    pickle.dump(corpus, open('corpus.pkl', 'wb'))
    np.savez('pruned', pruned=pruned, style_ids=style_ids)
    print "Preprocessing done"
corpus = pickle.load(open('corpus.pkl', 'rb'))
data = np.load('pruned.npz')
pruned, style_ids = data['pruned'], data['style_ids']
n_styles = style_ids.max() + 1
n_vocab = pruned.max() + 1
n_topics = 32  # defined but unused below
n_units = 256
epochs = 10
batchsize = 1024

class RNNLM(chainer.Chain):
    """Recurrent neural net language model adapted from the Chainer
    Penn Treebank example: a deep LSTM network for arbitrary-length
    input, with a linear readout that classifies the style ID from
    the final LSTM cell state.
    """
    def __init__(self, n_vocab, n_units, n_styles, train=True):
        super(RNNLM, self).__init__(
            embed=L.EmbedID(n_vocab, n_units),
            l1=L.LSTM(n_units, n_units),
            l2=L.LSTM(n_units, n_units),
            linear=L.Linear(n_units, n_styles))
        self.train = train

    def reset_state(self):
        self.l1.reset_state()
        self.l2.reset_state()

    def __call__(self, x):
        # Update only the internal state, no outputs
        h0 = self.embed(x)
        h1 = self.l1(F.dropout(h0, train=self.train))
        self.l2(F.dropout(h1, train=self.train))

    def predict(self, t):
        # Get last cell state and predict style id
        prediction = self.linear(self.l2.c)
        # First argument is unnormalized log probability
        return F.softmax_cross_entropy(prediction, t)
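# A minimal CPU smoke test of the class above with toy data (my own
# illustration, not part of the original gist):
#
#   toy = RNNLM(n_vocab=10, n_units=8, n_styles=3)
#   toy.reset_state()
#   xs = np.random.randint(0, 10, size=(4, 5)).astype(np.int32)  # 4 docs, 5 words
#   for col in range(xs.shape[1]):
#       toy(Variable(xs[:, col]))
#   loss = toy.predict(Variable(np.zeros(4, dtype=np.int32)))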

model = RNNLM(n_vocab, n_units, n_styles)
if os.path.exists('model.hdf5'):
    print "Reloading model"
    serializers.load_hdf5('model.hdf5', model)
else:
    # Initialize all parameters uniformly in (-0.1, 0.1)
    for param in model.params():
        data = param.data
        data[:] = np.random.uniform(-0.1, 0.1, data.shape)

# Skip this line if you don't have a GPU
model.to_gpu()
xp = model.xp  # xp is either numpy or cupy

optimizer = Adam()
optimizer.setup(model)
optimizer.add_hook(GradientClipping(5.0))
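# GradientClipping here rescales the whole gradient vector whenever its
# L2 norm exceeds 5.0, which keeps backpropagation through up to 100
# timesteps from blowing up.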

def chunks(n, *args):
    # Slice several equal-length arrays into aligned minibatches of n rows
    for i in xrange(0, len(args[0]), n):
        yield [a[i:i + n] for a in args]
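# For example (illustrative): with n=2,
#   list(chunks(2, np.arange(6).reshape(3, 2), np.array([7, 8, 9])))
# yields [ [array([[0, 1], [2, 3]]), array([7, 8])],
#          [array([[4, 5]]), array([9])] ]
# i.e. row-aligned minibatches over all of the passed arrays.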

training = True
if training:
    for epoch in range(epochs):
        for chunk, style_id in chunks(batchsize, pruned, style_ids):
            model.reset_state()
            style_id = Variable(xp.asarray(style_id))
            # Feed the minibatch through the LSTM one word column at a time
            for col in range(chunk.shape[1]):
                word = Variable(xp.asarray(chunk[:, col]))
                model(word)
            # Now we predict the style from the final cell state
            loss = model.predict(style_id)
            optimizer.zero_grads()
            loss.backward()
            optimizer.update()
            # softmax_cross_entropy already averages over the minibatch
            log_perp = float(loss.data)
            msg = "Current log_perplexity is %1.3e, perp is %1.3e"
            print msg % (log_perp, np.exp(log_perp))
    print "Saving model"
    serializers.save_hdf5('model.hdf5', model)
else:
    while True:  # infinite loop for input
        msg = "\n\nWhat kind of item would you like from Stitch Fix?"
        text = raw_input(msg)
        texts = [unicode(text)]
        tokens, vocab = preprocess.tokenize(texts, 100, tag=False,
                                            parse=False, entity=False)
        compact = corpus.to_compact(tokens)
        pruned = corpus.filter_count(compact, min_count=5)
        model.reset_state()
        for col in range(pruned.shape[1]):
            word = Variable(xp.asarray(pruned[:, col]))
            model(word)
        prediction = F.softmax(model.linear(model.l2.c))
        prediction.to_cpu()
        probs = prediction.data
        # Take the ten most probable style IDs for this single document
        sids = np.argsort(probs.ravel())[::-1][:10]
        url = "https://fashionthing.stitchfix.com/styles/%i\n"
        print "Predicted top style IDs: "
        print "\n".join([url % sid for sid in sids])