Skip to content

Instantly share code, notes, and snippets.

@todofixthis
Last active August 22, 2020 01:07
Show Gist options
  • Save todofixthis/79a2f213989a3584211e49bfba582b40 to your computer and use it in GitHub Desktop.
Save todofixthis/79a2f213989a3584211e49bfba582b40 to your computer and use it in GitHub Desktop.
MongoDB transparent escaping/unescaping of illegal keys
# coding=utf-8
from __future__ import absolute_import, print_function, unicode_literals
from pprint import pprint
from bson import ObjectId, SON
from pymongo import MongoClient
from pymongo.collection import Collection
from key_escaper import DeterministicKeyEscaper
def main():
    """
    Round-trips a document with illegal keys through MongoDB and checks
    that the original keys come back intact.
    """
    database = MongoClient()['test_db']
    target = database['test_collection']

    # Keys beginning with '$' are normally rejected by MongoDB.
    # :see: https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
    original = {'$foo': 'bar', '$baz': 'luhrmann'}

    round_tripped = retrieve(target, store(target, original))

    print('Asserting that retrieved document matches what we stored.')
    # ``_id`` was assigned by MongoDB, so it can't take part in the
    # comparison.
    del round_tripped['_id']
    assert round_tripped.to_dict() == original
    print('Match!')
def store(collection, document):
    # type: (Collection, dict) -> ObjectId
    """
    Escapes ``document`` and inserts the result into ``collection``,
    printing the document before and after escaping.

    :return: The ``_id`` that MongoDB assigned to the inserted document.
    """
    print('Original document:')
    pprint(document)
    print('')

    # From MongoDB's point of view the document is *incoming*, so
    # ``transform_incoming`` is the method that escapes it.
    escaped = DeterministicKeyEscaper().transform_incoming(
        document,
        collection.name,
    )

    print('Escaped document:')
    pprint(escaped.to_dict())
    print('')

    return collection.insert_one(escaped).inserted_id
def retrieve(collection, document_id):
    # type: (Collection, ObjectId) -> SON
    """
    Loads the document with the given ``_id`` from ``collection`` and
    restores its original (unescaped) keys, printing both forms.
    """
    stored = collection.find_one({'_id': document_id})

    print('Stored document:')
    pprint(stored)
    print('')

    # The document is now *outgoing* from MongoDB, so
    # ``transform_outgoing`` undoes the escaping.
    escaper = DeterministicKeyEscaper()
    restored = escaper.transform_outgoing(stored, collection.name)

    print('Unescaped document:')
    pprint(restored.to_dict())
    print('')

    return restored
# Run the round-trip demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
# coding=utf-8
from __future__ import absolute_import, division, print_function, \
unicode_literals
from abc import ABCMeta, abstractmethod as abstract_method
from codecs import decode
from hashlib import md5
from sys import getdefaultencoding as get_default_encoding
from typing import Any, Dict, Iterable, List, Mapping, Text, Tuple, Union
from bson import InvalidDocument, SON
from pymongo.son_manipulator import SONManipulator
from six import binary_type, iteritems, string_types, text_type, with_metaclass
# Public API of this module: the abstract base class plus its two
# concrete key-escaping strategies.
__all__ = [
    'BaseKeyEscaper',
    'DeterministicKeyEscaper',
    'NonDeterministicKeyEscaper',
]
class BaseKeyEscaper(with_metaclass(ABCMeta, SONManipulator)):
    """
    Escapes illegal keys, ensuring that the original values can be
    recovered later.

    Any key that starts with ``$`` or with :py:attr:`magic_prefix`, or
    that contains ``.`` or a null byte, is replaced by the value returned
    from :py:meth:`escape_key`. A nested map from escaped keys back to the
    originals is stored in the document under :py:attr:`magic_prefix`, so
    that :py:meth:`transform_outgoing` can restore them.

    Note that escaped keys may be difficult to query for (how difficult
    depends on the subclass). That is still infinitely preferable to
    MongoDB refusing to store the document in the first place.
    """

    # Prefix that identifies escaped keys. The escaped-key map itself is
    # stored in the document under exactly this key.
    magic_prefix = '__escaped__'

    def __init__(self):
        super(BaseKeyEscaper, self).__init__()

        # The following attributes are only used while escaping keys
        # (they are reset by ``transform_incoming``).

        # Path (keys / stringified list indices) from the document root to
        # the value currently being escaped. It is used to populate
        # ``escaped_keys`` at the correct location.
        self.current_path = None  # type: List[Text]

        # Nested map of the keys escaped so far. Each node is
        # ``[original_key_or_None, {children}]``.
        self.escaped_keys = None  # type: Dict[Text, Union[Dict, Text]]

    @abstract_method
    def escape_key(self, key):
        # type: (Text) -> Text
        """
        Escapes a single key.

        Implementations must return a replacement key that MongoDB will
        accept.
        """
        raise NotImplementedError(
            'Not implemented in {cls}.'.format(cls=type(self).__name__),
        )

    def will_copy(self):
        # type: () -> bool
        """
        Does this manipulator create a copy of the SON?
        """
        # ``transform_incoming`` does create a copy, but
        # ``transform_outgoing`` modifies its input (it pops the escaped-key
        # map). Well, we have to pick one!
        #
        # We'll go with ``False`` because it is not safe to assume that
        # this manipulator will create a copy of the SON.
        return False

    def transform_incoming(self, son, collection):
        # type: (Union[Mapping, SON], Text) -> SON
        """
        Transforms a document before it is stored to the database.

        Returns a new :py:class:`SON` in which every illegal key has been
        escaped. The escaped-key map is added under
        :py:attr:`magic_prefix`. The mapping structure of ``son`` is not
        modified.

        :raise InvalidDocument:
            If the document contains a key that is not a string or that is a
            binary string that cannot be decoded.
        """
        self.current_path = []  # type: List[Text]
        self.escaped_keys = {}  # type: Dict[Text, Union[Dict, Text]]

        transformed = self._escape(son)
        transformed[self.magic_prefix] = self.escaped_keys
        return transformed

    def transform_outgoing(self, son, collection):
        # type: (Union[Mapping, SON, None], Text) -> Union[Mapping, SON, None]
        """
        Transforms a document after it is retrieved from the database.

        Note that this method will directly modify the document: the
        escaped-key map is popped off of ``son``.

        ``son`` is returned unchanged if it is ``None`` (e.g. ``find_one``
        matched nothing) or if it carries no usable escaped-key map.
        """
        if son is None:
            return son

        escaped_keys = son.pop(self.magic_prefix, None)

        if not isinstance(escaped_keys, Mapping):
            # Document is corrupted or was not escaped when it was stored.
            return son

        return self._unescape(son, escaped_keys) if escaped_keys else son

    def _escape(self, son):
        # type: (Union[Mapping, SON]) -> SON
        """
        Recursively crawls the document, transforming keys as it goes.

        Returns a new :py:class:`SON`; ``son`` itself is not modified.
        """
        copy = SON()

        for (key, value) in iteritems(son):  # type: Tuple[Text, Any]
            key = self._coerce_key(key)

            if (
                    key.startswith('$')
                    or key.startswith(self.magic_prefix)
                    or ('.' in key)
                    or ('\x00' in key)
            ):
                key = self._escape_key(key)

            self.current_path.append(key)
            copy[key] = self._escape_value(value)
            self.current_path.pop()

        return copy

    def _coerce_key(self, key):
        # type: (Any) -> Text
        """
        Returns ``key`` as a unicode string.

        Python 2 compatibility: binary strings are allowed, so long as they
        can be decoded. Decoding uses the default encoding, or UTF-8 if the
        default is ASCII.

        :raise InvalidDocument:
            If ``key`` is not a string, or is a binary string that cannot be
            decoded.
        """
        if isinstance(key, binary_type):
            encoding = get_default_encoding()
            if encoding == 'ascii':
                encoding = 'utf-8'

            try:
                key = decode(key, encoding)
            except UnicodeDecodeError:
                # Leave ``key`` as bytes; it is rejected just below.
                pass

        if not isinstance(key, text_type):
            raise InvalidDocument(
                'documents must have only string keys, '
                'key was {path}[{actual!r}]'.format(
                    actual=key,
                    path='.'.join(self.current_path),
                ),
            )

        return key

    def _escape_key(self, key):
        # type: (Text) -> Text
        """
        Transforms an illegal key into something that MongoDB will
        approve of. The original key is recorded in ``self.escaped_keys``
        so that it can be unescaped later.
        """
        new_key = self.escape_key(key)

        # Walk to the node of ``self.escaped_keys`` that corresponds to
        # ``self.current_path``, creating ``[None, {}]`` placeholders for
        # path segments that did not themselves need escaping.
        crawler = self.escaped_keys
        for segment in self.current_path:
            crawler = crawler.setdefault(segment, [None, {}])[1]

        crawler[new_key] = [key, {}]
        return new_key

    def _escape_value(self, value):
        """
        Recursively escapes nested values inside mappings and iterables.
        """
        # Escape nested mappings.
        if isinstance(value, Mapping):
            return self._escape(value)

        # Scan iterables for nested mappings. Text and binary strings are
        # iterable too, but must be stored as-is. On Python 3, ``bytes`` is
        # not part of ``six.string_types``, so it has to be excluded
        # explicitly. Otherwise binary values (e.g. ``bson.Binary``) would
        # be turned into lists of ints.
        if isinstance(value, Iterable) \
                and not isinstance(value, (text_type, binary_type)):
            copy = []
            for i, item in enumerate(value):
                self.current_path.append(text_type(i))
                copy.append(self._escape_value(item))
                self.current_path.pop()
            return copy

        # Any other value is safe to return unescaped.
        return value

    def _unescape(self, son, escaped_keys):
        """
        Recursively unescapes a value.

        :param son:
            The (sub-)document or list to unescape.

        :param escaped_keys:
            The node of the escaped-key map that corresponds to ``son``. It
            maps keys (or stringified list indices) to
            ``[original_key_or_None, children]`` pairs.

        :return:
            A new :py:class:`SON` for mappings, or a new list for
            non-string iterables. Any other value is returned as-is.
        """
        if isinstance(son, Mapping):
            copy = SON()
            for key, value in iteritems(son):
                if key not in escaped_keys:
                    copy[key] = value
                    continue

                # - ``r_key`` is the unescaped key value. ``None`` means the
                #   key did not need to be escaped; it's just a placeholder
                #   so that we can find a nested object that was escaped.
                # - ``r_children`` contains information needed to unescape
                #   nested objects (if any).
                r_key, r_children = escaped_keys[key]
                if r_key is None:
                    r_key = key

                # Only descend if something nested was escaped.
                copy[r_key] = (
                    self._unescape(value, r_children) if r_children else value
                )
            return copy

        # Strings (text or binary) are iterable but never contain escaped
        # values. See ``_escape_value``.
        if isinstance(son, Iterable) \
                and not isinstance(son, (text_type, binary_type)):
            copy = []
            for i, value in enumerate(son):
                key = text_type(i)
                if key in escaped_keys:
                    # Lists don't have keys that need escaping; we're only
                    # interested in whether the value is a nested mapping.
                    _, r_children = escaped_keys[key]
                    if r_children:
                        value = self._unescape(value, r_children)
                copy.append(value)
            return copy

        return son
class NonDeterministicKeyEscaper(BaseKeyEscaper):
    """
    A KeyEscaper that uses an internal counter to generate escaped keys.

    This method is a bit faster and tends to yield smaller escaped keys
    than DeterministicKeyEscaper, but the result is more difficult to
    query.
    """

    def __init__(self):
        super(NonDeterministicKeyEscaper, self).__init__()

        # Used to ensure each escaped key is unique within a document.
        # ``transform_incoming`` resets it for every document. Starting at
        # 0 here (rather than ``None``) means ``escape_key`` can also be
        # called directly without raising TypeError.
        self.escaped_key_count = 0  # type: int

    def transform_incoming(self, son, collection):
        # type: (Union[Mapping, SON], Text) -> SON
        """
        Restarts the counter, then escapes the document as usual.
        """
        self.escaped_key_count = 0  # type: int
        return \
            super(NonDeterministicKeyEscaper, self) \
            .transform_incoming(son, collection)

    def escape_key(self, key):
        # type: (Text) -> Text
        """
        Returns :py:attr:`magic_prefix` followed by the next counter value.

        The content of ``key`` does not influence the result; each call
        advances the counter.
        """
        escaped = self.magic_prefix + text_type(self.escaped_key_count)
        self.escaped_key_count += 1
        return escaped
class DeterministicKeyEscaper(BaseKeyEscaper):
    """
    A KeyEscaper that derives each escaped key from an MD5 hash of the
    original key.

    This method is a little slower and tends to yield larger escaped keys
    than NonDeterministicKeyEscaper. Because a given key always escapes to
    the same value, you can execute queries against the escaped keys more
    easily.
    """

    def escape_key(self, key):
        # type: (Text) -> Text
        # hashlib needs a byte string (mandatory on Python 3), so encode the
        # key as UTF-8 before hashing.
        digest = md5(key.encode('utf-8')).hexdigest()
        return '{prefix}{digest}'.format(
            prefix=self.magic_prefix,
            digest=digest,
        )
# coding=utf-8
from __future__ import absolute_import, unicode_literals
from abc import ABCMeta, abstractproperty as abstract_property
from inspect import isabstract as is_abstract
from unittest import TestCase
from bson import SON
from pymongo import MongoClient
from six import with_metaclass
from key_escaper import (
BaseKeyEscaper,
DeterministicKeyEscaper,
NonDeterministicKeyEscaper,
)
# Test cases exported by this module. The abstract BaseKeyEscaperTestCase is
# omitted on purpose: it only provides shared tests for the concrete
# subclasses below.
__all__ = [
    'DeterministicKeyEscaperTestCase',
    'NonDeterministicKeyEscaperTestCase',
]
class TestCaseMeta(ABCMeta):
    """
    ABC metaclass that marks abstract test-case templates so that test
    runners skip them.

    Every class it creates gets a ``__test__`` attribute. The attribute is
    ``True`` only when the class has no unimplemented abstract members.
    """

    # noinspection PyShadowingBuiltins
    def __init__(cls, name, bases=None, dict=None):
        super(TestCaseMeta, cls).__init__(name, bases, dict)

        # Test collectors such as nose honour a falsy ``__test__``.
        # :see: https://nose.readthedocs.io/en/latest/finding_tests.html
        still_abstract = is_abstract(cls)
        setattr(cls, '__test__', not still_abstract)
class BaseKeyEscaperTestCase(with_metaclass(TestCaseMeta, TestCase)):
    """
    Defines base functionality and templates for KeyEscaper test cases.

    Subclasses supply the manipulator under test via ``get_manipulator``.
    Every ``test_*`` method here then runs against it. The tests talk to a
    MongoDB server reached through a default ``MongoClient()``.
    """

    @abstract_property
    def get_manipulator(self):
        # type: () -> BaseKeyEscaper
        """Returns the KeyEscaper instance that the tests should exercise."""
        raise NotImplementedError(
            'Not implemented in {cls}.'.format(cls=type(self).__name__),
        )

    def setUp(self):
        super(BaseKeyEscaperTestCase, self).setUp()

        client = MongoClient()
        # Release the client's connections once the test has finished.
        self.addCleanup(client.close)

        self.collection = client['test_db']['test_collection']

        # Purge any existing documents from the collection.
        self.collection.drop()

        self.manipulator = self.get_manipulator()  # type: BaseKeyEscaper

    def _store_document(self, document):
        """
        Escapes ``document`` and inserts it into the test collection.

        :return: The ``_id`` of the inserted document.
        """
        escaped = self.manipulator.transform_incoming(
            document,
            self.collection.name,
        )
        return self.collection.insert_one(escaped).inserted_id

    def _retrieve_document(self, filter_):
        """
        Loads the first document matching ``filter_`` and unescapes it.

        The ``_id`` field is omitted, since tests can't predict its value.

        :return:
            The unescaped document (SON results are converted with
            ``to_dict``), or ``None`` if nothing matched.
        """
        stored = self.collection.find_one(
            filter=filter_,
            projection={'_id': False},
        )
        if stored is None:
            return None

        unescaped = self.manipulator.transform_outgoing(
            stored,
            self.collection.name,
        )
        if isinstance(unescaped, SON):
            unescaped = unescaped.to_dict()
        return unescaped

    def assertKeysEscaped(self, document):
        """
        Asserts that the KeyEscaper correctly escapes/unescapes keys in the
        document.
        """
        document_id = self._store_document(document)
        self.assertEqual(
            self._retrieve_document({'_id': document_id}),
            document,
        )

    def test_illegal_key_names_dollar(self):
        """
        The stored document includes keys that starts with '$'.

        This is a MongoDB no-no, according to
        https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
        """
        self.assertKeysEscaped({
            '$topLevel': {
                'severity': 'innocent enough',
                'explanation': 'this is a common enough use case',
            },

            'nested': {
                '$tricky': 'can we handle nested values?',

                '$deep': {
                    '$reference': 'we need to go deeper',
                },
            },

            '$iñtërnâtiônàlizætiøn': 'non-ascii characters supported, too',

            'perfectly$legal':
                "keys may contain '$' so long as it's not the first character",

            'string': '$values may start with "$", no problem',
            'list': ['$list', '$items', '$are', '$also', '$exempt'],
        })

    def test_illegal_key_names_dot(self):
        """
        The stored document includes keys that include '.' characters.

        This is a MongoDB no-no, according to
        https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
        """
        self.assertKeysEscaped({
            'top.level': {
                'severity': 'innocent enough',
                'explanation': 'this is a common enough use case',
            },

            'nested': {
                '.tricky': 'can we handle nested values?',

                '.deep': {
                    'reference.': 'we need to go deeper',
                },
            },

            '.iñtërnâtiônàlizætiøn': 'non-ascii characters supported, too',

            'string': 'values.may.contain "." no.problem',
            'list': ['.list', 'items.', '.are.', '.also', 'exempt.'],
        })

    def test_illegal_key_names_null(self):
        """
        The stored document includes keys that include null bytes.

        This is a MongoDB no-no, according to
        https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
        """
        self.assertKeysEscaped({
            # These all should evaluate to the same code point, but
            # just to make absolutely sure....
            '\U00000000top\x00level\u0000': {
                'severity': 'suspect',
                'explanation':
                    'not sure why you would ever really need to do this',
            },

            'nested': {
                '\x00tricky\u0000': 'can we handle nested values?',

                '\x00deep': {
                    '\U00000000reference': 'we need to go deeper',
                },
            },

            '\x00iñtërnâtiônàlizætiøn': 'non-ascii characters supported, too',

            'string': 'values\x00may\x00contain\x00nulls\x00no\x00problem',

            'list': [
                '\x00list',
                'items\x00',
                '\x00are\x00',
                'also\x00',
                '\x00exempt',
            ],
        })

    def test_illegal_key_names_magic(self):
        """
        The stored document includes key names that coincide with
        escaped keys.
        """
        self.assertKeysEscaped({
            # This is the attribute where the self.manipulator stores the
            # escaped keys.
            self.manipulator.magic_prefix: {
                'severity': 'strange',
                'explanation': 'i guess it could happen',
            },

            # This is an example of an escaped key.
            self.manipulator.magic_prefix + '1':
                'somebody has to think of these things',

            # This is nonsense, but props for creative thinking.
            self.manipulator.magic_prefix + 'wonka':
                'there is no life i know to compare with pure imagination',

            'nested': {
                self.manipulator.magic_prefix: 'can we handle nested values?',
                self.manipulator.magic_prefix + '0': 'same story, different day',

                self.manipulator.magic_prefix + 'deep': {
                    self.manipulator.magic_prefix: 'we need to go deeper',
                },
            },

            self.manipulator.magic_prefix + 'iñtërnâtiônàlizætiøn':
                'non-ascii characters supported, too',

            # Values may use the magic prefix without consequence.
            'string': self.manipulator.magic_prefix,

            'list': [
                self.manipulator.magic_prefix,
                self.manipulator.magic_prefix + '0',
                self.manipulator.magic_prefix + 'foobar',
            ],
        })

    def test_illegal_key_names_combo(self):
        """The stored document has all kinds of illegal keys."""
        self.assertKeysEscaped({
            self.manipulator.magic_prefix + '$very.very.\x00illegal\x00': {
                'severity': 'major',
                'explanation': 'did you even read the instructions?',
            },

            'nested': {
                '$dollars': 'starts with $',
                'has.dot': 'contains a .',
                'has\x00null': 'contains a null',
                '$iñtërnâtiônàlizætiøn': 'contains non-ascii',

                '$level.down': {
                    '..': 'low-budget ascii bear',
                },

                self.manipulator.magic_prefix: 'overslept',
            },
        })

    def test_safe_byte_strings(self):
        """
        Byte strings are allowed, so long as they can be converted into
        unicode strings.
        """
        document_id = self._store_document({
            b'$ascii_escaped': 'escaped, safe; contains ascii only',
            b'ascii_unescaped': 'unescaped, safe; contains ascii only',

            '$iñtërnâtiônàlizætiøn_escaped'.encode('utf-8'):
                'escaped, safe; non-ascii, but can be decoded w/ default encoding',

            'iñtërnâtiônàlizætiøn_unescaped'.encode('utf-8'):
                'unescaped, safe; non-ascii, but can be decoded w/ default encoding',
        })

        retrieved = self._retrieve_document({'_id': document_id})

        self.assertDictEqual(
            retrieved,

            {
                # Note that keys are automatically converted to unicode strings
                # before storage.
                '$ascii_escaped': 'escaped, safe; contains ascii only',
                'ascii_unescaped': 'unescaped, safe; contains ascii only',

                '$iñtërnâtiônàlizætiøn_escaped':
                    'escaped, safe; non-ascii, but can be decoded w/ default encoding',

                'iñtërnâtiônàlizætiøn_unescaped':
                    'unescaped, safe; non-ascii, but can be decoded w/ default encoding',
            },
        )

    def test_unsafe_byte_strings(self):
        """
        Any byte string that can't be converted into a unicode string is
        invalid.
        """
        # Imported locally: the module-level imports don't provide these
        # names.
        from sys import getdefaultencoding as get_default_encoding
        from bson import InvalidDocument

        # Ensure that we pick the wrong encoding, regardless of system
        # configuration.
        wrong_encoding = \
            'latin-1' if get_default_encoding() == 'utf-16' else 'utf-16'

        with self.assertRaises(InvalidDocument):
            self.manipulator.transform_incoming(
                {'$iñtërnâtiônàlizætiøn'.encode(wrong_encoding): 'wrong encoding!'},
                self.collection.name,
            )

        # An exception will be raised even if the key doesn't need to be
        # escaped.
        with self.assertRaises(InvalidDocument):
            self.manipulator.transform_incoming(
                {'iñtërnâtiônàlizætiøn'.encode(wrong_encoding): 'wrong encoding!'},
                self.collection.name,
            )
class DeterministicKeyEscaperTestCase(BaseKeyEscaperTestCase):
    """Runs the shared KeyEscaper tests against DeterministicKeyEscaper."""

    def get_manipulator(self):
        return DeterministicKeyEscaper()

    def test_query_by_escaped_key(self):
        """
        It is possible (with a little work) to find a document using
        an escaped key.
        """
        document = {
            'data': {
                'responseValues': {
                    '$firstName': 'Marcus',
                    '$lastName': 'Brody',
                },
            },
        }

        self._store_document(document)

        escape = self.manipulator.escape_key

        # Escaping the whole dotted path produces one opaque key that does
        # not exist anywhere in the stored document, so nothing matches.
        whole_path_filter = {escape('data.responseValues.$lastName'): 'Brody'}
        self.assertIsNone(self.collection.find_one(whole_path_filter))

        # Escaping only the final segment of the dotted path does match.
        last_segment_filter = {
            'data.responseValues.' + escape('$lastName'): 'Brody',
        }
        self.assertDictEqual(
            self._retrieve_document(last_segment_filter),
            document,
        )
class NonDeterministicKeyEscaperTestCase(BaseKeyEscaperTestCase):
    """Runs the shared KeyEscaper tests against NonDeterministicKeyEscaper."""

    def get_manipulator(self):
        return NonDeterministicKeyEscaper()

    def test_query_by_escaped_key(self):
        """
        As its name suggests, NonDeterministicKeyEscaper uses (effectively)
        unpredictable replacement names for escaped keys.
        """
        document = {
            'data': {
                'responseValues': {
                    '$firstName': 'Marcus',
                    '$lastName': 'Brody',
                },
            },
        }

        self._store_document(document)

        # Guessing the escaped key is possible in theory. Outside of
        # contrived unit-test examples, though, it is very unlikely ever to
        # be practical.
        #
        # If you need to query against escaped keys, use
        # DeterministicKeyEscaper instead.
        guessed_key = \
            'data.responseValues.' + self.manipulator.escape_key('$lastName')

        self.assertIsNone(self.collection.find_one({guessed_key: 'Brody'}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment