Last active
May 21, 2020 10:30
-
-
Save LittleKey/016c46be3aa3bd51d3b2178aa68a7e39 to your computer and use it in GitHub Desktop.
通过微信发送命令到远程服务器. 基于werkzeug + gunicorn(gevent) 实现, 造了很多轮子
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf8
from gevent.monkey import patch_all; patch_all()  # noqa
import base64
import binascii
import datetime
import hashlib
import hmac
import json
import logging
import os
import random
import re
import secrets
import string
import struct
import sys
import time
import traceback
from contextlib import ExitStack
from functools import wraps
from threading import local, get_ident
from urllib.parse import urlparse

import requests
import untangle
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Random import atfork
from blinker import signal
from werkzeug.wrappers import Request, Response
from werkzeug.routing import Map, Rule
from werkzeug.exceptions import HTTPException, NotFound
from werkzeug.wsgi import (
    get_current_url,
    get_content_length,
)
import gunicorn.util
from gunicorn.app.wsgiapp import WSGIApplication
### | |
# Logger | |
### | |
# Application logger at INFO level. Any handlers already attached to it are
# detached so that only the handlers added afterwards in this module remain.
logger = logging.getLogger('remote.control.app')
logger.setLevel(logging.INFO)
while logger.handlers:
    logger.removeHandler(logger.handlers[0])
class LogFormatter(logging.Formatter):
    """Formatter that prefixes each message with the request's log context.

    The context comes from ``get_log_ctx()``. Values stored under the
    ``'-'`` key are rendered positionally. Every other entry is rendered as
    ``key => value``. The prefix has the form ``[meta meta ...] ## message``.
    """

    def format(self, record):
        original_msg = record.msg
        msg = obj_to_str(record.msg)
        metas = []
        for k, v in get_log_ctx().items():
            if k == '-':
                metas.extend(obj_to_str(item) for item in v)
            else:
                metas.append(f'{obj_to_str(k)} => {obj_to_str(v)}')
        if metas:
            meta = ' '.join(metas)
            msg = f'[{meta}] ## {msg}'
        # Swap the message in only for this call, so the prefix is not
        # applied a second time if another handler formats the same record.
        record.msg = msg
        try:
            return super(LogFormatter, self).format(record)
        finally:
            record.msg = original_msg
# Emit every application log record to stdout using the context-aware
# LogFormatter: "<time> <LEVEL> <logger>[<pid>] <message>".
stdout_handler = logging.StreamHandler(stream=sys.stdout)
formatter = LogFormatter(
    '%(asctime)s %(levelname)-6s %(name)s[%(process)d] %(message)s'
)
stdout_handler.setFormatter(fmt=formatter)
logger.addHandler(stdout_handler)
def get_log_ctx():
    """Return the current request's logging context, or ``{}`` if none is set."""
    return ctx.get('logging_ctx', default={})
def add_log_ctx(*args, **kwargs):
    """Attach metadata to the current request's logging context.

    Positional values are appended to the ``'-'`` list, which LogFormatter
    renders positionally. Keyword values are stored as key/value pairs,
    which LogFormatter renders as ``key => value``. Dicts keep insertion
    order (Python 3.7+), so entries appear in the order they were added.
    """
    logging_ctx = ctx.get('logging_ctx')
    if logging_ctx is None:
        # ``ctx.get``'s default is not stored in the context, so a fresh
        # dict must be saved explicitly or these additions would be lost.
        logging_ctx = {}
        ctx.set('logging_ctx', logging_ctx)
    if args:
        logging_ctx.setdefault('-', []).extend(args)
    if kwargs:
        # Flat key/value pairs: the formatter iterates them directly.
        logging_ctx.update(kwargs)
### | |
# Utils | |
### | |
def dict_to_obj(d):
    """Return a plain object whose attributes are the items of mapping *d*."""
    class Obj(object):
        pass

    wrapped = Obj()
    vars(wrapped).update(d)
    return wrapped
def obj_to_str(o):
    """Render primitives (int/str/float/bool) with str(), anything else with repr()."""
    primitive_types = (int, str, float, bool)
    return str(o) if isinstance(o, primitive_types) else repr(o)
def get_request_json(request):
    """Parse and return the JSON body of the werkzeug *request*.

    Returns None when the request has no Content-Length or its
    Content-Type is set to something other than ``application/json``.
    A missing Content-Type is treated as JSON. A 4MB
    ``max_content_length`` is set on the request before the body is read.

    Raises ValueError (``json.JSONDecodeError``) on malformed JSON.
    """
    request.max_content_length = 1024 * 1024 * 4  # 4MB
    content_length = get_content_length(request.environ)
    if not content_length:
        return None
    content_type = request.headers.get('content-type')
    if content_type and not content_type.startswith('application/json'):
        return None
    # get_data() reads the whole body in one call and caches it. Looping
    # until the Content-Length is reached would spin forever if the client
    # sent fewer bytes than it announced.
    data = request.get_data()
    return json.loads(data.decode('utf8'))
def get_current_method(environ):
    """Return the HTTP method from the WSGI *environ*, or None if absent."""
    if 'REQUEST_METHOD' in environ:
        return environ['REQUEST_METHOD']
    return None
def get_current_client_host(environ):
    """Return the client address reported by the reverse proxy.

    ``X-Real-IP`` wins when it is non-empty. Otherwise the raw
    ``X-Forwarded-For`` value is returned, or None if that is absent too.
    """
    real_ip = environ.get('HTTP_X_REAL_IP', None)
    if real_ip:
        return real_ip
    return environ.get('HTTP_X_FORWARDED_FOR', None)
def get_request_info_msg():
    """Build the one-line access-log summary for the current request.

    The status code is taken from ``ctx.response`` when it is a werkzeug
    ``Response``. Otherwise 500 is reported.
    """
    # TODO: make the logging metadata more flexible
    response = ctx.response
    if isinstance(response, Response):
        status_code = response.status_code
    else:
        status_code = 500
    parts = [
        f'reqid => {ctx.request_id}',
        f'host => {ctx.request_host}',
        f'{ctx.request_method} {ctx.request_path}',
        f'{status_code} - {ctx.request_cost_time}ms',
    ]
    return ' '.join(parts)
def get_random_string(length):
    """Return *length* random ASCII letters/digits, or '' if length <= 0.

    Characters are drawn independently, with replacement, from the
    cryptographically secure ``secrets`` source. Any length is therefore
    supported and characters may repeat. The result is used as the
    nonce/random prefix of encrypted WeChat replies.
    """
    if length <= 0:
        return ''
    alphabet = string.ascii_letters + string.digits
    # random.sample() could not exceed len(alphabet) (62) and was not a
    # secure source for cryptographic nonces.
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def encrypt_secret_config(secret_config):
    """Encrypt *secret_config* (JSON-serialisable) with the RSA public key.

    Returns base64-encoded ciphertext bytes. This is the format that
    ``load_config`` base64-decodes and RSA-decrypts from
    ``wechat_secret_file``.
    """
    secret_bytes = json.dumps(secret_config).encode()
    with open(config['public_key_file'], 'r') as f:
        pubkey = RSA.importKey(f.read())
    encrypted = pubkey.encrypt(secret_bytes, 'M')
    # The legacy PyCrypto RSA ``encrypt`` returns a tuple whose first item
    # is the ciphertext. Accept a plain bytes result as well.
    if isinstance(encrypted, tuple):
        encrypted = encrypted[0]
    # Encode, not decode: the reader base64-*decodes* this value.
    return base64.b64encode(encrypted)
### | |
# Wechat Utils | |
### | |
def verify_signature(signature, *data):
    """Return True if *signature* is the SHA1 hex digest of the sorted *data*.

    The strings in *data* are sorted, UTF-8 encoded and fed to SHA1 in
    order. The resulting hex digest must equal *signature*. A non-str
    signature never matches.
    """
    if not isinstance(signature, str):
        return False
    sha1 = hashlib.sha1()
    for s in sorted(data):
        sha1.update(s.encode('utf8'))
    # Constant-time comparison, so response timing does not reveal how much
    # of a forged signature was correct. Comparing bytes also avoids the
    # TypeError compare_digest raises for non-ASCII str input.
    return hmac.compare_digest(signature.encode('utf8'),
                               sha1.hexdigest().encode('ascii'))
def decrypt_data(encoding_msg, msg_signature, timestamp, token, nonce,
                 encoding_aes_key, appid):
    """Verify and decrypt an encrypted ("aes") WeChat message.

    After AES-CBC decryption and PKCS#7 unpadding the plaintext is laid out
    as ``random(16B) + msg_len(4B, big-endian) + msg + appid``. This is the
    same framing ``encrypt_data`` produces.

    Returns the ``untangle``-parsed XML document of ``msg``.
    Raises ValueError on a bad signature, padding, length or appid.
    """
    if not verify_signature(msg_signature, timestamp,
                            token, nonce, encoding_msg):
        raise ValueError('Signature Error')
    cipher = base64.b64decode(encoding_msg.encode('utf8'))
    # The configured key is stored without its trailing base64 '=' padding.
    aes_key = base64.b64decode((encoding_aes_key + '=').encode())
    decryptor = AES.new(aes_key, AES.MODE_CBC, aes_key[:16])
    decryption_bytes = decryptor.decrypt(cipher)
    if not decryption_bytes:
        raise ValueError('Padding Error')
    # PKCS#7 to a 32-byte block (as in encrypt_data): pad value is 1..32.
    pkcs7_padding_len = decryption_bytes[-1]
    if not 1 <= pkcs7_padding_len <= 32:
        raise ValueError('Padding Error')
    decryption_bytes = decryption_bytes[:-pkcs7_padding_len]  # cut padding
    if len(decryption_bytes) < 20:
        raise ValueError('Message Length Error.')
    msg_len = struct.unpack('>I', decryption_bytes[16:20])[0]
    # Slice by the declared length, so an empty appid is handled correctly.
    msg_data = decryption_bytes[20:20 + msg_len]
    if msg_len != len(msg_data):
        raise ValueError('Message Length Error.')
    msg_appid = decryption_bytes[20 + msg_len:].decode('utf8')
    if msg_appid != appid:
        raise ValueError('Invalid AppId')
    return untangle.parse(msg_data.decode('utf8'))
def encrypt_data(data, encoding_aes_key, token, appid):
    """Encrypt reply *data* for WeChat "aes" mode and wrap it in signed XML.

    Encrypt = Base64(AES_CBC(random(16B) + msg_len(4B) + msg + appid)),
    where msg_len is the big-endian byte length of the UTF-8 encoded msg.
    The XML also carries the SHA1 signature over the sorted
    (token, timestamp, nonce, Encrypt), the timestamp and the nonce.
    """
    timestamp = str(int(time.time()))
    nonce = get_random_string(16)
    msg = f'{data}'.encode('utf8')
    # The length prefix counts *bytes*. len() of the str under-counts
    # non-ASCII (e.g. Chinese) text, and the receiver then mis-frames the
    # message.
    msg_bytes = b''.join((
        nonce.encode(),
        struct.pack('>I', len(msg)),
        msg,
        appid.encode('utf8'),
    ))
    # PKCS#7 padding to a 32-byte block; the pad value is 1..32.
    padding_len = 32 - len(msg_bytes) % 32
    msg_bytes += bytes([padding_len]) * padding_len
    # AES-CBC with the first 16 key bytes as IV, mirroring decrypt_data.
    aes_key = base64.b64decode((encoding_aes_key + '=').encode())
    encryptor = AES.new(aes_key, AES.MODE_CBC, aes_key[:16])
    cipher = encryptor.encrypt(msg_bytes)
    encoding_msg = base64.b64encode(cipher).decode('utf8')
    # Signature: SHA1 over the sorted, concatenated fields.
    sha1 = hashlib.sha1()
    for s in sorted([token, timestamp, nonce, encoding_msg]):
        sha1.update(s.encode('utf8'))
    return f'''
<xml>
<Encrypt><![CDATA[{encoding_msg}]]></Encrypt>
<MsgSignature><![CDATA[{sha1.hexdigest()}]]></MsgSignature>
<TimeStamp>{timestamp}</TimeStamp>
<Nonce><![CDATA[{nonce}]]></Nonce>
</xml>
'''
def _cdata_escape(value):
    """Return *value* as text that is safe inside an XML CDATA section."""
    # ']]>' would close the CDATA section early (breaking or injecting
    # XML). Split it across two adjacent CDATA sections instead.
    return f'{value}'.replace(']]>', ']]]]><![CDATA[>')


def wechat_deco(func):
    """Wrap a WeChat endpoint handler ``func(request)``.

    Before calling *func* the wrapper:

    - requires a valid ``signature`` over (token, timestamp, nonce);
    - parses a non-empty XML body into ``request.data``;
    - with ``encrypt_type=aes``, decrypts the body.

    If *func* raises, the error is logged and its text becomes the reply
    content. When ``ctx.response_type_wrap == 'xml'``, the handler's
    ``{'msg_type', 'content'}`` dict is rendered as a reply message
    addressed back to ``openid`` (and encrypted again in aes mode).

    Raises IllegalRequest on a missing/invalid signature or a
    decryption/encryption failure.
    """
    @wraps(func)
    def wrapper(request):
        appid = config['wechat']['appid']
        token = config['wechat']['token']
        encoding_aes_key = config['wechat']['encoding_aes_key']
        openid = request.args.get('openid', '')
        encrypt_type = request.args.get('encrypt_type', '')
        msg_signature = request.args.get('msg_signature', '')
        signature = request.args.get('signature', '')
        nonce = request.args.get('nonce', '')
        timestamp = request.args.get('timestamp', '')
        public_account_id = ''
        # The signature is mandatory. If it were only checked when present,
        # any client could bypass verification by simply omitting it.
        if not signature or \
                not verify_signature(signature, token, timestamp, nonce):
            raise IllegalRequest('Signature Error')
        if request.data:
            xml_data = untangle.parse(request.data.decode('utf8'))
            public_account_id = xml_data.xml.ToUserName.cdata
            request.data = xml_data.xml
            if encrypt_type == 'aes':
                try:
                    xml_msg = decrypt_data(
                        request.data.Encrypt.cdata, msg_signature, timestamp,
                        token, nonce, encoding_aes_key, appid)
                    request.data = xml_msg.xml
                except Exception as e:
                    logger.exception(f'Decrypt Wechat data error. {e}')
                    raise IllegalRequest('Decrypt data error.') from e
        try:
            response = func(request)
        except Exception as e:
            logger.exception(f'Handle wechat data error: {e}')
            response = {'content': e}
        if ctx.response_type_wrap == 'xml':
            create_time = int(time.time())
            msg_type = _cdata_escape(response.get('msg_type', 'text'))
            content = _cdata_escape(response.get('content', ''))
            # openid comes from the query string, so escape it as well.
            to_user = _cdata_escape(openid)
            from_user = _cdata_escape(public_account_id)
            response = f'''
<xml>
<ToUserName><![CDATA[{to_user}]]></ToUserName>
<FromUserName><![CDATA[{from_user}]]></FromUserName>
<CreateTime>{create_time}</CreateTime>
<MsgType><![CDATA[{msg_type}]]></MsgType>
<Content><![CDATA[{content}]]></Content>
</xml>
'''
            if encrypt_type == 'aes':
                try:
                    response = \
                        encrypt_data(response, encoding_aes_key, token, appid)
                except Exception as e:
                    logger.exception(f'Encrypt Wechat data error. {e}')
                    raise IllegalRequest('Encrypt data error.') from e
        return response
    return wrapper
### | |
# Context | |
### | |
# Per-request context. Values live in a ``threading.local`` keyed by the
# current thread ident. Because gevent's ``patch_all()`` runs first, that
# is presumably per greenlet, i.e. per request under the gevent worker.
class Context(type):
    """Metaclass whose single class, ``ctx``, acts as a per-thread store.

    Attribute reads on ``ctx`` return the value stored for the current
    thread ident, or None if unset. Attribute writes store a value.
    ``clear()`` drops everything stored for the current ident.
    Only a class named ``'Context'`` may be created from it.
    """
    # thread ident -> local(); shared with the class built below.
    _locals = {}

    def __new__(mcls, name, bases, attrs, **kwargs):
        if name != 'Context':
            raise TypeError('Create ContextDelete class error')
        attrs['_locals'] = mcls._locals
        attrs['__dict__'] = {}
        return super().__new__(mcls, name, bases, attrs)

    @property
    def _local(cls):
        return cls._locals.setdefault(get_ident(), local())

    def get(cls, name, default=None):
        """Return the value stored under *name*, or *default*."""
        return getattr(cls._local, name, default)

    def set(cls, name, value):
        """Store *value* under *name* for the current thread ident."""
        setattr(cls._local, name, value)

    def clear(cls):
        """Drop every value stored for the current thread ident."""
        cls._local.__dict__.clear()
        # Also forget the ident itself. Otherwise ``_locals`` grows by one
        # entry for every thread/greenlet that ever handled a request.
        cls._locals.pop(get_ident(), None)

    def __getattr__(cls, name):
        return cls.get(name)

    def __setattr__(cls, name, value):
        cls.set(name, value)


ctx = Context('Context', (object,), {})
def gen_request_id():
    """Return a SHA1 hex id for the current request.

    Hashes the client host, URL, method and millisecond timestamp stored
    in ``ctx`` together with a random salt. Those ``ctx`` fields must be
    set before this is called.
    """
    # 64 random bits rather than a two-digit number (only 90 values), so
    # identical requests in the same millisecond get distinct ids.
    salt = secrets.token_hex(8)
    seed = (
        f'request_id:{ctx.request_host}.'
        f'{ctx.request_url}.'
        f'{ctx.request_method}.'
        f'{ctx.request_timestamp}.{salt}'
    )
    return hashlib.sha1(seed.encode()).hexdigest()
def context_deco(func):
    """Context middleware for ``wsgi_app``: set up and tear down ``ctx``.

    Before calling the wrapped ``func(app, environ, start_response)`` it
    records the timestamp (ms), client host, URL parts, method and a
    request id in ``ctx``. It then writes one access-log line and always
    clears ``ctx``.

    An ``HTTPException`` escaping the app is logged and rendered as its
    HTTP response. Any other exception is logged with its traceback and
    re-raised.
    """
    @wraps(func)
    def wrapper(app, environ, start_response):
        try:
            now = datetime.datetime.now()
            ctx.request_timestamp = int(now.timestamp() * 1000)
            ctx.request_host = get_current_client_host(environ)
            url = get_current_url(environ)
            if url:
                url = urlparse(url)
                ctx.request_url = url.geturl()
                ctx.request_protocol = url.scheme
                ctx.request_path = url.path
                ctx.request_query = url.query
            ctx.request_method = get_current_method(environ)
            # `gen_request_id` must run after url, method and timestamp are set
            ctx.request_id = gen_request_id()
            iterable_resp = func(app, environ, start_response)
        except HTTPException as http_exc:
            # Expected HTTP errors (4xx, etc.): the message suffices, no traceback.
            logger.warning(f'{get_request_info_msg()}\n{http_exc}')
            return http_exc(environ, start_response)  # iterable response
        except Exception:
            # logger.exception already appends the traceback; embedding
            # traceback.format_exc() in the message would print it twice.
            logger.exception(get_request_info_msg())
            raise  # unknown error: no custom response, let the server 500
        else:
            logger.info(get_request_info_msg())
            return iterable_resp
        finally:
            # Clear the per-request context once the response is returned.
            ctx.clear()
    return wrapper
### | |
# Exception | |
### | |
class IllegalRequest(HTTPException):
    """HTTP 400 raised for malformed, unverifiable or disallowed requests."""
    description = "IllegalRequest"
    code = 400
class Unauthorized(HTTPException):
    """HTTP 401 raised when the caller's token is not authorized."""
    code = 401
    # Was misspelled ``descriptionn``, so the default description never
    # took effect when raised without an argument.
    description = "Unauthorized"
### | |
# WSGI Application | |
### | |
def set_response_type_wrap(response_type):
    """Override how the current request's response is wrapped.

    Values used in this file: ``'json'``, ``'xml'`` and ``'text'``.
    """
    ctx.set('response_type_wrap', response_type)
def dispatcher_deco(dispatcher):
    """Wrap ``dispatch_request(app, request)`` with response post-processing.

    - Stores the request, and the final response, in ``ctx``.
    - With ``response_type_wrap == 'json'``, wraps a successful body as
      ``{"code", "data"}`` and an HTTP error as ``{"code", "description"}``,
      and sets a JSON Content-Type.
    - Adds the ``Request-Id`` header and records ``ctx.request_cost_time``
      in milliseconds.

    HTTP exceptions are re-raised with ``exc.response`` set to the wrapped
    error response.
    """
    @wraps(dispatcher)
    def wrapper(app, request):
        ctx.request = request
        try:
            response = dispatcher(app, request)
            if ctx.response_type_wrap == 'json':
                response.data = json.dumps({
                    "code": response.status_code,
                    "data": json.loads(response.data.decode('utf8')),
                })
            # Publish only after wrapping succeeded, so a wrapping failure is
            # not logged with this response's (successful) status code.
            ctx.response = response
            return response
        except HTTPException as exc:
            response = exc.get_response(request.environ)
            if ctx.response_type_wrap == 'json':
                response.data = json.dumps({
                    "code": exc.code,
                    "description": exc.description,
                })
            ctx.response = response
            exc.response = response
            raise
        finally:
            # NOTE: after an uncaught (non-HTTP) exception there is no
            # response, so only the cost time is recorded.
            response = ctx.response
            if response:
                # common headers
                response.headers['Request-Id'] = ctx.request_id
                if ctx.response_type_wrap == 'json':
                    response.headers['Content-Type'] = \
                        'application/json; charset=utf-8'
            request_timestamp = ctx.request_timestamp
            if request_timestamp:
                ctx.request_cost_time = \
                    int(time.time() * 1000 - request_timestamp)
    return wrapper
class RemoteControlApp(WSGIApplication):
    """Gunicorn application that is also the WSGI app and URL router.

    Handlers are registered with :meth:`route`. Each request is matched
    against ``url_map`` and dispatched to the handler in ``func_map``.
    Gunicorn settings are read from ``config['env']``.
    """

    def init(self, parser, opts, args):
        self.url_map = Map()
        self.func_map = {}
        env = dict_to_obj(config['env'])
        self.app_uri = env.app_uri
        # Run gunicorn's own init first. WSGIApplication.init sets
        # ``default_proc_name`` from args[0] (the app URI) in gunicorn
        # 19/20, so it would overwrite a value set before it.
        super(RemoteControlApp, self).init(parser, opts, [self.app_uri])
        self.cfg.set('default_proc_name', env.app_name)
        self.cfg.set('worker_class', env.worker_class)
        self.cfg.set('worker_connections', env.worker_connections)
        self.cfg.set('loglevel', 'info')
        self.cfg.set('graceful_timeout', env.graceful_timeout)
        self.cfg.set('timeout', env.timeout)
        self.cfg.set('bind', env.bind)
        self.cfg.set('workers', env.num_workers)
        self.cfg.set('errorlog', '-')

    def load_wsgiapp(self):
        self.chdir()
        return gunicorn.util.import_app(self.app_uri)

    def run(self):
        super(RemoteControlApp, self).run()

    @dispatcher_deco
    def dispatch_request(self, request):
        """Match *request* to a registered handler and wrap its result."""
        adapter = self.url_map.bind_to_environ(request.environ)
        endpoint, values = adapter.match()
        handler = self.func_map.get(endpoint, None)
        if not handler:  # handler not found
            raise NotFound(f"Unknown url '{ctx.request_path}'")
        resp = handler(request, **values)
        return Response(resp)

    @context_deco
    def wsgi_app(self, environ, start_response):
        request = Request(environ)
        response = self.dispatch_request(request)
        return response(environ, start_response)

    def __call__(self, environ, start_response):
        return self.wsgi_app(environ, start_response)

    def route(self, path, methods, **kwargs):
        """Register the decorated handler for URL rule *path*.

        *methods* restricts the accepted HTTP methods. For any other
        method, werkzeug's ``adapter.match`` raises MethodNotAllowed (405).
        Keyword options:

        - ``resp_wrap``: stored in ``ctx.response_type_wrap`` per request;
        - ``wechat=True``: wrap the handler with ``wechat_deco``.

        Raises ValueError if *path* is already registered.
        """
        def decorator(func):
            resp_wrap = kwargs.get('resp_wrap')
            # Build the (possibly WeChat-wrapped) handler once, instead of
            # re-wrapping it on every request.
            if kwargs.get('wechat', False):
                handler = wechat_deco(func)
            else:
                handler = func

            @wraps(func)
            def wrapper(request, **values):
                ctx.response_type_wrap = resp_wrap
                return handler(request, **values)

            endpoint = f'{path}'
            if endpoint in self.func_map:
                raise ValueError(f'Endpoint({endpoint}) already registered.')
            # Pass methods through; without it every route accepted any
            # HTTP method.
            rule = Rule(path, endpoint=endpoint, methods=methods)
            self.url_map.add(rule)
            self.func_map[endpoint] = wrapper
            return wrapper
        return decorator
### | |
# init app | |
### | |
def load_config(config_path):
    """Load the JSON configuration file at *config_path*.

    When ``wechat_secret_file``, ``private_key_file`` and
    ``private_key_pass`` are all set (truthy), the secret file is
    base64-decoded, RSA-decrypted with the passphrase-protected private
    key, parsed as JSON and stored under ``config['wechat']``.
    """
    with open(config_path) as config_fp:
        config = json.load(config_fp)

    secret_path = config.get('wechat_secret_file')
    key_path = config.get('private_key_file')
    passphrase = config.get('private_key_pass')
    if not (secret_path and key_path and passphrase):
        return config

    with open(secret_path, 'rb') as fd, open(key_path) as fk:
        encrypt_bytes = base64.b64decode(fd.read())
        # NOTE: the `atfork` hook must be called after a process fork for
        # the crypto library to work correctly.
        atfork()
        prikey = RSA.importKey(fk.read(), passphrase)
        plaintext = prikey.decrypt(encrypt_bytes).decode()
        config['wechat'] = json.loads(plaintext)
    return config
# Module-level singletons: the configuration is read at import time, and
# the gunicorn application (whose `init` reads config['env']) is built next.
config = load_config(config_path='config.json')
app = RemoteControlApp()
### | |
# Register Routers: map urls and handlers | |
### | |
@app.route('/control/<command>', methods=['GET', 'POST'], resp_wrap='json')
def sserver_control_daemon_command(request, command):
    """Control the local shadowsocks daemon (``ssserver``).

    - GET ``/control/status`` returns the last 50 lines of the log.
    - POST ``/control/{start,restart,stop}`` runs ``ssserver -d <command>``.
      It requires a JSON body whose ``token`` is one of
      ``config['wechat']['openids']``.

    Returns the command output as a JSON string (wrapped by
    ``resp_wrap='json'``). Raises IllegalRequest for unknown commands and
    Unauthorized for a bad token.
    """
    import shlex
    from subprocess import getoutput

    work_dir = config.get('shadowsocks_work_dir')
    if request.method == 'GET':
        if command != 'status':
            raise IllegalRequest(f"Illegal command '{command}'")
        log_file = os.path.join(work_dir, 'shadowsocks.log')
        # Paths come from config. Quote them so spaces or shell
        # metacharacters cannot break (or extend) the command line.
        return json.dumps(getoutput(f'tail -n 50 {shlex.quote(log_file)}'))
    if command not in ['start', 'restart', 'stop']:
        raise IllegalRequest(f"Illegal command '{command}'")
    json_body = get_request_json(request) or {}
    if json_body.get("token", None) not in config['wechat']['openids']:
        raise Unauthorized('Invalid token')
    config_file = os.path.join(work_dir, 'config.json')
    pid_file = os.path.join(work_dir, 'pid.file')
    log_file = os.path.join(work_dir, 'shadowsocks.log')
    # `command` is whitelisted above; the paths are quoted.
    return json.dumps(getoutput(
        f"ssserver -c {shlex.quote(config_file)} -d {command} "
        f"--pid-file {shlex.quote(pid_file)} "
        f"--log-file {shlex.quote(log_file)}"
    ))
@app.route('/wechat/msg', methods=['GET', 'POST'],
           wechat=True, resp_wrap='xml')
def wechat_message_recevied(request):
    """Handle WeChat's URL verification and incoming text messages.

    - GET echoes ``echostr`` back as plain text.
    - POST messages starting with ``#ssserver`` are forwarded to the local
      ``/control/<command>`` endpoint. The last word (split on space, ':'
      or '_') is the command, and the sender's ``openid`` is the token.
      Any other message gets a canned reply.

    Returns ``{'msg_type': 'text', 'content': ...}`` for ``wechat_deco``.
    Raises IllegalRequest when the POST body has no ``Content``.
    """
    from urllib.parse import quote

    if request.method == 'GET':
        set_response_type_wrap('text')
        return request.args.get('echostr', 'failed')
    try:
        msg = request.data.Content.cdata
    except AttributeError as e:
        logger.exception(e)
        raise IllegalRequest('No content post.') from e
    if msg.startswith('#ssserver'):
        # Drop empty pieces: ''.isspace() is False, so the old filter kept
        # them, and trailing whitespace produced an empty command.
        strings = [s.strip() for s in re.split(r'[ :_]', msg) if s.strip()]
        command = strings[-1]
        bind = config['env']['bind']
        # Quote the command so '/', '?' or '#' cannot alter the target URL.
        # The timeout (seconds) keeps a stuck call from hanging the worker.
        resp = requests.post(
            f'http://{bind}/control/{quote(command, safe="")}',
            json={'token': request.args.get('openid')},
            timeout=10,
        )
        if resp.status_code == 200:
            content = resp.json().get('data')
        else:
            content = 'something wrong'
    else:
        content = 'Everything\'s Alright'
    return {'msg_type': 'text', 'content': content}
# Start the gunicorn-based application when this file is run as a script.
if __name__ == '__main__':
    app.run()
TODO
- make logging metadata more flexible
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
app的
config.json
文件