Created
August 31, 2017 13:38
-
-
Save mkroutikov/29c7b1eb3453bafb1034e24c5cab892e to your computer and use it in GitHub Desktop.
ilabs.client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import socket | |
import json | |
from urllib import request | |
import time | |
from izutil.progress import NULL_PROGRESS_METER | |
class ILabsEngine:
    """Client for the api.innodatalabs.com prediction and feedback endpoints.

    Every request carries the ``User-Key`` header given at construction.
    """

    URL_UPLOAD = 'https://api.innodatalabs.com/v1/documents/input/'
    URL_PREDICT = 'https://api.innodatalabs.com/v1/reference/{domain}/{name}'
    URL_FEEDBACK = 'https://api.innodatalabs.com/v1/documents/training/{domain}/{batch_id}.xml'

    # Maximum number of task-status polls before ``__call__`` gives up.
    MAX_POLLS = 100
    # Cap (seconds) on the doubling delay between consecutive status polls.
    MAX_POLL_DELAY = 60

    def __init__(self, domain, user_key, timeout=None):
        """Create a client.

        :param domain: model domain, substituted into the predict/feedback URLs.
        :param user_key: API key sent in the ``User-Key`` header.
        :param timeout: per-request socket timeout in seconds; ``None`` falls
            back to ``socket._GLOBAL_DEFAULT_TIMEOUT`` (urlopen's own default).
        """
        self.domain = domain
        self._user_key = user_key
        self._timeout = timeout
        if self._timeout is None:
            self._timeout = socket._GLOBAL_DEFAULT_TIMEOUT

    def request(self, method, url, data=None, content_type=None):
        """Send an HTTP request and return the full response body as bytes.

        Errors raised by ``urllib.request.urlopen`` propagate to the caller.
        """
        print('>', method, url)
        headers = {
            'User-Key': self._user_key,
            'User-Agent': 'RED workbench',
            'cache-control': 'no-cache',
        }
        if content_type is not None:
            headers['Content-Type'] = content_type
        req = request.Request(url, method=method, data=data, headers=headers)
        # Use the response as a context manager so its socket is released
        # as soon as the body has been read.
        with request.urlopen(req, timeout=self._timeout) as response:
            return response.read()

    def post(self, url, data, content_type):
        """POST *data* to *url* with the given Content-Type; return body bytes."""
        return self.request('POST', url, data, content_type=content_type)

    def get(self, url):
        """GET *url* and return the response body bytes."""
        return self.request('GET', url)

    @staticmethod
    def _force_https(url):
        """Upgrade a leading ``http:`` scheme to ``https:``; leave the rest intact."""
        if url.startswith('http:'):
            return 'https:' + url[len('http:'):]
        return url

    def _post_document(self, url, binary_data):
        """Upload raw bytes to *url* and return the decoded JSON reply."""
        reply = self.post(url, data=binary_data, content_type='application/octet-stream')
        return json.loads(reply.decode())

    def _cancel_quietly(self, task_cancel_url):
        """Best-effort task cancellation.

        Transport/HTTP failures (``OSError`` subclasses such as ``URLError``)
        are reported but not raised, so they cannot mask the exception that
        triggered the cancellation.
        """
        try:
            self.get(task_cancel_url)
        except OSError as exc:
            print('Failed to cancel task', task_cancel_url, exc)

    def _wait_for_completion(self, task_status_url, progress):
        """Poll *task_status_url* with doubling delays until the task completes.

        :returns: the final decoded status object (has ``completed`` truthy).
        :raises RuntimeError: if the status reply is not a JSON object, or the
            task has not completed after ``MAX_POLLS`` polls.
        """
        delay = 1
        for _ in range(self.MAX_POLLS):
            # Sleep one second at a time so the progress meter can count down.
            for remaining in range(delay, 0, -1):
                time.sleep(1.0)
                progress.worked(subtitle='retrying in: %s' % remaining)
            response = self.get(task_status_url)
            status = json.loads(response.decode())
            if not isinstance(status, dict):
                raise RuntimeError('Unexpected task status response: %r' % (response,))
            progress.worked(subtitle='progress: %s/%s' % (status['progress'], status['steps']))
            if status['completed']:
                return status
            delay = min(delay * 2, self.MAX_POLL_DELAY)
        raise RuntimeError('timeout')

    def __call__(self, binary_data, progress=NULL_PROGRESS_METER):
        """Run a prediction on *binary_data* and return the output document.

        Uploads the document, submits a prediction task for it, polls the task
        until it completes, then downloads the result. If anything fails after
        the task was submitted, a best-effort cancel request is sent before
        the original error propagates.

        :param binary_data: raw document bytes to upload.
        :param progress: factory used as ``with progress(title=...)``; its
            ``worked(subtitle=...)`` is called to report each stage.
        :returns: the prediction output as raw bytes.
        :raises RuntimeError: on polling timeout or a server-reported task error.
        """
        with progress(title='Predicting using domain {domain}'.format(domain=self.domain)):
            task_cancel_url = None
            try:
                upload = self._post_document(self.URL_UPLOAD, binary_data)
                bytes_accepted = int(upload['bytes_accepted'])
                input_filename = upload['input_filename']
                progress.worked(subtitle='uploaded, accepted size=%s' % bytes_accepted)

                predict_reference_url = self.URL_PREDICT.format(name=input_filename, domain=self.domain)
                job = json.loads(self.get(predict_reference_url).decode())
                task_id = job['task_id']
                # The server may advertise plain-http URLs; always talk https.
                task_cancel_url = self._force_https(job['task_cancel_url'])
                document_output_url = self._force_https(job['document_output_url'])
                task_status_url = self._force_https(job['task_status_url'])
                progress.worked(subtitle='job submitted, task id: %s' % task_id)

                status = self._wait_for_completion(task_status_url, progress)
                err = status.get('error')
                if err is not None:
                    raise RuntimeError('Prediction server returned error: ' + str(err))

                progress.worked(subtitle='finished, fetching result')
                predictions = self.get(document_output_url)
                # Success: the task is finished, so there is nothing to cancel.
                task_cancel_url = None
            finally:
                if task_cancel_url is not None:
                    self._cancel_quietly(task_cancel_url)
        return predictions

    def send_feedback(self, binary_data, batch_id, progress=NULL_PROGRESS_METER):
        """Upload a corrected document as training feedback for *batch_id*.

        :param binary_data: raw document bytes to upload.
        :param batch_id: batch identifier, substituted into ``URL_FEEDBACK``.
        :param progress: factory used as ``with progress(title=...)``.
        :returns: the ``bytes_accepted`` count reported by the server.
        """
        title = 'Sending feedback for domain {domain}, batch id {batch_id}'.format(
            domain=self.domain, batch_id=batch_id)
        with progress(title=title):
            url = self.URL_FEEDBACK.format(domain=self.domain, batch_id=batch_id)
            reply = self._post_document(url, binary_data)
            bytes_accepted = int(reply['bytes_accepted'])
            print('Uploaded feedback for domain', self.domain, 'as', url, 'bytes_accepted', bytes_accepted)
        return bytes_accepted
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment