Created
December 22, 2009 07:14
-
-
Save simonw/261572 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# Copy of http://dpaste.com/136418/
# Based on http://github.com/simonw/django-signed
# All cookie related code is untested.
import base64 | |
import hmac | |
import struct | |
import time | |
from django.conf import settings | |
from django.core.exceptions import SuspiciousOperation | |
from django.utils.hashcompat import sha_constructor | |
from django.utils import simplejson | |
from django.utils.functional import wraps | |
class BadSignature(ValueError, SuspiciousOperation):
    """Raised when a signed value is malformed or its signature does not verify.

    Besides being a ``SuspiciousOperation`` it also derives from
    ``ValueError``, which makes it more convenient to catch and has basically
    the correct semantics.
    """
class SignedCookies(object):
    """Read-only mapping-style view over a cookie mapping that verifies signatures.

    Looking up a cookie name fetches the raw value from the wrapped mapping and
    passes it through ``signer.unsign``; the key used for verification is
    ``extra_key + name`` so a signed value cannot be replayed under another
    cookie name.
    """

    def __init__(self, signer, cookies, extra_key=''):
        """Store the signer, the wrapped cookie mapping and the key prefix.

        signer    -- object exposing ``unsign(value, extra_key=...)``
        cookies   -- mapping of cookie name to raw (signed) cookie value
        extra_key -- string prepended to the cookie name when verifying
        """
        self.signer = signer
        # The original assigned the undefined name ``request`` here, which made
        # every construction fail with NameError.
        self.cookies = cookies
        self.extra_key = extra_key

    def __getitem__(self, name):
        """Return the verified value of cookie ``name``.

        Raises KeyError if the cookie is absent from the wrapped mapping or if
        its signature does not verify (a tampered cookie is treated as missing).
        """
        cookie = self.cookies[name]
        try:
            return self.signer.unsign(cookie, extra_key=self.extra_key + name)
        except BadSignature:
            raise KeyError(name)

    def get(self, name, default=None):
        """Return the verified value of cookie ``name``, or ``default``.

        ``default`` is returned both when the cookie is missing and when its
        signature is invalid, mirroring ``dict.get``.
        """
        try:
            return self[name]
        except KeyError:
            return default
class Signer(object):
    """Sign and verify bytestrings, and serialize/sign arbitrary objects.

    A signed value has the form ``<value><separator><signature>``, where the
    signature is the URL-safe base64 encoding (padding stripped) of an HMAC of
    the value. The HMAC key is ``key + extra_key``.
    """

    def __init__(self, key=None, compress=False, serializer=simplejson, hash_constructor=sha_constructor, separator='.'):
        """Configure the signer.

        key              -- HMAC key; falls back to ``settings.SECRET_KEY`` when falsy
        compress         -- if true, ``dumps`` tries zlib compression
        serializer       -- object with ``dumps``/``loads`` used by ``dumps``/``loads``
        hash_constructor -- digest constructor handed to ``hmac.new``
        separator        -- string placed between value and signature
        """
        self.key = key or settings.SECRET_KEY
        self.compress = compress
        self.serializer = serializer
        self.hash_constructor = hash_constructor
        self.separator = separator

    def split(self, signed_value):
        """Split at the last separator; returns a list of at most two parts."""
        return signed_value.rsplit(self.separator, 1)

    def join(self, value, signature):
        """Concatenate ``value``, the separator and ``signature``."""
        return value + self.separator + signature

    def encode(self, s):
        """URL-safe base64 encode ``s`` and strip the trailing ``=`` padding."""
        return base64.urlsafe_b64encode(s).rstrip('=')

    def decode(self, s):
        """Inverse of ``encode``: restore the ``=`` padding, then decode."""
        pad = len(s) % 4
        if pad:
            s += '=' * (4 - pad)
        return base64.urlsafe_b64decode(s)

    def signature(self, value, extra_key=''):
        """Return the encoded HMAC of ``value`` keyed with ``key + extra_key``."""
        return self.encode(hmac.new(self.key + extra_key, value, self.hash_constructor).digest())

    @staticmethod
    def _constant_time_compare(expected, actual):
        """Compare two bytestrings without short-circuiting on the first mismatch.

        A plain ``==`` returns as soon as a byte differs, which can leak how
        much of a forged signature was correct. Every byte pair is examined
        here; only a length difference returns early.
        """
        if len(expected) != len(actual):
            return False
        mismatch = 0
        for x, y in zip(expected, actual):
            mismatch |= ord(x) ^ ord(y)
        return mismatch == 0

    def sign(self, value, extra_key=''):
        """Return ``value`` with its signature appended.

        Raises TypeError if ``value`` is not a ``str`` instance.
        """
        if not isinstance(value, str):
            raise TypeError('sign() needs bytestring, got %s: %s' % (type(value), repr(value)))
        return self.join(value, self.signature(value, extra_key=extra_key))

    def unsign(self, signed_value, extra_key=''):
        """Verify ``signed_value`` and return the original value.

        Raises TypeError if ``signed_value`` is not a ``str`` instance and
        BadSignature if it has no separator or the signature does not match.
        """
        if not isinstance(signed_value, str):
            raise TypeError('unsign() needs bytestring, got %s: %s' % (type(signed_value), repr(signed_value)))
        try:
            # Unpacking fails with ValueError when no separator is present.
            value, signature = self.split(signed_value)
        except (ValueError, TypeError):
            raise BadSignature()
        expected = self.signature(value, extra_key=extra_key)
        if self._constant_time_compare(expected, signature):
            return value
        raise BadSignature()

    def dumps(self, obj, extra_key=''):
        """Serialize ``obj``, optionally compress, encode and sign it.

        A compressed payload is marked by a leading separator in front of the
        encoded data; ``loads`` looks for that marker.
        """
        s = self.serializer.dumps(obj)
        is_compressed = False  # Flag for if it's been compressed or not
        if self.compress:
            import zlib  # Avoid zlib dependency unless compress is being used
            compressed = zlib.compress(s)
            # Only keep the compressed form when it is at least two bytes shorter.
            if len(compressed) < (len(s) - 1):
                s = compressed
                is_compressed = True
        s = self.encode(s)
        if is_compressed:
            s = self.separator + s
        return self.sign(s, extra_key=extra_key)

    def loads(self, s, extra_key=''):
        """Verify ``s`` and return the deserialized object (inverse of ``dumps``).

        Raises whatever ``unsign`` raises for an invalid signature. The payload
        is only deserialized after the signature has verified.
        NOTE(review): with an unsafe serializer such as pickle, anyone holding
        the key can make this execute arbitrary code - keep the key secret.
        """
        value = self.unsign(s, extra_key=extra_key)
        # A leading separator is the "compressed" marker written by dumps().
        if value and value[0] == self.separator:
            value = value[1:]
            is_compressed = True
        else:
            is_compressed = False
        value = self.decode(value)
        if is_compressed:
            import zlib
            value = zlib.decompress(value)
        return self.serializer.loads(value)

    def get_cookie(self, request, name, extra_key=''):
        """Return the verified value of cookie ``name`` from ``request.COOKIES``.

        The verification key is ``extra_key + name``. Lookup errors from
        ``request.COOKIES`` and BadSignature from ``unsign`` propagate.
        """
        cookie = request.COOKIES[name]
        return self.unsign(cookie, extra_key=extra_key + name)

    def set_cookie(self, response, name, value, extra_key='', **kwargs):
        """Sign ``value`` with ``extra_key + name`` and set it on ``response``.

        Remaining keyword arguments are forwarded to ``response.set_cookie``.
        """
        cookie = self.sign(value, extra_key=extra_key + name)
        response.set_cookie(name, cookie, **kwargs)

    def sign_cookies(self, extra_key=''):
        """Return a view decorator that verifies incoming and signs outgoing cookies.

        ``request.COOKIES`` is replaced by a ``SignedCookies`` wrapper before
        the view runs; afterwards every cookie on the response is signed with
        ``extra_key + <cookie name>``.
        """
        def decorator(view_func):
            @wraps(view_func)
            def decorated(request, *args, **kwargs):
                request.COOKIES = SignedCookies(self, request.COOKIES, extra_key=extra_key)
                response = view_func(request, *args, **kwargs)
                # NOTE(review): assumes response.cookies is a SimpleCookie whose
                # entries are Morsel objects - confirm against the Django version
                # in use. Passing the Morsel itself to sign() (as the original
                # did) always raised TypeError; sign its string value instead and
                # update the morsel in place so path/expiry attributes survive.
                for name in response.cookies.keys():
                    morsel = response.cookies[name]
                    signed = self.sign(morsel.value, extra_key=extra_key + name)
                    real_value, coded_value = response.cookies.value_encode(signed)
                    morsel.set(name, real_value, coded_value)
                return response
            return decorated
        return decorator
class SignatureExpired(BadSignature):
    """Raised by ``TimestampedSigner.unsign`` when a valid signature is older than its ttl."""
class TimestampedSigner(Signer):
    """Signer that embeds the signing time and rejects values older than ``ttl``.

    The signed form is ``<value><sep><timestamp><sep><signature>``; the
    timestamp is covered by the signature.
    """

    # Fixed little-endian 32-bit layout. The native 'I' format used previously
    # depends on the host's byte order, so a value signed on one architecture
    # could not be checked on another. '<I' matches what little-endian hosts
    # already produced.
    _TIMESTAMP_FORMAT = '<I'

    def __init__(self, ttl=60, **kwargs):
        """``ttl`` is the maximum accepted age in seconds; other kwargs go to ``Signer``."""
        self.ttl = ttl
        super(TimestampedSigner, self).__init__(**kwargs)

    def sign(self, value, extra_key=''):
        """Append the encoded current time to ``value`` and sign the result."""
        timestamp = self.encode(struct.pack(self._TIMESTAMP_FORMAT, int(time.time())))
        value = self.join(value, timestamp)
        return super(TimestampedSigner, self).sign(value, extra_key=extra_key)

    def unsign(self, s, extra_key=''):
        """Verify ``s``, check its age and return the original value.

        Raises BadSignature if the signature is invalid or the timestamp part
        is missing or undecodable, and SignatureExpired if the value is older
        than ``ttl`` seconds.
        """
        value = super(TimestampedSigner, self).unsign(s, extra_key=extra_key)
        try:
            # A correctly signed value without a usable timestamp (for example
            # one produced by a plain Signer with the same key) must be
            # reported as a bad signature, not as a stray unpacking error.
            value, timestamp = self.split(value)
            timestamp = struct.unpack(self._TIMESTAMP_FORMAT, self.decode(timestamp))[0]
        except (ValueError, TypeError, struct.error):
            raise BadSignature()
        if self.ttl < int(time.time()) - timestamp:
            raise SignatureExpired()
        return value

    def set_cookie(self, response, name, value, extra_key='', max_age=None, **kwargs):
        """Set a signed cookie whose ``max_age`` never exceeds ``ttl``.

        When ``max_age`` is omitted the cookie lives for ``ttl`` seconds;
        ``min(None, ttl)`` would otherwise yield ``None`` on Python 2 and raise
        TypeError on Python 3.
        """
        if max_age is None:
            max_age = self.ttl
        else:
            max_age = min(max_age, self.ttl)
        super(TimestampedSigner, self).set_cookie(response, name, value, extra_key=extra_key, max_age=max_age, **kwargs)
# Shortcut functions: module-level wrappers around one shared Signer that is
# built with default arguments when this module is imported.
_default_signer = Signer()
def sign(value, extra_key=''):
    """Sign ``value`` using the module-level default signer and return the result."""
    signer = _default_signer
    return signer.sign(value, extra_key=extra_key)
def unsign(signed_value, extra_key=''):
    """Verify ``signed_value`` with the module-level default signer and return the value."""
    signer = _default_signer
    return signer.unsign(signed_value, extra_key=extra_key)
def dumps(value, extra_key=''):
    """Serialize and sign ``value`` using the module-level default signer."""
    signer = _default_signer
    return signer.dumps(value, extra_key=extra_key)
def loads(signed_value, extra_key=''):
    """Verify and deserialize ``signed_value`` using the module-level default signer."""
    signer = _default_signer
    return signer.loads(signed_value, extra_key=extra_key)
# tests
if __name__ == "__main__":
    import unittest, pickle, hashlib

    class Test(unittest.TestCase):
        """Round-trip and tamper checks run against both signer classes."""

        # Inputs for sign()/unsign() and for dumps()/loads() respectively.
        sign_tests = ["a", "123", "\00\01\02", "foo\t.bar", "\n\n", "1.302.55", '!"$%&/()=?`"', "x"*1000]
        dumps_tests = [1, 2.3, False, True, None, {'foo': 'bar'}, {'foo': 'x'*1000}, ['foo', 'bar', 'baz'], u"fööbär"]

        def _test_signer(self, signer):
            """Exercise one signer: round trips, extra_key mismatches, tampered input."""
            for sign_name, unsign_name in (('sign', 'unsign'), ('dumps', 'loads')):
                samples = getattr(self, "%s_tests" % sign_name)
                do_sign = getattr(signer, sign_name)
                do_unsign = getattr(signer, unsign_name)
                for sample in samples:
                    signed = do_sign(sample)
                    self.assertEqual(sample, do_unsign(signed))
                    extra_signed = do_sign(sample, extra_key='foo')
                    self.assertEqual(sample, do_unsign(extra_signed, extra_key='foo'))
                    # A value must only verify under the extra_key it was signed with.
                    self.assertRaises(BadSignature, do_unsign, extra_signed)
                    self.assertRaises(BadSignature, do_unsign, signed, extra_key='foo')
                    tampered = (
                        repr(sample),
                        signed[:-3],
                        signed[1:],
                        signed.replace('.', '').replace('#', ''),
                    )
                    for bad_signed in tampered:
                        self.assertRaises(BadSignature, do_unsign, bad_signed)

        def test(self):
            """Run _test_signer for every subset of the options, then check expiry."""
            options = dict(key='SECRET', compress=True, serializer=pickle, hash_constructor=hashlib.md5, separator='#')
            # Each option doubles the list: every existing combination is kept
            # and also repeated with the option added.
            kwargs_list = [{}]
            for opt_name, opt_value in options.items():
                kwargs_list += [dict(base, **{opt_name: opt_value}) for base in kwargs_list]
            for cls in (Signer, TimestampedSigner):
                for kwargs in kwargs_list:
                    self._test_signer(cls(**kwargs))
            # A one-second ttl must have expired after sleeping two seconds.
            t_signer = TimestampedSigner(ttl=1)
            signed = t_signer.sign("foo")
            time.sleep(2)
            self.assertRaises(SignatureExpired, t_signer.unsign, signed)

    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment