Loader for cmudict dataset (CMU Pronouncing Dictionary) for keras
"""
Load cmudict/CMU Pronouncing Dictionary as a dataset for keras
author: wassname
url: https://gist.github.com/wassname/ce0ac72c50832b0a27e67bd97ba36080
"""
from keras.utils.data_utils import get_file
from sklearn.model_selection import train_test_split
import numpy as np
import itertools
import re
# based on https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py
class CharacterTable(object):
    '''
    Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    '''
    def __init__(self, chars='', maxlen=None, null_char=' ', left_pad=False):
        self.chars = sorted(set([null_char] + list(chars)))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen
        self.left_pad = left_pad
        self.null_char = null_char

    def fit(self, Cs, null_char=' '):
        """Determine chars and maxlen by fitting to the data"""
        self.chars = sorted(set(itertools.chain([null_char], *Cs)))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = max(len(c) for c in Cs)
        self.null_char = null_char

    def encode(self, Cs, maxlen=None):
        """Convert a list of sequences to a padded one-hot array"""
        maxlen = maxlen if maxlen else self.maxlen
        n = len(Cs)
        X = np.zeros((n, maxlen, len(self.chars)), dtype=np.bool)
        for j, C in enumerate(Cs):
            if self.left_pad:
                C = [self.null_char] * (maxlen - len(C)) + list(C)
            else:
                C = list(C) + [self.null_char] * (maxlen - len(C))
            for i, c in enumerate(C):
                X[j, i, self.char_indices[c]] = True
        return X

    def decode(self, Xs, calc_argmax=True):
        """Convert one-hot arrays (or probability vectors) back to character sequences"""
        if calc_argmax:
            Xs = Xs.argmax(axis=-1)
        return np.array(list([self.indices_char[x] for x in X] for X in Xs))
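
# Quick sanity check of CharacterTable (an illustrative sketch, not part of the original gist):
# fit on a couple of toy words, encode them to one-hot arrays, then decode back to text.
_table = CharacterTable()
_table.fit(['cat', 'house'])      # chars = null_char + unique characters, maxlen = 5
_onehot = _table.encode(['cat'])  # shape (1, 5, len(_table.chars)), right-padded with null_char
assert [''.join(row) for row in _table.decode(_onehot)] == ['cat  ']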
def get_data(origin='https://raw.githubusercontent.com/cmusphinx/cmudict/master/cmudict.dict',
             test_size=0.33, verbose=False, maxlen_x=None, maxlen_y=None,
             blacklist='().0123456789', max_phonemes=np.inf, max_chars=np.inf, seed=42):
    """
    Process the CMU Pronouncing Dictionary into one-hot encoded grapheme and phoneme data

    # Output
        (X_train, y_train): One-hot encoded graphemes and phonemes
        (X_test, y_test): Test data
        (xtable, ytable): Character en/decoding tables

    # Arguments
        seed: random seed for data split and shuffle
        test_size: fraction of data to set aside for testing
        verbose: print messages about data processing
        maxlen_x: crop and pad grapheme sequences to this length
        maxlen_y: crop and pad phoneme sequences to this length
        max_phonemes: keep only entries with at most this many phonemes
        max_chars: keep only entries with at most this many characters
        blacklist: remove words containing these characters, e.g. HOUSE(2) for the second variant of house
    """
    cmudict_path = get_file("cmudict-py", origin=origin, untar=False)

    # load data
    X, y = [], []
    for line in open(cmudict_path, 'r').readlines():
        word, pron = line.strip().split(' ', 1)
        X.append(list(word))
        y.append(pron.split(' '))
    X = np.array(X)
    y = np.array(y)
    if verbose: print('loaded {} entries from cmu_dict'.format(len(X)))

    # split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

    # filter out duplicate entries like 'HOUSE(2)'
    p = re.compile('[%s]' % (blacklist))
    X_train, y_train = zip(*[(x, y) for x, y in zip(X_train, y_train) if not bool(p.findall(''.join(x)))])
    X_test, y_test = zip(*[(x, y) for x, y in zip(X_test, y_test) if not bool(p.findall(''.join(x)))])
    if verbose: print('removed duplicate entries leaving {}'.format(len(X_train) + len(X_test)))

    # filter out complex entries
    X_train, y_train = zip(*[(x, y) for x, y in zip(X_train, y_train) if len(y) <= max_phonemes and len(x) <= max_chars])
    X_test, y_test = zip(*[(x, y) for x, y in zip(X_test, y_test) if len(y) <= max_phonemes and len(x) <= max_chars])
    if verbose: print('restricted to less than {} phonemes leaving {} entries'.format(max_phonemes, len(X_train) + len(X_test)))

    # encode x and y and pad them
    xtable = CharacterTable()
    xtable.fit(X_test + X_train)
    if maxlen_x: xtable.maxlen = maxlen_x
    X_train = xtable.encode(X_train)
    X_test = xtable.encode(X_test)

    ytable = CharacterTable()
    ytable.fit(y_test + y_train)
    if maxlen_y: ytable.maxlen = maxlen_y
    y_train = ytable.encode(y_train)
    y_test = ytable.encode(y_test)

    if verbose:
        print('X_train shape:', X_train.shape)
        print('X_test shape:', X_test.shape)
        print('y_train shape:', y_train.shape)
        print('y_test shape:', y_test.shape)

    return (X_train, y_train), (X_test, y_test), (xtable, ytable)
# usage
(X_train, y_train), (X_test, y_test), (xtable, ytable) = get_data(
    verbose=1,
    max_phonemes=5,
    max_chars=5
)
# loaded 135009 entries from cmu_dict
# removed duplicate entries leaving 125814
# restricted to less than 5 phonemes leaving 22906 entries
# X_train shape: (15411, 5, 29)
# X_test shape: (7495, 5, 29)
# y_train shape: (15411, 5, 70)
# y_test shape: (7495, 5, 70)
# decoding back to text
[''.join(i) for i in xtable.decode(X_train[:2])]
# ['kjos ', 'wilt ']
[' '.join(i) for i in ytable.decode(y_train[:2])]
# ['K Y AO1 S ', 'W IH1 L T ']
##############################################################################################
# fitting a model that predicts the pronunciation given the spelling (grapheme to phoneme)  #
##############################################################################################
from keras.models import Sequential
from keras.layers import Activation, Dense, TimeDistributed, Bidirectional, LSTM, GRU

RNN = LSTM
batch_size = 100
nb_chars = len(xtable.chars)   # size of the input alphabet (spelling characters)
nb_phons = len(ytable.chars)   # size of the output alphabet (phoneme symbols)
maxlen_x = xtable.maxlen
maxlen_y = ytable.maxlen
hidden_nodes = len(ytable.chars)

# crop to a multiple of the batch size to prevent errors
train_crop = len(X_train) - (len(X_train) % batch_size)
test_crop = len(X_test) - (len(X_test) % batch_size)
print(train_crop, test_crop)
X_train = X_train[:train_crop]
y_train = y_train[:train_crop]
X_test = X_test[:test_crop]
y_test = y_test[:test_crop]

model = Sequential()
# Encode
model.add(Bidirectional(RNN(hidden_nodes, return_sequences=True),
                        batch_input_shape=(batch_size, maxlen_x, nb_chars)))
# Decode
model.add(RNN(hidden_nodes, return_sequences=True, consume_less='mem'))
# Classifier
model.add(TimeDistributed(Dense(nb_phons)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()

history = model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=20, verbose=1, validation_split=0.1)
# output
"""
____________________________________________________________________________________________________
Layer (type)                        Output Shape          Param #     Connected to
====================================================================================================
bidirectional_2 (Bidirectional)     (300, 5, 140)         56000       bidirectional_input_2[0][0]
____________________________________________________________________________________________________
lstm_5 (LSTM)                       (300, 5, 70)          59080       bidirectional_2[0][0]
____________________________________________________________________________________________________
lstm_6 (LSTM)                       (300, 5, 70)          39480       lstm_5[0][0]
____________________________________________________________________________________________________
timedistributed_2 (TimeDistributed) (300, 5, 70)          4970        lstm_6[0][0]
____________________________________________________________________________________________________
activation_2 (Activation)           (300, 5, 70)          0           timedistributed_2[0][0]
====================================================================================================
Total params: 159530
____________________________________________________________________________________________________
Train on 15300 samples, validate on 7200 samples
Epoch 1/20
15300/15300 [==============================] - 18s - loss: 3.0481 - acc: 0.0715 - val_loss: 2.7764 - val_acc: 0.0834
Epoch 2/20
15300/15300 [==============================] - 17s - loss: 2.5749 - acc: 0.1278 - val_loss: 2.3741 - val_acc: 0.1785
Epoch 3/20
15300/15300 [==============================] - 17s - loss: 2.2065 - acc: 0.2157 - val_loss: 2.0415 - val_acc: 0.2531
Epoch 4/20
15300/15300 [==============================] - 20s - loss: 1.8324 - acc: 0.3243 - val_loss: 1.6104 - val_acc: 0.4021
Epoch 5/20
15300/15300 [==============================] - 23s - loss: 1.3781 - acc: 0.4658 - val_loss: 1.1892 - val_acc: 0.5059
Epoch 6/20
15300/15300 [==============================] - 20s - loss: 1.0463 - acc: 0.5321 - val_loss: 0.9577 - val_acc: 0.5432
Epoch 7/20
15300/15300 [==============================] - 18s - loss: 0.8722 - acc: 0.5641 - val_loss: 0.8393 - val_acc: 0.5686
Epoch 8/20
15300/15300 [==============================] - 20s - loss: 0.7767 - acc: 0.5846 - val_loss: 0.7719 - val_acc: 0.5855
Epoch 9/20
15300/15300 [==============================] - 23s - loss: 0.7104 - acc: 0.5997 - val_loss: 0.7117 - val_acc: 0.5970
Epoch 10/20
15300/15300 [==============================] - 17s - loss: 0.6569 - acc: 0.6122 - val_loss: 0.6665 - val_acc: 0.6089
Epoch 11/20
15300/15300 [==============================] - 23s - loss: 0.6159 - acc: 0.6216 - val_loss: 0.6349 - val_acc: 0.6155
Epoch 12/20
15300/15300 [==============================] - 23s - loss: 0.5802 - acc: 0.6299 - val_loss: 0.6033 - val_acc: 0.6235
Epoch 13/20
15300/15300 [==============================] - 23s - loss: 0.5495 - acc: 0.6377 - val_loss: 0.5771 - val_acc: 0.6293
Epoch 14/20
15300/15300 [==============================] - 22s - loss: 0.5232 - acc: 0.6441 - val_loss: 0.5567 - val_acc: 0.6351
Epoch 15/20
15300/15300 [==============================] - 24s - loss: 0.5000 - acc: 0.6492 - val_loss: 0.5423 - val_acc: 0.6382
Epoch 16/20
15300/15300 [==============================] - 21s - loss: 0.4796 - acc: 0.6554 - val_loss: 0.5232 - val_acc: 0.6423
Epoch 17/20
15300/15300 [==============================] - 21s - loss: 0.4612 - acc: 0.6597 - val_loss: 0.5082 - val_acc: 0.6471
Epoch 18/20
15300/15300 [==============================] - 20s - loss: 0.4448 - acc: 0.6639 - val_loss: 0.4972 - val_acc: 0.6497
Epoch 19/20
15300/15300 [==============================] - 21s - loss: 0.4298 - acc: 0.6682 - val_loss: 0.4865 - val_acc: 0.6524
Epoch 20/20
15300/15300 [==============================] - 22s - loss: 0.4173 - acc: 0.6719 - val_loss: 0.4747 - val_acc: 0.6563
"""
# Tests
def test_dataset_cmudict():
    (X_train, y_train), (X_test, y_test), (xtable, ytable) = get_data()

    # lengths
    assert len(X_train) > 0
    assert len(X_test) > 0
    assert len(X_train) == len(y_train)
    assert len(X_test) == len(y_test)

    # unique and can be decoded
    dx_train = [' '.join(xx) for xx in xtable.decode(X_train)]
    assert len(dx_train) == len(set(dx_train)), 'X_train should be unique'
    dy_train = [' '.join(yy) for yy in ytable.decode(y_train)]
    assert len(dy_train) == len(set(dy_train)), 'y_train should be unique'

    # should be one-hot
    assert len(X_train.shape) == 3, 'should be one-hot'
    assert len(y_train.shape) == 3, 'should be one-hot'
    assert y_test.reshape((-1, y_test.shape[-1])).sum(-1).all(), 'should be one-hot'
    assert X_test.reshape((-1, X_test.shape[-1])).sum(-1).all(), 'should be one-hot'
    assert y_train.reshape((-1, y_train.shape[-1])).sum(-1).all(), 'should be one-hot'
    assert X_train.reshape((-1, X_train.shape[-1])).sum(-1).all(), 'should be one-hot'

    # no overlap between train and test
    dx_test = [' '.join(xx) for xx in xtable.decode(X_test)]
    x = dx_train + dx_test
    assert len(x) == len(set(x)), 'should be no overlap between test and train'