Skip to content

Instantly share code, notes, and snippets.

@bvanvugt
Created September 12, 2015 06:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bvanvugt/32fbc679bfff94253da9 to your computer and use it in GitHub Desktop.
Save bvanvugt/32fbc679bfff94253da9 to your computer and use it in GitHub Desktop.
import sqlite3
class CohortDatabase(object):
    """In-memory SQLite store aggregating values per (cohort, period) pair.

    Every distinct (cohort, period) pair owns exactly one row in the ``data``
    table, holding how many values were inserted for it (``count``) and their
    running total (``sum``).
    """

    def __init__(self):
        """Open a fresh in-memory database and create the ``data`` table."""
        self.db = sqlite3.connect(':memory:')
        # Create Tables
        self.db.cursor().execute("""
            CREATE TABLE data
            (cohort text, period text, count int, sum real,
            UNIQUE (cohort, period) ON CONFLICT FAIL)
        """)

    def insert(self, cohort, period, value):
        """Add ``value`` to the running total of the (cohort, period) pair.

        The row is created on first use, then its ``count`` is incremented
        and ``value`` is added to its ``sum``.
        """
        cursor = self.db.cursor()
        # Make sure the row exists. The statement-level OR IGNORE takes
        # precedence over the table's ON CONFLICT FAIL clause, so an already
        # existing pair is silently left untouched here.
        cursor.execute("""
            INSERT OR IGNORE INTO data (cohort, period, count, sum)
            VALUES (?, ?, 0, 0.0)
        """, (cohort, period))
        # Fold the new value into the (now guaranteed to exist) row.
        cursor.execute("""
            UPDATE data
            SET count = count + 1, sum = sum + ?
            WHERE cohort = ? AND period = ?
        """, (value, cohort, period))

    def select(self, cohort, period):
        """Return the summed value for (cohort, period), or 0.0 if absent.

        Raises:
            Exception: if more than one row matches the pair (the UNIQUE
                constraint on the table is meant to prevent this).
        """
        cursor = self.db.cursor()
        cursor.execute("""
            SELECT sum
            FROM data
            WHERE cohort = ? AND period = ?
        """, (cohort, period))
        rows = cursor.fetchall()
        if not rows:
            return 0.0
        if len(rows) > 1:
            raise Exception('SELECT returned more than one result, it returned %s' % len(rows))
        return rows[0][0]

    def raw(self):
        """Return every row as a dict with keys cohort, period, count, sum."""
        cursor = self.db.cursor()
        cursor.execute("""
            SELECT cohort, period, count, sum
            FROM data
        """)
        return [
            {'cohort': row[0], 'period': row[1], 'count': row[2], 'sum': row[3]}
            for row in cursor.fetchall()
        ]

    def dump(self):
        """Print all rows as comma-separated lines, preceded by a header."""
        cols = ['cohort', 'period', 'count', 'sum']
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(','.join(cols))
        for row in self.raw():
            print(','.join([str(row[col]) for col in cols]))

    def print_table(self):
        """Print a cohort-by-period grid of sums as comma-separated lines.

        The first line is a header: an empty cell followed by the period
        names. Each following line is a cohort name followed by its sum for
        every period (0.0 where the pair has no data).
        """
        cursor = self.db.cursor()
        cursor.execute("""SELECT DISTINCT cohort FROM data ORDER BY cohort ASC""")
        cohorts = [row[0] for row in cursor.fetchall()]
        # Order the columns by period itself so the layout does not depend on
        # which cohorts happen to reference each period.
        cursor.execute("""SELECT DISTINCT period FROM data ORDER BY period ASC""")
        periods = [row[0] for row in cursor.fetchall()]
        # One empty leading cell keeps the header aligned with the data rows,
        # whose first cell is the cohort name.
        print(','.join([''] + periods))
        for cohort in cohorts:
            cells = [cohort]
            for period in periods:
                cells.append(str(self.select(cohort, period)))
            print(','.join(cells))

    @classmethod
    def test(cls):
        """Self-check: insert a few values and assert the aggregated rows."""
        db = cls()
        expected = []
        assert db.raw() == expected
        db.insert('c_one', 'm_one', 4)
        expected = [
            {'cohort': 'c_one', 'period': 'm_one', 'count': 1, 'sum': 4}
        ]
        assert db.raw() == expected
        db.insert('c_one', 'm_two', 5.7)
        expected = [
            {'cohort': 'c_one', 'period': 'm_one', 'count': 1, 'sum': 4},
            {'cohort': 'c_one', 'period': 'm_two', 'count': 1, 'sum': 5.7}
        ]
        assert db.raw() == expected
        db.insert('c_one', 'm_one', 78.999)
        expected = [
            {'cohort': 'c_one', 'period': 'm_one', 'count': 2, 'sum': 82.999},
            {'cohort': 'c_one', 'period': 'm_two', 'count': 1, 'sum': 5.7}
        ]
        assert db.raw() == expected
# Demo: fill a database with a few sample values and print the resulting
# cohort-by-period table.
x = CohortDatabase()
for _cohort, _period, _value in (
    ('cohort_1', 'month_1', 4),
    ('cohort_1', 'month_21', 44),
    ('cohort_34', 'month_21', 444),
):
    x.insert(_cohort, _period, _value)
x.print_table()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment