# Keras translation network.
# The [Anki repository](http://www.manythings.org/anki/) has a lot of sentence pairs for learning a language,
# and they are ideal for training a translation network.
# To judge the quality of a translation it helps to understand both languages a bit, so in my case
# the [Dutch-English](http://www.manythings.org/anki/nld-eng.zip),
# [French-English](http://www.manythings.org/anki/fra-eng.zip)
# and [German-English](http://www.manythings.org/anki/deu-eng.zip) pairs were perfect.
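# Optional helper to fetch and unpack one of the Anki archives so that ./data.txt
# exists before the script runs. A minimal sketch: the member name inside the zip
# ('nld.txt') and the output filename are assumptions, not part of the script below.
import io
import zipfile
from urllib.request import urlopen

def download_anki_pairs(url='http://www.manythings.org/anki/nld-eng.zip',
                        member='nld.txt', out='data.txt'):
    '''Download an Anki zip and extract its tab-separated pair file to `out`.'''
    with urlopen(url) as response:
        archive = zipfile.ZipFile(io.BytesIO(response.read()))
    with open(out, 'wb') as f:
        f.write(archive.read(member))

# download_anki_pairs()  # uncomment to fetch the Dutch-English pairs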
import os
import re
import string
from unicodedata import normalize
from numpy import array
from numpy.random import shuffle
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import Embedding
from keras.layers import RepeatVector
from keras.layers import TimeDistributed
from keras.callbacks import ModelCheckpoint
class Transformation():
    def __init__(self, filepath):
        if os.path.exists(filepath):
            self.filepath = filepath
            self.load_text(filepath)
            self.get_clean_lines()
            self.split_lines()
        else:
            raise Exception("The given file does not exist.")

    def load_text(self, filepath):
        '''Reads the text contained in the given file into self.raw_text.'''
        with open(filepath, mode='rt', encoding='utf-8') as f:
            self.raw_text = f.read()

    def _clean_lines(self, lines):
        cleaned = list()
        # prepare regex for char filtering
        re_punc = re.compile('[%s]' % re.escape(string.punctuation))
        re_print = re.compile('[^%s]' % re.escape(string.printable))
        for line in lines:
            # normalize unicode characters
            line = normalize('NFD', line).encode('ascii', 'ignore')
            line = line.decode('UTF-8')
            pair = line.split('\t')
            duo = list()
            for item in pair:
                # tokenize on white space
                words = item.split()
                words = [word.lower() for word in words]
                # remove punctuation and non-printable characters
                words = [re_punc.sub('', w) for w in words]
                words = [re_print.sub('', w) for w in words]
                # keep purely alphabetic tokens
                words = [word for word in words if word.isalpha()]
                duo.append(' '.join(words))
            line = '\t'.join(duo)
            cleaned.append(line)
        return array(cleaned)

    def get_clean_lines(self):
        '''Shuffles and cleans the lines of the given document.'''
        lines = self.raw_text.strip().split('\n')
        shuffle(lines)
        self.clean_lines = self._clean_lines(lines)

    def split_lines(self):
        '''Splits each tab-separated pair into two LanguageData objects.'''
        l1 = list()
        l2 = list()
        for line in self.clean_lines:
            parts = line.split('\t')
            l1.append(parts[0])
            l2.append(parts[1])
        self.lang1 = LanguageData(l1)
        self.lang2 = LanguageData(l2)
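# Sanity check of the cleaning step on a made-up pair (the sample text is an
# illustration, not a line from the corpus). _clean_lines never touches self,
# so it can be exercised without constructing a Transformation:
_sample = Transformation._clean_lines(None, ['Run!\tRen je!'])
assert _sample[0] == 'run\tren je'  # lower-cased, punctuation gone, tab kept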
class LanguageData():
    def __init__(self, lines):
        self.hot_lines = None
        self.clean_lines = array(lines)
        self.create_tokenizer()
        # reserve index 0 for padding
        self.vocab_size = len(self.tokenizer.word_index) + 1
        self.max_word_count = max(len(line.split()) for line in lines)
        self.encode_sequences()
        self.split()

    def split(self):
        '''Holds out roughly 10% of the lines as test data.'''
        test_amount = int(round(len(self.encoded_lines) / 10))
        p = len(self.encoded_lines) - test_amount
        if self.hot_lines is None:
            self.train_data = self.encoded_lines[:p]
            self.test_data = self.encoded_lines[p:]
        else:
            self.train_data = self.hot_lines[:p]
            self.test_data = self.hot_lines[p:]

    def create_tokenizer(self):
        self.tokenizer = Tokenizer()
        self.tokenizer.fit_on_texts(self.clean_lines)

    def encode_sequences(self):
        # integer encode sequences
        a = self.tokenizer.texts_to_sequences(self.clean_lines)
        # pad sequences with 0's
        self.encoded_lines = pad_sequences(a, maxlen=self.max_word_count, padding='post')

    def make_hot(self):
        '''One-hot encodes the padded sequences for use as training targets.'''
        ylist = list()
        for a in self.encoded_lines:
            hot = to_categorical(a, num_classes=self.vocab_size)
            ylist.append(hot)
        y = array(ylist)
        self.hot_lines = y.reshape(self.encoded_lines.shape[0], self.encoded_lines.shape[1], self.vocab_size)
        self.split()
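# What LanguageData computes, on a toy corpus (a sketch; the exact integer ids
# depend on the Tokenizer's frequency ordering):
_toy = LanguageData(['hi there', 'hi'])
assert _toy.vocab_size == 3        # 'hi', 'there', plus the reserved padding index 0
assert _toy.max_word_count == 2
# _toy.encoded_lines is now something like [[1, 2], [1, 0]] after post-padding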
class Learning():
    def __init__(self, source: LanguageData, target: LanguageData, n_unit=64):
        self.source = source
        self.target = target
        self.create_model(self.source.vocab_size, self.target.vocab_size, self.source.max_word_count, self.target.max_word_count, n_unit)

    def create_model(self, src_vocab, tar_vocab, src_timesteps, tar_timesteps, n_units):
        # encoder-decoder: embed the source, encode it into a single vector,
        # repeat that vector for every target timestep, decode word by word
        model = Sequential()
        model.add(Embedding(src_vocab, n_units, input_length=src_timesteps, mask_zero=True))
        model.add(LSTM(n_units))
        model.add(RepeatVector(tar_timesteps))
        model.add(LSTM(n_units, return_sequences=True))
        model.add(TimeDistributed(Dense(tar_vocab, activation='softmax')))
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
        self.model = model
        model.summary()

    def learn(self):
        print("Starting to train")
        self.model.fit(self.source.train_data, self.target.train_data, epochs=10, batch_size=64, verbose=0)
        e = self.model.evaluate(self.source.test_data, self.target.test_data, verbose=0)
        print("%s: %.0f%%" % (self.model.metrics_names[1], e[1] * 100))
        return e[1]
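# Optional: checkpoint the best weights during training. A minimal sketch using
# ModelCheckpoint (imported above); learn_with_checkpoint() and the 'model.h5'
# filename are illustrative, not part of the training loop above.
def learn_with_checkpoint(learner, filename='model.h5'):
    checkpoint = ModelCheckpoint(filename, monitor='val_loss',
                                 save_best_only=True, mode='min', verbose=1)
    return learner.model.fit(learner.source.train_data, learner.target.train_data,
                             epochs=10, batch_size=64, verbose=2,
                             validation_data=(learner.source.test_data,
                                              learner.target.test_data),
                             callbacks=[checkpoint])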
transfo = Transformation("./data.txt")
transfo.lang2.make_hot()
learner = Learning(transfo.lang1, transfo.lang2)
learner.learn()
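# Turning a prediction back into words, a minimal sketch: the network emits one
# softmax distribution per target timestep; take the argmax at each step and map
# the integer ids back through the target tokenizer (index 0 is padding and is
# dropped). translate() and word_for_id() are illustrative helpers, not part of
# the original script.
from numpy import argmax

def word_for_id(integer, tokenizer):
    for word, index in tokenizer.word_index.items():
        if index == integer:
            return word
    return None

def translate(learner, encoded_source_line):
    prediction = learner.model.predict(encoded_source_line.reshape(1, -1), verbose=0)[0]
    integers = [argmax(vector) for vector in prediction]
    words = [word_for_id(i, learner.target.tokenizer) for i in integers]
    return ' '.join(w for w in words if w is not None)

# e.g. print(translate(learner, transfo.lang1.test_data[0]))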