Skip to content

Instantly share code, notes, and snippets.

@LittleKey
Last active May 21, 2020 10:30
Show Gist options
  • Save LittleKey/016c46be3aa3bd51d3b2178aa68a7e39 to your computer and use it in GitHub Desktop.
Save LittleKey/016c46be3aa3bd51d3b2178aa68a7e39 to your computer and use it in GitHub Desktop.
通过微信发送命令到远程服务器. 基于werkzeug + gunicorn(gevent) 实现, 造了很多轮子
# coding: utf8
from gevent.monkey import patch_all; patch_all() # noqa
import re
import os
import sys
import time
import json
import struct
import string
import random
import base64
import hashlib
import logging
import datetime
import requests
import binascii
import traceback
from functools import wraps
from contextlib import ExitStack
from urllib.parse import urlparse
from threading import local, get_ident
import untangle
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Random import atfork
from blinker import signal
from werkzeug.wrappers import Request, Response
from werkzeug.routing import Map, Rule
from werkzeug.exceptions import HTTPException, NotFound
from werkzeug.wsgi import (
get_current_url,
get_content_length,
)
import gunicorn.util
from gunicorn.app.wsgiapp import WSGIApplication
###
# Logger
###
# Application-wide logger; start from a clean slate so that handlers
# attached earlier (e.g. on module reload) do not duplicate output.
logger = logging.getLogger('remote.control.app')
logger.setLevel(logging.INFO)
# Iterate over a copy because removeHandler mutates logger.handlers.
for handler in list(logger.handlers):
    logger.removeHandler(handler)
class LogFormatter(logging.Formatter):
    """Formatter that prefixes each message with the current logging context.

    The context is read through ``get_log_ctx()``. Items stored under the
    ``'-'`` key are rendered as bare values, every other entry is rendered
    as ``key => value``. The resulting prefix looks like
    ``[meta1 meta2 k => v] ## message``.
    """

    def format(self, record):
        """Return the formatted record, with context metadata prepended."""
        msg = obj_to_str(record.msg)
        log_ctx = get_log_ctx()
        if log_ctx:
            metas = []
            for k, v in log_ctx.items():
                if k == '-':
                    # positional metas are stored as a list under '-'
                    metas.extend(obj_to_str(item) for item in v)
                else:
                    metas.append(f'{obj_to_str(k)} => {obj_to_str(v)}')
            if metas:
                meta = ' '.join(metas)
                msg = f'[{meta}] ## {msg}'
        record.msg = msg
        return super(LogFormatter, self).format(record)
# Log line layout: timestamp, level, logger name with pid, then the message
# (already prefixed with context metadata by LogFormatter).
formatter = LogFormatter(
    '%(asctime)s %(levelname)-6s '
    '%(name)s[%(process)d] %(message)s'
)
# All application logs go to stdout.
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(formatter)
logger.addHandler(stdout_handler)
def get_log_ctx():
    """Return the logging metadata stored in the request context.

    Falls back to a new empty dict when nothing has been stored yet.
    """
    return ctx.get('logging_ctx', default={})
def add_log_ctx(*args, **kwargs):
    """Attach metadata to the logging context of the current request.

    Positional values are appended to the list stored under the ``'-'`` key;
    keyword values are stored as ``key => value`` entries. ``LogFormatter``
    renders both kinds in front of every log message.
    """
    # Dicts keep insertion order from Python 3.7 on (3.6 in CPython), so
    # metadata is rendered in the order it was added.
    logging_ctx = ctx.get('logging_ctx')
    if logging_ctx is None:
        # Store the new dict in the context, otherwise the update is lost.
        logging_ctx = {}
        ctx.set('logging_ctx', logging_ctx)
    if args:
        logging_ctx.setdefault('-', []).extend(args)
    if kwargs:
        logging_ctx.update(kwargs)
###
# Utils
###
def dict_to_obj(d):
    """Return a plain object whose attributes are the items of mapping *d*."""
    class Obj(object):
        pass

    instance = Obj()
    vars(instance).update(d)
    return instance
def obj_to_str(o):
    """Render *o* for logging: ``str()`` for scalars, ``repr()`` otherwise."""
    scalar_types = (int, str, float, bool)
    return str(o) if isinstance(o, scalar_types) else repr(o)
def get_request_json(request):
    """Return the decoded JSON body of *request*, or ``None``.

    ``None`` is returned when the request declares no content length, when
    its content type is present but not ``application/json``, or when the
    body turns out to be empty. Invalid JSON propagates the ``ValueError``
    raised by ``json.loads``.
    """
    request.max_content_length = 1024 * 1024 * 4  # 4MB
    content_length = get_content_length(request.environ)
    if not content_length:
        return None
    content_type = request.headers.get('content-type')
    if content_type and not content_type.startswith('application/json'):
        return None
    # get_data() reads and caches the whole body in one call; looping until
    # Content-Length bytes arrive would spin forever on a short/empty body.
    body = request.get_data()
    if not body:
        return None
    return json.loads(body.decode('utf8'))
def get_current_method(environ):
    """Return the HTTP method from the WSGI *environ*, or ``None``."""
    return environ.get('REQUEST_METHOD')
def get_current_client_host(environ):
    """Return the client address taken from the proxy headers in *environ*.

    ``X-Real-IP`` wins when it is set and non-empty; otherwise the value of
    ``X-Forwarded-For`` (possibly ``None``) is returned.
    """
    real_ip = environ.get('HTTP_X_REAL_IP')
    if real_ip:
        return real_ip
    return environ.get('HTTP_X_FORWARDED_FOR')
def get_request_info_msg():
    """Build the one-line access-log summary for the current request.

    The line carries the request id, client host, method, path, status code
    and elapsed time, all read from the request context.
    """
    # TODO make logging meta more flexible
    if isinstance(ctx.response, Response):
        status_code = ctx.response.status_code
    else:
        # no response object was recorded: report a server error
        status_code = 500
    return (
        f'reqid => {ctx.request_id} '
        f'host => {ctx.request_host} '
        f'{ctx.request_method} {ctx.request_path} '
        f'{status_code} - {ctx.request_cost_time}ms'
    )
def get_random_string(length):
    """Return a random string of ASCII letters and digits.

    Args:
        length: number of characters wanted; values <= 0 yield ``''``.

    Characters are drawn with replacement from the OS entropy source, so any
    length is supported (sampling without replacement would fail for lengths
    above the 62-character alphabet).
    """
    if length <= 0:
        return ''
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.SystemRandom().choices(alphabet, k=length))
def encrypt_secret_config(secret_config):
    """Encrypt *secret_config* with the configured RSA public key.

    The config is serialized to JSON, encrypted with the key read from
    ``config['public_key_file']`` and returned as base64-encoded bytes,
    which is the form ``load_config`` decodes and decrypts.
    """
    secret_bytes = json.dumps(secret_config).encode()
    with open(config['public_key_file'], 'r') as f:
        pubkey = RSA.importKey(f.read())
    # PyCrypto's encrypt() returns a tuple; the ciphertext is its first item.
    encrypt_bytes = pubkey.encrypt(secret_bytes, 'M')[0]
    return base64.b64encode(encrypt_bytes)
###
# Wechat Utils
###
def verify_signature(signature, *data):
    """Check *signature* against the SHA-1 of the sorted *data* strings.

    The strings are sorted, UTF-8 encoded and hashed in that order; the
    hex digest must equal *signature*.
    """
    payload = b''.join(part.encode('utf8') for part in sorted(data))
    expected = hashlib.sha1(payload).hexdigest()
    return signature == expected
def decrypt_data(encoding_msg, msg_signature, timestamp, token, nonce,
                 encoding_aes_key, appid):
    """Verify and decrypt an AES-encrypted WeChat message.

    The decrypted payload is laid out as
    ``random(16B) + msg_len(4B) + msg + appid`` followed by PKCS#7 padding.

    Returns:
        The message XML parsed by ``untangle``.

    Raises:
        ValueError: on signature mismatch, wrong message length or when the
            trailing app id differs from *appid*.
    """
    signature_ok = verify_signature(
        msg_signature, timestamp, token, nonce, encoding_msg)
    if not signature_ok:
        raise ValueError('Signature Error')

    encrypted = base64.b64decode(encoding_msg.encode('utf8'))
    # the configured key lacks its base64 padding character
    key = base64.b64decode((encoding_aes_key + '=').encode())
    # the first 16 bytes of the key double as the CBC IV
    plain = AES.new(key, AES.MODE_CBC, key[:16]).decrypt(encrypted)

    # the last byte holds the PKCS#7 padding length
    pad_len = int(binascii.hexlify(plain[-1:]), 16)
    plain = plain[:-pad_len]  # cut padding

    # bytes 16..20 carry the message length
    declared_len = int(binascii.hexlify(plain[16:20]), 16)
    appid_len = len(appid.encode())
    body = plain[20:-appid_len]
    if declared_len != len(body):
        raise ValueError('Message Length Error.')

    trailing_appid = plain[-appid_len:].decode('utf8')
    if trailing_appid != appid:
        raise ValueError('Invalid AppId')
    return untangle.parse(body.decode('utf8'))
def encrypt_data(data, encoding_aes_key, token, appid):
    """Encrypt *data* into a signed WeChat reply envelope.

    Args:
        data: the plain reply XML (str).
        encoding_aes_key: base64 AES key without its trailing ``=``.
        token: WeChat token used for the message signature.
        appid: app id appended to the plaintext.

    Returns:
        An XML string holding the base64 ciphertext, its SHA-1 signature,
        the timestamp and the nonce.
    """
    timestamp = str(int(time.time()))
    # noqa msg_encrypt=Base64_Encode(AES_Encrypt [random(16B)+ msg_len(4B) + msg + $AppId])
    nonce = get_random_string(16)
    data_bytes = data.encode()
    # The length prefix counts bytes, not characters: decrypt_data compares
    # it with the byte length, so non-ASCII replies need the encoded size.
    prefix = nonce.encode() + struct.pack('>I', len(data_bytes))
    msg_bytes = prefix + data_bytes + appid.encode()
    # pkcs7 encode (32-byte blocks; a full block is added when aligned)
    padding_len = 32 - len(msg_bytes) % 32
    msg_bytes += bytes([padding_len]) * padding_len
    # aes cbc encrypt; the first 16 key bytes are used as the IV
    aes_key = base64.b64decode((encoding_aes_key + '=').encode())
    encryptor = AES.new(aes_key, AES.MODE_CBC, aes_key[:16])
    cipher = encryptor.encrypt(msg_bytes)
    encoding_msg = base64.b64encode(cipher).decode('utf8')
    # signature: SHA-1 over the sorted token/timestamp/nonce/ciphertext
    sha1 = hashlib.sha1()
    for s in sorted([token, timestamp, nonce, encoding_msg]):
        sha1.update(s.encode('utf8'))
    return f'''
<xml>
<Encrypt><![CDATA[{encoding_msg}]]></Encrypt>
<MsgSignature><![CDATA[{sha1.hexdigest()}]]></MsgSignature>
<TimeStamp>{timestamp}</TimeStamp>
<Nonce><![CDATA[{nonce}]]></Nonce>
</xml>
'''
def wechat_deco(func):
    """Wrap a handler with WeChat request verification and reply formatting.

    The wrapper checks the request signature, parses (and, for
    ``encrypt_type=aes``, decrypts) the XML body into ``request.data``,
    calls *func*, and wraps a dict result into a WeChat XML reply when the
    response type is ``'xml'`` (encrypting it again for ``aes`` requests).

    Raises:
        IllegalRequest: on signature mismatch or when decrypting the request
            / encrypting the reply fails.
    """
    @wraps(func)
    def wrapper(request):
        appid = config['wechat']['appid']
        token = config['wechat']['token']
        encoding_aes_key = config['wechat']['encoding_aes_key']
        openid = request.args.get('openid', '')
        encrypt_type = request.args.get('encrypt_type', '')
        msg_signature = request.args.get('msg_signature', '')
        signature = request.args.get('signature', '')
        nonce = request.args.get('nonce', '')
        timestamp = request.args.get('timestamp', '')
        public_account_id = ''
        # The plain signature is only checked when the request carries one.
        if signature and \
                not verify_signature(signature, token, timestamp, nonce):
            raise IllegalRequest('Signature Error')
        if request.data:
            xml_data = untangle.parse(request.data.decode('utf8'))
            public_account_id = xml_data.xml.ToUserName.cdata
            # Replace the raw body with the parsed <xml> element so the
            # handler can read fields as attributes.
            request.data = xml_data.xml
            if encrypt_type == 'aes':
                try:
                    xml_msg = decrypt_data(
                        request.data.Encrypt.cdata, msg_signature, timestamp,
                        token, nonce, encoding_aes_key, appid)
                    request.data = xml_msg.xml
                except Exception as e:
                    logger.exception(f'Decrypt Wechat data error. {e}')
                    raise IllegalRequest('Decrypt data error.')
        try:
            response = func(request)
        except Exception as e:
            # Handler failures are reported back to the user as the reply
            # content instead of failing the request.
            logger.exception(f'Handle wechat data error: {e}')
            response = {'content': e}
        # Handlers may switch the response type (see set_response_type_wrap);
        # only 'xml' responses are wrapped into a WeChat reply envelope.
        if ctx.response_type_wrap == 'xml':
            create_time = int(time.time())
            msg_type = response.get('msg_type', 'text')
            content = response.get('content', '')
            response = f'''
<xml>
<ToUserName><![CDATA[{openid}]]></ToUserName>
<FromUserName><![CDATA[{public_account_id}]]></FromUserName>
<CreateTime>{create_time}</CreateTime>
<MsgType><![CDATA[{msg_type}]]></MsgType>
<Content><![CDATA[{content}]]></Content>
</xml>
'''
            if encrypt_type == 'aes':
                try:
                    response = \
                        encrypt_data(response, encoding_aes_key, token, appid)
                except Exception as e:
                    logger.exception(f'Encrypt Wechat data error. {e}')
                    raise IllegalRequest('Encrypt data error.')
        return response
    return wrapper
###
# Context
###
# gunicorn handles each request in its own thread,
# so the storage below is kept local to the current thread.
class Context(type):
    """Metaclass that turns the ``Context`` class into per-thread storage.

    Attribute reads and writes on the created class are redirected to a
    ``threading.local`` kept in ``_locals`` under the current thread ident.
    Attributes that were never set read as ``None``.
    """
    # thread ident -> threading.local holding that thread's values
    _locals = {}

    def __new__(mcls, name, basees, attrs, **kwargs):
        # Only the single storage class named 'Context' may be created.
        if name != 'Context':
            raise TypeError('Create ContextDelete class error')
        attrs['_locals'] = mcls._locals
        attrs['__dict__'] = {}
        return super().__new__(mcls, name, basees, attrs)

    @property
    def _local(cls):
        """Storage object of the current thread, created on first use."""
        return cls._locals.setdefault(get_ident(), local())

    def get(cls, name, default=None):
        """Return the value stored under *name*, or *default*."""
        return getattr(cls._local, name, default)

    def set(cls, name, value):
        """Store *value* under *name* for the current thread."""
        setattr(cls._local, name, value)

    def clear(cls):
        """Drop everything stored for the current thread.

        The whole entry is removed (not just emptied) so that ``_locals``
        does not keep one stale object per thread ident ever seen.
        """
        cls._locals.pop(get_ident(), None)

    def __getattr__(cls, name):
        # Only reached for names not found by normal lookup.
        return cls.get(name)

    def __setattr__(cls, name, value):
        cls.set(name, value)


ctx = Context('Context', (object,), {})
def gen_request_id():
    """Return a SHA-1 hex id derived from the current request context.

    The id mixes the request host, url, method and timestamp with a random
    two-digit salt, so those context fields must be set beforehand.
    """
    salt = random.randint(10, 99)
    seed = (
        f'request_id:{ctx.request_host}.'
        f'{ctx.request_url}.'
        f'{ctx.request_method}.'
        f'{ctx.request_timestamp}.{salt}'
    )
    return hashlib.sha1(seed.encode()).hexdigest()
def context_deco(func):
    """Context middleware: populate and reset the request context.

    Before calling *func* the wrapper records timestamp, client host, url
    parts, method and a request id in ``ctx``. HTTP errors are logged and
    turned into their own WSGI response; any other error is logged and
    re-raised. The context is cleared once the call is over.
    """
    @wraps(func)
    def wrapper(app, environ, start_response):
        try:
            started_at = datetime.datetime.now()
            ctx.request_timestamp = int(started_at.timestamp() * 1000)
            ctx.request_host = get_current_client_host(environ)
            current_url = get_current_url(environ)
            if current_url:
                parsed = urlparse(current_url)
                ctx.request_url = parsed.geturl()
                ctx.request_protocol = parsed.scheme
                ctx.request_path = parsed.path
                ctx.request_query = parsed.query
            ctx.request_method = get_current_method(environ)
            # the request id is derived from url, method and timestamp,
            # so it has to be generated last
            ctx.request_id = gen_request_id()
            result = func(app, environ, start_response)
        except HTTPException as http_error:
            logger.exception(get_request_info_msg() + f'\n{http_error}')
            # an HTTPException is itself a WSGI app: return its response
            return http_error(environ, start_response)
        except Exception as error:
            formatted = traceback.format_exc()
            logger.exception(get_request_info_msg() + f'\n{formatted}')
            # unknown error: no custom response, let it propagate
            raise error
        else:
            logger.info(get_request_info_msg())
            return result
        finally:
            # reset the per-request storage once the response is returned
            ctx.clear()
    return wrapper
###
# Exception
###
class IllegalRequest(HTTPException):
    """HTTP 400 error raised when a request is rejected as invalid."""
    code = 400
    description = "IllegalRequest"
class Unauthorized(HTTPException):
    """HTTP 401 error raised when the caller's token is not accepted."""
    code = 401
    # Must be spelled `description` for HTTPException to pick it up.
    description = "Unauthorized"
###
# WSGI Application
###
def set_response_type_wrap(response_type):
    """Record how the current request's response should be wrapped."""
    ctx.set('response_type_wrap', response_type)
def dispatcher_deco(dispatcher):
    """Wrap a dispatcher with response post-processing.

    The wrapper stores request and response in ``ctx``, wraps the body in a
    ``{"code", "data"}`` / ``{"code", "description"}`` JSON envelope when the
    response type is ``'json'``, adds common headers and records the
    request's elapsed time in ``ctx.request_cost_time``.
    """
    @wraps(dispatcher)
    def wrapper(app, request):
        ctx.request = request
        try:
            response = dispatcher(app, request)
            ctx.response = response
            if ctx.response_type_wrap == 'json':
                # the handler's body is itself JSON; nest it under "data"
                response.data = json.dumps({
                    "code": response.status_code,
                    "data": json.loads(response.data.decode('utf8')),
                })
            return response
        except HTTPException as exc:
            response = exc.get_response(request.environ)
            ctx.response = response
            if ctx.response_type_wrap == 'json':
                response.data = json.dumps({
                    "code": exc.code,
                    "description": exc.description,
                })
            # attach the customized response so it is the one sent when the
            # exception is later called as a WSGI app
            exc.response = response
            raise exc
        finally:
            # NOTE: if an uncaught (non-HTTP) exception occurred, no response
            # was stored, so there is nothing to decorate
            response = ctx.response
            if response:
                # common headers
                response.headers['Request-Id'] = ctx.request_id
                if ctx.response_type_wrap == 'json':
                    response.headers['Content-Type'] = \
                        'application/json; charset=utf-8'
            request_timestamp = ctx.request_timestamp
            if request_timestamp:
                # elapsed time in milliseconds
                ctx.request_cost_time = \
                    int(time.time() * 1000 - request_timestamp)
    return wrapper
class RemoteControlApp(WSGIApplication):
    """Gunicorn application that is also the WSGI app it serves.

    Gunicorn settings come from ``config['env']``; URL routing is handled
    with a werkzeug ``Map`` filled through the ``route`` decorator.
    """

    def init(self, parser, opts, args):
        """Create the routing tables and apply gunicorn settings from config."""
        self.url_map = Map()
        self.func_map = {}
        env = dict_to_obj(config['env'])
        self.cfg.set('default_proc_name', env.app_name)
        self.cfg.set('worker_class', env.worker_class)
        self.cfg.set('worker_connections', env.worker_connections)
        self.cfg.set('loglevel', 'info')
        self.cfg.set('graceful_timeout', env.graceful_timeout)
        self.cfg.set('timeout', env.timeout)
        self.cfg.set('bind', env.bind)
        self.cfg.set('workers', env.num_workers)
        self.cfg.set('errorlog', '-')
        self.app_uri = env.app_uri
        # the app uri is taken from config instead of the command line
        args = [self.app_uri]
        super(RemoteControlApp, self).init(parser, opts, args)

    def load_wsgiapp(self):
        """Import and return the WSGI callable named by ``app_uri``."""
        self.chdir()
        return gunicorn.util.import_app(self.app_uri)

    def run(self):
        """Start the gunicorn server."""
        super(RemoteControlApp, self).run()

    @dispatcher_deco
    def dispatch_request(self, request):
        """Match the request URL and return the handler's ``Response``.

        Raises:
            NotFound: when the matched endpoint has no registered handler.
            HTTPException: routing errors raised by the URL adapter.
        """
        adapter = self.url_map.bind_to_environ(request.environ)
        endpoint, values = adapter.match()
        handler = self.func_map.get(endpoint, None)
        if not handler:  # handler not found
            raise NotFound(f"Unknown url '{ctx.request_path}'")
        resp = handler(request, **values)
        return Response(resp)

    @context_deco
    def wsgi_app(self, environ, start_response):
        """WSGI entry point: build the request, dispatch, send the response."""
        request = Request(environ)
        response = self.dispatch_request(request)
        return response(environ, start_response)

    def __call__(self, environ, start_response):
        return self.wsgi_app(environ, start_response)

    def route(self, path, methods, **kwargs):
        """Return a decorator registering a handler for *path*.

        Args:
            path: werkzeug URL rule, also used as the endpoint name.
            methods: HTTP methods the rule accepts.
            **kwargs: ``resp_wrap`` selects the response wrapping stored in
                the context for each request; ``wechat=True`` wraps the
                handler with ``wechat_deco``.

        Raises:
            ValueError: when *path* is already registered.
        """
        def decorator(func):
            # Wrap once at registration time rather than on every request.
            handler = wechat_deco(func) if kwargs.get('wechat', False) \
                else func

            @wraps(func)
            def wrapper(request, **values):
                ctx.response_type_wrap = kwargs.get('resp_wrap')
                return handler(request, **values)

            endpoint = f'{path}'
            if self.func_map.get(endpoint, None):
                raise ValueError(f'Endpoint({endpoint}) already registered.')
            # Pass `methods` on so that other HTTP methods are rejected by
            # the router instead of reaching the handler.
            rule = Rule(path, endpoint=endpoint, methods=methods)
            self.url_map.add(rule)
            self.func_map[endpoint] = wrapper
            return wrapper
        return decorator
###
# init app
###
def load_config(config_path):
    """Load the JSON config at *config_path*, decrypting WeChat secrets.

    When ``wechat_secret_file``, ``private_key_file`` and
    ``private_key_pass`` are all set, the secret file is base64-decoded,
    decrypted with the RSA private key and stored (as parsed JSON) under
    ``config['wechat']``.
    """
    with open(config_path) as config_fp:
        config = json.load(config_fp)

    secret_path = config.get('wechat_secret_file')
    key_path = config.get('private_key_file')
    passphrase = config.get('private_key_pass')
    if not (secret_path and key_path and passphrase):
        return config

    with open(secret_path, 'rb') as secret_fp, open(key_path) as key_fp:
        encrypted = base64.b64decode(secret_fp.read())
        # NOTE: after a process fork the `atfork` hook must be called,
        # otherwise the crypto library does not work properly.
        atfork()
        private_key = RSA.importKey(key_fp.read(), passphrase)
        decrypted = private_key.decrypt(encrypted).decode()
        config['wechat'] = json.loads(decrypted)
    return config
# Module-level singletons: the config is read from the working directory at
# import time, and `app` is both the gunicorn application and the WSGI app.
config = load_config('config.json')
app = RemoteControlApp()
###
# Register Routers: map urls and handlers
###
@app.route('/control/<command>', methods=['GET', 'POST'], resp_wrap='json')
def sserver_control_daemon_command(request, command):
    """Control the shadowsocks daemon or show its recent log.

    GET accepts only ``status`` and returns the last 50 log lines. Other
    requests accept ``start``/``restart``/``stop`` and require a JSON body
    whose ``token`` is one of the configured WeChat openids.

    Raises:
        IllegalRequest: for an unsupported command.
        Unauthorized: when the token is missing or unknown.
    """
    from subprocess import getoutput
    work_dir = config.get('shadowsocks_work_dir')

    if request.method == 'GET':
        if command == 'status':
            log_file = os.path.join(work_dir, 'shadowsocks.log')
            return json.dumps(getoutput(f'tail -n 50 {log_file}'))
        raise IllegalRequest(f"Illegal command '{command}'")

    # only whitelisted daemon commands ever reach the shell
    if command not in ['start', 'restart', 'stop']:
        raise IllegalRequest(f"Illegal command '{command}'")

    json_body = get_request_json(request) or {}
    token = json_body.get("token", None)
    if token not in config['wechat']['openids']:
        raise Unauthorized('Invalid token')

    config_file = os.path.join(work_dir, 'config.json')
    pid_file = os.path.join(work_dir, 'pid.file')
    log_file = os.path.join(work_dir, 'shadowsocks.log')
    output = getoutput(
        f"ssserver -c {config_file} -d {command} "
        f"--pid-file {pid_file} --log-file {log_file}"
    )
    return json.dumps(output)
@app.route('/wechat/msg', methods=['GET', 'POST'],
           wechat=True, resp_wrap='xml')
def wechat_message_recevied(request):
    """Handle WeChat callbacks.

    GET requests answer the server verification by echoing ``echostr`` as
    plain text. Other requests carry a user message: text starting with
    ``#ssserver`` is forwarded as a command to the local ``/control``
    endpoint, anything else gets a fixed reply.

    Returns:
        The echo string for GET, otherwise a dict with ``msg_type`` and
        ``content`` for the reply.

    Raises:
        IllegalRequest: when the posted message has no content.
    """
    if request.method == 'GET':
        set_response_type_wrap('text')
        return request.args.get('echostr', 'failed')
    try:
        msg = request.data.Content.cdata
    except AttributeError as e:
        logger.exception(e)
        raise IllegalRequest('No content post.')
    if msg.startswith('#ssserver'):
        # Drop empty pieces too (consecutive or trailing separators),
        # otherwise the command would come out as an empty string.
        strings = [s.strip() for s in re.split(r'[ :_]', msg) if s.strip()]
        command = strings[-1]
        bind = config['env']['bind']
        # the sender's openid doubles as the control endpoint token
        resp = requests.post(f'http://{bind}/control/{command}',
                             json={'token': request.args.get('openid')})
        if resp.status_code == 200:
            content = resp.json().get('data')
        else:
            content = 'something wrong'
    else:
        content = 'Everything\'s Alright'
    return {'msg_type': 'text', 'content': content}
# Start the gunicorn server when the module is executed as a script.
if __name__ == '__main__':
    app.run()
@LittleKey
Copy link
Author

LittleKey commented Jun 22, 2018

app的config.json文件

{
  "shadowsocks_work_dir": "/etc/shadowsocks",
  "env": {
    "app_uri": "app:app",
    "app_name": "remote.control",
    "worker_class": "gevent",
    "worker_connections": 1000,
    "graceful_timeout": 3,
    "timeout": 30,
    "bind": "0.0.0.0:5000",
    "num_workers": 1
  },
  "wechat_secret_file": "wechat_secret.encrypt",
  "private_key_file": "private.key",
  "private_key_pass": "******",
  "public_key_file": "public.key"
}

@LittleKey
Copy link
Author

TODO

  • make logging metadata more flexible

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment