Created
September 24, 2015 06:04
-
-
Save chawyehsu/7dc6986fc44ca6ba715c to your computer and use it in GitHub Desktop.
GitHub pages Webhook
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Start a Python service that handles webhook requests forwarded by nginx.
from hashlib import sha1, sha256
from http.server import BaseHTTPRequestHandler, CGIHTTPRequestHandler, HTTPServer
from threading import Thread, RLock
#import pprint
import hmac
import ipaddress
import json
import logging
import os.path
import subprocess
import sys

import requests
# Secret configured on the GitHub webhook, as bytes (hmac needs a bytes key).
# It is read from the HOOK_SECRET_KEY environment variable so the secret does
# not have to be committed with the script. With the default empty key, anyone
# can compute a valid signature, so always set it in production.
HOOK_SECRET_KEY = os.environ.get('HOOK_SECRET_KEY', '').encode('utf-8')
# Absolute directory of this script; build.sh is looked up and run from here.
_PWD = os.path.abspath(os.path.dirname(__file__))
# Exec the update shell | |
def execute_cmd(args, cwd=None, timeout=30): | |
if isinstance(args, str): args = [args] | |
try: | |
with subprocess.Popen(args, stdout=subprocess.PIPE, cwd=cwd) as proc: | |
try: | |
output, unused_err = proc.communicate(timeout=timeout) | |
except: | |
proc.kill() | |
raise | |
retcode = proc.poll() | |
if retcode: | |
raise subprocess.CalledProcessError(retcode, proc.args, | |
output=output) | |
return output.decode('utf-8', 'ignore') if output else '' | |
except Exception as ex: | |
logging.error('EXECUTE_CMD_ERROR: %s', ' '.join(str(x) for x in args)) | |
raise ex | |
class HttpHandler(BaseHTTPRequestHandler):
    """Handle GitHub webhook POSTs and rebuild the site on gh-pages pushes.

    Only POST is implemented. The class deliberately derives from
    BaseHTTPRequestHandler rather than CGIHTTPRequestHandler. The latter
    would also answer GET/HEAD by serving files from the working directory
    and running CGI scripts under its cgi_directories (/cgi-bin, /htbin).
    This webhook needs neither, and serving files could expose this script.

    Builds are coalesced through class-level state shared by every handler
    instance. Each accepted push increments ``_counter`` and starts a thread.
    That thread runs build.sh unless a build is already in progress. A
    running build re-checks the counter when it finishes.
    """

    _lock = RLock()      # guards _counter and _building
    _counter = 0         # accepted pushes not yet covered by a build
    _building = False    # True while build.sh is running

    # Signature headers GitHub may send, strongest first, with the expected
    # algorithm prefix and the hmac digest constructor for each.
    _SIGNATURE_HEADERS = (
        ('X-Hub-Signature-256', 'sha256', sha256),
        ('X-Hub-Signature', 'sha1', sha1),
    )

    def _validate_signature(self, data):
        """Return True if *data* matches the request's GitHub HMAC signature.

        Args:
            data: the raw request body (bytes), exactly as received.

        Prefers ``X-Hub-Signature-256`` (``sha256=<hex>``) and falls back to
        the legacy ``X-Hub-Signature`` (``sha1=<hex>``). A missing or
        malformed header yields False instead of raising.
        """
        for header, algorithm, digestmod in self._SIGNATURE_HEADERS:
            value = self.headers.get(header)
            if value is None:
                continue
            sha_name, sep, signature = value.partition('=')
            if sep != '=' or sha_name != algorithm:
                return False
            # hmac needs a bytes key and message; the body is already bytes.
            mac = hmac.new(HOOK_SECRET_KEY, msg=data, digestmod=digestmod)
            # Compare as bytes: compare_digest rejects non-ASCII str input,
            # and header values come from the (untrusted) client.
            return hmac.compare_digest(mac.hexdigest().encode('ascii'),
                                       signature.encode('utf-8', 'replace'))
        return False

    def _handle_payload(self, payload):
        """Return True if *payload* describes a push to the gh-pages branch.

        GitHub push events carry the full ref name, ``refs/heads/<branch>``.
        Other JSON values (arrays, scalars) are never a matching push.
        """
        return (isinstance(payload, dict)
                and payload.get('ref') == 'refs/heads/gh-pages')

    def _build_site(self):
        """Run build.sh until no accepted push is left unbuilt.

        Returns immediately if another thread is already building or nothing
        is pending. Otherwise it claims the pending count and runs the build
        outside the lock. It then loops, so pushes that arrived during the
        build are picked up even if that build failed.
        """
        while True:
            with HttpHandler._lock:
                if HttpHandler._counter == 0 or HttpHandler._building:
                    return
                HttpHandler._counter = 0
                HttpHandler._building = True
            logging.info("====== Site update shell start ======")
            try:
                resp = execute_cmd(os.path.join(_PWD, 'build.sh'),
                                   cwd=_PWD,
                                   timeout=600)
                logging.info(resp)
                logging.info("====== Site update shell end ======")
            except Exception:
                # Log and keep looping: a failed build must not strand
                # pushes that were counted while it ran.
                logging.exception('Site update shell failed')
            finally:
                with HttpHandler._lock:
                    HttpHandler._building = False

    def _github_hook_networks(self):
        """Fetch GitHub's webhook source networks from its meta API.

        Raises:
            requests.RequestException: network failure or HTTP error status.
            ValueError: the body is not JSON, or a block is not a network.
            KeyError: the JSON has no 'hooks' entry.
        """
        # Bounded timeout so a slow API cannot stall this single-threaded
        # server indefinitely.
        resp = requests.get('https://api.github.com/meta', timeout=5)
        resp.raise_for_status()
        return [ipaddress.ip_network(block) for block in resp.json()['hooks']]

    def _reply(self, code, body):
        """Send a complete plain-text response with status *code*."""
        self.send_response(code)
        self.send_header('Content-Type', 'text/plain; charset=utf-8')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)
        self.wfile.flush()

    def do_POST(self):
        """Validate one webhook delivery and schedule a build if it applies.

        Responses:
            403: the client IP is missing, malformed, or outside GitHub's
                hook ranges.
            503: GitHub's hook ranges could not be fetched.
            411: no valid Content-Length.
            401: bad or missing signature.
            400: body is not UTF-8 JSON.
            501: the event is not a push to gh-pages.
            200: a build has been scheduled.
        """
        # The server listens on 127.0.0.1 behind nginx, which forwards the
        # real client address in X-Real-IP.
        request_ip = self.headers.get('X-Real-IP')
        logging.info('POST request from IP: %s', request_ip)
        try:
            # A missing header (None) also raises ValueError here.
            client_ip = ipaddress.ip_address(request_ip)
        except ValueError:
            self.send_error(403)
            return
        try:
            hook_networks = self._github_hook_networks()
        except (requests.RequestException, ValueError, KeyError):
            logging.exception('Cannot fetch GitHub hook address blocks')
            self.send_error(503)
            return
        # Mixed IPv4/IPv6 membership tests are simply False.
        if not any(client_ip in network for network in hook_networks):
            logging.info('Rejecting %s: not in GitHub hook blocks', request_ip)
            self.send_error(403)
            return

        try:
            data_length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):
            data_length = -1
        if data_length < 0:
            # A negative length would make rfile.read() block until EOF.
            self.send_error(411)
            return
        post_data = self.rfile.read(data_length)

        if not self._validate_signature(post_data):
            # send_error emits the headers; a bare send_response would leave
            # them buffered and the client would get no response.
            self.send_error(401)
            return

        try:
            payload = json.loads(post_data.decode('utf-8'))
        except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
            self.send_error(400)
            return

        if not self._handle_payload(payload):
            self._reply(501, b'Event mismatch, so do nothing.')
            return

        # The build runs in the background, after this response is sent.
        self._reply(200, b'Ok, site update scheduled.')
        with HttpHandler._lock:
            HttpHandler._counter += 1
        Thread(target=self._build_site).start()
if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        level=logging.INFO)
    # The listening port may be given as the first command-line argument.
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 4001
    logging.info('starting the server at 127.0.0.1:%s', port)
    # Loopback only: the service is meant to sit behind a local nginx proxy.
    httpd = HTTPServer(('127.0.0.1', port), HttpHandler)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the service; exit without a traceback.
        logging.info('server stopped')
    finally:
        # Release the listening socket explicitly.
        httpd.server_close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment