Skip to content

Instantly share code, notes, and snippets.

@fizyk
Last active December 12, 2015 10:09
Show Gist options
  • Save fizyk/4757230 to your computer and use it in GitHub Desktop.
Save fizyk/4757230 to your computer and use it in GitHub Desktop.
SQLAlchemy NestedSet mixin + tests
# -*- coding: utf-8 -*-
'''
Tested for SQLAlchemy >=0.7.x and 0.8
'''
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Unicode
from sqlalchemy import ForeignKey
from sqlalchemy import Sequence
from sqlalchemy.sql import select, case, not_, and_
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.attributes import get_history
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref
class NestedSetMixin(object):
    '''
    Defines NestedSetMixin for nested set models.

    Every node keeps its left reach (``lft`` column), right reach (``rgt``
    column) and ``depth``. The descendants of a node are exactly the rows whose
    reaches lie strictly between the node's own reaches. The
    ``_listener_*_nested_set`` classmethods are meant to be registered as
    ``before_insert`` / ``before_update`` / ``before_delete`` mapper listeners
    of the mapped class; they keep those values in order.

    The mapped class must provide an ``id`` primary key column, since
    ``parent_id`` and ``parent`` are declared against ``cls.id``.
    '''

    # Left reach of given nested object (updated automatically by the listeners).
    _left = Column('lft', Integer, nullable=True)
    # Right reach of given nested object (updated automatically by the listeners).
    _right = Column('rgt', Integer, nullable=True)
    # Depth of given nested object; roots have depth 0 (updated automatically).
    _depth = Column('depth', Integer, nullable=True, default=0)

    @declared_attr
    def parent_id(cls):
        '''Foreign key to the parent node's ``id``; NULL for root nodes.'''
        return Column(Integer, ForeignKey(cls.id))

    @declared_attr
    def parent(cls):
        '''
        Many-to-one relation to the parent node.

        The ``children`` backref has ``cascade='delete'``, so deleting a node
        through the session also deletes its children.
        '''
        return relationship(
            cls,
            remote_side=cls.id,
            primaryjoin=lambda: (cls.parent_id == cls.id),
            backref=backref('children', cascade='delete'),
        )

    __mapper_args__ = {
        # Allows the listener to fire for each instance before going to the
        # next one. Does not work on delete though!
        'batch': False,
    }

    @property
    def is_leaf(self):
        '''
        Determines whether the node is a leaf, based on its left and right reaches.

        A node without descendants spans exactly two consecutive numbers.
        Both reaches must already be assigned (they are NULL/None until the
        insert listener has run).
        '''
        return self._right - self._left == 1

    @classmethod
    def _listener_insert_nested_set(cls, mapper, connection, nested):
        '''
        Listener fired before INSERT is being performed. Keeps all lft, rgt and
        depth values in order.

        A node with neither ``parent`` nor ``parent_id`` set is appended as the
        right-most root. Otherwise it is appended as the right-most child of its
        parent, and every node further right is shifted by 2 to make room.

        .. note::
            See `sqlalchemy.orm.events.MapperEvents.before_insert <http://docs.sqlalchemy.org/en/latest/orm/events.html#sqlalchemy.orm.events.MapperEvents.before_insert>`_ for more information
        '''
        # let's get the mapped table first
        nested_table = cls.__get_table(mapper, nested)
        # The node is a root only when no parent was given at all, neither as
        # an object nor as an id (same condition as in the update listener).
        # A pending object given only ``parent_id`` has ``parent`` still None,
        # so it must not be treated as a root.
        if not nested.parent_id and not nested.parent:
            # depth is 0, and the right-most sibling is just the right-most root
            nested._depth = 0
            right_most_sibling = connection.scalar(
                select([nested_table.c.rgt])
                .where(nested_table.c.depth == 0)
                .order_by(nested_table.c.rgt.desc())
                .limit(1)
            )
            if not right_most_sibling:
                # that means we have the first element in the table
                nested._left = 1
                nested._right = 2
            else:
                nested._left = right_most_sibling + 1
                nested._right = right_most_sibling + 2
            return

        # Same mixin behaviour no matter whether the new parent was passed as
        # an object or as its id. With only an id, the related object might not
        # be available in the session, so its depth is read from the database.
        if nested.parent:
            parent_id = nested.parent.id
            parent_depth = nested.parent._depth
        else:
            parent_id = nested.parent_id
            parent_depth = connection.scalar(
                select([nested_table.c.depth]).where(nested_table.c.id == parent_id)
            )
        nested._depth = parent_depth + 1
        # how far right the parent's branch currently reaches
        parent_right = connection.scalar(
            select([nested_table.c.rgt]).where(nested_table.c.id == parent_id)
        )
        # Make room for the new node. Nodes entirely to the right of the
        # parent's right reach move both reaches. The parent and its ancestors,
        # whose ranges enclose parent_right, only extend their right reach.
        # The +2 is because a new node always spans two numbers.
        connection.execute(
            nested_table.update(nested_table.c.rgt >= parent_right).values(
                lft=case(
                    [(nested_table.c.lft > parent_right, nested_table.c.lft + 2)],
                    else_=nested_table.c.lft
                ),
                rgt=case(
                    [(nested_table.c.rgt >= parent_right, nested_table.c.rgt + 2)],
                    else_=nested_table.c.rgt
                )
            )
        )
        nested._left = parent_right
        nested._right = parent_right + 1

    @classmethod
    def _listener_update_nested_set(cls, mapper, connection, nested):
        '''
        Listener fired before UPDATE is being performed. Keeps all lft, rgt and
        depth values in order.

        Acts only if the object has been modified
        (**object_session(nested).is_modified(nested, include_collections=False)**)
        and its ``parent`` or ``parent_id`` attribute has changed
        (**get_history(nested, attribute).has_changes()**). Changes to other
        attributes are not our concern.

        Moving a node is done as "remove from the old place, insert at the new
        place". A leaf is simply re-inserted. A whole branch is shifted by a
        common offset and its depth adjusted.

        .. note::
            See `sqlalchemy.orm.events.MapperEvents.before_update <http://docs.sqlalchemy.org/en/latest/orm/events.html#sqlalchemy.orm.events.MapperEvents.before_update>`_ for more information
        '''
        # get_history reads changes for given attribute. `See <http://docs.sqlalchemy.org/en/latest/orm/session.html?highlight=attributes%20get_history#sqlalchemy.orm.attributes.History>`_
        if not (object_session(nested).is_modified(nested, include_collections=False)
                and (get_history(nested, 'parent').has_changes()
                     or get_history(nested, 'parent_id').has_changes())):
            return

        nested_table = cls.__get_table(mapper, nested)
        # ids of all descendants of the moved node (rows strictly inside its reaches)
        children = [row[0] for row in connection.execute(
            select([nested_table.c.id])
            .where(nested_table.c.lft > nested._left)
            .where(nested_table.c.rgt < nested._right)
        ).fetchall()]
        # Remove the branch from its old place. This closes the gap it leaves
        # behind; rows inside the branch keep their old reaches for now.
        cls._listener_delete_nested_set(mapper, connection, nested)
        if not children:
            # a leaf is simply inserted at its new place
            cls._listener_insert_nested_set(mapper, connection, nested)
            return

        # How many numbers the moved branch spans (node plus descendants).
        # direction: offset applied to the reaches of the node and its descendants.
        # depth_change: offset applied to their depths.
        diff = (nested._right - nested._left) + 1
        if not nested.parent and not nested.parent_id:
            # becomes a root: place it after the right-most root outside the branch
            depth_change = -nested._depth
            right_most_sibling = connection.scalar(
                select([nested_table.c.rgt])
                .where(and_(nested_table.c.depth == 0,
                            not_(nested_table.c.id.in_(children)),
                            not_(nested_table.c.id == nested.id)))
                .order_by(nested_table.c.rgt.desc())
                .limit(1)
            )
            if not right_most_sibling:
                # first element in a table
                direction = 1 - nested._left
            else:
                direction = right_most_sibling - nested._left + 1
        else:
            # Same behaviour no matter whether the new parent was assigned as an
            # object or by id (the object might not be loaded in memory).
            if nested.parent:
                parent_id = nested.parent.id
            else:
                parent_id = nested.parent_id
            # read the new parent's right reach and depth
            parents_right_reach, parents_depth = connection.execute(
                select([nested_table.c.rgt, nested_table.c.depth])
                .where(nested_table.c.id == parent_id)
            ).fetchone()
            depth_change = (parents_depth + 1) - nested._depth
            # Make room for the moved branch, leaving the branch itself untouched.
            connection.execute(
                nested_table.update(
                    and_(nested_table.c.rgt >= parents_right_reach,
                         not_(nested_table.c.id.in_(children)),
                         not_(nested_table.c.id == nested.id))
                ).values(
                    lft=case(
                        [(nested_table.c.lft > parents_right_reach, nested_table.c.lft + diff)],
                        else_=nested_table.c.lft
                    ),
                    rgt=case(
                        [(nested_table.c.rgt >= parents_right_reach, nested_table.c.rgt + diff)],
                        else_=nested_table.c.rgt
                    )
                )
            )
            # the branch's left reach lands on the parent's former right reach
            direction = parents_right_reach - nested._left

        # shift the descendants by the same offsets
        connection.execute(
            nested_table.update(nested_table.c.id.in_(children)).values(
                lft=(nested_table.c.lft + direction),
                rgt=(nested_table.c.rgt + direction),
                depth=(nested_table.c.depth + depth_change)
            )
        )
        # And the nested object itself: the normal SQLAlchemy UPDATE issued
        # after this listener writes these values into the database.
        nested._left += direction
        nested._right += direction
        nested._depth += depth_change

    @classmethod
    def _listener_delete_nested_set(cls, mapper, connection, nested):
        '''
        Listener fired before DELETE is being performed. Keeps all lft, rgt and
        depth values in order by closing the gap the deleted branch leaves.

        Acts only for the tip of a deleted branch (a node whose parent is not
        being deleted in the same session). The update listener also calls it
        to detach a branch before moving it.

        .. note::
            See `sqlalchemy.orm.events.MapperEvents.before_delete <http://docs.sqlalchemy.org/en/latest/orm/events.html#sqlalchemy.orm.events.MapperEvents.before_delete>`_ for more information
        '''
        # Proceed only for the tip of the deleted branch: the gap computed for
        # the tip already spans all of its descendants. Handling each
        # descendant as well would be wrong because deletes ignore batch=False.
        if nested.parent in object_session(nested).deleted:
            return
        nested_table = cls.__get_table(mapper, nested)
        # the distance we recover after removing this branch
        diff = (nested._right - nested._left) + 1
        # shift everything to the right of the branch back to close the gap
        connection.execute(
            nested_table.update(nested_table.c.rgt > nested._right).values(
                lft=case(
                    [(nested_table.c.lft > nested._right, nested_table.c.lft - diff)],
                    else_=nested_table.c.lft
                ),
                rgt=case(
                    [(nested_table.c.rgt > nested._right, nested_table.c.rgt - diff)],
                    else_=nested_table.c.rgt
                )
            )
        )

    @classmethod
    def __get_table(cls, mapper, nested):
        '''
        Returns the proper Table object to be used within listeners.

        :param sqlalchemy.orm.mapper.Mapper mapper: a mapper object
        :param NestedSetMixin nested: an instance of a NestedSetMixin model

        .. note::
            For polymorphic tables ``mapper.mapped_table`` returns a Join
            expression, which doesn't provide access to our columns. That's why
            we check for column access and, if needed, take the table from the
            object instead.
        '''
        nested_table = mapper.mapped_table
        if not hasattr(nested_table.c, 'rgt'):
            nested_table = nested.__table__
        return nested_table
# -*- coding: utf-8 -*-
'''
Created on 13-07-2012
@author: sliwinski
'''
import unittest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column
from sqlalchemy import Unicode
from sqlalchemy import Integer
from sqlalchemy.event import listen
from nestedset import NestedSetMixin
class Base(declarative_base()):
    '''Abstract declarative base for the test models; requests InnoDB tables on MySQL.'''
    __table_args__ = (dict(mysql_engine='InnoDB'),)
    __abstract__ = True
class Category(Base, NestedSetMixin):
    '''Minimal nested set model: an integer primary key plus a unicode name.'''
    __tablename__ = 'category'
    id = Column(Integer(), primary_key=True)
    category = Column(Unicode(length=255))
# Hook the NestedSetMixin bookkeeping into Category's mapper events
# (registered in insert, update, delete order).
for _event_name, _listener in (
        ('before_insert', NestedSetMixin._listener_insert_nested_set),
        ('before_update', NestedSetMixin._listener_update_nested_set),
        ('before_delete', NestedSetMixin._listener_delete_nested_set),
):
    listen(Category, _event_name, _listener)
del _event_name, _listener
class Test(unittest.TestCase):
    '''Tests keeping of lft, rgt and depth values by NestedSetMixin.'''

    def setUp(self, echo=False):
        '''
        setUp test method @see unittest.TestCase.setUp

        Creates a fresh in-memory SQLite database with all tables of ``Base``.

        :param bool echo: whether to echo sqlalchemy queries, or not
        '''
        connection = 'sqlite://'
        engine = create_engine(connection, echo=echo)
        Base.metadata.drop_all(engine)
        Base.metadata.create_all(engine)
        self.session = sessionmaker(bind=engine)()
        self.engine = engine

    def tearDown(self):
        '''
        Closes the session and drops the tables after each test, so we don't
        have any problem if the model changes.
        '''
        # Close the session first: its still-open transaction could otherwise
        # block dropping the tables on server backends (e.g. MySQL).
        self.session.close()
        Base.metadata.drop_all(self.engine)
        self.engine.dispose()
        del self.engine

    def _assert_node(self, name, depth, left, right):
        '''
        Loads the category called ``name`` and checks its nested set values.

        Messages are built from the expected values, so they can't go stale.

        :returns: the loaded Category, so a test can keep working with it
        '''
        cat = self.session.query(Category).filter(Category.category == name).one()
        self.assertEqual(cat._depth, depth,
                         'Wrong depth of {0}, should be {1}, is {2}'.format(name, depth, cat._depth))
        self.assertEqual(cat._left, left,
                         'left value of {0} should be {1}, is {2}'.format(name, left, cat._left))
        self.assertEqual(cat._right, right,
                         'right value of {0} should be {1}, is {2}'.format(name, right, cat._right))
        return cat

    def test_nested_add1(self):
        '''NestedSetMixin::FirstRootInsert'''
        c = Category(category=u'first')
        self.session.add(c)
        self.session.commit()
        cat = self._assert_node(u'first', 0, 1, 2)
        self.assertTrue(cat.is_leaf, 'This category should be considered a leaf')

    def test_nested_add2(self):
        '''NestedSetMixin::SecondRootInsert'''
        c0 = Category(category=u'first')
        c1 = Category(category=u'second')
        self.session.add_all([c0, c1])
        self.session.commit()
        self._assert_node(u'second', 0, 3, 4)

    def test_nested_add3(self):
        '''NestedSetMixin::ChildInsert'''
        c0 = Category(category=u'first')
        c1 = Category(category=u'second')
        c2 = Category(category=u'first.first')
        c2.parent = c0
        self.session.add_all([c0, c1, c2])
        self.session.commit()
        self._assert_node(u'first.first', 1, 2, 3)
        self._assert_node(u'second', 0, 5, 6)
        self._assert_node(u'first', 0, 1, 4)

    def test_nested_update(self):
        '''NestedSetMixin::Move leaf-category'''
        c0 = Category(category=u'first')
        c1 = Category(category=u'second')
        c2 = Category(category=u'first.first')
        c2.parent = c0
        self.session.add_all([c0, c1, c2])
        self.session.commit()
        cat1 = self.session.query(Category).filter(Category.category == u'second').one()
        cat2 = self.session.query(Category).filter(Category.category == u'first.first').one()
        cat2.parent = cat1
        self.session.commit()
        self._assert_node(u'first.first', 1, 4, 5)
        self._assert_node(u'second', 0, 3, 6)
        self._assert_node(u'first', 0, 1, 2)

    def test_nested_update_node(self):
        '''NestedSetMixin::Move node-category'''
        c0 = Category(category=u'c0')
        c1 = Category(category=u'c1')
        c2 = Category(category=u'c2')
        c3 = Category(category=u'c3')
        c4 = Category(category=u'c4')
        c2.parent = c0
        c3.parent = c2
        self.session.add_all([c0, c1, c2, c3, c4])
        self.session.commit()
        cat1 = self.session.query(Category).filter(Category.category == u'c4').one()
        cat2 = self.session.query(Category).filter(Category.category == u'c2').one()
        cat2.parent = cat1
        self.session.commit()
        self._assert_node(u'c2', 1, 6, 9)
        self._assert_node(u'c3', 2, 7, 8)
        self._assert_node(u'c1', 0, 3, 4)
        self._assert_node(u'c0', 0, 1, 2)
        self._assert_node(u'c4', 0, 5, 10)

    def test_nested_update_node2(self):
        '''NestedSetMixin::Move node-category to root'''
        c0 = Category(category=u'c0')
        c1 = Category(category=u'c1')
        c2 = Category(category=u'c2')
        c3 = Category(category=u'c3')
        c2.parent = c0
        c3.parent = c2
        self.session.add_all([c0, c1, c2, c3])
        self.session.commit()
        cat2 = self.session.query(Category).filter(Category.category == u'c2').one()
        cat2.parent = None
        self.session.commit()
        self._assert_node(u'c2', 0, 5, 8)
        self._assert_node(u'c3', 1, 6, 7)
        self._assert_node(u'c1', 0, 3, 4)
        self._assert_node(u'c0', 0, 1, 2)

    def test_nested_delete(self):
        '''NestedSetMixin::Delete children leaf'''
        c0 = Category(category=u'first')
        c1 = Category(category=u'second')
        c2 = Category(category=u'first.first')
        c3 = Category(category=u'first.second')
        c2.parent = c0
        c3.parent = c0
        self.session.add_all([c0, c1, c2, c3])
        self.session.commit()
        cat = self._assert_node(u'first.second', 1, 4, 5)
        self.session.delete(cat)
        self.session.commit()
        self._assert_node(u'first.first', 1, 2, 3)
        self._assert_node(u'second', 0, 5, 6)
        self._assert_node(u'first', 0, 1, 4)

    def test_nested_delete_node(self):
        '''NestedSetMixin::Delete node'''
        c0 = Category(category=u'first')
        c1 = Category(category=u'second')
        c2 = Category(category=u'first.first')
        c2.parent = c0
        c3 = Category(category=u'first.first.buu')
        c3.parent = c2
        self.session.add_all([c0, c1, c2, c3])
        self.session.commit()
        cat = self._assert_node(u'first.first', 1, 2, 5)
        self.session.delete(cat)
        self.session.commit()
        self._assert_node(u'second', 0, 3, 4)
        self._assert_node(u'first', 0, 1, 2)

    def test_complicated_nodes(self):
        '''NestedSetMixin::Complicated nodes'''
        c1 = Category(category=u'1')
        c2 = Category(category=u'2')
        c3 = Category(category=u'3')
        c4 = Category(category=u'21')
        c5 = Category(category=u'4')
        c6 = Category(category=u'5')
        c7 = Category(category=u'6')
        c8 = Category(category=u'11')
        c9 = Category(category=u'12')
        c10 = Category(category=u'211')
        c11 = Category(category=u'212')
        c8.parent = c1
        c9.parent = c1
        c10.parent = c4
        c11.parent = c4
        self.session.add_all([c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11])
        self.session.commit()
        self._assert_node(u'1', 0, 1, 6)
        self.session.commit()
        self._assert_node(u'212', 1, 14, 15)
        moving_node = self._assert_node(u'6', 0, 21, 22)
        parent_node = self._assert_node(u'21', 0, 11, 16)
        moving_node.parent = parent_node
        self.session.commit()
        # new values should be here!
        self._assert_node(u'6', 1, 16, 17)
        self._assert_node(u'21', 0, 11, 18)

    def test_is_leaf(self):
        '''NestedSetMixin::is_leaf'''
        c = Category(category=u'first')
        self.session.add(c)
        self.session.commit()
        cat = self.session.query(Category).filter(Category.category == u'first').one()
        self.assertTrue(cat.is_leaf, 'This category should be considered a leaf')

    def test_is_leaf_two_elements(self):
        '''NestedSetMixin::is_leaf two elements'''
        c0 = Category(category=u'first')
        c1 = Category(category=u'second')
        c1.parent = c0
        self.session.add_all([c0, c1])
        self.session.commit()
        cat = self.session.query(Category).filter(Category.category == u'first').one()
        self.assertFalse(cat.is_leaf, 'This category should not be considered a leaf')
        cat = self.session.query(Category).filter(Category.category == u'second').one()
        self.assertTrue(cat.is_leaf, 'This category should be considered a leaf')
if __name__ == '__main__':
    # unittest.main() picks test names up from sys.argv, so a single test can
    # be selected from the command line instead of by editing sys.argv here.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment