Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save GINK03/6a015233b81a6413bc90d974665fb4d3 to your computer and use it in GitHub Desktop.
Save GINK03/6a015233b81a6413bc90d974665fb4d3 to your computer and use it in GitHub Desktop.
# Number of times build_model's _scheduler callback has run; _scheduler
# increments it and uses it as the index into learning_rates.
counter = 0
def build_model(mode=None, maxlen=None, output_dim=None):
    """Build and compile the LSTM classifier plus its learning-rate callback.

    The network is LSTM(128*15) -> Dropout(0.5) -> Dense(output_dim) ->
    softmax, compiled with categorical cross-entropy.

    Args:
        mode: Optimizer selector. "rms" uses ``RMSprop(lr=0.01)``; "adam"
            uses ``Adam()`` with its defaults.
        maxlen: Number of timesteps per input sequence; every timestep is a
            512-dimensional vector.
        output_dim: Width of the final Dense/softmax layer.

    Returns:
        A ``(model, change_lr)`` tuple: the compiled model and the ``LRS``
        callback wrapping ``_scheduler``.

    Raises:
        ValueError: If ``mode`` is neither "rms" nor "adam". Previously this
            surfaced as an UnboundLocalError at compile time, after the
            whole network had already been constructed.
    """
    # Fail fast, before any layers are allocated, when no optimizer can be
    # selected.
    if mode not in ("rms", "adam"):
        raise ValueError(
            'unknown optimizer mode %r; expected "rms" or "adam"' % (mode,)
        )

    print('Build model...')

    def _scheduler(epoch):
        """Per-epoch callback: advance the global counter, report the lr."""
        global counter
        counter += 1
        # NOTE(review): the looked-up rate is never applied to the optimizer;
        # the lookup is kept so an exhausted/missing learning_rates table
        # still fails loudly. learning_rates is defined outside this block.
        rate = learning_rates[counter]
        print(model.optimizer.lr)
        # NOTE(review): returns the optimizer's lr attribute as-is, not
        # `rate` -- confirm this is what the LRS callback expects.
        return model.optimizer.lr

    # LRS is imported elsewhere in the file (presumably Keras'
    # LearningRateScheduler -- TODO confirm).
    change_lr = LRS(_scheduler)

    model = Sequential()
    model.add(LSTM(128*15, return_sequences=False, input_shape=(maxlen, 512)))
    model.add(Dropout(0.5))
    model.add(Dense(output_dim))
    model.add(Activation('softmax'))

    if mode == "rms":
        optimizer = RMSprop(lr=0.01)
    else:  # mode == "adam", guaranteed by the validation above
        optimizer = Adam()

    model.compile(loss='categorical_crossentropy', optimizer=optimizer)
    return model, change_lr
def dynast(preds, temperature=1.0):
    """Return the index of the largest temperature-rescaled probability.

    The probabilities are moved to log space, divided by ``temperature``,
    re-normalised with a softmax, and the arg-max index is returned.

    NOTE(review): because the result is an arg-max (not a random draw), any
    positive ``temperature`` yields the same index; the parameter only
    matters if sampling is introduced later.

    Args:
        preds: 1-D sequence/array of probabilities.
        temperature: Divisor applied to the log-probabilities.

    Returns:
        Integer index (NumPy integer) of the most probable entry.
    """
    preds = np.asarray(preds).astype('float64')
    # A zero probability legitimately maps to -inf here; silence the
    # divide-by-zero warning that np.log(0) would otherwise emit.
    with np.errstate(divide='ignore'):
        logits = np.log(preds) / temperature
    # Shift by the maximum before exponentiating. Without this, a small
    # temperature makes every exp() underflow to 0, the normalisation
    # becomes 0/0 = NaN, and argmax silently returns index 0.
    logits = logits - np.max(logits)
    exp_preds = np.exp(logits)
    preds = exp_preds / np.sum(exp_preds)
    return np.argmax(preds)
def train(mode):
    """Train the emoji-conditioned character model, sampling after each epoch.

    Steps:
      1. Load up to 20000 pickled samples from ``./dataset/`` and shuffle
         them. Files that fail to unpickle are skipped.
      2. Encode the samples into ``X`` (samples x 10 timesteps x 512
         features) and a one-hot ``Y`` over ``char_index``.
      3. Dump the per-sample ``[emoji, emoji_vec]`` pairs to
         ``emojis_vec.pkl``.
      4. Fit one epoch at a time for iterations 1..9999, saving a checkpoint
         under ``./models.<mode>/`` and printing a 200-character sample
         after every iteration.

    Args:
        mode: Optimizer name forwarded to ``build_model``; also used in the
            checkpoint directory name.

    Returns:
        None.

    NOTE(security): every input file is read with ``pickle.loads``, which
    can execute arbitrary code. Only run this on files you created yourself.
    """
    import os  # local import: this block cannot see the file's import header

    # ---- 1. load the pickled samples ---------------------------------
    datasets = []
    for ni, name in enumerate(glob.glob('./dataset/*')[:20000]):
        if ni % 1000 == 0:
            print("now loading pkl iter %d" % ni)
        try:
            with open(name, 'rb') as f:
                data = pickle.loads(f.read())
        except (pickle.UnpicklingError, EOFError):
            # Corrupt, truncated or empty file: skip it. EOFError is what
            # pickle raises for empty/cut-off input.
            continue
        datasets.append(data)
    random.shuffle(datasets)
    max_texts = len(datasets)

    with open('./char_index.pkl', 'rb') as f:
        char_index = pickle.loads(f.read())
    index_char = {index: char for char, index in char_index.items()}
    with open('./term_vec.pkl', 'rb') as f:
        term_vec = pickle.loads(f.read())

    # ---- 2. encode into dense arrays ---------------------------------
    X = np.zeros((max_texts, 10, 512), dtype=np.float64)
    Y = np.zeros((max_texts, len(char_index)), dtype=np.float64)
    emojis_vec = []
    for i, dataset in enumerate(datasets):
        if i % 1000 == 0:
            print("now mapping to numpy iter %d" % i)
        # Each sample is a 4-tuple; the third field is not used here.
        x, emoji, _emoji_each_vec, y = dataset

        # Sum the vectors of the distinct symbols in `emoji`.
        # assumes term_vec values are 256-dimensional -- TODO confirm
        emoji_vec = np.zeros((256), dtype=np.float64)
        for symbol in set(emoji):
            try:
                emoji_vec += np.array(term_vec[symbol])
            except KeyError as err:
                # Symbols without a vector are reported and ignored.
                print("error %s" % err)
        emojis_vec.append([emoji, emoji_vec])

        Y[i, char_index[y]] = 1.
        # First 256 features: element-wise product with the emoji vector;
        # last 256 features: element-wise sum.
        for t, vec in enumerate(x):
            X[i, t, :256] = vec * emoji_vec
            X[i, t, 256:] = vec + emoji_vec

    # ---- 3. persist the emoji vectors --------------------------------
    with open('emojis_vec.pkl', 'wb') as f:
        f.write(pickle.dumps(emojis_vec))

    # ---- 4. train / checkpoint / sample ------------------------------
    # `scheduler` is currently unused: the callbacks argument to fit() is
    # disabled below.
    model, scheduler = build_model(mode=mode, maxlen=10, output_dim=len(char_index))

    # Create the checkpoint directory up front so the first model.save()
    # cannot fail after a full epoch of training.
    os.makedirs("./models.%s" % mode, exist_ok=True)

    for iteration in range(1, 10000):
        print()
        print('-' * 50)
        print('Iteration', iteration)
        model.fit(X, Y, batch_size=64, nb_epoch=1)  # , callbacks=[scheduler])
        MODEL_NAME = "./models.%s/%09d.model" % (mode, iteration)
        model.save(MODEL_NAME)

        for diversity in [1.0]:
            print()
            print('----- diversity:', diversity)
            # Seed window: ten '*' placeholders, conditioned on the emoji
            # vector of one randomly chosen training sample.
            sent = ["*"] * 10
            emoji, _emoji_vec = random.choice(emojis_vec)
            print('----- Generating with seed: "' + emoji +" " + "".join(sent) + '"')
            sys.stdout.write("".join(sent))
            for _ in range(200):
                window = np.zeros((1, 10, 512))
                try:
                    for t, char in enumerate(sent):
                        window[0, t, :256] = term_vec[char] * _emoji_vec
                        window[0, t, 256:] = term_vec[char] + _emoji_vec
                    preds = model.predict(window, verbose=0)[0]
                    next_index = dynast(preds, diversity)
                    next_char = index_char[next_index]
                    # Slide the 10-character window forward by one.
                    sent.append(next_char)
                    sent = sent[1:]
                    sys.stdout.write(next_char)
                    sys.stdout.flush()
                except KeyError:
                    # A character without a vector (or an unknown index)
                    # ends this sample early; generation is best-effort.
                    break
            print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment