Last active
November 4, 2016 01:57
-
-
Save wassname/ce0ac72c50832b0a27e67bd97ba36080 to your computer and use it in GitHub Desktop.
Loader for cmudict dataset (CMU Pronouncing Dictionary) for keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Load cmudict/CMU Pronouncing Dictionary as a dataset for keras | |
author: wassname | |
url : https://gist.github.com/wassname/ce0ac72c50832b0a27e67bd97ba36080 | |
""" | |
from keras.utils.data_utils import get_file | |
from sklearn.model_selection import train_test_split | |
import numpy as np | |
import itertools | |
import re | |
# based on https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py | |
class CharacterTable(object):
    '''
    Given a set of characters:
    + Encode them to a one hot integer representation
    + Decode the one hot integer representation to their character output
    + Decode a vector of probabilities to their character output

    "Characters" may be any hashable, sortable tokens (e.g. phoneme strings
    such as 'AO1'). ``null_char`` is always part of the vocabulary and is
    used to pad sequences up to ``maxlen``.
    '''
    def __init__(self, chars='', maxlen=None, null_char=' ', left_pad=False):
        self._set_chars(itertools.chain([null_char], chars))
        self.maxlen = maxlen
        self.left_pad = left_pad
        self.null_char = null_char

    def _set_chars(self, chars):
        """Build the sorted vocabulary and both lookup tables from an iterable of tokens."""
        self.chars = sorted(set(chars))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}

    def fit(self, Cs, null_char=' '):
        """Determine chars and maxlen by fitting to data.

        Cs: a sequence of token sequences. The vocabulary becomes every token
        seen plus ``null_char``; maxlen becomes the longest sequence length
        (0 for empty input).
        """
        self._set_chars(itertools.chain([null_char], *Cs))
        self.maxlen = max((len(c) for c in Cs), default=0)
        self.null_char = null_char

    def encode(self, Cs, maxlen=None):
        """One-hot encode a batch of token sequences.

        Sequences longer than ``maxlen`` are cropped to their first ``maxlen``
        tokens; shorter ones are padded with ``null_char`` (on the left when
        ``left_pad`` is set, otherwise on the right).

        Returns a bool array of shape (len(Cs), maxlen, len(self.chars)).
        Raises KeyError for a token that is not in the vocabulary.
        """
        maxlen = maxlen if maxlen else self.maxlen
        X = np.zeros((len(Cs), maxlen, len(self.chars)), dtype=bool)  # np.bool was removed in NumPy 1.24
        for j, C in enumerate(Cs):
            C = list(C)[:maxlen]  # crop, otherwise X[j, i] below would overflow
            padding = [self.null_char] * (maxlen - len(C))
            C = padding + C if self.left_pad else C + padding
            for i, c in enumerate(C):
                X[j, i, self.char_indices[c]] = True
        return X

    def decode(self, Xs, calc_argmax=True):
        """Map a batch of one-hot (or probability) arrays back to tokens.

        With ``calc_argmax`` the last axis is reduced by argmax; otherwise
        ``Xs`` must already hold integer indices. Returns an array of tokens
        with one row per sequence.
        """
        if calc_argmax:
            Xs = np.asarray(Xs).argmax(axis=-1)
        return np.array([[self.indices_char[x] for x in X] for X in Xs])
def _load_cmudict(path):
    """Parse a cmudict file into parallel lists of spellings and pronunciations.

    Each useful line is ``word PH1 PH2 ...``, optionally followed by an inline
    ``# comment`` (cmudict.dict annotates some entries this way), which is
    discarded. Blank or malformed lines are skipped. Returns
    (list of character lists, list of phoneme lists).
    """
    words, prons = [], []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.split('#', 1)[0].strip()
            parts = line.split(None, 1)
            if len(parts) != 2:
                continue
            word, pron = parts
            words.append(list(word))
            prons.append(pron.split())
    return words, prons


def _filter_pairs(xs, ys, keep):
    """Return (xs, ys) as lists restricted to the pairs where keep(x, y) is true."""
    kept = [(x, y) for x, y in zip(xs, ys) if keep(x, y)]
    return [x for x, _ in kept], [y for _, y in kept]


def get_data(origin='https://raw.githubusercontent.com/cmusphinx/cmudict/master/cmudict.dict', test_size=0.33, verbose=False, maxlen_x=None, maxlen_y=None, blacklist='().0123456789', max_phonemes=np.inf, max_chars=np.inf, seed=42):
    """
    Process the CMU pronunciation dictionary as one-hot encoded grapheme and phoneme data

    # Output
        (X_train, y_train): One-hot encoded graphemes and phonemes
        (X_test, y_test): Test data
        (xtable, ytable): Character en/decoding tables

    # Arguments
        origin: URL of the dictionary file (downloaded and cached with keras `get_file`)
        seed: random seed for the train/test split
        test_size: fraction of data to set aside for testing
        verbose: print messages about data processing
        maxlen_x: crop and pad grapheme sequences to this length
        maxlen_y: crop and pad phoneme sequences to this length
        max_phonemes: keep only entries with <= max_phonemes phonemes
        max_chars: keep only entries with <= max_chars characters
        blacklist: remove words containing any of these characters, e.g. the
            "(" in HOUSE(2), which marks the second variant of house
    """
    cmudict_path = get_file("cmudict-py", origin=origin, untar=False)

    # load data (plain lists: entries are ragged, which np.array rejects in NumPy >= 1.24)
    X, y = _load_cmudict(cmudict_path)
    if verbose: print('loaded {} entries from cmu_dict'.format(len(X)))

    # split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

    # filter out duplicate entries like 'HOUSE(2)'; a set avoids regex-escaping the blacklist
    bad_chars = set(blacklist)

    def not_blacklisted(x, _):
        return not bad_chars.intersection(x)
    X_train, y_train = _filter_pairs(X_train, y_train, not_blacklisted)
    X_test, y_test = _filter_pairs(X_test, y_test, not_blacklisted)
    if verbose: print('removed duplicate entries leaving {}'.format(len(X_train) + len(X_test)))

    # filter out complex entries
    def short_enough(x, y_):
        return len(y_) <= max_phonemes and len(x) <= max_chars
    X_train, y_train = _filter_pairs(X_train, y_train, short_enough)
    X_test, y_test = _filter_pairs(X_test, y_test, short_enough)
    if verbose: print('restricted to less than {} phonemes leaving {} entries'.format(max_phonemes, len(X_train) + len(X_test)))

    # encode x and y, cropping to maxlen and padding
    xtable = CharacterTable()
    xtable.fit(X_test + X_train)
    if maxlen_x: xtable.maxlen = maxlen_x
    X_train = xtable.encode([x[:xtable.maxlen] for x in X_train])
    X_test = xtable.encode([x[:xtable.maxlen] for x in X_test])

    ytable = CharacterTable()
    ytable.fit(y_test + y_train)
    if maxlen_y: ytable.maxlen = maxlen_y
    y_train = ytable.encode([y_[:ytable.maxlen] for y_ in y_train])
    y_test = ytable.encode([y_[:ytable.maxlen] for y_ in y_test])

    if verbose:
        print('X_train shape:', X_train.shape)
        print('X_test shape:', X_test.shape)
        print('y_train shape:', y_train.shape)
        print('y_test shape:', y_test.shape)
    return (X_train, y_train), (X_test, y_test), (xtable, ytable)
# usage
(X_train, y_train), (X_test, y_test), (xtable, ytable) = get_data(
    verbose=1,
    max_phonemes=5,
    max_chars=5
)
# loaded 135009 entries from cmu_dict
# removed duplicate entries leaving 125814
# restricted to less than 5 phonemes leaving 22906 entries
# X_train shape: (15411, 5, 29)
# X_test shape: (7495, 5, 29)
# y_train shape: (15411, 5, 70)
# y_test shape: (7495, 5, 70)

# decoding back to text
print([''.join(i) for i in xtable.decode(X_train[:2])])
# ['kjos ', 'wilt ']
print([' '.join(i) for i in ytable.decode(y_train[:2])])
# ['K Y AO1 S ', 'W IH1 L T ']

##############################################################################################
# fitting a model that predicts the pronunciation given the spelling (grapheme to phoneme)  #
##############################################################################################
from keras.models import Sequential
from keras.layers import Activation, Dense, TimeDistributed, Bidirectional, LSTM, GRU

RNN = LSTM
batch_size = 100  # must be defined before cropping below

# The model is built with a fixed batch size (batch_input_shape), so every
# array fed to it must hold a whole number of batches.
train_crop = len(X_train) - (len(X_train) % batch_size)
test_crop = len(X_test) - (len(X_test) % batch_size)
print(train_crop, test_crop)
X_train = X_train[:train_crop]
y_train = y_train[:train_crop]
X_test = X_test[:test_crop]
y_test = y_test[:test_crop]

nb_chars = len(xtable.chars)  # grapheme (input) vocabulary size
nb_phons = len(ytable.chars)  # phoneme (output) vocabulary size
maxlen_x = xtable.maxlen
maxlen_y = ytable.maxlen
hidden_nodes = nb_phons

# The network emits one output step per input step, so both sequence
# lengths must agree.
if maxlen_x != maxlen_y:
    raise ValueError('model needs maxlen_x == maxlen_y, got {} and {}'.format(maxlen_x, maxlen_y))

model = Sequential()
# Encode
model.add(Bidirectional(RNN(hidden_nodes, return_sequences=True), batch_input_shape=(batch_size, maxlen_x, nb_chars)))
# Decode
model.add(RNN(hidden_nodes, return_sequences=True, consume_less='mem'))
# classifier
model.add(TimeDistributed(Dense(nb_phons)))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()
# Validate on the (already batch-cropped) test set. validation_split would
# carve out a slice whose size is not a multiple of batch_size.
history = model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=20, verbose=1, validation_data=(X_test, y_test))
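# Example output from an earlier run. NOTE: it was produced with batch_size=300, an extra LSTM layer, and the test set as validation data, so shapes and parameter counts differ from the code above.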
""" | |
____________________________________________________________________________________________________ | |
Layer (type) Output Shape Param # Connected to | |
==================================================================================================== | |
bidirectional_2 (Bidirectional) (300, 5, 140) 56000 bidirectional_input_2[0][0] | |
____________________________________________________________________________________________________ | |
lstm_5 (LSTM) (300, 5, 70) 59080 bidirectional_2[0][0] | |
____________________________________________________________________________________________________ | |
lstm_6 (LSTM) (300, 5, 70) 39480 lstm_5[0][0] | |
____________________________________________________________________________________________________ | |
timedistributed_2 (TimeDistribute(300, 5, 70) 4970 lstm_6[0][0] | |
____________________________________________________________________________________________________ | |
activation_2 (Activation) (300, 5, 70) 0 timedistributed_2[0][0] | |
==================================================================================================== | |
Total params: 159530 | |
____________________________________________________________________________________________________ | |
Train on 15300 samples, validate on 7200 samples | |
Epoch 1/20 | |
15300/15300 [==============================] - 18s - loss: 3.0481 - acc: 0.0715 - val_loss: 2.7764 - val_acc: 0.0834 | |
Epoch 2/20 | |
15300/15300 [==============================] - 17s - loss: 2.5749 - acc: 0.1278 - val_loss: 2.3741 - val_acc: 0.1785 | |
Epoch 3/20 | |
15300/15300 [==============================] - 17s - loss: 2.2065 - acc: 0.2157 - val_loss: 2.0415 - val_acc: 0.2531 | |
Epoch 4/20 | |
15300/15300 [==============================] - 20s - loss: 1.8324 - acc: 0.3243 - val_loss: 1.6104 - val_acc: 0.4021 | |
Epoch 5/20 | |
15300/15300 [==============================] - 23s - loss: 1.3781 - acc: 0.4658 - val_loss: 1.1892 - val_acc: 0.5059 | |
Epoch 6/20 | |
15300/15300 [==============================] - 20s - loss: 1.0463 - acc: 0.5321 - val_loss: 0.9577 - val_acc: 0.5432 | |
Epoch 7/20 | |
15300/15300 [==============================] - 18s - loss: 0.8722 - acc: 0.5641 - val_loss: 0.8393 - val_acc: 0.5686 | |
Epoch 8/20 | |
15300/15300 [==============================] - 20s - loss: 0.7767 - acc: 0.5846 - val_loss: 0.7719 - val_acc: 0.5855 | |
Epoch 9/20 | |
15300/15300 [==============================] - 23s - loss: 0.7104 - acc: 0.5997 - val_loss: 0.7117 - val_acc: 0.5970 | |
Epoch 10/20 | |
15300/15300 [==============================] - 17s - loss: 0.6569 - acc: 0.6122 - val_loss: 0.6665 - val_acc: 0.6089 | |
Epoch 11/20 | |
15300/15300 [==============================] - 23s - loss: 0.6159 - acc: 0.6216 - val_loss: 0.6349 - val_acc: 0.6155 | |
Epoch 12/20 | |
15300/15300 [==============================] - 23s - loss: 0.5802 - acc: 0.6299 - val_loss: 0.6033 - val_acc: 0.6235 | |
Epoch 13/20 | |
15300/15300 [==============================] - 23s - loss: 0.5495 - acc: 0.6377 - val_loss: 0.5771 - val_acc: 0.6293 | |
Epoch 14/20 | |
15300/15300 [==============================] - 22s - loss: 0.5232 - acc: 0.6441 - val_loss: 0.5567 - val_acc: 0.6351 | |
Epoch 15/20 | |
15300/15300 [==============================] - 24s - loss: 0.5000 - acc: 0.6492 - val_loss: 0.5423 - val_acc: 0.6382 | |
Epoch 16/20 | |
15300/15300 [==============================] - 21s - loss: 0.4796 - acc: 0.6554 - val_loss: 0.5232 - val_acc: 0.6423 | |
Epoch 17/20 | |
15300/15300 [==============================] - 21s - loss: 0.4612 - acc: 0.6597 - val_loss: 0.5082 - val_acc: 0.6471 | |
Epoch 18/20 | |
15300/15300 [==============================] - 20s - loss: 0.4448 - acc: 0.6639 - val_loss: 0.4972 - val_acc: 0.6497 | |
Epoch 19/20 | |
15300/15300 [==============================] - 21s - loss: 0.4298 - acc: 0.6682 - val_loss: 0.4865 - val_acc: 0.6524 | |
Epoch 20/20 | |
15300/15300 [==============================] - 22s - loss: 0.4173 - acc: 0.6719 - val_loss: 0.4747 - val_acc: 0.6563 | |
""" | |
# Tests | |
def test_dataset_cmudict():
    """Check get_data() on the full dictionary: sizes, one-hot encoding, uniqueness, split hygiene."""
    (X_train, y_train), (X_test, y_test), (xtable, ytable) = get_data()

    # lengths
    assert len(X_train) > 0
    assert len(X_test) > 0
    assert len(X_train) == len(y_train)
    assert len(X_test) == len(y_test)

    # should be one-hot: rank 3, and exactly one set bit per (sample, timestep)
    for arr in (X_train, y_train, X_test, y_test):
        assert arr.ndim == 3, 'should be one-hot'
        assert (arr.reshape((-1, arr.shape[-1])).sum(-1) == 1).all(), 'should be one-hot'

    # spellings decode and are unique, both within train and across splits
    dx_train = [' '.join(xx) for xx in xtable.decode(X_train)]
    assert len(dx_train) == len(set(dx_train)), 'X_train should be unique'
    dx_test = [' '.join(xx) for xx in xtable.decode(X_test)]
    x = dx_train + dx_test
    assert len(x) == len(set(x)), 'should be no overlap between test and train'

    # pronunciations decode to one token per timestep. They are NOT required
    # to be unique, because homophones (e.g. "knight"/"night") share one.
    dy_train = ytable.decode(y_train)
    assert dy_train.shape == y_train.shape[:2]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment