Instantly share code, notes, and snippets.

Embed
What would you like to do?
Random Python scripts
A random collection of Python scripts (mostly Python 2.7)
"""
Argument parser template
"""
import argparse
def build_parser():
    """Build the template argument parser.

    Returns an ``argparse.ArgumentParser`` with:
      * ``a``: mandatory positional, kept as a string;
      * ``b``: mandatory positional, converted to ``int``;
      * ``-r``: optional value (``None`` when omitted);
      * ``-s/--silent``: flag stored as ``args.silent``.
    """
    parser = argparse.ArgumentParser(description='Your application description')
    # simple argument (mandatory)
    parser.add_argument('a', help='some description')
    # cast positional argument to int
    parser.add_argument('b', type=int, help='some description')
    # option (optional)
    parser.add_argument('-r', help='some description')
    # flag: args.silent becomes True when -s/--silent is given.
    # argparse names the attribute after the first long option string,
    # so there is no ``args.s`` attribute.
    parser.add_argument('-s', '--silent', action='store_true', default=False, help='some description')
    return parser


if __name__ == '__main__':
    # parse arguments/options to an object args
    args = build_parser().parse_args()
    # call the arguments/options
    print(args.a)
    print(args.b)
    print(args.r)
    print(args.silent)
# Parse BibTeX entries from an input file and render them in IEEEtran.cls format
# http://www.michaelshell.org/tex/ieeetran/
# Usage: python bibtexconverter.py [bibtex file]
#
# BibTeX example (input):
# @article{lecun2015deep,
# title={Deep learning},
# author={LeCun, Yann and Bengio, Yoshua and Hinton, Geoffrey},
# journal={Nature},
# volume={521},
# number={7553},
# pages={436--444},
# year={2015},
# publisher={Nature Publishing Group}
# }
#
# IEEEtran example (output):
# \bibitem{lecun2015deep} Y.~LeCun and Y.~Bengio and G.~Hinton, \emph{Deep learning}.\hskip 1em plus 0.5em minus 0.4em\relax Nature, Nature Publishing Group, 2015.
import re
import sys
from pprint import pprint
def ieee(refs):
    """Print every parsed reference as an IEEEtran ``\\bibitem`` line.

    A blank line is printed first, and each item is followed by an extra
    blank line. Single-argument ``print(...)`` calls produce identical output
    on Python 2 and Python 3 (the old print statements were Python 2 only).

    :param refs: iterable of reference dicts accepted by ``_ieee``.
    """
    print('\n')
    for ref in refs:
        print(_ieee(ref) + '\n')
def _ieee(dic):
    """Render one reference dict as a single IEEEtran ``\\bibitem`` string.

    Reads the ``refcode``, ``author``, ``title`` and ``year`` keys (a missing
    key raises KeyError). Venue fields are delegated to ``_ieee_publisher``.
    """
    template = """\\bibitem{{{}}} {}, \\emph{{{}}}.\\hskip 1em plus 0.5em minus 0.4em\\relax {}, {}."""
    # evaluated in the same order as the placeholders appear
    fields = (
        dic['refcode'],
        _ieee_author(dic['author']),
        dic['title'],
        _ieee_publisher(dic),
        dic['year'],
    )
    return template.format(*fields)
def _ieee_publisher(dic):
publisher = []
keys = ['journal', 'booktitle', 'publisher', 'organization']
for key in keys:
if key in dic:
publisher.append(dic[key])
return ', '.join(publisher)
def _ieee_author(text):
formatted = []
authors = text.split(' and ')
for a in authors:
names = a.split(', ')
if len(names) >= 2:
last, first = names[0], names[1]
formatted.append(first[0].upper() + '.~' + last)
else:
formatted.append(names[0])
return ' and '.join(formatted)
if __name__ == '__main__':
if len(sys.argv) < 2:
print 'Usage: python bibtexconverter.py [bibtex file]'
exit()
filename = sys.argv[1]
# collect BibTex entries from input file
# separated by blank line
entries = []
with open(filename) as f:
entry = []
for line in f:
line = line.strip()
if len(line) > 0:
# save line
entry.append(line)
elif len(entry) > 0:
# blank line
entries.append(entry)
entry = []
# last entry
if len(entry) > 0:
entries.append(entry)
# parse BibTex entries
references = []
for entry in entries:
dic = {}
dic['refcode'] = re.search(r'@(article|inproceedings|thesis){([\w\d]*),', entry[0], re.M | re.I).group(2)
for i in range(1, (len(entry) - 1)):
key, value = entry[i].split('=')
value = re.search(r'{([^{}]*)}', value, re.M | re.I).group(1)
dic[key] = value
references.append(dic)
# render entries in IEEEtran.cls format
# http://www.michaelshell.org/tex/ieeetran/
ieee(references)
'''
Convert ENAMEX Named-Entity annotated file to Stanford NLP format (token-based)
@Author yohanes.gultom@gmail
ENAMEX example (2 sentences):
Sementara itu Pengamat Pasar Modal <ENAMEX TYPE="PERSON">Dandossi Matram</ENAMEX> mengatakan, sulit bagi sebuah <ENAMEX TYPE="ORGANIZATION">kantor akuntan publik</ENAMEX> (<ENAMEX TYPE="ORGANIZATION">KAP</ENAMEX>) untuk dapat menyelesaikan audit perusahaan sebesar <ENAMEX TYPE="ORGANIZATION">Telkom</ENAMEX> dalam waktu 3 bulan. 1
<ENAMEX TYPE="ORGANIZATION">Telkom</ENAMEX> akan melakukan RUPS pada 30 Juli 2004 yang selain melaporkan kinerja 2003 juga akan meminta persetujuan untuk pemecahan nilai nominal saham atau stock split 1:2. 2
'''
import sys
import re
# "<ENAMEX" token, possibly with glued prefix text such as "("
START_PATTERN = re.compile(r'^(.*?)<ENAMEX$', re.I)
# 'TYPE="X">word</ENAMEX>tail': an entity that fits in one token
END_SINGLE_PATTERN = re.compile(r'^TYPE="(.*?)">(.*?)</ENAMEX>(.*?)$', re.I)
# 'TYPE="X">word': first token of a multi-token entity
TYPE_PATTERN = re.compile(r'^TYPE="(.*?)">(.*?)$', re.I)
# 'word</ENAMEX>tail': last token of a multi-token entity
END_MULTI_PATTERN = re.compile(r'^(.*?)</ENAMEX>(.*?)$', re.I)
# end of sentence: "<text><optional period><TAB><sentence number>".
# The text group is lazy so a trailing period lands in group 2 instead of
# being swallowed into the word (the greedy version emitted it twice).
EOS_PATTERN = re.compile(r'^([^<>]*?)(\.?)\t(\d+)$', re.I)
NON_ENTITY_TYPE = 'O'


def check_and_process_eos(token):
    """Handle a sentence-final token ("word.<TAB>N").

    Relies on the module-level ``out`` (writable file) and ``cur_type``
    (current label) set by the conversion loop. If *token* matches
    EOS_PATTERN, this writes:
      * the word, if it is not empty;
      * the period as its own token, if one was present;
      * a blank line marking the sentence boundary.
    It then returns True. Otherwise it writes nothing and returns False.
    """
    match = EOS_PATTERN.match(token)
    if not match:
        return False
    word, period = match.group(1), match.group(2)
    if word:
        out.write(word + '\t' + cur_type + '\n')
    if period:
        out.write('.' + '\t' + cur_type + '\n')
    out.write('\n')
    return True
infile = sys.argv[1]
outfile = sys.argv[2]
# label applied to tokens inside a multi-token entity; 'O' outside entities
cur_type = NON_ENTITY_TYPE
# Text mode (not 'rb'): under Python 3 binary mode yields bytes, which the
# str regex patterns cannot match. On Python 2 both modes read str.
with open(infile, 'r') as f, open(outfile, 'w') as out:
    for line in f:
        for token in line.strip().split(' '):
            token = token.strip()
            if not token:
                continue
            # "<ENAMEX" (possibly with glued prefix text such as "(")
            match = START_PATTERN.match(token)
            if match:
                if match.group(1):
                    out.write(match.group(1) + '\t' + NON_ENTITY_TYPE + '\n')
                continue
            # single-token entity: 'TYPE="X">word</ENAMEX>tail'
            match = END_SINGLE_PATTERN.match(token)
            if match:
                out.write(match.group(2) + '\t' + match.group(1) + '\n')
                cur_type = NON_ENTITY_TYPE
                tail = match.group(3)
                # emit the tail only if there is one; an empty tail used to
                # produce a bogus line with an empty token
                if tail and not check_and_process_eos(tail):
                    out.write(tail + '\t' + cur_type + '\n')
                continue
            # first token of a multi-token entity: 'TYPE="X">word'
            match = TYPE_PATTERN.match(token)
            if match:
                cur_type = match.group(1)
                if match.group(2):
                    out.write(match.group(2) + '\t' + cur_type + '\n')
                continue
            # last token of a multi-token entity: 'word</ENAMEX>tail'
            match = END_MULTI_PATTERN.match(token)
            if match:
                if match.group(1):
                    out.write(match.group(1) + '\t' + cur_type + '\n')
                cur_type = NON_ENTITY_TYPE
                tail = match.group(2)
                if tail and not check_and_process_eos(tail):
                    out.write(tail + '\t' + cur_type + '\n')
                continue
            if check_and_process_eos(token):
                continue
            out.write(token + '\t' + cur_type + '\n')
#!/usr/bin/env python3
"""
Simple example on compiling & deploying simple smartcontract, and calling its methods
Setup:
pip3 install web3==4.7.2 py-solc==3.2.0
python3 -m solc.install v0.4.24
export PATH="$PATH:$HOME/.py-solc/solc-v0.4.24/bin"
@author yohanes.gultom@gmail.com
"""
from web3 import Web3, HTTPProvider, middleware
from solc import compile_source
import random
def compile_contract(contract_source_file, contractName=None):
    """
    Compile a Solidity source file and select one contract from the output.

    When ``contractName`` is given, the interface stored under
    ``'<stdin>:' + contractName`` is returned together with the bare name.
    Otherwise the first key of the compiler output (which includes the
    ``<stdin>:`` prefix) is used as the name.

    Returns a ``(contractName, contract_interface)`` tuple.
    """
    with open(contract_source_file, "r") as source:
        source_code = source.read()
    compiled_sol = compile_source(source_code)
    if contractName:
        return contractName, compiled_sol['<stdin>:' + contractName]
    first_name = list(compiled_sol)[0]
    return first_name, compiled_sol[first_name]
def deploy_contract(acct, contract_interface, contract_args=None):
    """
    Deploy a compiled contract through a locally signed transaction.

    Uses the module-level ``w3`` connection. ``contract_args`` (if truthy)
    are passed positionally to the constructor. The function blocks for up
    to 120 s waiting for the receipt and returns the new contract address.
    """
    factory = w3.eth.contract(abi=contract_interface['abi'], bytecode=contract_interface['bin'])
    if contract_args:
        constructed = factory.constructor(*contract_args)
    else:
        constructed = factory.constructor()
    tx = constructed.buildTransaction({
        'from': acct.address,
        'nonce': w3.eth.getTransactionCount(acct.address),
    })
    print("Signing and sending raw tx ...")
    raw_tx = acct.signTransaction(tx).rawTransaction
    tx_hash = w3.eth.sendRawTransaction(raw_tx)
    print("tx_hash = {} waiting for receipt ...".format(tx_hash.hex()))
    receipt = w3.eth.waitForTransactionReceipt(tx_hash, timeout=120)
    address = receipt["contractAddress"]
    print("Receipt accepted. gasUsed={gasUsed} contractAddress={contractAddress}".format(**receipt))
    return address
def exec_contract(acct, nonce, func):
    """
    Sign and send a state-changing contract call.

    ``func`` is a bound contract function (e.g. ``contract.functions.set(x)``).
    The transaction uses the explicit ``nonce`` and is sent through the
    module-level ``w3``. Returns the transaction hash as a hex string
    without waiting for it to be mined.
    """
    txn = func.buildTransaction({'from': acct.address, 'nonce': nonce})
    raw_tx = acct.signTransaction(txn).rawTransaction
    sent_hash = w3.eth.sendRawTransaction(raw_tx)
    return sent_hash.hex()
if __name__ == '__main__':
    # contract.sol is expected to contain:
    #
    #   pragma solidity ^0.4.21;
    #   contract simplestorage {
    #       uint public storedData;
    #       event Updated(address by, uint _old, uint _new);
    #       function set(uint x) {
    #           uint old = storedData;
    #           storedData = x;
    #           emit Updated(msg.sender, old, x);
    #       }
    #       function get() constant returns (uint retVal) {
    #           return storedData;
    #       }
    #   }
    import os
    import time

    # config
    RPC_ADDRESS = 'http://localhost:8545'
    CONTRACT_SOL = 'contract.sol'
    CONTRACT_NAME = 'simplestorage'
    # Prefer the PRIVATE_KEY environment variable so a real key never has to
    # be written into the source; the placeholder is kept as the fallback.
    PRIVATE_KEY = os.environ.get('PRIVATE_KEY', "youraddressprivatekey")
    # seconds to wait for the Updated event before giving up
    EVENT_TIMEOUT = 120
    # instantiate web3 object (module-level, used by deploy/exec helpers)
    w3 = Web3(HTTPProvider(RPC_ADDRESS, request_kwargs={'timeout': 120}))
    # use additional middleware for PoA networks (e.g. Rinkeby)
    # w3.middleware_stack.inject(middleware.geth_poa_middleware, layer=0)
    acct = w3.eth.account.privateKeyToAccount(PRIVATE_KEY)
    # compile contract to get abi
    print('Compiling contract..')
    contract_name, contract_interface = compile_contract(CONTRACT_SOL, CONTRACT_NAME)
    # deploy contract
    print('Deploying contract..')
    contract_address = deploy_contract(acct, contract_interface)
    # create contract object
    contract = w3.eth.contract(address=contract_address, abi=contract_interface['abi'])
    # call non-transactional method
    val = contract.functions.get().call()
    print('Invoke get()={}'.format(val))
    assert val == 0
    # call transactional method
    nonce = w3.eth.getTransactionCount(acct.address)
    from_block_number = w3.eth.blockNumber
    new_val = random.randint(1, 100)
    contract_func = contract.functions.set(new_val)
    print('Invoke set()={}'.format(new_val))
    tx_hash = exec_contract(acct, nonce, contract_func)
    print('tx_hash={} waiting for receipt..'.format(tx_hash))
    tx_receipt = w3.eth.waitForTransactionReceipt(tx_hash, timeout=120)
    print("Receipt accepted. gasUsed={gasUsed} blockNumber={blockNumber}".format(**tx_receipt))
    # catch event: poll with a pause and a deadline instead of a tight,
    # unbounded loop that hammers the RPC node
    contract_filter = contract.events.Updated.createFilter(fromBlock=from_block_number)
    print('Waiting for event..')
    deadline = time.time() + EVENT_TIMEOUT
    entries = contract_filter.get_all_entries()
    while not entries:
        if time.time() > deadline:
            raise RuntimeError('Timed out waiting for Updated event')
        time.sleep(1)
        entries = contract_filter.get_all_entries()
    # _new == new_val
    args = entries[0].args
    print(args)
    assert args._old == 0
    assert args._new == new_val
    assert args.by == acct.address
    # call non-transactional method
    val = contract.functions.get().call()
    print('Invoke get()={}'.format(val))
    assert val == new_val
import os
import sys
# directory containing this script
dir_path = os.path.dirname(os.path.realpath(__file__))
# file name of this script (including its extension)
basename = os.path.basename(os.path.realpath(__file__))
# directory to process, taken from the first command-line argument
mypath = sys.argv[1]
# print each regular file in mypath, joined onto this script's directory
for f in os.listdir(mypath):
    path = os.path.join(mypath, f)
    if os.path.isfile(path):
        print(os.path.join(dir_path, path))
# rename "<Prefix>_<rest>.jpg" to "<prefix>.jpg" (prefix lower-cased)
target_dir = mypath
for f in os.listdir(target_dir):
    stem, ext = os.path.splitext(f)
    if ext != '.jpg':
        continue
    new_name = stem.split('_')[0].lower() + ext
    if new_name == f:
        continue
    new_path = os.path.join(target_dir, new_name)
    # os.rename silently replaces an existing target on POSIX, so files
    # sharing a prefix (e.g. "cat_1.jpg", "cat_2.jpg") would clobber each other
    if os.path.exists(new_path):
        print('Skipping {}: {} already exists'.format(f, new_name))
        continue
    os.rename(os.path.join(target_dir, f), new_path)
"""
Finding fingerprint and calculating simple fuzzy similarity
@author yohanes.gultom@gmail.com
Prerequisites on Ubuntu:
* Python 2.7 and pip
* FFMPEG `sudo apt install ffmpeg`
* AcoustID fingerprinter `sudo apt install acoustid-fingerprinter`
* PyAcoustID `pip install pyacoustid`
* FuzzyWuzzy `pip install fuzzywuzzy[speedup]`
"""
import acoustid
import sys
import os
import chromaprint
import numpy as np
import matplotlib.pyplot as plt
from fuzzywuzzy import fuzz
# directory of full-length tracks used as the match database
DIR_DATABASE = 'music/full'
# directory of partial clips that are matched against the database
DIR_SAMPLES = 'music/partial'


def get_fingerprint(filepath):
    """
    Fingerprint one audio file.

    Returns a ``(fingerprint, version, duration)`` tuple. The fingerprint is
    the decoded list of signed integers produced by chromaprint from
    acoustid's encoded fingerprint.
    """
    duration, encoded = acoustid.fingerprint_file(filepath)
    decoded_fp, fp_version = chromaprint.decode_fingerprint(encoded)
    return (decoded_fp, fp_version, duration)
def build_fingerprint_database(dirpath, file_ext='.mp3'):
    """
    Fingerprint every regular file in *dirpath* whose extension equals
    *file_ext* (case-sensitive).

    Returns a dict mapping file name (not full path) to its fingerprint.
    """
    print('Processing {}..'.format(dirpath))
    fingerprints = {}
    for entry in os.listdir(dirpath):
        full_path = os.path.join(dirpath, entry)
        extension = os.path.splitext(entry)[1]
        if not (os.path.isfile(full_path) and extension == file_ext):
            continue
        print('Getting fingerprint from database item: {}..'.format(entry))
        fp, _version, _duration = get_fingerprint(full_path)
        fingerprints[entry] = fp
    return fingerprints
def plot_fingerprints(db):
    """
    Show each fingerprint in *db* (name -> fingerprint) as a bitmap.

    Fingerprints are drawn in a single-column grid with one titled row per
    entry. An empty *db* draws nothing, because ``add_subplot`` would fail
    with zero rows. ``dict.items()`` replaces the Python 2-only
    ``iteritems()``.
    """
    if not db:
        return
    fig = plt.figure()
    numrows = len(db)
    for plot_id, (name, fp) in enumerate(db.items(), start=1):
        # single column grid, one row per fingerprint
        axes = fig.add_subplot(numrows, 1, plot_id)
        axes.imshow(get_fingerprint_bitmap(fp))
        axes.set_title(name)
    plt.show()
def get_fingerprint_bitmap(fp):
    """
    Expand a fingerprint (sequence of signed ints) into a boolean bitmap.

    Each value is masked to 32 bits and written as a 32-character binary
    string, most significant bit first. Its bits become one column, so a
    non-empty *fp* gives an array of shape ``(32, len(fp))``.
    """
    columns = []
    for value in fp:
        bits = '{:32b}'.format(value & 0xffffffff)
        columns.append([bit == '1' for bit in bits])
    return np.array(columns).T
if __name__ == '__main__':
    # load database and samples
    database = build_fingerprint_database(DIR_DATABASE)
    samples = build_fingerprint_database(DIR_SAMPLES)
    print('\n')
    # find the best match of each sample in the database
    # (.items() works on Python 2 and 3; .iteritems() is Python 2 only)
    for sample, sample_fp in samples.items():
        print('Similarity score of "{}":'.format(sample))
        best_match = None
        for name, fp in database.items():
            similarity = fuzz.ratio(sample_fp, fp)
            if best_match is None or best_match['score'] < similarity:
                best_match = {
                    'score': similarity,
                    'name': name
                }
            print('{} {}%'.format(name, similarity))
        if best_match is None:
            # empty database: nothing to compare against (format(**None) would crash)
            print('No match: database is empty\n')
        else:
            print('Best match: {name} ({score}%)\n'.format(**best_match))
    # plot database
    plot_fingerprints(database)
# Train a ProbabilisticProjectiveDependencyParser using CoNLL-U treebank from Universal Dependencies https://github.com/UniversalDependencies
# In this script we are using Indonesian treebank https://github.com/UniversalDependencies/UD_Indonesian
from pprint import pprint
from nltk.parse import (
DependencyGraph,
ProbabilisticProjectiveDependencyParser
)
import io


def _conllu_to_conll(sentence_block):
    """Keep only the plain token lines of one CoNLL-U sentence block.

    Drops blank lines, comment lines (starting with '#'), multiword-token
    ranges (IDs like "3-4") and empty nodes (IDs like "5.1"). NLTK's
    DependencyGraph reader is not expected to handle these CoNLL-U-only
    lines.
    """
    kept = []
    for line in sentence_block.splitlines():
        if not line.strip() or line.startswith('#'):
            continue
        token_id = line.split('\t', 1)[0]
        if '-' in token_id or '.' in token_id:
            continue
        kept.append(line)
    return '\n'.join(kept)


# open the treebank as UTF-8 text; io.open works on Python 2 and 3, whereas
# str.decode does not exist on Python 3
with io.open('id-ud-train.conllu', 'r', encoding='utf-8') as f:
    sentence_blocks = (_conllu_to_conll(b) for b in f.read().split('\n\n'))
    # parse dependency graphs, skipping blocks that end up empty
    graphs = [DependencyGraph(b, top_relation_label='root') for b in sentence_blocks if b]
# train ProbabilisticProjectiveDependencyParser
ppdp = ProbabilisticProjectiveDependencyParser()
print('Training Probabilistic Projective Dependency Parser...')
ppdp.train(graphs)
# try to parse a sentence
# and print tree ordered by probability (most probable first)
sent = ['Melingge', 'adalah', 'gampong', 'di', 'kecamatan', 'Pulo', 'Aceh', '.']
print('Parsing \'' + " ".join(sent) + '\'...')
print('Parse:')
for tree in ppdp.parse(sent):
    pprint(tree)
'''
Simple script to test sending email using SMTP server
'''
import smtplib
from email.MIMEMultipart import MIMEMultipart
from email.MIMEText import MIMEText
# smtp config
SMTP_SERVER = 'smtp.gmail.com'
SMTP_PORT = 587
SMTP_USER = 'user@gmail.com'
SMTP_PASS = 'password'
# email content
to = "yohanes.gultom@gmail.com"
subject = "Just a test mail"
body = "This is just a test message from a new server. Kindly ignore it and proceed with what you are doing. Thank you!"
if __name__ == '__main__':
    msg = MIMEMultipart()
    msg['From'] = SMTP_USER
    msg['To'] = to
    msg['Subject'] = subject
    msg.attach(MIMEText(body, 'plain'))
    server = smtplib.SMTP(SMTP_SERVER, SMTP_PORT)
    try:
        # upgrade the plain connection to TLS before sending credentials
        server.starttls()
        server.login(SMTP_USER, SMTP_PASS)
        server.sendmail(SMTP_USER, to, msg.as_string())
    finally:
        # always end the session, even when STARTTLS, login or sending fails
        try:
            server.quit()
        except smtplib.SMTPServerDisconnected:
            # server already dropped the connection; just release the socket
            server.close()
'''
Convert Named-Entity tagged file (Open NLP format) to Stanford NLP format (token-based)
@Author yohanes.gultom@gmail
Tagged file example (2 sentences):
"Internal DPD Sulsel mudah-mudahan dalam waktu dekat ada keputusan. Sudah ada keputusan kita serahkan ke DPP dan Rabu ini kita akan rapat harian soal itu," kata <PERSON>Sudding</PERSON> kepada Tribunnews.com, <TIME>Senin (30/1/2012)</TIME>.
Menurut <PERSON>Sudding</PERSON>, DPP Hanura pada prinsipnya memberikan kesempatan dan ruang sama bagi pengurus DPD dan DPC Hanura Sulsel untuk menyampaikan aspirasinya.
"Dan diberikan kesempatan melakukan verfikasi akar msalah yang terjadi di DPD Hanura Sulsel," kata dia.
'''
import sys
import re
# "pre<TAG>entity</TAG>post" packed into one whitespace-separated token
SINGLE_PATTERN = re.compile(r'^([^<>]*)<(\w+)>([^<]*)</(\w+)>([^<>]*)$', re.I)
# "pre<TAG>first": opening tag of an entity spanning several tokens
START_PATTERN = re.compile(r'^([^<>]*)<(\w+)>([^<]*)$', re.I)
# "last</TAG>post": closing tag of an entity spanning several tokens
END_PATTERN = re.compile(r'^([^<>]*)</(\w+)>([^<]*)$', re.I)
# token that ends a sentence with a period: "word."
EOS_PATTERN = re.compile(r'^([^<>]*)\.$', re.I)
NON_ENTITY_TYPE = 'O'


def _write_token(out, text, label):
    """Write a "text<TAB>label" line; empty text is skipped (no blank tokens)."""
    if text:
        out.write(text + '\t' + label + '\n')


def _write_with_eos(out, text, label):
    """Write *text* with *label*, handling a sentence-final period.

    If *text* ends with '.', the period is written as its own token and is
    followed by a blank line that closes the sentence.
    """
    match = EOS_PATTERN.match(text)
    if match:
        _write_token(out, match.group(1), label)
        out.write('.' + '\t' + label + '\n')
        out.write('\n')
    else:
        _write_token(out, text, label)


infile = sys.argv[1]
outfile = sys.argv[2]
# label for tokens inside a multi-token entity; 'O' outside entities
cur_type = NON_ENTITY_TYPE
# Text mode (not 'rb'): under Python 3 binary mode yields bytes, which the
# str regex patterns cannot match. On Python 2 both modes read str.
with open(infile, 'r') as f, open(outfile, 'w') as out:
    for line in f:
        for token in line.strip().split(' '):
            token = token.strip()
            if not token:
                continue
            match = SINGLE_PATTERN.match(token)
            if match:
                pre, tag, text, close_tag, post = match.groups()
                if tag != close_tag:
                    raise ValueError('Invalid tag pair: {} and {}'.format(tag, close_tag))
                _write_token(out, pre, NON_ENTITY_TYPE)
                _write_token(out, text, tag)
                # a '.' right after the closing tag also ends the sentence
                _write_with_eos(out, post, NON_ENTITY_TYPE)
                continue
            match = START_PATTERN.match(token)
            if match:
                pre, tag, text = match.groups()
                _write_token(out, pre, NON_ENTITY_TYPE)
                cur_type = tag
                _write_token(out, text, cur_type)
                continue
            match = END_PATTERN.match(token)
            if match:
                text, tag, post = match.groups()
                if tag != cur_type:
                    raise ValueError('Invalid tag pair: {} and {}'.format(cur_type, tag))
                _write_token(out, text, cur_type)
                cur_type = NON_ENTITY_TYPE
                _write_with_eos(out, post, NON_ENTITY_TYPE)
                continue
            # plain token, possibly ending the sentence
            _write_with_eos(out, token, cur_type)
# VIP currency notification script
# Usage: python vip2.py <gmail_username> <gmail_password> <to_email>
# Author: yohanes.gultom@gmail.com
from bs4 import BeautifulSoup
from bs4.element import Tag
from re import sub
from decimal import Decimal
import urllib2
import backoff
import smtplib
import sys
url = 'https://www.vip.co.id'
# rules to send email: notify when rates[currency][type] <op> value
rules = [
    {'currency': 'SGD', 'op': '>=', 'type': 'buy', 'value': 9400}
]
# fail with a usage message instead of an IndexError on missing arguments
if len(sys.argv) < 4:
    sys.exit('Usage: python vip2.py <gmail_username> <gmail_password> <to_email>')
smtp_config = {
    'username': sys.argv[1],
    'password': sys.argv[2],
    'server': 'smtp.gmail.com',
    'port': 465,
    'from': 'VIP Bot',
    'to': sys.argv[3]
}
# raw RFC 822 message; placeholders: from, to, currency, buy, sell, source url
message_tpl = '''From: {0}\r\nTo: {1}\r\nSubject: {2} to IDR today\r\nMIME-Version: 1.0\r\nContent-Type: text/html\r\n\r\n
<h1>{2} to IDR</h1>
<ul>
<li>Buy: IDR {3}</li>
<li>Sell: IDR {4}</li>
</ul>
<p>Source: {5}</p>
'''
@backoff.on_exception(backoff.expo, urllib2.URLError, max_tries=3)
def fetch_content(url):
    """
    Open *url* and return the urllib2 response object; the caller reads and
    closes it.

    A browser-like User-Agent is sent, matching the Python 3 version of this
    script, because the default "Python-urllib" agent may be refused by the
    site. URLError (including HTTPError) is retried up to 3 times with
    exponential backoff.
    """
    request = urllib2.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    return urllib2.urlopen(request)
def parse_currency(s):
    """Convert a formatted amount (e.g. 'IDR 9,400.00') to a Decimal.

    Every character other than digits and '.' is removed first.
    """
    digits_and_dots = sub(r'[^\d.]', '', str(s))
    return Decimal(digits_and_dots)
import operator

# retrieve and parse rates
print('Fetching content from {}..'.format(url))
rates = {}
response = fetch_content(url)
try:
    html = response.read()
finally:
    # urllib2 responses are not context managers on Python 2; close explicitly
    response.close()
soup = BeautifulSoup(html, 'html.parser')
rate_table = soup.select('#rate-table tr')
# first row is the table header
for rate in rate_table[1:]:
    values = []
    for content in rate.contents:
        if isinstance(content, Tag):
            # `'title' in tag` searches the tag's children, not its
            # attributes, so test the attribute explicitly
            if content.has_attr('title'):
                values.append(content['title'])
            elif content.contents:
                values.append(content.contents[0])
            else:
                values.append('')
    if len(values) < 3:
        # not a (currency, rate, rate) row
        continue
    try:
        first = parse_currency(values[1])
        second = parse_currency(values[2])
    except ArithmeticError:
        # decimal.InvalidOperation (non-numeric cell) subclasses ArithmeticError
        continue
    rates[str(values[0])] = {
        'buy': min(first, second),
        'sell': max(first, second)
    }

# operators allowed in rules; replaces eval() of a formatted expression
rule_ops = {
    '<': operator.lt,
    '<=': operator.le,
    '==': operator.eq,
    '!=': operator.ne,
    '>=': operator.ge,
    '>': operator.gt,
}

# check rules
print('Checking rules..')
server_ssl = None
try:
    for rule in rules:
        rate = rates.get(rule['currency'])
        if rate is None:
            continue
        if not rule_ops[rule['op']](rate[rule['type']], rule['value']):
            continue
        print('Found matching rule: {}'.format(rule))
        message = message_tpl.format(
            smtp_config['from'],
            smtp_config['to'],
            rule['currency'],
            rate['buy'],
            rate['sell'],
            url
        )
        if server_ssl is None:
            # connect lazily: no SMTP session when nothing has to be sent
            server_ssl = smtplib.SMTP_SSL(smtp_config['server'], smtp_config['port'])
            server_ssl.ehlo()
            server_ssl.login(smtp_config['username'], smtp_config['password'])
        print('Sending email..')
        server_ssl.sendmail(smtp_config['from'], smtp_config['to'], message)
finally:
    # close the SMTP connection even if sending fails
    if server_ssl is not None:
        server_ssl.close()
print('Done!')
# VIP currency notification script
# Require Python >= 3.5.2
# Usage: python vip3.py <gmail_username> <gmail_password> <to_email>
# Author: yohanes.gultom@gmail.com
from bs4 import BeautifulSoup
from bs4.element import Tag
from re import sub
from decimal import Decimal
from urllib.request import Request, urlopen
import urllib.error
import backoff
import smtplib
import sys
url = 'https://www.vip.co.id'
# rules to send email: notify when rates[currency][type] <op> value
rules = [
    {'currency': 'SGD', 'op': '>=', 'type': 'buy', 'value': 9400}
]
# fail with a usage message instead of an IndexError on missing arguments
if len(sys.argv) < 4:
    sys.exit('Usage: python vip3.py <gmail_username> <gmail_password> <to_email>')
smtp_config = {
    'username': sys.argv[1],
    'password': sys.argv[2],
    'server': 'smtp.gmail.com',
    'port': 465,
    'from': 'VIP Bot',
    'to': sys.argv[3]
}
# raw RFC 822 message; placeholders: from, to, currency, buy, sell, source url
message_tpl = '''From: {0}\r\nTo: {1}\r\nSubject: {2} to IDR today\r\nMIME-Version: 1.0\r\nContent-Type: text/html\r\n\r\n
<h1>{2} to IDR</h1>
<ul>
<li>Buy: IDR {3}</li>
<li>Sell: IDR {4}</li>
</ul>
<p>Source: {5}</p>
'''
@backoff.on_exception(backoff.expo, urllib.error.URLError, max_tries=3)
def fetch_content(url):
    """
    Download *url* and return the raw response body as bytes.

    Sends a browser-like User-Agent. URLError (including HTTPError) is
    retried up to 3 times with exponential backoff. The response is closed
    by the ``with`` block instead of being left for the garbage collector.
    """
    req = Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    with urlopen(req) as response:
        return response.read()
def parse_currency(s):
    """Strip everything except digits and '.' from *s* and parse a Decimal."""
    cleaned = sub(r'[^\d.]', '', str(s))
    return Decimal(cleaned)
import operator

# retrieve and parse rates
print('Fetching content from {}..'.format(url))
rates = {}
html = fetch_content(url)
soup = BeautifulSoup(html, 'html.parser')
rate_table = soup.select('#rate-table tr')
# first row is the table header
for rate in rate_table[1:]:
    values = []
    for content in rate.contents:
        if isinstance(content, Tag):
            # `'title' in tag` searches the tag's children, not its
            # attributes, so test the attribute explicitly
            if content.has_attr('title'):
                values.append(content['title'])
            elif content.contents:
                values.append(content.contents[0])
            else:
                values.append('')
    if len(values) < 3:
        # not a (currency, rate, rate) row
        continue
    try:
        first = parse_currency(values[1])
        second = parse_currency(values[2])
    except ArithmeticError:
        # decimal.InvalidOperation (non-numeric cell) subclasses ArithmeticError
        continue
    rates[str(values[0])] = {
        'buy': min(first, second),
        'sell': max(first, second)
    }

# operators allowed in rules; replaces eval() of a formatted expression
rule_ops = {
    '<': operator.lt,
    '<=': operator.le,
    '==': operator.eq,
    '!=': operator.ne,
    '>=': operator.ge,
    '>': operator.gt,
}

# check rules
print('Checking rules..')
server_ssl = None
try:
    for rule in rules:
        rate = rates.get(rule['currency'])
        if rate is None:
            continue
        if not rule_ops[rule['op']](rate[rule['type']], rule['value']):
            continue
        print('Found matching rule: {}'.format(rule))
        message = message_tpl.format(
            smtp_config['from'],
            smtp_config['to'],
            rule['currency'],
            rate['buy'],
            rate['sell'],
            url
        )
        if server_ssl is None:
            # connect lazily: no SMTP session when nothing has to be sent
            server_ssl = smtplib.SMTP_SSL(smtp_config['server'], smtp_config['port'])
            server_ssl.ehlo()
            server_ssl.login(smtp_config['username'], smtp_config['password'])
        print('Sending email..')
        server_ssl.sendmail(smtp_config['from'], smtp_config['to'], message)
finally:
    # close the SMTP connection even if sending fails
    if server_ssl is not None:
        server_ssl.close()
print('Done!')
#!/usr/bin/python
"""
Simple Voting HTTP server with MySQL database
Setup in Ubuntu:
$ sudo apt-get install python-pip python-dev libmysqlclient-dev
$ pip install MySQL-python
"""
import MySQLdb
import cgi
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
import os

# Server and database configuration. Each value can be overridden with an
# environment variable so real credentials need not be kept in the source;
# the original hard-coded values remain the defaults.
PORT_NUMBER = int(os.environ.get('VOTE_PORT', 8080))
DB_HOST = os.environ.get('VOTE_DB_HOST', 'localhost')
DB_USER = os.environ.get('VOTE_DB_USER', 'root')
DB_PASS = os.environ.get('VOTE_DB_PASS', 'root')
DB_NAME = os.environ.get('VOTE_DB_NAME', 'vote')
class VoteHandler(BaseHTTPRequestHandler):
    """
    HTTP request handler for simple voting.

    GET renders the voting form. POST records one vote and re-renders the
    form with a status message. The page is rendered before any header is
    sent, so a rendering failure (e.g. the database is unreachable) yields a
    clean 500 instead of a truncated 200 response.
    """

    def _respond(self, *form_args, **form_kwargs):
        """Render the vote form with the given arguments and send it.

        On a rendering error the exception is printed and a 500 is sent.
        """
        try:
            html = get_vote_form_html(*form_args, **form_kwargs)
        except Exception as e:
            print(e)
            self.send_error(500, 'Server error. Please contact support')
            return
        self.send_response(200)
        self.send_header('Content-type', 'text/html')
        self.end_headers()
        self.wfile.write(html)

    def do_GET(self):
        """Serve the voting form."""
        self._respond()

    def do_POST(self):
        """Record a vote from the submitted form and report the outcome."""
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
        )
        try:
            candidate = form.getvalue('candidate')
            state = form.getvalue('state')
            # inc_vote returns the number of updated rows; exactly 1 means
            # the (candidate, state) pair exists in the vote table
            voted = inc_vote(candidate, state) == 1
        except Exception as e:
            print(e)
            self._respond('Server error. Please contact support', message_color='red')
            return
        if voted:
            self._respond('Thanks for your vote!', message_color='green')
        else:
            self._respond('Vote error. Invalid candidate and/or state', message_color='red')
def get_vote_form_html(message_html=None, message_color='green'):
    """
    Build the complete voting page.

    The candidate and state radio groups are filled from the distinct
    values stored in the database. When *message_html* is truthy, it is
    shown in a paragraph coloured *message_color*; it is inserted as raw
    HTML.
    """
    candidates = get_radio_group_html('candidate', get_distinct_vote('candidate'))
    states = get_radio_group_html('state', get_distinct_vote('state'))
    pieces = ['\n<html>\n<head><title>Voting App</title></head>\n<body>\n']
    if message_html:
        pieces.append('\n<p style="color:{}">{}</p>\n'.format(message_color, message_html))
    pieces.append(
        '\n<form action="" method="POST">\n'
        '<table>\n'
        '<tr><td>Candidates:</td><td>{}</td></tr>\n'
        '<tr><td>States:</td><td>{}</td></tr>\n'
        '<tr><td><input type="submit" value="Submit"/></td></tr>\n'
        '</table>\n'
        '</form>\n'.format(candidates, states)
    )
    pieces.append('\n</body>\n</html>\n')
    return ''.join(pieces)
def get_distinct_vote(col):
    """
    Return the sorted distinct values of column *col* of the ``vote`` table.

    Uses the module-level ``db`` connection.

    Raises:
        Exception: if ``db`` is not open.
        ValueError: if *col* is not a plain identifier. Column names cannot
            be bound as query parameters, so they are validated instead of
            being formatted into the SQL blindly.
    """
    if not db:
        raise Exception('Connection not opened')
    if not col or not all(c.isalnum() or c == '_' for c in col):
        raise ValueError('Invalid column name: {!r}'.format(col))
    cursor = db.cursor()
    try:
        cursor.execute('SELECT DISTINCT {} FROM vote'.format(col))
        results = cursor.fetchall()
    finally:
        cursor.close()
    return sorted([row[0] for row in results])
def inc_vote(candidate, state):
    """
    Increase the vote count of one (candidate, state) row by 1.

    Uses the module-level ``db`` connection. The change is committed on
    success and rolled back on failure.

    Returns:
        The number of affected rows reported by ``cursor.execute``; 0 means
        no row matched the given candidate and state.

    Raises:
        Exception: if ``db`` is not open, or if the update fails. The
            failure message includes the underlying error.
    """
    if not db:
        raise Exception('Connection not opened')
    cursor = None
    try:
        cursor = db.cursor()
        # use parameterized query to prevent sql injection
        affected_rows = cursor.execute("UPDATE vote SET total_votes = total_votes + 1 WHERE candidate = %s AND state = %s", [candidate, state])
        db.commit()
        return affected_rows
    except Exception as e:
        db.rollback()
        # keep the original cause visible to the caller's log
        raise Exception('Database update failed: {}'.format(e))
    finally:
        if cursor is not None:
            cursor.close()
def get_radio_group_html(group_name, values):
    """
    Render *values* as radio inputs that share the name *group_name*.

    The first option is pre-checked, and options are separated by a single
    space. Values come from the database, so they are HTML-escaped (quotes
    included) before being placed in the attribute and the label; this
    prevents markup injection.
    """
    try:
        from html import escape
    except ImportError:  # Python 2 has no html module
        from cgi import escape
    name = escape(group_name, True)
    inputs = []
    for val in values:
        default = '' if inputs else 'checked'
        text = escape('{}'.format(val), True)
        inputs.append('<input type="radio" name="{0}" value="{1}" {2}/> {1}'.format(name, text, default))
    return ' '.join(inputs)
if __name__ == '__main__':
    # module-level names read by the request handlers / helpers
    db = None
    server = None
    try:
        # connect to database
        db = MySQLdb.connect(DB_HOST, DB_USER, DB_PASS, DB_NAME)
        print('Connected to database {}@{}'.format(DB_NAME, DB_HOST))
        # start HTTP server
        server = HTTPServer(('', PORT_NUMBER), VoteHandler)
        print('Server is started and accessible on http://localhost:{}'.format(PORT_NUMBER))
        print('Press CTRL+C to shutdown..')
        server.serve_forever()
    except KeyboardInterrupt:
        print('Shutting down the web server')
    finally:
        # release resources on any exit path, not only on CTRL+C; either
        # may be unset if startup failed part-way
        if server is not None:
            server.server_close()
        if db is not None:
            db.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment