@Orbifold
Created January 27, 2018 12:37
Keras translation network.
# The [Anki repository](http://www.manythings.org/anki/) offers plenty of sentence pairs for learning a language, and they are ideal for training a translation network.
# To judge the quality of a translation it helps to understand both languages a bit, so in my case
# the [Dutch-English](http://www.manythings.org/anki/nld-eng.zip),
# [French-English](http://www.manythings.org/anki/fra-eng.zip)
# and [German-English](http://www.manythings.org/anki/deu-eng.zip) pairs were perfect. A small download sketch follows the imports below.
import os
import re
import string
from unicodedata import normalize

import numpy as np
from numpy import array
from numpy.random import shuffle

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding, RepeatVector, TimeDistributed
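# A minimal, optional sketch (not part of the original gist) for fetching one of the Anki
# archives and saving the tab-separated sentence-pair file as ./data.txt, the path the
# driver code at the bottom expects. The file name inside the zip is an assumption, so we
# simply take the first member that is not an "about" file.
import io
import zipfile
from urllib.request import urlopen

def fetch_pairs(url='http://www.manythings.org/anki/nld-eng.zip', out='./data.txt'):
    with zipfile.ZipFile(io.BytesIO(urlopen(url).read())) as z:
        # each line of the extracted file is a tab-separated pair: an English sentence
        # followed by its translation
        member = next(name for name in z.namelist() if not name.startswith('_'))
        with open(out, 'wb') as f:
            f.write(z.read(member))
# Call fetch_pairs() once to create ./data.txt before running the rest of the script.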
class Transformation:
    def __init__(self, filepath):
        if os.path.exists(filepath):
            self.filepath = filepath
            self.load_text(filepath)
            self.get_clean_lines()
            self.split_lines()
        else:
            raise Exception("The given file does not exist.")

    def load_text(self, filepath):
        '''Reads the text contained in the given file.'''
        with open(filepath, mode='rt', encoding='utf-8') as f:
            self.raw_text = f.read()

    def _clean_lines(self, lines):
        cleaned = list()
        # prepare regex for char filtering
        re_punc = re.compile('[%s]' % re.escape(string.punctuation))
        re_print = re.compile('[^%s]' % re.escape(string.printable))
        for line in lines:
            # normalize unicode characters
            line = normalize('NFD', line).encode('ascii', 'ignore')
            line = line.decode('UTF-8')
            pair = line.split('\t')
            duo = list()
            for item in pair:
                # tokenize on white space
                words = item.split()
                # lowercase, strip punctuation and non-printable characters, drop non-alphabetic tokens
                words = [word.lower() for word in words]
                words = [re_punc.sub('', w) for w in words]
                words = [re_print.sub('', w) for w in words]
                words = [word for word in words if word.isalpha()]
                duo.append(' '.join(words))
            line = '\t'.join(duo)
            cleaned.append(line)
        return array(cleaned)

    def get_clean_lines(self):
        '''Shuffles the document's lines and stores their cleaned versions.'''
        lines = self.raw_text.strip().split('\n')
        shuffle(lines)
        self.clean_lines = self._clean_lines(lines)

    def split_lines(self):
        '''Splits each tab-separated pair into its two languages.'''
        l1 = list()
        l2 = list()
        for line in self.clean_lines:
            parts = line.split('\t')
            l1.append(parts[0])
            l2.append(parts[1])
        self.lang1 = LanguageData(l1)
        self.lang2 = LanguageData(l2)
class LanguageData:
    def __init__(self, lines):
        self.hot_lines = None
        self.clean_lines = array(lines)
        self.create_tokenizer()
        self.vocab_size = len(self.tokenizer.word_index) + 1
        self.max_word_count = max(len(line.split()) for line in lines)
        self.encode_sequences()
        self.split()

    def split(self):
        '''Reserves the last 10% of the lines as test data.'''
        test_amount = int(round(len(self.encoded_lines) / 10))
        p = len(self.encoded_lines) - test_amount
        if self.hot_lines is None:
            self.train_data = self.encoded_lines[:p]
            self.test_data = self.encoded_lines[p:]
        else:
            self.train_data = self.hot_lines[:p]
            self.test_data = self.hot_lines[p:]

    def create_tokenizer(self):
        self.tokenizer = Tokenizer()
        self.tokenizer.fit_on_texts(self.clean_lines)

    def encode_sequences(self):
        # integer encode sequences
        a = self.tokenizer.texts_to_sequences(self.clean_lines)
        # pad sequences with 0's
        self.encoded_lines = pad_sequences(a, maxlen=self.max_word_count, padding='post')

    def make_hot(self):
        '''One-hot encodes the padded sequences, as required for the target language.'''
        ylist = list()
        for a in self.encoded_lines:
            hot = to_categorical(a, num_classes=self.vocab_size)
            ylist.append(hot)
        y = array(ylist)
        self.hot_lines = y.reshape(self.encoded_lines.shape[0], self.encoded_lines.shape[1], self.vocab_size)
        self.split()
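# Note: applied to a 1-D padded sequence of length max_word_count, to_categorical returns a
# (max_word_count, vocab_size) one-hot matrix, so make_hot stacks the lines into a tensor of
# shape (n_lines, max_word_count, vocab_size), which is what the softmax decoder below expects.
# For example (hypothetical values): to_categorical(array([3, 1, 0, 0]), num_classes=5).shape == (4, 5).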
class Learning:
    def __init__(self, source: LanguageData, target: LanguageData, n_units=64):
        self.source = source
        self.target = target
        self.create_model(self.source.vocab_size, self.target.vocab_size, self.source.max_word_count, self.target.max_word_count, n_units)

    def create_model(self, src_vocab, tar_vocab, src_timesteps, tar_timesteps, n_units):
        '''Builds an encoder-decoder network: an embedding plus LSTM encoder, a RepeatVector bridge, and an LSTM decoder with a word-level softmax.'''
        model = Sequential()
        model.add(Embedding(src_vocab, n_units, input_length=src_timesteps, mask_zero=True))
        model.add(LSTM(n_units))
        model.add(RepeatVector(tar_timesteps))
        model.add(LSTM(n_units, return_sequences=True))
        model.add(TimeDistributed(Dense(tar_vocab, activation='softmax')))
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
        self.model = model
        model.summary()

    def learn(self):
        print("Starting to train")
        self.model.fit(self.source.train_data, self.target.train_data, epochs=10, batch_size=64, verbose=0)
        e = self.model.evaluate(self.source.test_data, self.target.test_data, verbose=0)
        print("%s: %.0f%%" % (self.model.metrics_names[1], e[1] * 100))
        return e[1]
transfo = Transformation("./data.txt")
transfo.lang2.make_hot()
learner = Learning(transfo.lang1, transfo.lang2)
learner.learn()
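# A minimal inference sketch (not part of the original gist): decode the model's predictions
# for a few held-out source sentences back into words via the target tokenizer. The helper
# names below are illustrative, not part of the classes above.
def word_for_id(integer, tokenizer):
    '''Looks up the word that the tokenizer mapped to the given integer index.'''
    for word, index in tokenizer.word_index.items():
        if index == integer:
            return word
    return None

def translate_samples(learner, n=5):
    predictions = learner.model.predict(learner.source.test_data[:n])
    for sequence in predictions:
        indices = [np.argmax(step) for step in sequence]
        words = [word_for_id(i, learner.target.tokenizer) for i in indices]
        # index 0 is padding and has no word, so it is skipped
        print(' '.join(w for w in words if w is not None))

translate_samples(learner)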