-
-
Save bszilagyi/139c8abfc754fd40ed76 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
import csv | |
import sys | |
# Names of the row fields (CSV columns) the spell functions read and write.
START_DATE = 'start_date'
END_DATE = 'end_date'
SPELL_ID = 'spell_id'
IMPUTED_END_DATE = 'imputed_end_date'

# Default upper bound on the gap between one row's end date and the next
# row's start date for the two rows to count as the same spell.
TOLERANCE = 31  # days

# Default fields that must hold equal values in two rows for them to be
# treated as records of the same entity.
PRIMARY_KEYS = [
    'frame_id',
    'person_id',
]
def mold_date_end(any_date):
    """Coerce an end date to a ``datetime``; unparseable text means "open-ended".

    Args:
        any_date: A ``'YYYY-MM-DD'`` string, or any non-string value.

    Returns:
        The parsed ``datetime`` when ``any_date`` is a valid date string;
        ``datetime(9999, 12, 31)`` when it is a string that does not parse
        (for example an empty CSV cell); ``any_date`` itself, unchanged,
        when it is not a string.
    """
    try:
        string_types = basestring  # Python 2: matches both str and unicode
    except NameError:
        string_types = str  # Python 3 has no basestring
    if not isinstance(any_date, string_types):
        return any_date
    try:
        return datetime.strptime(any_date, '%Y-%m-%d')
    except ValueError:
        # Only a failed parse maps to the far-future sentinel; catching
        # everything would also hide interrupts and programming errors.
        return datetime(9999, 12, 31)
def mold_date_start(any_date):
    """Coerce a start date to a ``datetime``; unparseable text means "unknown start".

    Args:
        any_date: A ``'YYYY-MM-DD'`` string, or any non-string value.

    Returns:
        The parsed ``datetime`` when ``any_date`` is a valid date string;
        ``datetime(1111, 12, 31)`` when it is a string that does not parse
        (for example an empty CSV cell); ``any_date`` itself, unchanged,
        when it is not a string.
    """
    try:
        string_types = basestring  # Python 2: matches both str and unicode
    except NameError:
        string_types = str  # Python 3 has no basestring
    if not isinstance(any_date, string_types):
        return any_date
    try:
        return datetime.strptime(any_date, '%Y-%m-%d')
    except ValueError:
        # Only a failed parse maps to the far-past sentinel; catching
        # everything would also hide interrupts and programming errors.
        return datetime(1111, 12, 31)
def dates_within_tolerance(date1, date2, tolerance=TOLERANCE):
    """Tell whether ``date2`` falls at most ``tolerance`` days after ``date1``.

    ``date1`` is normalised with ``mold_date_end`` and ``date2`` with
    ``mold_date_start``. The whole-day difference ``date2 - date1`` is then
    compared with ``tolerance``; a zero or negative difference (``date2`` on
    or before ``date1``) therefore also counts as within tolerance.

    Args:
        date1: End date of the earlier record.
        date2: Start date of the later record.
        tolerance: Largest accepted difference, in days.

    Returns:
        ``True`` when the difference in days does not exceed ``tolerance``.
    """
    next_start = mold_date_start(date2)
    previous_end = mold_date_end(date1)
    gap = next_start - previous_end
    return tolerance >= gap.days
def same_primary_key(row1, row2, primary_keys=PRIMARY_KEYS):
    """Tell whether two rows agree on every primary-key field.

    Args:
        row1: First mapping of field name to value.
        row2: Second mapping of field name to value.
        primary_keys: Field names that must be present in both rows and
            hold equal values.

    Returns:
        ``False`` as soon as a key is missing from either row or the two
        values differ; ``True`` otherwise (including for an empty key list).
    """
    return all(
        key in row1 and key in row2 and row1[key] == row2[key]
        for key in primary_keys
    )
def same_spell(row1, row2, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Tell whether ``row2`` continues the spell that ``row1`` belongs to.

    The rows must match on ``primary_keys``, and the start date of ``row2``
    must be within ``tolerance`` days of the end date of ``row1``. The dates
    are only looked at when the primary keys match.

    Args:
        row1: The earlier row; its ``END_DATE`` field is read.
        row2: The later row; its ``START_DATE`` field is read.
        tolerance: Largest accepted gap, in days.
        primary_keys: Field names identifying the entity.

    Returns:
        ``True`` when both conditions hold, ``False`` otherwise.
    """
    return (
        same_primary_key(row1, row2, primary_keys=primary_keys)
        and dates_within_tolerance(row1[END_DATE], row2[START_DATE], tolerance=tolerance)
    )
def tag_spells(input_rows, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Yield a copy of every input row, tagged with a spell id.

    Consecutive rows are compared with ``same_spell``. Each output row is a
    copy of the input row with two extra fields:

    * ``SPELL_ID`` - starts at 1 and is incremented each time a row does not
      continue the spell of the row before it.
    * ``IMPUTED_END_DATE`` - the start date of the following row when that
      row continues the same spell, otherwise ``''``.

    Args:
        input_rows: Iterable of dict-like rows (each must support
            ``.copy()``), compared strictly in the order given.
        tolerance: Largest accepted gap between rows, in days.
        primary_keys: Field names identifying the entity.

    Yields:
        One tagged copy per input row, in input order. Nothing is yielded
        for empty input.
    """
    previous_row = None
    spell_id = 1
    for current_row in input_rows:
        if previous_row is not None:
            if same_spell(previous_row, current_row,
                          tolerance=tolerance, primary_keys=primary_keys):
                # The spell goes on: the next row's start closes this one.
                previous_row[IMPUTED_END_DATE] = current_row[START_DATE]
            else:
                spell_id += 1
            yield previous_row
        # Work on a copy so the caller's row objects are never modified.
        previous_row = current_row.copy()
        previous_row[IMPUTED_END_DATE] = ''
        previous_row[SPELL_ID] = spell_id
    # Flush the last row, but only if there was one: emitting None for empty
    # input would break consumers such as csv.DictWriter.writerows.
    if previous_row is not None:
        yield previous_row
def merge_spells(input_rows, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Collapse each run of consecutive same-spell rows into a single row.

    The first row of a run is kept; every following row that ``same_spell``
    accepts as a continuation only contributes its ``END_DATE``, which
    overwrites the end date of the kept row.

    Args:
        input_rows: Iterable of dict-like rows (each must support
            ``.copy()``), compared strictly in the order given.
        tolerance: Largest accepted gap between rows, in days.
        primary_keys: Field names identifying the entity.

    Yields:
        One row per spell, in input order. Nothing is yielded for empty
        input.
    """
    previous_row = None
    for current_row in input_rows:
        # Compare against None rather than truthiness so that an empty row
        # is still treated as a row.
        if previous_row is not None and same_spell(
                previous_row, current_row,
                tolerance=tolerance, primary_keys=primary_keys):
            previous_row[END_DATE] = current_row[END_DATE]
            continue
        if previous_row is not None:
            yield previous_row
        # Copy, as tag_spells does, so extending END_DATE above never
        # modifies the caller's row objects.
        previous_row = current_row.copy()
    # Flush the last spell, but only if there was one: emitting None for
    # empty input would break consumers such as csv.DictWriter.writerows.
    if previous_row is not None:
        yield previous_row
def main():
    """Filter CSV from stdin to stdout, tagging or merging spells.

    The first command-line argument selects the mode: ``m`` or ``merge``
    runs ``merge_spells`` with ``'pos5'`` added to the primary keys and
    keeps the input columns; anything else (or no argument) runs
    ``tag_spells`` and appends the ``SPELL_ID`` and ``IMPUTED_END_DATE``
    columns.
    """
    mode = sys.argv[1] if len(sys.argv) > 1 else ''
    reader = csv.DictReader(sys.stdin)
    if reader.fieldnames is None:
        # Completely empty input: there is no header to echo and no rows.
        return
    if mode in ('m', 'merge'):
        new_fields = []
        action = merge_spells
        # Build a new list instead of `PRIMARY_KEYS += [...]`: augmented
        # assignment makes the name local to this function and raises
        # UnboundLocalError.
        primary_keys = PRIMARY_KEYS + ['pos5']
    else:
        new_fields = [SPELL_ID, IMPUTED_END_DATE]
        action = tag_spells
        primary_keys = PRIMARY_KEYS
    writer = csv.DictWriter(sys.stdout, fieldnames=list(reader.fieldnames) + new_fields)
    writer.writeheader()
    writer.writerows(action(reader, primary_keys=primary_keys))


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
frame_id | person_id | start_date | end_date | spell_id | |
---|---|---|---|---|---|
1 | 1 | 1992-01-01 | 1994-10-31 | 1 | |
1 | 1 | 1994-11-01 | 2000-10-31 | 1 | |
1 | 2 | 1992-01-01 | 1994-10-31 | 2 | |
2 | 3 | 1994-11-01 | 3 | ||
2 | 3 | 2000-12-31 | 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
frame_id | person_id | start_date | end_date | |
---|---|---|---|---|
1 | 1 | 1992-01-01 | 1994-10-31 | |
1 | 1 | 1994-11-01 | 2000-10-31 | |
1 | 2 | 1992-01-01 | 1994-10-31 | |
2 | 3 | 1994-11-01 | ||
2 | 3 | 2000-12-31 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import merge_spells as module | |
import unittest | |
class TestDateMold(unittest.TestCase):
    """Checks for the ``mold_date_start`` / ``mold_date_end`` helpers."""

    def test_string_equals_date(self):
        # A well-formed date string must come back as the matching datetime
        # from both helpers (previously this test asserted nothing).
        from datetime import datetime
        expected = datetime(1992, 6, 23)
        self.assertEqual(module.mold_date_start('1992-06-23'), expected)
        self.assertEqual(module.mold_date_end('1992-06-23'), expected)
class TestTolerance(unittest.TestCase):
    """Checks for ``module.dates_within_tolerance`` with the default tolerance.

    Uses ``assertTrue`` / ``assertFalse`` rather than the deprecated
    ``failUnless`` / ``failIf`` aliases, which newer Python versions no
    longer provide.
    """

    def test_same_is_within(self):
        date = '2009-01-03'
        self.assertTrue(module.dates_within_tolerance(date, date))

    def test_tolerance_is_within(self):
        date1 = '1992-06-21'
        date2 = '1992-06-23'
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_100_years_is_outside(self):
        date1 = '1892-06-23'
        date2 = '1992-06-23'
        self.assertFalse(module.dates_within_tolerance(date1, date2))

    def test_missing_start_is_inside(self):
        date1 = '1992-06-23'
        date2 = ''
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_missing_end_is_inside(self):
        date1 = ''
        date2 = '1992-06-23'
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_null_is_within(self):
        # TODO: placeholder carried over unchanged; it asserts nothing yet.
        pass

    def test_earlier_is_within(self):
        # A second date before the first gives a negative gap, which counts
        # as within tolerance.
        date1 = '1992-06-21'
        date2 = '1992-05-21'
        self.assertTrue(module.dates_within_tolerance(date1, date2))
class TestPrimaryKeys(unittest.TestCase):
    """Checks for ``module.same_primary_key`` with one and with two key fields.

    Uses ``assertTrue`` / ``assertFalse`` rather than the deprecated
    ``failUnless`` / ``failIf`` aliases, which newer Python versions no
    longer provide.
    """

    def test_single_same(self):
        row1 = dict(id=1, data='a')
        row2 = dict(id=1, data='b')
        self.assertTrue(module.same_primary_key(row1, row2, primary_keys=['id']))

    def test_single_different(self):
        row1 = dict(id=1, data='a')
        row2 = dict(id=2, data='b')
        self.assertFalse(module.same_primary_key(row1, row2, primary_keys=['id']))

    def test_multiple_same(self):
        row1 = dict(state=1, county='a', data='a')
        row2 = dict(state=1, county='a', data='b')
        self.assertTrue(module.same_primary_key(row1, row2, primary_keys=['state', 'county']))

    def test_multiple_different(self):
        row1 = dict(state=1, county='a', data='a')
        row2 = dict(state=1, county='b', data='b')
        self.assertFalse(module.same_primary_key(row1, row2, primary_keys=['state', 'county']))
class TestImpuation(unittest.TestCase):
    """Checks for the ``IMPUTED_END_DATE`` field added by ``module.tag_spells``.

    Uses ``assertIn`` rather than the deprecated ``failUnless`` alias, which
    newer Python versions no longer provide.
    """

    def tag_spells(self, rows):
        """Run ``module.tag_spells`` keyed on ``id`` and collect the output."""
        return list(module.tag_spells(rows, primary_keys=['id']))

    def test_output_has_new_field(self):
        row = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        reader = iter([row])
        output = self.tag_spells(reader)
        self.assertIn(module.IMPUTED_END_DATE, output[0])

    def test_interpolated_end_is_beginning(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        reader = iter([row1, row2])
        output = self.tag_spells(reader)
        self.assertEqual(output[0][module.IMPUTED_END_DATE], row2[module.START_DATE])

    def test_interpolated_end_of_empty(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        reader = iter([row1, row2])
        output = self.tag_spells(reader)
        self.assertEqual(output[0][module.IMPUTED_END_DATE], row2[module.START_DATE])
class TestTagSpells(unittest.TestCase):
    """Checks the ``spell_id`` values that ``module.tag_spells`` assigns."""

    @staticmethod
    def _period(entity, start, end, **extra):
        """Build one row dict with ``id``, ``start_date`` and ``end_date``."""
        record = {'id': entity, 'start_date': start, 'end_date': end}
        record.update(extra)
        return record

    def tag_spells(self, rows):
        """Run ``module.tag_spells`` keyed on ``id``, dropping the imputed field."""
        tagged = []
        for record in module.tag_spells(rows, primary_keys=['id']):
            # drop new field to keep using old tests
            record.pop(module.IMPUTED_END_DATE)
            tagged.append(record)
        return tagged

    def test_single_row_unchanged(self):
        only = self._period(1, '1992-01-01', '1992-12-31')
        output = self.tag_spells(iter([only]))
        expected = dict(only, spell_id=1)
        self.assertEqual(expected, output[0])

    def test_two_rows_merged(self):
        first = self._period(1, '1992-01-01', '1992-12-31')
        second = self._period(1, '1993-01-03', '1993-12-31')
        output = self.tag_spells(iter([first, second]))
        expected = [
            self._period(1, '1992-01-01', '1992-12-31', spell_id=1),
            self._period(1, '1993-01-03', '1993-12-31', spell_id=1),
        ]
        self.assertListEqual(expected, output)

    def test_three_rows_one_spell(self):
        periods = [
            ('1992-01-01', '1992-12-31'),
            ('1993-01-03', '1993-12-31'),
            ('1993-12-31', '1994-12-31'),
        ]
        rows = [self._period(1, start, end) for start, end in periods]
        output = self.tag_spells(iter(rows))
        expected = [self._period(1, start, end, spell_id=1) for start, end in periods]
        self.assertListEqual(expected, output)

    def test_three_rows_two_spells(self):
        first = self._period(1, '1992-01-01', '1992-12-31')
        second = self._period(1, '1993-01-03', '1993-12-31')
        third = self._period(1, '1994-12-31', '1995-12-31')
        output = self.tag_spells(iter([first, second, third]))
        expected = [
            self._period(1, '1992-01-01', '1992-12-31', spell_id=1),
            self._period(1, '1993-01-03', '1993-12-31', spell_id=1),
            self._period(1, '1994-12-31', '1995-12-31', spell_id=2),
        ]
        self.assertListEqual(expected, output)

    def test_open_spell_closed(self):
        first = self._period(1, '1992-01-01', '')
        second = self._period(1, '1993-01-03', '1993-12-31')
        output = self.tag_spells(iter([first, second]))
        expected = [
            self._period(1, '1992-01-01', '', spell_id=1),
            self._period(1, '1993-01-03', '1993-12-31', spell_id=1),
        ]
        self.assertListEqual(expected, output)

    def test_gap_causes_new_spell(self):
        first = self._period(1, '1992-01-01', '1992-12-31')
        second = self._period(1, '1994-01-03', '1994-12-31')
        output = self.tag_spells(iter([first, second]))
        expected = [dict(first, spell_id=1), dict(second, spell_id=2)]
        self.assertListEqual(expected, output)

    def test_new_entity_new_spell(self):
        first = self._period(1, '1992-01-01', '1992-12-31')
        second = self._period(2, '1992-12-31', '1994-12-31')
        output = self.tag_spells(iter([first, second]))
        expected = [dict(first, spell_id=1), dict(second, spell_id=2)]
        self.assertListEqual(expected, output)


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment