Skip to content

Instantly share code, notes, and snippets.

@kanjieater
Created November 27, 2020 22:44
Show Gist options
  • Save kanjieater/7ae2e2c930a66604c4e011fe032367a3 to your computer and use it in GitHub Desktop.
Save kanjieater/7ae2e2c930a66604c4e011fe032367a3 to your computer and use it in GitHub Desktop.
ankiconnect but without crashing on client disconnects
# Copyright 2016-2020 Alex Yatskov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import select
import socket
from . import web, util
log = None
logPath = 'C:/Users/AppData/Roaming/Anki2/addons21/2055492159/log.log'
if logPath is not None:
log = open(logPath, 'w')
def logEvent(name, data):
if log is not None:
log.write('[{}]\n'.format(name))
log.write('[{}]\n'.format(data))
# json.dump(data, log, indent=4, sort_keys=True)
log.write('\n\n')
log.flush()
#
# WebRequest
#
class WebRequest:
def __init__(self, headers, body):
self.headers = headers
self.body = body
#
# WebClient
#
class WebClient:
def __init__(self, sock, handler):
self.sock = sock
self.handler = handler
self.readBuff = bytes()
self.writeBuff = bytes()
def advance(self, recvSize=1024):
if self.sock is None:
return False
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
logEvent('rlist', rlist)
self.sock.settimeout(1.0)
if rlist:
while True:
try:
logEvent('getmsg', 'notclosed')
msg = self.sock.recv(recvSize)
except: #ConnectionResetError or socket.timeout
logEvent('close', 'close')
self.close()
return False
if not msg:
logEvent('not msg', msg)
self.close()
return False
logEvent('msg', msg)
self.readBuff += msg
req, length = self.parseRequest(self.readBuff)
if req is not None:
logEvent('build up', 'buildup')
self.readBuff = self.readBuff[length:]
self.writeBuff += self.handler(req)
break
else:
logEvent('huh', 'huh')
if wlist and self.writeBuff:
try:
length = self.sock.send(self.writeBuff)
self.writeBuff = self.writeBuff[length:]
if not self.writeBuff:
self.close()
return False
except:
self.close()
return False
return True
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
self.readBuff = bytes()
self.writeBuff = bytes()
def parseRequest(self, data):
parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
if len(parts) == 1:
return None, 0
headers = {}
for line in parts[0].split('\r\n'.encode('utf-8')):
pair = line.split(': '.encode('utf-8'))
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
headerLength = len(parts[0]) + 4
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
totalLength = headerLength + bodyLength
if totalLength > len(data):
return None, 0
body = data[headerLength : totalLength]
return WebRequest(headers, body), totalLength
#
# WebServer
#
class WebServer:
def __init__(self, handler):
self.handler = handler
self.clients = []
self.sock = None
def advance(self):
if self.sock is not None:
self.acceptClients()
self.advanceClients()
def acceptClients(self):
rlist = select.select([self.sock], [], [], 0)[0]
if not rlist:
return
clientSock = self.sock.accept()[0]
if clientSock is not None:
clientSock.setblocking(False)
# clientSock.settimeout(5.0)
self.clients.append(WebClient(clientSock, self.handlerWrapper))
def advanceClients(self):
self.clients = list(filter(lambda c: c.advance(), self.clients))
def listen(self):
self.close()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.setblocking(False)
self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
self.sock.listen(util.setting('webBacklog'))
def handlerWrapper(self, req):
if len(req.body) == 0:
body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8')
else:
try:
params = json.loads(req.body.decode('utf-8'))
body = json.dumps(self.handler(params)).encode('utf-8')
except ValueError:
body = json.dumps(None).encode('utf-8')
# handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config
webCorsOriginList = util.setting('webCorsOriginList')
# keep support for deprecated 'webCorsOrigin' field, as long it is not removed
webCorsOrigin = util.setting('webCorsOrigin')
if webCorsOrigin:
webCorsOriginList.append(webCorsOrigin)
corsOrigin = 'http://localhost'
allowAllCors = '*' in webCorsOriginList # allow CORS for all domains
if len(webCorsOriginList) == 1 and not allowAllCors:
corsOrigin = webCorsOriginList[0]
elif b'origin' in req.headers:
originStr = req.headers[b'origin'].decode()
if originStr in webCorsOriginList or allowAllCors:
corsOrigin = originStr
headers = [
['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', corsOrigin],
['Content-Length', str(len(body))]
]
resp = bytes()
for key, value in headers:
if value is None:
resp += '{}\r\n'.format(key).encode('utf-8')
else:
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
resp += '\r\n'.encode('utf-8')
resp += body
return resp
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
for client in self.clients:
client.close()
self.clients = []
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment