Skip to content

Instantly share code, notes, and snippets.

@bszilagyi
Forked from perediszilard/merge_spells.py
Last active December 4, 2015 14:56
Show Gist options
  • Save bszilagyi/139c8abfc754fd40ed76 to your computer and use it in GitHub Desktop.
Save bszilagyi/139c8abfc754fd40ed76 to your computer and use it in GitHub Desktop.
from datetime import datetime
import csv
import sys
# Column names read from / written to the CSV rows.
START_DATE = 'start_date'
END_DATE = 'end_date'
# Output columns appended by tag_spells() in the default (tag) mode.
SPELL_ID = 'spell_id'
IMPUTED_END_DATE = 'imputed_end_date'
# Largest gap, in days, between one row's end and the next row's start that
# still counts as the same spell (see dates_within_tolerance).
TOLERANCE = 31 # days
# Columns that must be equal for two rows to belong to the same entity.
# NOTE: this list object is used as a default argument by several functions;
# mutating it in place changes their defaults too.
PRIMARY_KEYS = ['frame_id', 'person_id']
# Parsing format for every date column.
_DATE_FORMAT = '%Y-%m-%d'
# Sentinels substituted for empty or unparseable dates: an unknown end becomes
# the far future and an unknown start the far past, so the gap computed by
# dates_within_tolerance() is negative and the rows count as contiguous.
_MISSING_END = datetime(9999, 12, 31)
_MISSING_START = datetime(1111, 12, 31)
# Equivalent to ``basestring`` on Python 2 (str and unicode) and to ``str`` on
# Python 3, where ``basestring`` does not exist.
_STRING_TYPES = (str, type(u''))


def _mold_date(any_date, fallback):
    """Parse a ``YYYY-MM-DD`` string, returning *fallback* if it cannot be parsed.

    Values that are not strings (e.g. an already-parsed ``datetime``) are
    returned unchanged.
    """
    if not isinstance(any_date, _STRING_TYPES):
        return any_date
    try:
        return datetime.strptime(any_date, _DATE_FORMAT)
    except ValueError:
        # Empty cells and malformed dates both end up here; only ValueError
        # is caught so unrelated errors (e.g. KeyboardInterrupt) propagate.
        return fallback


def mold_date_end(any_date):
    """Return *any_date* as an end date; empty/invalid strings become 9999-12-31."""
    return _mold_date(any_date, _MISSING_END)


def mold_date_start(any_date):
    """Return *any_date* as a start date; empty/invalid strings become 1111-12-31."""
    return _mold_date(any_date, _MISSING_START)
def dates_within_tolerance(date1, date2, tolerance=TOLERANCE):
    """Return whether *date2* begins at most *tolerance* days after *date1* ends.

    *date1* is interpreted as an end date and *date2* as a start date. An
    empty or unparseable value on either side is replaced by a sentinel that
    makes the gap negative, so such pairs count as within tolerance. A
    *date2* earlier than *date1* also gives a negative gap and is within.
    """
    begins = mold_date_start(date2)
    ends = mold_date_end(date1)
    gap_in_days = (begins - ends).days
    return tolerance >= gap_in_days
def same_primary_key(row1, row2, primary_keys=PRIMARY_KEYS):
    """Return True if *row1* and *row2* agree on every column in *primary_keys*.

    A key missing from either row counts as a mismatch. An empty
    *primary_keys* matches any pair of rows.
    """
    return all(
        key in row1 and key in row2 and row1[key] == row2[key]
        for key in primary_keys
    )
def same_spell(row1, row2, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Return True if *row2* continues the spell that *row1* belongs to.

    Both rows must share every primary key. In addition, *row2*'s start date
    must fall within *tolerance* days of *row1*'s end date. The date columns
    are only read when the keys match.
    """
    return bool(
        same_primary_key(row1, row2, primary_keys=primary_keys)
        and dates_within_tolerance(
            row1[END_DATE], row2[START_DATE], tolerance=tolerance
        )
    )
def tag_spells(input_rows, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Yield copies of *input_rows* annotated with a spell id and imputed end.

    Only consecutive rows are compared (via ``same_spell``). The rows are
    presumably sorted by primary key and start date; NOTE(review): ordering
    is not checked here.

    Each yielded row is a shallow copy; the input rows are not modified.
    Two columns are added:

    * ``SPELL_ID``: starts at 1 and increases by one every time a row does
      not continue the previous row's spell.
    * ``IMPUTED_END_DATE``: within a spell, the next row's ``START_DATE``.
      The last row of each spell gets ``''``.

    An empty input yields nothing.
    """
    previous_row = None
    spell_id = 1
    for current_row in input_rows:
        if previous_row is not None:
            if same_spell(previous_row, current_row,
                          tolerance=tolerance, primary_keys=primary_keys):
                previous_row[IMPUTED_END_DATE] = current_row[START_DATE]
            else:
                spell_id += 1
            # previous_row is complete only once its successor has been seen.
            yield previous_row
        previous_row = current_row.copy()
        previous_row[IMPUTED_END_DATE] = ''
        previous_row[SPELL_ID] = spell_id
    # Flush the last buffered row. Without this guard, empty input would
    # yield None, which csv.DictWriter.writerows cannot write.
    if previous_row is not None:
        yield previous_row
def merge_spells(input_rows, tolerance=TOLERANCE, primary_keys=PRIMARY_KEYS):
    """Collapse runs of consecutive same-spell rows into a single row each.

    Each yielded row is a shallow copy of the first row of a run. Its
    ``END_DATE`` is replaced by the ``END_DATE`` of the last row merged into
    it. The input rows are not modified. An empty input yields nothing.
    """
    merged = None
    for current_row in input_rows:
        # Compare against the running merged row, so the tolerance is
        # measured from the most recently absorbed end date.
        if merged is not None and same_spell(
                merged, current_row,
                tolerance=tolerance, primary_keys=primary_keys):
            merged[END_DATE] = current_row[END_DATE]
            continue
        if merged is not None:
            yield merged
        # Copy so that extending END_DATE never touches the caller's dicts.
        merged = current_row.copy()
    # Flush the last run; skip entirely when the input was empty.
    if merged is not None:
        yield merged
def main():
    """Read CSV rows from stdin and write tagged or merged rows to stdout.

    Usage: ``merge_spells.py [m|merge] < input.csv > output.csv``

    * With ``m`` or ``merge`` as the first argument, consecutive rows of the
      same spell are merged with ``merge_spells``. ``pos5`` is added to the
      primary key columns in this mode.
    * Otherwise rows are tagged with ``tag_spells``, and the ``SPELL_ID`` and
      ``IMPUTED_END_DATE`` columns are appended to the header.
    """
    try:
        mode = sys.argv[1]
    except IndexError:
        mode = ''
    reader = csv.DictReader(sys.stdin)
    if mode in ('m', 'merge'):
        new_fields = []
        # Build the extended key list locally. Augmented assignment to
        # PRIMARY_KEYS here would make it a local name of main() and raise
        # UnboundLocalError.
        merge_keys = list(PRIMARY_KEYS) + ['pos5']

        def action(rows):
            return merge_spells(rows, primary_keys=merge_keys)
    else:
        new_fields = [SPELL_ID, IMPUTED_END_DATE]
        action = tag_spells
    # DictReader.fieldnames is None when stdin is empty.
    fieldnames = list(reader.fieldnames or []) + new_fields
    writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(action(reader))


if __name__ == '__main__':
    main()
frame_id person_id start_date end_date spell_id
1 1 1992-01-01 1994-10-31 1
1 1 1994-11-01 2000-10-31 1
1 2 1992-01-01 1994-10-31 2
2 3 1994-11-01 3
2 3 2000-12-31 3
frame_id person_id start_date end_date
1 1 1992-01-01 1994-10-31
1 1 1994-11-01 2000-10-31
1 2 1992-01-01 1994-10-31
2 3 1994-11-01
2 3 2000-12-31
import merge_spells as module
import unittest
class TestDateMold(unittest.TestCase):
    """Tests for the mold_date_start / mold_date_end date parsers."""

    def test_string_equals_date(self):
        # A well-formed date string parses to the equivalent datetime, and an
        # already-parsed datetime passes through unchanged.
        from datetime import datetime
        expected = datetime(1992, 6, 23)
        self.assertEqual(module.mold_date_start('1992-06-23'), expected)
        self.assertEqual(module.mold_date_end('1992-06-23'), expected)
        self.assertEqual(module.mold_date_start(expected), expected)
        self.assertEqual(module.mold_date_end(expected), expected)
class TestTolerance(unittest.TestCase):
    """Tests for dates_within_tolerance(end_date, start_date)."""

    def test_same_is_within(self):
        date = '2009-01-03'
        self.assertTrue(module.dates_within_tolerance(date, date))

    def test_tolerance_is_within(self):
        date1 = '1992-06-21'
        date2 = '1992-06-23'
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_100_years_is_outside(self):
        date1 = '1892-06-23'
        date2 = '1992-06-23'
        self.assertFalse(module.dates_within_tolerance(date1, date2))

    def test_missing_start_is_inside(self):
        date1 = '1992-06-23'
        date2 = ''
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_missing_end_is_inside(self):
        date1 = ''
        date2 = '1992-06-23'
        self.assertTrue(module.dates_within_tolerance(date1, date2))

    def test_null_is_within(self):
        # Both bounds missing (empty CSV cells arrive as '').
        self.assertTrue(module.dates_within_tolerance('', ''))

    def test_earlier_is_within(self):
        date1 = '1992-06-21'
        date2 = '1992-05-21'
        self.assertTrue(module.dates_within_tolerance(date1, date2))
class TestPrimaryKeys(unittest.TestCase):
    """Tests for same_primary_key with explicit key lists."""

    def test_single_same(self):
        row1 = dict(id=1, data='a')
        row2 = dict(id=1, data='b')
        self.assertTrue(module.same_primary_key(row1, row2, primary_keys=['id']))

    def test_single_different(self):
        row1 = dict(id=1, data='a')
        row2 = dict(id=2, data='b')
        self.assertFalse(module.same_primary_key(row1, row2, primary_keys=['id']))

    def test_multiple_same(self):
        row1 = dict(state=1, county='a', data='a')
        row2 = dict(state=1, county='a', data='b')
        self.assertTrue(
            module.same_primary_key(row1, row2, primary_keys=['state', 'county']))

    def test_multiple_different(self):
        row1 = dict(state=1, county='a', data='a')
        row2 = dict(state=1, county='b', data='b')
        self.assertFalse(
            module.same_primary_key(row1, row2, primary_keys=['state', 'county']))
class TestImpuation(unittest.TestCase):
    """Tests for the imputed end date column added by tag_spells."""
    # NOTE: the class keeps its original (misspelled) name so existing test
    # selectors that reference it keep working.

    def tag_spells(self, rows):
        return list(module.tag_spells(rows, primary_keys=['id']))

    def test_output_has_new_field(self):
        row = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        reader = iter([row])
        output = self.tag_spells(reader)
        self.assertIn(module.IMPUTED_END_DATE, output[0])

    def test_interpolated_end_is_beginning(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        reader = iter([row1, row2])
        output = self.tag_spells(reader)
        self.assertEqual(output[0][module.IMPUTED_END_DATE], row2[module.START_DATE])

    def test_interpolated_end_of_empty(self):
        row1 = dict(id=1, start_date='1992-01-01', end_date='')
        row2 = dict(id=1, start_date='1993-01-03', end_date='1993-12-31')
        reader = iter([row1, row2])
        output = self.tag_spells(reader)
        self.assertEqual(output[0][module.IMPUTED_END_DATE], row2[module.START_DATE])
class TestTagSpells(unittest.TestCase):
    """Spell-id assignment by tag_spells, ignoring the imputed end column."""

    def tag_spells(self, rows):
        # Strip the imputed-end column so the expectations below only mention
        # the original fields plus spell_id.
        tagged = list(module.tag_spells(rows, primary_keys=['id']))
        for row in tagged:
            row.pop(module.IMPUTED_END_DATE)
        return tagged

    @staticmethod
    def _with_spell_ids(rows, spell_ids):
        """Return copies of *rows*, each extended with its expected spell_id."""
        return [dict(row, spell_id=spell) for row, spell in zip(rows, spell_ids)]

    def test_single_row_unchanged(self):
        only = dict(id=1, start_date='1992-01-01', end_date='1992-12-31')
        result = self.tag_spells(iter([only]))
        self.assertEqual(dict(only, spell_id=1), result[0])

    def test_two_rows_merged(self):
        rows = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31'),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31'),
        ]
        expected = self._with_spell_ids(rows, [1, 1])
        self.assertListEqual(expected, self.tag_spells(iter(rows)))

    def test_three_rows_one_spell(self):
        rows = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31'),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31'),
            dict(id=1, start_date='1993-12-31', end_date='1994-12-31'),
        ]
        expected = self._with_spell_ids(rows, [1, 1, 1])
        self.assertListEqual(expected, self.tag_spells(iter(rows)))

    def test_three_rows_two_spells(self):
        rows = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31'),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31'),
            dict(id=1, start_date='1994-12-31', end_date='1995-12-31'),
        ]
        expected = self._with_spell_ids(rows, [1, 1, 2])
        self.assertListEqual(expected, self.tag_spells(iter(rows)))

    def test_open_spell_closed(self):
        rows = [
            dict(id=1, start_date='1992-01-01', end_date=''),
            dict(id=1, start_date='1993-01-03', end_date='1993-12-31'),
        ]
        expected = self._with_spell_ids(rows, [1, 1])
        self.assertListEqual(expected, self.tag_spells(iter(rows)))

    def test_gap_causes_new_spell(self):
        rows = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31'),
            dict(id=1, start_date='1994-01-03', end_date='1994-12-31'),
        ]
        expected = self._with_spell_ids(rows, [1, 2])
        self.assertListEqual(expected, self.tag_spells(iter(rows)))

    def test_new_entity_new_spell(self):
        rows = [
            dict(id=1, start_date='1992-01-01', end_date='1992-12-31'),
            dict(id=2, start_date='1992-12-31', end_date='1994-12-31'),
        ]
        expected = self._with_spell_ids(rows, [1, 2])
        self.assertListEqual(expected, self.tag_spells(iter(rows)))
# Run the test cases defined above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment