Skip to content

Instantly share code, notes, and snippets.

@unrealwill
Created September 29, 2016 09:29
Show Gist options
  • Save unrealwill/4bd198c1a4b129837395b11d555252e2 to your computer and use it in GitHub Desktop.
Save unrealwill/4bd198c1a4b129837395b11d555252e2 to your computer and use it in GitHub Desktop.
#LICENSE MIT#
from keras.models import Model
from keras.layers import Input, Dense, Merge, Recurrent
from keras.layers.recurrent import SimpleRNN, GRU,LSTM
from keras.layers.embeddings import Embedding
from keras.layers.wrappers import TimeDistributed
import random
import numpy as np
maxStringLength=20
spreadlength = 4 * maxStringLength
def SpreadWithZeros(x, length, rng=None):
    """Scatter each row of ``x`` into a zero-filled row of width ``length``.

    The symbols of a row are written left to right. The first goes into
    column 0, and each following symbol lands 1 to 3 columns (chosen at
    random) after the previous one, so 0 to 2 zeros separate neighbours.
    Placement stops when either the row's symbols or the ``length`` output
    columns run out. Trailing symbols are dropped when ``length`` is too
    small.

    Args:
        x: 2-D integer array, one sequence per row.
        length: number of columns in the returned arrays.
        rng: optional object with a ``randint(a, b)`` method (for example
            ``random.Random(seed)``) used to draw the strides. Defaults to
            the global ``random`` module, which reproduces the original
            random stream.

    Returns:
        ``(out, mask)``: two ``uint16`` arrays of shape
        ``(x.shape[0], length)``.
        - ``out`` holds the spread symbols.
        - ``mask`` is 1 from column 0 through the column of the last placed
          symbol and 0 afterwards. When placement was cut short, the mask
          covers the whole row.
    """
    if rng is None:
        rng = random
    nrows, ncols = x.shape[0], x.shape[1]
    out = np.zeros((nrows, length), dtype='uint16')
    mask = np.zeros((nrows, length), dtype='uint16')
    for i in range(nrows):
        co = 0
        ind = 0
        while ind < ncols and co < length:
            out[i, co] = x[i, ind]
            ind += 1
            # No stride is drawn after the final symbol. When every symbol
            # fits, ``co`` therefore ends on the column of the last one.
            if ind < ncols:
                co += rng.randint(1, 3)
        # ``co`` can overshoot ``length`` (by up to 3) when placement was cut
        # short, and ``length`` may be 0. Clip so the mask never indexes
        # past the last column.
        mask[i, :min(co + 1, length)] = 1
    return out, mask
def BuildSpreadModel(Nclass, loss):
    """Build and compile the toy CTC model that reads zero-spread sequences.

    Pipeline:
        - uint16 ids of width ``spreadlength``;
        - a 5-dimensional ``Embedding`` over ``Nclass`` ids;
        - a 100-unit tanh ``SimpleRNN`` that emits every time step;
        - a per-step softmax over ``Nclass + 1`` outputs.

    Args:
        Nclass: number of input symbol ids accepted by the embedding.
        loss: loss name handed to ``compile``. This script passes
            "ctc_cost_for_train" or "ctc_cost_precise". These are not stock
            Keras losses; presumably they come from the custom Keras fork
            this script targets.

    Returns:
        The compiled ``Model``, configured for temporal (per-step) sample
        weights.
    """
    tokens = Input(batch_shape=(None, spreadlength), dtype='uint16')
    embedded = Embedding(Nclass, 5, input_length=spreadlength)(tokens)
    # A SimpleRNN should not even be needed to solve this toy problem.
    hidden = SimpleRNN(100, return_sequences=True, activation='tanh')(embedded)
    # One unit beyond Nclass: presumably the CTC blank label. Confirm against
    # the fork's loss implementation.
    perStepProbs = TimeDistributed(Dense(Nclass + 1, activation="softmax"))(hidden)
    model = Model(tokens, perStepProbs)
    model.compile(loss=loss, optimizer='adam', sample_weight_mode='temporal')
    return model
def TrainSpreadModel(model, nbepoch, batch_size, max_symbol=30):
    """Train ``model`` on freshly generated random batches, printing each result.

    Each step does three things:
        - draws ``batch_size`` random target strings of ``maxStringLength``
          symbols in ``[1, max_symbol)``;
        - spreads them to ``spreadlength`` columns with ``SpreadWithZeros``;
        - runs one ``train_on_batch`` step.

    Args:
        model: a compiled model, as returned by ``BuildSpreadModel``.
        nbepoch: number of training steps (one fresh batch per step).
        batch_size: number of strings per batch.
        max_symbol: exclusive upper bound for generated symbol ids (default
            30, the previously hard-coded value). The ids feed an
            ``Embedding(Nclass, ...)``, so this must not exceed the
            ``Nclass`` the model was built with. The lower bound stays at 1,
            presumably so that 0 only ever means the padding that
            ``SpreadWithZeros`` inserts.

    Returns:
        A list with ``res[0]`` from every step. This is presumably the batch
        loss; confirm against the Keras fork in use.
    """
    losses = []
    for _ in range(nbepoch):
        X = np.random.randint(1, max_symbol, (batch_size, maxStringLength))
        xcat, mask = SpreadWithZeros(X, spreadlength)
        # ``sm_mask`` and ``return_sm`` are extensions of the custom Keras fork
        # this script targets, not stock Keras arguments. ``res[1]`` presumably
        # holds the per-step softmax, which a best-path CTC decoder could turn
        # back into label sequences for inspection.
        res = model.train_on_batch(x=xcat, y=X, sm_mask=mask, return_sm=True)
        print(res[0])
        losses.append(res[0])
    return losses
def _build_and_train(loss, nclass=33, nbepoch=10, batch_size=50):
    """Build a spread model compiled with ``loss``, train it briefly, return it."""
    model = BuildSpreadModel(nclass, loss)
    TrainSpreadModel(model, nbepoch, batch_size)
    return model


# The same architecture is trained once per CTC loss variant, in this order.
modelTrain = _build_and_train("ctc_cost_for_train")
modelPrecise = _build_and_train("ctc_cost_precise")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment