-
-
Save e3krisztian/2de3e85e2b226afb1b6c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
import csv | |
import sys | |
# Column names holding the beginning and the end of each row's period.
START_DATE, END_DATE = 'start_date', 'end_date'
# Largest gap, in days, allowed between one row's end date and the next row's
# start date for both rows to be counted as one spell.
TOLERANCE = 31  # days
# Columns that must be equal for consecutive rows to belong to one spell.
PRIMARY_KEYS = ['frame_id', 'person_id']
def mold_date(any_date):
    """Coerce a date given as text into a ``datetime``.

    Strings are parsed with the ``'%Y-%m-%d'`` format.  A string that cannot
    be parsed (for example an empty cell) is mapped to 9999-12-31, i.e. it is
    treated as lying in the far future.  Any non-string value is returned
    unchanged, so ``datetime`` objects pass straight through.

    :param any_date: a ``'YYYY-MM-DD'`` string, or an already-converted value.
    :returns: a ``datetime`` for string input, otherwise ``any_date`` itself.
    """
    try:
        string_types = basestring  # Python 2: covers both str and unicode
    except NameError:
        string_types = str  # Python 3: basestring no longer exists
    if not isinstance(any_date, string_types):
        return any_date
    try:
        return datetime.strptime(any_date, '%Y-%m-%d')
    except ValueError:
        # Only malformed or empty text should land here; a bare ``except:``
        # would also swallow KeyboardInterrupt / SystemExit.
        return datetime(9999, 12, 31)
def dates_within_tolerance(date1, date2, tolerance=TOLERANCE):
    """Return True when ``date2`` comes at most ``tolerance`` days after ``date1``.

    Both dates are passed through ``mold_date`` first, so unparseable strings
    (such as empty cells) stand for 9999-12-31.  The gap is the ``days``
    component of ``date2 - date1``; a ``date2`` on or before ``date1`` gives
    a zero or negative gap and therefore always counts as within.
    """
    gap_in_days = (mold_date(date2) - mold_date(date1)).days
    return not gap_in_days > tolerance
def same_primary_key(row1, row2, primary_keys=PRIMARY_KEYS):
    """Tell whether two rows agree on every column listed in ``primary_keys``.

    A key missing from either row counts as a mismatch.  With an empty
    ``primary_keys`` the rows trivially agree.
    """
    return all(
        key in row1 and key in row2 and row1[key] == row2[key]
        for key in primary_keys
    )
def same_spell(row1, row2, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Decide whether ``row2`` continues the spell described by ``row1``.

    Both conditions must hold: the rows share their primary key, and
    ``row2`` starts no more than ``tolerance`` days after ``row1`` ends.
    The date check is skipped entirely when the keys differ.
    """
    return bool(
        same_primary_key(row1, row2, primary_keys=primary_keys)
        and dates_within_tolerance(row1[END_DATE], row2[START_DATE], tolerance=tolerance)
    )
def tag_spells(input_rows, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Yield every input row tagged with a ``spell_id``.

    Each row is compared with the spell built so far (a copy of the row that
    opened it, with its end date extended by later merged rows).  If
    ``same_spell`` holds, the row joins that spell; otherwise a new spell
    starts and the id grows by one.  Ids start at 1 and keep increasing
    across primary keys.  Only the current spell is compared, so rows for
    the same key must be contiguous in ``input_rows``.

    The input rows are modified in place (``spell_id`` is set on them) and
    are yielded lazily.

    :param input_rows: iterable of dict-like rows with START_DATE / END_DATE.
    :param tolerance: maximal gap in days still joining two rows.
    :param primary_keys: columns that must match for rows to share a spell.
    """
    spell = None  # running summary of the current spell (a row copy)
    spell_id = 1
    for current_row in input_rows:
        if spell is None:
            spell = current_row.copy()
        elif same_spell(spell, current_row, tolerance=tolerance, primary_keys=primary_keys):
            # Extend the spell only when this row ends later: a row nested
            # inside an earlier, longer one must not shorten the spell, or
            # a subsequent contiguous row would be split off wrongly.
            if mold_date(current_row[END_DATE]) > mold_date(spell[END_DATE]):
                spell[END_DATE] = current_row[END_DATE]
        else:
            spell_id += 1
            spell = current_row.copy()
        current_row['spell_id'] = spell_id
        yield current_row
def main():
    """Copy CSV rows from stdin to stdout, adding a ``spell_id`` column.

    Rows are tagged in input order by ``tag_spells``, which only compares
    adjacent rows.  NOTE(review): the input is presumably grouped by primary
    key and ordered by start date; this is assumed, not checked.
    """
    reader = csv.DictReader(sys.stdin)
    fieldnames = reader.fieldnames
    if fieldnames is None:
        # Empty input: no header to extend and no rows to tag.
        return
    fieldnames = list(fieldnames)
    if 'spell_id' not in fieldnames:
        # Re-tagging already tagged data must not duplicate the column.
        fieldnames.append('spell_id')
    writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(tag_spells(reader))


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
frame_id | person_id | start_date | end_date | spell_id | |
---|---|---|---|---|---|
1 | 1 | 1992-01-01 | 1994-10-31 | 1 | |
1 | 1 | 1994-11-01 | 2000-10-31 | 1 | |
1 | 2 | 1992-01-01 | 1994-10-31 | 2 | |
2 | 3 | 1994-11-01 | 3 | ||
2 | 3 | 2000-12-31 | 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
frame_id | person_id | start_date | end_date | |
---|---|---|---|---|
1 | 1 | 1992-01-01 | 1994-10-31 | |
1 | 1 | 1994-11-01 | 2000-10-31 | |
1 | 2 | 1992-01-01 | 1994-10-31 | |
2 | 3 | 1994-11-01 | ||
2 | 3 | 2000-12-31 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import merge_spells as module | |
import unittest | |
class TestDateMold(unittest.TestCase):
    """Tests for ``mold_date``."""

    def test_string_equals_date(self):
        from datetime import datetime
        self.assertEqual(datetime(2009, 1, 3), module.mold_date('2009-01-03'))

    def test_unparseable_is_far_future(self):
        from datetime import datetime
        self.assertEqual(datetime(9999, 12, 31), module.mold_date(''))

    def test_non_string_passes_through(self):
        from datetime import datetime
        date = datetime(2009, 1, 3)
        self.assertIs(date, module.mold_date(date))
class TestTolerance(unittest.TestCase):
    """Tests for ``dates_within_tolerance`` with the default tolerance (31 days)."""

    def test_same_is_within(self):
        date = '2009-01-03'
        self.assertTrue(module.dates_within_tolerance(date, date))

    def test_tolerance_is_within(self):
        # Exactly 31 days apart: the boundary is inclusive.
        date1 = '1992-06-01'
        date2 = '1992-07-02'
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_beyond_tolerance_is_outside(self):
        # 32 days apart: one day past the boundary.
        date1 = '1992-06-01'
        date2 = '1992-07-03'
        self.assertFalse(module.dates_within_tolerance(date1, date2))

    def test_100_years_is_outside(self):
        date1 = '1892-06-23'
        date2 = '1992-06-23'
        self.assertFalse(module.dates_within_tolerance(date1, date2))

    def test_null_is_within(self):
        # An empty date1 is molded to 9999-12-31, so date2 lies before it.
        date1 = ''
        date2 = '1992-06-23'
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_earlier_is_within(self):
        date1 = '1992-06-21'
        date2 = '1992-05-21'
        self.assertTrue(module.dates_within_tolerance(date1, date2))
class TestPrimaryKeys(unittest.TestCase):
    """Tests for ``same_primary_key``."""

    def test_single_same(self):
        row1 = dict(id=1, data='a')
        row2 = dict(id=1, data='b')
        self.assertTrue(module.same_primary_key(row1, row2, primary_keys=['id']))

    def test_single_different(self):
        row1 = dict(id=1, data='a')
        row2 = dict(id=2, data='b')
        self.assertFalse(module.same_primary_key(row1, row2, primary_keys=['id']))

    def test_multiple_same(self):
        row1 = dict(state=1, county='a', data='a')
        row2 = dict(state=1, county='a', data='b')
        self.assertTrue(module.same_primary_key(row1, row2, primary_keys=['state', 'county']))

    def test_multiple_different(self):
        row1 = dict(state=1, county='a', data='a')
        row2 = dict(state=1, county='b', data='b')
        self.assertFalse(module.same_primary_key(row1, row2, primary_keys=['state', 'county']))

    def test_missing_key_is_different(self):
        row1 = dict(id=1, data='a')
        row2 = dict(data='a')
        self.assertFalse(module.same_primary_key(row1, row2, primary_keys=['id']))
class TestTagSpells(unittest.TestCase):
    """Tests for ``tag_spells``.

    Expected rows are always built as fresh dicts: ``tag_spells`` mutates its
    input rows in place, so comparing the output against the input objects
    would compare each row with itself and prove nothing.
    """

    def tag_spells(self, rows):
        return list(module.tag_spells(rows, primary_keys=['id']))

    def test_empty_input(self):
        self.assertListEqual([], self.tag_spells(iter([])))

    def test_single_row_unchanged(self):
        row = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        output = self.tag_spells(iter([row]))
        desired = dict(id=1, start_date='1992-01-01', end_date='1992-12-31', spell_id=1)
        self.assertListEqual([desired], output)

    def test_two_rows_merged(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        output = self.tag_spells(iter([row1, row2]))
        desired1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31', spell_id=1)
        desired2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31', spell_id=1)
        self.assertListEqual([desired1, desired2], output)

    def test_three_rows_one_spell(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        row3 = dict(id=1, start_date='1993-12-31', end_date='1994-12-31')
        output = self.tag_spells(iter([row1, row2, row3]))
        desired = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31', spell_id=1),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31', spell_id=1),
            dict(id=1, start_date='1993-12-31', end_date='1994-12-31', spell_id=1),
        ]
        self.assertListEqual(desired, output)

    def test_three_rows_two_spells(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        row3 = dict(id=1, start_date='1994-12-31', end_date='1995-12-31')
        output = self.tag_spells(iter([row1, row2, row3]))
        desired = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31', spell_id=1),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31', spell_id=1),
            dict(id=1, start_date='1994-12-31', end_date='1995-12-31', spell_id=2),
        ]
        self.assertListEqual(desired, output)

    def test_open_spell_closed(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        output = self.tag_spells(iter([row1, row2]))
        desired = [
            dict(id=1, start_date='1992-01-01', end_date='', spell_id=1),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31', spell_id=1),
        ]
        self.assertListEqual(desired, output)

    def test_gap_causes_new_spell(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=1, start_date='1994-01-03', end_date='1994-12-31')
        output = self.tag_spells(iter([row1, row2]))
        desired = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31', spell_id=1),
            dict(id=1, start_date='1994-01-03', end_date='1994-12-31', spell_id=2),
        ]
        self.assertListEqual(desired, output)

    def test_new_entity_new_spell(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=2, start_date='1992-12-31', end_date='1994-12-31')
        output = self.tag_spells(iter([row1, row2]))
        desired = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31', spell_id=1),
            dict(id=2, start_date='1992-12-31', end_date='1994-12-31', spell_id=2),
        ]
        self.assertListEqual(desired, output)


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment