Skip to content

Instantly share code, notes, and snippets.

@brianmhunt
Last active December 3, 2015 19:01
Show Gist options
  • Save brianmhunt/9974822 to your computer and use it in GitHub Desktop.
Save brianmhunt/9974822 to your computer and use it in GitHub Desktop.
tags for ndb on Google App Engine / Python
"""
From blog post: http://brianmhunt.github.io/articles/ndb-tags/
License: MIT <http://brianmhunt.mit-license.org/>
"""
from google.appengine.ext import ndb
# Default value of the ``limit`` argument of Tag.get_linked_by_tag.
MAX_TAGS_FOR_TAGGABLE = 1000
# Default value of the ``page_size`` argument of Tag.get_popular_query.
POPULAR_PAGE_SIZE = 30
class Tag(ndb.Model):
    """Keep track of data related to a tag added with the TagMixin class.

    One entity exists per distinct (case-insensitive) tag, stored under the
    key name built by ``tag_to_keyname``. It records the keys of the
    entities carrying the tag and how many there are.
    """

    # Tag text; the validator returns the lower-cased value.
    tag = ndb.StringProperty(required=True, indexed=True,
                             validator=lambda p, v: v.lower())
    # Number of linked entities; get_popular_query orders on this.
    count = ndb.IntegerProperty(default=0, indexed=True)
    # Keys of the entities that currently carry this tag.
    linked = ndb.KeyProperty(repeated=True)
    created = ndb.DateTimeProperty(auto_now_add=True)
    modified = ndb.DateTimeProperty(auto_now=True)

    @staticmethod
    def tag_to_keyname(tag):
        """Return the datastore key name used for ``tag`` (lower-cased)."""
        return "tag__{}".format(tag.lower())

    @staticmethod
    def tag_to_key(tag):
        """Return the ``ndb.Key`` of the Tag entity for ``tag``."""
        return ndb.Key("Tag", Tag.tag_to_keyname(tag))

    @classmethod
    def get_linked_by_tag(cls, tag, limit=MAX_TAGS_FOR_TAGGABLE):
        """Return the list of keys linked to ``tag``.

        Returns an empty list when no Tag entity exists for ``tag``.
        ``limit`` is accepted for interface compatibility but is not
        applied to the result.
        """
        try:
            return Tag.tag_to_key(tag).get().linked
        except AttributeError:
            # get() returned None: nothing has been tagged with this yet.
            return []

    @classmethod
    def get_or_create_async(cls, tag):
        """Return a future for the Tag entity of ``tag``, creating it if
        it does not exist yet."""
        return Tag.get_or_insert_async(Tag.tag_to_keyname(tag), tag=tag)

    @classmethod
    def get_popular_query(cls, page_size=POPULAR_PAGE_SIZE):
        """Return a query over all tags, highest ``count`` first.

        ``page_size`` is accepted for interface compatibility but is not
        applied; the caller chooses how many results to fetch.
        """
        return Tag.query().order(-Tag.count)

    def unlink_async(self, key):
        """Remove ``key`` from this tag and store the entity.

        Safe to call when ``key`` is not linked: nothing is removed and no
        error is raised. Returns the future of the asynchronous put.
        """
        # Drop every occurrence so previously duplicated links are healed.
        self.linked = [linked_key for linked_key in self.linked
                       if linked_key != key]
        # Derive the count from the list so the two can never drift apart.
        self.count = len(self.linked)
        return self.put_async()

    def link_async(self, key):
        """Add ``key`` to this tag and store the entity.

        Linking a key that is already present is a no-op for the list, so
        repeated calls do not inflate ``count``. Returns the future of the
        asynchronous put.
        """
        if key not in self.linked:
            self.linked.append(key)
        # Derive the count from the list so the two can never drift apart.
        self.count = len(self.linked)
        return self.put_async()
class TagMixin(object):
    """A mixin that adds taggability to an ``ndb.Model`` subclass.

    Adds a repeated ``tags`` string property and keeps the matching
    ``Tag`` entities up to date whenever the tagged entity is put.
    The last-stored tags are remembered on the instance as ``_tm_tags``.
    """

    tags = ndb.StringProperty(repeated=True, indexed=True)

    @classmethod
    def _post_get_hook(cls, key, future):
        """Remember the stored tags so the put hook can detect changes.

        NDB invokes this hook on the class with the key and the get
        future, so it must be a classmethod taking ``(key, future)``.
        """
        entity = future.get_result()
        # A get for a missing entity yields None; nothing to remember.
        if entity is not None:
            # Copy, so in-place edits of entity.tags remain detectable.
            entity._tm_tags = list(entity.tags)

    def _post_put_hook(self, future):
        """Modify the associated Tag instances to reflect any updates.

        Compares the tags remembered from the last get/put with the
        current ones, links this entity to newly added tags and unlinks
        it from removed ones, then waits for all updates to finish.
        """
        # Tag entities are keyed by the lower-cased tag, so compare
        # case-insensitively; otherwise 'A' -> 'a' would both link and
        # unlink the very same Tag entity.
        old_tagset = set(t.lower() for t in getattr(self, '_tm_tags', []))
        new_tagset = set(t.lower() for t in self.tags)

        # These are tags that have changed
        added_tags = new_tagset - old_tagset
        deleted_tags = old_tagset - new_tagset

        # Get the key for this post
        self_key = future.get_result()

        @ndb.transactional_tasklet
        def update_changed(tag):
            tag_instance = yield Tag.get_or_create_async(tag)
            if tag in added_tags:
                yield tag_instance.link_async(self_key)
            else:
                yield tag_instance.unlink_async(self_key)

        ndb.Future.wait_all([
            update_changed(tag) for tag in added_tags | deleted_tags
        ])

        # Update for any successive puts on this model. Store a copy:
        # keeping a reference to self.tags would make later in-place
        # mutations (e.g. tags.append) invisible to the comparison above.
        self._tm_tags = list(self.tags)
"""
License: MIT <http://brianmhunt.mit-license.org/>
"""
import logging
import unittest
from google.appengine.ext import ndb, testbed
from models import Tag, TagMixin
class TestTagModel(TagMixin, ndb.Model):
    """Minimal taggable model used as a fixture by the tag tests."""

    name = ndb.StringProperty()
class TestTags(unittest.TestCase):
    """Exercise TagMixin and Tag against the App Engine testbed stubs."""

    def setUp(self):
        tb = testbed.Testbed()
        tb.activate()
        tb.setup_env()
        tb.init_datastore_v3_stub()
        tb.init_memcache_stub()
        # The tests reuse the same tag key names; clear NDB's in-context
        # cache so entities cached by one test cannot leak into the next.
        ndb.get_context().clear_cache()
        self.tb = tb

    def tearDown(self):
        self.tb.deactivate()

    def test_init(self):
        """A model stored without tags reports an empty tag list."""
        ttm = TestTagModel(name="X")
        ttm.put()
        self.assertIsNotNone(ttm)
        self.assertEqual(ttm.tags, [])

    def test_add_by_arg(self):
        """Tags passed to the constructor are kept after a put."""
        ttm = TestTagModel(name="X", tags=['a'])
        ttm.put()
        self.assertEqual(ttm.tags, ['a'])

    def test_add_assign(self):
        """Tags assigned after the first put are linked on the next put."""
        ttm = TestTagModel(name="X")
        ttm.put()
        ttm.tags = ['b']
        self.assertEqual(ttm.tags, ['b'])
        # Persist the assignment so the put hook actually links the tag.
        ttm.put()
        self.assertEqual(len(Tag.get_linked_by_tag('b')), 1)

    def test_del(self):
        """Removing a tag unlinks the entity and keeps count in sync."""
        ttm = TestTagModel(name="X", tags=['b', 'c'])
        ttm.put()
        ttm.tags = ['c']  # delete 'b'
        ttm.put()
        b_tags = Tag.get_or_create_async('b').get_result()
        c_tags = Tag.get_or_create_async('c').get_result()
        self.assertEqual(len(b_tags.linked), 0)
        self.assertEqual(len(c_tags.linked), 1)
        self.assertEqual(len(b_tags.linked), b_tags.count)
        self.assertEqual(len(c_tags.linked), c_tags.count)

    def test_tag_count_none(self):
        """An unknown tag has no linked entities."""
        self.assertEqual(len(Tag.get_linked_by_tag('x')), 0)

    def test_tag_count(self):
        """Every entity put with a tag is linked to it."""
        ttms = [
            TestTagModel(name='X1', tags=['d']),
            TestTagModel(name='X2', tags=['d']),
            TestTagModel(name='X3', tags=['d']),
        ]
        for t in ttms:
            t.put()
        self.assertEqual(len(Tag.get_linked_by_tag('d')), 3)

    def test_tag_count_del(self):
        """Clearing one entity's tags removes only that entity's link."""
        ttms = [
            TestTagModel(name='X1', tags=['d']),
            TestTagModel(name='X2', tags=['d']),
            TestTagModel(name='X3', tags=['d']),
        ]
        for t in ttms:
            t.put()
        ttms[0].tags = []
        ttms[0].put()
        self.assertEqual(len(Tag.get_linked_by_tag('d')), 2)

    def test_get_linked_by_tag(self):
        """Linked keys are tracked per tag."""
        TestTagModel(name='X1', tags=['a']).put()
        TestTagModel(name='X2', tags=['b']).put()
        TestTagModel(name='X3', tags=['b', 'c']).put()
        self.assertEqual(len(Tag.get_linked_by_tag("a")), 1)
        self.assertEqual(len(Tag.get_linked_by_tag("b")), 2)
        self.assertEqual(len(Tag.get_linked_by_tag("zzz")), 0)

    def test_popular_query(self):
        """The popular query returns the most used tag first."""
        TestTagModel(name='X1', tags=['a', 'b']).put()
        TestTagModel(name='X2', tags=['b']).put()
        res = Tag.get_popular_query().fetch(5)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[0].tag, 'b')
        self.assertEqual(res[1].tag, 'a')
# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment