Skip to content

Instantly share code, notes, and snippets.

@ottokart
Created May 1, 2019 18:34
Show Gist options
  • Save ottokart/8f61c406324969cc326d0bb128e3fc4a to your computer and use it in GitHub Desktop.
Save ottokart/8f61c406324969cc326d0bb128e3fc4a to your computer and use it in GitHub Desktop.
Simple script to serve a punctuation API on https://github.com/ottokart/punctuator2 (similar to http://bark.phon.ioc.ee/punctuator). need to change the configuration at the beginning of the file to point to the model file. Also, the script should be in the same dir as the other punctuator scripts. The Europarl model that I use in the demo (http:…
# coding: utf-8
from __future__ import division
from nltk.tokenize import word_tokenize
import models
import data
import theano
import tornado.ioloop
import tornado.web
import re
import theano.tensor as T
import numpy as np
from tornado.options import define, options
define('port', default=8080, help='run on the given port', type=int)
### CONFIGURATION ###
MAX_TOTAL_LENGTH = 10000
LANGUAGES = {
'EN': {
'model_path': './Model_669040_h256_lr0.02.pcl',
'untokenizer': lambda text: text.replace(" '", "'").replace(" n't", "n't").replace("can not", "cannot"),
},
}
### END OF CONFIGURATION ###
default_tokenizer = word_tokenize
default_untokenizer = lambda text: text
class Punctuator(object):
def __init__(self, conf):
self.model_path = conf['model_path']
self.tokenizer = conf.get('tokenizer', default_tokenizer)
self.untokenizer = conf.get('untokenizer', default_untokenizer)
self.lowercase = conf.get('lowercase', True)
x = T.imatrix('x')
net, _ = models.load(self.model_path, 1, x)
self.predict=theano.function(inputs=[x], outputs=net.y)
self.word_vocabulary=net.x_vocabulary
self.punctuation_vocabulary=net.y_vocabulary
self.reverse_punctuation_vocabulary = {v:k for k,v in self.punctuation_vocabulary.items()}
self.human_readable_punctuation_vocabulary = [p[0] for p in self.punctuation_vocabulary if p != data.SPACE]
numbers = re.compile(r'\d')
is_number = lambda x: len(numbers.sub('', x)) / len(x) < 0.6
def convert_punctuation_to_readable(punct_token):
if punct_token == data.SPACE:
return ' '
elif punct_token.startswith('-'):
return ' ' + punct_token[0] + ' '
else:
return punct_token[0] + ' '
def punctuate(words, word_vocabulary, predict, reverse_punctuation_vocabulary, lowercase, writer=None):
if len(words) == 0:
return
if words[-1] != data.END:
words += [data.END]
i = 0
while True:
subsequence = words[i:i+data.MAX_SEQUENCE_LEN]
if len(subsequence) == 0:
break
converted_subsequence = [word_vocabulary.get(data.NUM
if is_number(w)
else (w.lower() if lowercase else w),
word_vocabulary[data.UNK]) for w in subsequence]
y = predict(np.array([converted_subsequence], dtype=np.int32).T)
if writer:
writer.write(subsequence[0].title())
last_eos_idx = 0
punctuations = []
for y_t in y:
p_i = np.argmax(y_t.flatten())
punctuation = reverse_punctuation_vocabulary[p_i]
punctuations.append(punctuation)
if punctuation in data.EOS_TOKENS:
last_eos_idx = len(punctuations) # we intentionally want the index of next element
if subsequence[-1] == data.END:
step = len(subsequence) - 1
elif last_eos_idx != 0:
step = last_eos_idx
else:
step = len(subsequence) - 1
for j in range(step):
current_punctuation = punctuations[j]
yield current_punctuation
if writer:
writer.write(convert_punctuation_to_readable(current_punctuation))
if j < step - 1:
if current_punctuation in data.EOS_TOKENS:
writer.write(subsequence[1+j].title())
else:
writer.write(subsequence[1+j])
if writer:
writer.flush()
if subsequence[-1] == data.END:
break
i += step
if i > MAX_TOTAL_LENGTH:
break
### HANDLERS ###
class BaseHandler(tornado.web.RequestHandler):
def initialize(self, punctuators):
self.punctuators = punctuators
def get_punctuator(self, language):
if not language:
language = 'EN'
if language not in self.punctuators:
raise tornado.web.HTTPError(404)
return punctuators[language]
class MainHandler(BaseHandler):
def post(self, language):
punctuator = self.get_punctuator(language)
text = self.get_argument('text', '')
words = [w for w in punctuator.untokenizer(' '.join(punctuator.tokenizer(text))).split()
if w not in punctuator.punctuation_vocabulary and w not in punctuator.human_readable_punctuation_vocabulary]
list(punctuate(words, punctuator.word_vocabulary, punctuator.predict, punctuator.reverse_punctuation_vocabulary, punctuator.lowercase, writer=self))
if __name__ == '__main__':
tornado.options.parse_command_line()
punctuators = {}
for language, conf in LANGUAGES.iteritems():
print('Initializing %s...' % language)
punctuators[language] = Punctuator(conf)
init_params = dict(punctuators=punctuators)
print 'Serving...'
application = tornado.web.Application([
(r'/?(?P<language>[a-zA-Z]+)?/?', MainHandler, init_params),
])
application.listen(options.port)
tornado.ioloop.IOLoop.current().start()
@sadi304
Copy link

sadi304 commented Jun 1, 2021

Hello, we are having a strange issue replicating this code on our linux server with 8 cpus and 32 gb of ram. Each request is taking like 800ms (single request) and cpu jumps to 800%. But your deployed example taking 200ms on average. Strange thing is running this on a mac system with 12 cpu and 16gb ram it takes 200ms.

May we know your deployed configuration. Or python version? anything that can point us to a solution of this issue

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment