Created
November 17, 2011 12:20
-
-
Save idan/1373019 to your computer and use it in GitHub Desktop.
auth.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
requests.auth
~~~~~~~~~~~~~
This module contains the authentication handlers for Requests.
"""
import time | |
import hashlib | |
from random import getrandbits | |
import urllib | |
import hmac | |
import binascii | |
from base64 import b64encode | |
from urlparse import urlparse | |
from .utils import randombytes, parse_dict_header | |
class AuthBase(object):
    """Common ancestor of the authentication handlers in this module.

    A handler is invoked as ``handler(r)``; subclasses are expected to
    override ``__call__``.
    """

    def __call__(self, r):
        # Reaching this implementation means the subclass supplied no hook.
        reason = 'Auth hooks must be callable.'
        raise NotImplementedError(reason)
class HTTPBasicAuth(AuthBase):
    """Auth handler that adds an HTTP Basic ``Authorization`` header."""

    def __init__(self, username, password):
        # Both credentials are coerced with ``str()`` when the handler is built.
        self.username = str(username)
        self.password = str(password)

    def __call__(self, r):
        """Set ``r.headers['Authorization']`` to the Basic token; return ``r``."""
        credentials = '%s:%s' % (self.username, self.password)
        token = b64encode(credentials)
        r.headers['Authorization'] = 'Basic %s' % token
        return r
class HTTPDigestAuth(AuthBase):
    """Attaches HTTP Digest Authentication to the given Request object."""

    def __init__(self, username, password):
        self.username = username
        self.password = password
        # Nonce bookkeeping is stored on the instance so the request counter
        # (``nc``) keeps growing for as long as the server reuses a nonce.
        self.last_nonce = ''
        self.nonce_count = 0

    def handle_401(self, r):
        """Takes the given response and tries digest-auth, if needed.

        :param r: response whose ``www-authenticate`` header is inspected.
        :returns: ``r`` unchanged when the header carries no Digest
            challenge; otherwise the response obtained by re-sending
            ``r.request`` with a Digest ``Authorization`` header (``r`` is
            appended to that response's ``history``).  ``None`` is returned
            for challenges that cannot be answered: an algorithm other than
            MD5/SHA, or a ``qop`` list that does not offer ``auth``.
        """
        s_auth = r.headers.get('www-authenticate', '')
        if 'digest' not in s_auth.lower():
            return r

        chal = parse_dict_header(s_auth.replace('Digest ', ''))
        realm = chal['realm']
        nonce = chal['nonce']
        qop = chal.get('qop')
        algorithm = chal.get('algorithm', 'MD5').upper()
        opaque = chal.get('opaque', None)

        # XXX MD5-sess
        if algorithm == 'MD5':
            hash_factory = hashlib.md5
        elif algorithm == 'SHA':
            hash_factory = hashlib.sha1
        else:
            # Unsupported algorithm: give up here rather than failing later
            # on a hash function that was never defined.
            return None

        def H(data):
            return hash_factory(data).hexdigest()

        def KD(secret, data):
            return H("%s:%s" % (secret, data))

        # XXX not implemented yet
        entdig = None

        # The digest URI is the request-URI: it must not be empty, and the
        # query string is separated from the path by '?'.
        p_parsed = urlparse(r.request.url)
        path = p_parsed.path or '/'
        if p_parsed.query:
            path += '?' + p_parsed.query

        A1 = "%s:%s:%s" % (self.username, realm, self.password)
        A2 = "%s:%s" % (r.request.method, path)

        ncvalue = cnonce = None
        if qop is None:
            # No qop offered: digest without nc/cnonce.
            respdig = KD(H(A1), "%s:%s" % (nonce, H(A2)))
        elif 'auth' in [token.strip() for token in qop.split(',')]:
            # ``qop`` may be a comma-separated list; pick ``auth`` when offered.
            # getattr() keeps subclasses that skip this __init__ working.
            if nonce == getattr(self, 'last_nonce', ''):
                self.nonce_count = getattr(self, 'nonce_count', 0) + 1
            else:
                self.nonce_count = 1
                self.last_nonce = nonce

            ncvalue = '%08x' % self.nonce_count
            cnonce = hashlib.sha1("%s:%s:%s:%s" % (
                self.nonce_count, nonce, time.ctime(), randombytes(8))
            ).hexdigest()[:16]
            noncebit = "%s:%s:%s:%s:%s" % (nonce, ncvalue, cnonce, 'auth', H(A2))
            respdig = KD(H(A1), noncebit)
        else:
            # XXX handle auth-int.
            return None

        # XXX should the partial digests be encoded too?
        base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \
               'response="%s"' % (self.username, realm, nonce, path, respdig)
        if opaque:
            base += ', opaque="%s"' % opaque
        if entdig:
            base += ', digest="%s"' % entdig
        base += ', algorithm="%s"' % algorithm
        if cnonce is not None:
            base += ', qop=auth, nc=%s, cnonce="%s"' % (ncvalue, cnonce)

        # Re-send the original request with credentials and chain the
        # challenged response into the new response's history.
        r.request.headers['Authorization'] = 'Digest %s' % (base)
        r.request.send(anyway=True)
        _r = r.request.response
        _r.history.append(r)
        return _r

    def __call__(self, r):
        """Register ``handle_401`` as the ``response`` hook of ``r``; return ``r``."""
        r.hooks['response'] = self.handle_401
        return r
def escape(s):
    """Escape a URL including any /.

    Adheres to conventions laid out in section 3.6 of the OAuth 1.0 spec:
    http://tools.ietf.org/html/rfc5849#section-3.6

    Text is encoded to UTF-8 before quoting; byte strings are quoted as
    they are.  Only ``~`` is added to the characters ``urllib.quote``
    leaves unescaped, so ``/`` is percent-encoded too.
    """
    # Calling .encode() on a byte string makes Python 2 ASCII-decode it
    # first, which blows up on already-encoded non-ASCII data (for example
    # the output of utf8_str()).  Encode only when the input is not bytes.
    if not isinstance(s, bytes):
        s = s.encode('utf-8')
    return urllib.quote(s, safe='~')
def utf8_str(s):
    """Return ``s`` UTF-8 encoded if it is unicode, otherwise ``str(s)``."""
    return s.encode("utf-8") if isinstance(s, unicode) else str(s)
def generate_timestamp():
    """Return ``time.time()`` truncated to whole seconds, as a decimal string."""
    whole_seconds = int(time.time())
    return '%d' % whole_seconds
def generate_nonce():
    """Return a nonce string: 64 pseudorandom bits in decimal, then a timestamp."""
    random_part = str(getrandbits(64))
    return random_part + generate_timestamp()
class OAuth(object):
    """Sign requests with OAuth 1.0 (HMAC-SHA1) via the ``Authorization`` header.

    The handler is called with a request object ``r`` (its ``method``,
    ``url`` and ``headers`` attributes are used) and returns it after
    setting ``r.headers['Authorization']``.
    """

    def __init__(self, consumer_key, consumer_secret, access_token, access_token_secret):
        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token = access_token
        self.access_token_secret = access_token_secret

    def __call__(self, r):
        """Generate a signed request."""
        params = {
            'oauth_consumer_key': self.consumer_key,
            'oauth_nonce': generate_nonce(),
            'oauth_signature_method': 'HMAC-SHA1',
            'oauth_token': self.access_token,
            'oauth_timestamp': generate_timestamp(),
            'oauth_version': '1.0',
        }

        # sign() already returns a percent-encoded signature, so only the
        # remaining protocol values are encoded for the header here.
        signature = self.sign(r, params)
        fields = [(k, escape(utf8_str(v))) for k, v in params.items()]
        fields.append(('oauth_signature', signature))
        fields.sort()

        r.headers['Authorization'] = 'OAuth realm="", ' + ', '.join(
            ['{0}="{1}"'.format(k, v) for k, v in fields])
        return r

    def sign(self, r, params):
        """Return the percent-encoded HMAC-SHA1 signature for ``r``.

        The signature base string is ``METHOD&URL&PARAMS`` built from the
        three ``get_normalized_*`` helpers; the key is the encoded consumer
        secret and token secret joined by ``&``.
        """
        chunks = {
            'method': self.get_normalized_http_method(r),
            'url': self.get_normalized_http_url(r),
            'rparams': self.get_normalized_parameters(r, params)
        }
        raw = "{method}&{url}&{rparams}".format(**chunks)
        key = '{0}&{1}'.format(escape(self.consumer_secret), escape(self.access_token_secret))
        signature = hmac.new(key, raw, hashlib.sha1)
        # b2a_base64 appends a newline; drop it before encoding.
        return escape(binascii.b2a_base64(signature.digest())[:-1])

    def get_normalized_http_method(self, r):
        """Uppercases the http method."""
        return r.method.upper()

    def get_normalized_http_url(self, r):
        """Parses the URL and rebuilds it to be scheme://host/path.

        Scheme and host are lowercased, a default port (``:80`` for http,
        ``:443`` for https) is dropped, and the result is percent-encoded.
        """
        parts = urlparse(r.url)
        scheme, netloc, path = parts[:3]
        scheme = scheme.lower()
        netloc = netloc.lower()
        # Exclude default port numbers.
        if scheme == 'http' and netloc.endswith(':80'):
            netloc = netloc[:-3]
        elif scheme == 'https' and netloc.endswith(':443'):
            netloc = netloc[:-4]
        return escape('{0}://{1}{2}'.format(scheme, netloc, path))

    def get_normalized_parameters(self, r, params):
        """Return a string that contains the parameters that must be signed.

        ``params`` is not modified; an ``oauth_signature`` entry, if present,
        is left out of the result.
        """
        # NOTE(review): only ``params`` is signed; ``r`` is not consulted, so
        # query-string or form parameters of the request are not included.
        # Encode each name and value exactly once, then sort by encoded
        # name and value.
        pairs = sorted(
            (escape(utf8_str(k)), escape(utf8_str(v)))
            for k, v in params.items()
            if k != 'oauth_signature'
        )
        # The joined string is encoded once more as a whole because it is
        # one component of the signature base string.
        return escape('&'.join(['{0}={1}'.format(k, v) for k, v in pairs]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment