#!/usr/bin/env python3
"""Algorithm
1. read rows sorted by date :
- avg O(n log n)
- best O(n)
- worst O(n log n)
2. partition by name : O(n)
put all rows in buckets keyed by user id/name,
each bucket contains the _sorted_ rows for that user
3. find non compliants (streaming into 4) :
- avg O(n log n)
- best O(n)
- worst O(n log n)
if the latest record for user A is non-compliant,
then search since when A has been non-compliant,
return the recorded date at which A became non-compliant
if A has always been non-compliant, then return the latest record date
4. write non-compliants : O(n)
See the `meh.py` file for a variation with roughly the same characteriscs.
Overall not happy to have to load everything in memory.
"""
import argparse
import csv
import io
import sys
import tracemalloc
from collections import defaultdict
from operator import itemgetter

CSV_OPTIONS = {"dialect": "excel", "delimiter": ";", "quoting": csv.QUOTE_NONNUMERIC}
BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE * 10

def main(args):
    write_uncompliants(
        args.output_csv,
        find_non_compliants(
            partition((user["name"], user) for user in read_all_users(args.input_csv))
        ),
    )

def write_uncompliants(path, uncompliants):
    with open(path, "w", newline="", buffering=BUFFER_SIZE) as fd:
        writer = csv.writer(fd, **CSV_OPTIONS)
        trace("writing rows")
        writer.writerows(uncompliants)
        trace("rows written")

def read_all_users(path):
    def map_row(row):
        # the "extract_date" field should be in ISO-8601 format;
        # in a real case it could be another format,
        # we'd just have to parse it
        return {
            "name": str(row["name"]),
            # note: bool() of a non-empty string is always True, so this
            # assumes QUOTE_NONNUMERIC already turned the field into a
            # number (0.0/1.0) rather than a string like "False"
            "compliant": bool(row["compliant"]),
            "extract_date": str(row["extract_date"]),
        }

    with open(path, buffering=BUFFER_SIZE) as fd:
        reader = csv.DictReader(fd, **CSV_OPTIONS)
        users = map(map_row, reader)
        trace("sorting users")
        return sorted(users, key=itemgetter("extract_date"), reverse=True)

def find_non_compliants(users_bag):
    def non_compliant_since(records):
        # records are sorted by date, most recent first
        for i, record in enumerate(records):
            if record["compliant"]:
                # the record just after the last compliant one is the
                # first non-compliant record: use its date
                i -= 1
                break
        else:
            # has always been non-compliant: use the latest record
            i = 0
        return records[i]["extract_date"]

    trace("finding non compliant users")
    yield ("name", "date")
    yield from (
        (username, non_compliant_since(records))
        for username, records in users_bag.items()
        # if the latest record was not compliant
        if not records[0]["compliant"]
    )

def partition(key_value_iter):
    bag = defaultdict(list)
    for key, value in key_value_iter:
        bag[key].append(value)
    return bag

POWERS = ["", "K", "M", "G", "T"]

def trace(source="", peak=False):
    def fmt(size):
        power = 2 ** 10
        n = 0
        while size > power:
            size /= power
            n += 1
        return f"{size:.02f} {POWERS[n]}B"

    # remember the flag before shadowing `peak` with the measurement
    peak_flag = peak
    current, peak = tracemalloc.get_traced_memory()
    if current == 0:
        return
    msg = (
        f"Current({fmt(current)}) Peak({fmt(peak)}) {source}"
        if peak_flag
        else f"Current({fmt(current)}) {source}"
    )
    print(msg, file=sys.stderr)

def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("input_csv")
    parser.add_argument("output_csv")
    parser.add_argument("--tracemalloc", action="store_true")
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    if args.tracemalloc:
        tracemalloc.start()
    trace("start")
    main(args)
    trace(peak=True)
    tracemalloc.stop()
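
For reference, a run could look like this. The file names are illustrative,
and the input has to match CSV_OPTIONS: semicolon-delimited, non-numeric
fields quoted, with a 0/1 compliant column so that QUOTE_NONNUMERIC hands
bool() a number rather than a string:

$ cat users.csv
"name";"compliant";"extract_date"
"foo";1;"2021-03-01"
"foo";0;"2021-03-02"
"qaz";0;"2021-03-01"
"qaz";0;"2021-03-02"
$ python3 script.py users.csv report.csv
$ cat report.csv
"name";"date"
"foo";"2021-03-02"
"qaz";"2021-03-02"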
-- TABLE
CREATE TABLE USER (
    name varchar,
    compliant integer,
    extract_date integer -- simulate dates for the purpose of the example
);

-- VALUES
INSERT INTO USER
    (name, compliant, extract_date)
VALUES
    ('foo', 1, 0),
    ('bar', 1, 0),
    ('qaz', 0, 0),
    ('foo', 1, 1),
    ('qaz', 0, 1),
    ('foo', 0, 2),
    ('bar', 1, 2),
    ('qaz', 0, 2),
    ('foo', 0, 3),
    ('bar', 1, 3),
    ('qaz', 0, 3);

-- QUERY
-- incomplete: relies on SQLite letting a bare column ride along with
-- GROUP BY, and on the subquery's ORDER BY being preserved, neither of
-- which is guaranteed
SELECT name, compliant, extract_date
FROM (
    SELECT
        USER.name AS name,
        USER.compliant AS compliant,
        USER.extract_date AS extract_date
    FROM USER
    ORDER BY USER.extract_date DESC
)
GROUP BY name;
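
A possible complete query, as a sketch rather than the gist's definitive
answer: it assumes an engine with window functions (SQLite 3.25+,
PostgreSQL, ...). For each user it computes the date of their last
compliant record, keeps only users whose latest record is non-compliant,
and reports the first date after that last compliant record, falling back
to the latest date for users who have never been compliant.

WITH ranked AS (
    SELECT
        name,
        compliant,
        extract_date,
        MAX(extract_date) OVER w AS latest_date,
        MAX(CASE WHEN compliant <> 0 THEN extract_date END) OVER w AS last_compliant
    FROM USER
    WINDOW w AS (PARTITION BY name)
)
SELECT
    name,
    CASE
        WHEN last_compliant IS NULL THEN MAX(extract_date)
        ELSE MIN(CASE WHEN extract_date > last_compliant THEN extract_date END)
    END AS "date"
FROM ranked
-- keep only users whose latest record is non-compliant
WHERE last_compliant IS NULL OR last_compliant < latest_date
GROUP BY name, last_compliant;

On the sample data this returns (foo, 2) and (qaz, 3), matching the Python
implementation.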

Abstract presentation

| User              |
|-------------------|
| name string       |
| compliant boolean |
| extract_date date |

Find which users are non-compliant at the latest extract_date. For each non-compliant user, find the date at which they became non-compliant; if they have always been non-compliant, use the most recent date.

Without adding any columns or tables.
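
With the sample rows inserted above, the expected result is:

| name | date |
|------|------|
| foo  | 2    |
| qaz  | 3    |

bar is compliant at the latest date, so it is not reported; foo became non-compliant at date 2; qaz has never been compliant, so its latest date is used.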

#!/usr/bin/env python3
"""Algorithm
1. read rows and partition by name, keeping sorted :
- avg O(n log n)
put all rows in buckets keyed by user id/name,
each bucket contains the _sorted by date_ rows for that user
2. find and write non compliants :
- avg O(n log n)
- best O(n)
- worst O(n log n)
if the latest record for user A is non-compliant,
then search since when A has been non-compliant,
return the recorded date at which A became non-compliant
if A has always been non-compliant, then return the latest record date
Overall not happy to have to load everything in memory.
"""
import argparse
import csv
import io
import sys
import tracemalloc
from collections import defaultdict
from operator import itemgetter

CSV_OPTIONS = {"dialect": "excel", "delimiter": ";", "quoting": csv.QUOTE_NONNUMERIC}
BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE * 10

def main(args):
    users = read_users_partitioned(args.input_csv)
    find_and_write_non_compliants(args.output_csv, users)

def find_and_write_non_compliants(path, users):
    write_uncompliants(path, find_non_compliants(users))

def read_users_partitioned(path):
    def keyed_row(row):
        return (row["name"], row)

    return ordered_partition(
        (keyed_row(row) for row in read_users(path)),
        key=itemgetter("extract_date"),
        reverse=True,
    )

def write_uncompliants(path, uncompliants):
    with open(path, "w", newline="", buffering=BUFFER_SIZE) as fd:
        writer = csv.writer(fd, **CSV_OPTIONS)
        trace("writing rows")
        writer.writerows(uncompliants)
        trace("rows written")

def read_users(path):
    def map_row(row):
        # the "extract_date" field should be in ISO-8601 format;
        # in a real case it could be another format,
        # we'd just have to parse it
        return {
            "name": str(row[0]),
            # note: bool() of a non-empty string is always True, so this
            # assumes QUOTE_NONNUMERIC already turned the field into a
            # number (0.0/1.0) rather than a string like "False"
            "compliant": bool(row[1]),
            "extract_date": str(row[2]),
        }

    with open(path, buffering=BUFFER_SIZE) as fd:
        reader = csv.reader(fd, **CSV_OPTIONS)
        next(reader)  # skip the header row
        trace("reading rows")
        yield from (map_row(row) for row in reader)

def find_non_compliants(users_bag):
    def non_compliant_since(records):
        # records are sorted by date, most recent first
        for i, record in enumerate(records):
            if record["compliant"]:
                # the record just after the last compliant one is the
                # first non-compliant record: use its date
                i -= 1
                break
        else:
            # has always been non-compliant: use the latest record
            i = 0
        return records[i]["extract_date"]

    trace("finding non compliant users")
    yield ("name", "date")  # headers
    yield from (
        (username, non_compliant_since(records))  # values
        for username, records in users_bag.items()
        # if the latest record was not compliant
        if not records[0]["compliant"]
    )

def ordered_partition(key_value_iter, key=None, reverse=None):
    return partition(key_value_iter, sortedlist(key, reverse))

def partition(key_value_iter, ty=list):
    bag = defaultdict(ty)
    for key, value in key_value_iter:
        bag[key].append(value)
    return bag

POWERS = ["", "K", "M", "G", "T"]

def trace(source="", peak=False):
    def fmt(size):
        power = 2 ** 10
        n = 0
        while size > power:
            size /= power
            n += 1
        return f"{size:.02f} {POWERS[n]}B"

    # remember the flag before shadowing `peak` with the measurement
    peak_flag = peak
    current, peak = tracemalloc.get_traced_memory()
    if current == 0:
        return
    msg = (
        f"Current({fmt(current)}) Peak({fmt(peak)}) {source}"
        if peak_flag
        else f"Current({fmt(current)}) {source}"
    )
    print(msg, file=sys.stderr)

def sortedlist(key=None, reverse=None):
    class SortedList:
        def __init__(self, key, reverse):
            self.inner = []
            self.key = key
            self.reverse = reverse

        def append(self, value):
            self.inner.append(value)
            # timsort is quite efficient with already sorted lists: it
            # detects the sorted prefix, so this re-sort costs O(n)
            # rather than O(n log n) per append
            self.inner.sort(key=self.key, reverse=self.reverse)

        def __iter__(self):
            return iter(self.inner)

        def __getitem__(self, index):
            return self.inner[index]

        def __len__(self):
            return len(self.inner)

    def _id(x):
        return x

    if key is None:
        key = _id

    def inner():
        return SortedList(key, reverse)

    return inner

def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("input_csv")
    parser.add_argument("output_csv")
    parser.add_argument("--tracemalloc", action="store_true")
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    if args.tracemalloc:
        tracemalloc.start()
    trace("start")
    main(args)
    trace(peak=True)
    tracemalloc.stop()
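
The SortedList above re-sorts after every append; timsort keeps that at
roughly O(n) per insert, but it still re-scans the whole list each time. A
possible alternative, sketched under the assumption of Python 3.10+ (where
bisect.insort gained its key parameter); bisect_sortedlist is a made-up
name, meant as a drop-in for sortedlist() in ordered_partition:

import bisect

def bisect_sortedlist(key=None, reverse=False):
    """Same factory shape as sortedlist(), inserting with bisect instead."""

    class BisectSortedList:
        def __init__(self):
            # kept in ascending key order; `reverse` is applied on reads
            self.inner = []

        def append(self, value):
            # O(log n) to find the slot plus O(n) to shift, no sort pass
            bisect.insort(self.inner, value, key=key)

        def __iter__(self):
            return reversed(self.inner) if reverse else iter(self.inner)

        def __getitem__(self, index):
            # with reverse=True, index 0 must be the latest record
            return self.inner[-1 - index] if reverse else self.inner[index]

        def __len__(self):
            return len(self.inner)

    return BisectSortedList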