Skip to content

Instantly share code, notes, and snippets.

@schlamar
Created September 26, 2013 13:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schlamar/6713843 to your computer and use it in GitHub Desktop.
Save schlamar/6713843 to your computer and use it in GitHub Desktop.
SQLAlchemy ordered collection.
from sqlalchemy import create_engine, event, Column, Integer
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func, select
Base = declarative_base()


class Job(Base):
    """A job with an integer ``position`` in an ordered list.

    The ``before_insert`` / ``before_update`` / ``before_delete`` listeners in
    this module shift the positions of the other rows so the ordering stays
    contiguous (assuming explicitly requested positions lie within
    ``1 .. max + 1`` -- nothing enforces that).
    """
    __tablename__ = 'job'
    # Persist each Job individually (event -> SQL -> next event). By default
    # the unit of work fires before_insert for every pending Job in a flush
    # before emitting any INSERT, so a listener that reads MAX(position) on
    # the connection would hand the same position to all Jobs added together.
    __mapper_args__ = {'batch': False}

    id = Column(Integer, primary_key=True)
    # Filled in by the before_insert listener when left as None.
    position = Column(Integer, nullable=False)
# Throwaway in-memory SQLite database used by the test below.
engine = create_engine('sqlite:///:memory:')
Session = sessionmaker(bind=engine)
# Create the ``job`` table up front so sessions can use it immediately.
Base.metadata.create_all(bind=engine)
@event.listens_for(Job, 'before_insert')
def job_before_insert(mapper, connection, target):
    """Assign a position to a new job and make room for it.

    A job created without a position is appended after the current maximum
    (or gets 1 when the table is empty). Every existing job at or after the
    chosen position is then pushed one place later via ``fix_ordering``.
    """
    position = target.position
    if position is None:
        # NOTE(review): the list form select([...]) was removed in
        # SQLAlchemy 2.0; on 1.4+ this would be select(func.max(...)).
        highest = connection.scalar(select([func.max(Job.position)]))
        position = 1 if highest is None else highest + 1
    target.position = position
    fix_ordering(None, position, connection)
@event.listens_for(Job, 'before_delete')
def job_before_delete(mapper, connection, target):
    """Close the gap left by a job that is about to be deleted.

    The job's current position is read from the database, as
    ``job_before_update`` does, rather than from ``target.position``.
    ``fix_ordering`` renumbers rows with a bulk UPDATE on the connection,
    bypassing the ORM. The in-memory value can therefore be stale, either
    after an earlier flush in the same transaction or after another deletion
    handled earlier in the same flush. Using it could leave duplicate
    positions or gaps.
    """
    s = select([Job.position]).where(Job.id == target.id)
    old = connection.scalar(s)
    if old is None:
        # No stored row, hence no gap to close.
        return
    fix_ordering(old, None, connection)
@event.listens_for(Job, 'before_update')
def job_before_update(mapper, connection, target):
    """Shift the neighbours of a job whose position is changing.

    The previous position is read from the database (the job's own UPDATE has
    not been emitted yet) and paired with the new value held on ``target``.
    ``fix_ordering`` does nothing when the two are equal.
    """
    stored = select([Job.position]).where(Job.id == target.id)
    fix_ordering(connection.scalar(stored), target.position, connection)
def fix_ordering(old, new, conn):
    """Renumber the other jobs around an insert, delete or move.

    :param old: previous position of the job, or None for an insertion.
    :param new: requested position of the job, or None for a deletion.
    :param conn: connection on which the bulk UPDATE is executed.

    Returns without touching the database when ``old == new``. The UPDATE
    goes straight to the ``job`` table, so ``Job`` instances already loaded
    in a session are not refreshed by it.
    """
    pos = Job.position
    if old is None:
        # Insertion: everything at or after ``new`` moves one place later.
        shifted, condition = pos + 1, pos >= new
    elif new is None:
        # Deletion: everything after ``old`` moves one place earlier.
        shifted, condition = pos - 1, pos > old
    elif new < old:
        # Moved earlier: jobs in [new, old) move one place later.
        shifted, condition = pos + 1, (pos >= new) & (pos < old)
    elif old < new:
        # Moved later: jobs in (old, new] move one place earlier.
        shifted, condition = pos - 1, (pos > old) & (pos <= new)
    else:
        return
    stmt = Job.__table__.update().where(condition).values(position=shifted)
    conn.execute(stmt)
def assert_order(session, jobs):
    """Assert that the jobs stored in the database appear exactly as ``jobs``.

    Compares ids and positions pairwise, ordering the database rows by
    ``position``. The row count must match too: ``zip`` alone stops at the
    shorter sequence and would hide missing or extra rows.
    """
    db_jobs = session.query(Job).order_by(Job.position).all()
    assert len(db_jobs) == len(jobs), (len(db_jobs), len(jobs))
    for j, j_db in zip(jobs, db_jobs):
        assert j.position is not None
        assert j.id == j_db.id
        assert j.position == j_db.position
def test():
    """Exercise append, move, explicit insert and delete against the listeners.

    ``jobs`` mirrors the expected order and is compared with the database
    after every commit. The step comments use 1-based positions.
    """
    session = Session()
    try:
        jobs = []
        # range() rather than the Python 2-only xrange(); same result here.
        for _ in range(5):
            j = Job()
            jobs.append(j)
            session.add(j)
        session.commit()
        assert_order(session, jobs)

        # move 4 -> 2
        j = jobs[3]
        j.position = 2
        jobs.remove(j)
        jobs.insert(1, j)
        session.commit()
        assert_order(session, jobs)

        # move 1 -> 3
        j = jobs[0]
        j.position = 3
        jobs.remove(j)
        jobs.insert(2, j)
        session.commit()
        assert_order(session, jobs)

        # add 4
        j = Job(position=4)
        jobs.insert(3, j)
        session.add(j)
        session.commit()
        assert_order(session, jobs)

        # delete 3
        j = jobs[2]
        jobs.remove(j)
        session.delete(j)
        session.commit()
        assert_order(session, jobs)
    finally:
        # Release the connection even when an assertion fails.
        session.close()


if __name__ == '__main__':
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment