Last active
April 14, 2022 03:37
-
-
Save halflearned/59d2f03b4ad5cee1237137914c14ec82 to your computer and use it in GitHub Desktop.
Testing reconciliation with real sessions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
import pickle | |
from itertools import chain | |
from datetime import datetime, timezone | |
from sleep_x_ml_reconciliation_service import SleepXMLReconciliation | |
from sleep_x_ml_reconciliation_service.constants.reconciliation import ( | |
UNINTENDED_USER_FLAG, | |
LARGE_SLEEP_GAP_FLAG, | |
PUT_GAP_TOLERANCE_SECONDS | |
) | |
from sleep_x_ml_reconciliation_service.exceptions import ( | |
IncorrectArchitectureException, | |
IncorrectPredictionTimeException, | |
) | |
from sleep_x_ml_reconciliation_service.validation import ( | |
validate_sleep_response, | |
validate_put_response, | |
validate_sleep_put_overlap | |
) | |
def find_chunk(sequence, point):
    """
    Look up the chunk of `sequence` whose closed interval contains `point`.

    A chunk matches when chunk['start'] <= point <= chunk['end']. When several
    overlapping chunks match, the one with the latest start is returned (the
    earliest-listed chunk wins among equal starts). When nothing matches, an
    empty list is returned, so the result can be used in a truthiness check.
    """
    latest = None
    for candidate in sequence:
        if not (candidate["start"] <= point <= candidate["end"]):
            continue
        # Strict comparison keeps the first chunk seen among equal starts.
        if latest is None or candidate["start"] > latest["start"]:
            latest = candidate
    if latest is None:
        return []
    return latest
def from_timestamp(stamp):
    """
    Revert a UTC timestamp string back into integer milliseconds since the epoch,
    i.e. '2001-09-09T01:46:40.000Z' into 1000000000000.

    Args:
        stamp: String formatted as "%Y-%m-%dT%H:%M:%S.%fZ"; it is interpreted as UTC.

    Returns:
        Whole milliseconds elapsed since 1970-01-01T00:00:00Z, as an int.
        Fractional digits finer than a millisecond are floored away.

    Raises:
        ValueError: If `stamp` does not match the expected format.
    """
    parsed = datetime.strptime(stamp, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
    elapsed = parsed - datetime(1970, 1, 1, tzinfo=timezone.utc)
    # Integer arithmetic on purpose: int(1000 * parsed.timestamp()) goes through
    # a float and can land one millisecond low (1.001 s * 1000 is
    # 1000.9999999999999, which int() truncates to 1000).
    return elapsed.days * 86400000 + elapsed.seconds * 1000 + elapsed.microseconds // 1000
def test_holistic(mocker):
    """
    Heuristic test on real data from sessions stored in the alpha account.

    For every session in the pickled fixture, the PUT and sleep-stage table
    lookups are patched to return the stored responses, and
    SleepXMLReconciliation.get_reconciled_data is called for that session.
    If reconciliation raises one of the known exceptions, the test checks that
    the matching validation function raises it too. Otherwise a set of
    heuristic consistency checks is applied to the reconciled output; these do
    not prove correctness, they only catch obvious inconsistencies.

    Args:
        mocker: The pytest-mock fixture used to patch the table lookups.
    """
    # Instantiating with whatever default values we currently have.
    sleepxml = SleepXMLReconciliation()

    # NOTE(review): machine-specific fixture path, so this test only runs where
    # the file exists. pickle.load is acceptable only because the file is a
    # locally produced, trusted fixture; never point this at untrusted data.
    with open("/Users/hadadv/tmp/device-alpha", "rb") as f:
        example_data = pickle.load(f)

    # NOTE(review): the first 155 sessions are skipped. This looks like a
    # leftover from resuming a debugging run; confirm before widening the slice.
    for i, session_id in enumerate(example_data["common_session_ids"][155:]):
        print(i, session_id)
        put_response = example_data["put"][session_id]
        sleep_response = example_data["sleep"][session_id]

        mocker.patch(
            "sleep_x_ml_reconciliation_service.tables.primary_user_tracking"
            ".PrimaryUserTrackingDDB.get_session_data",
            return_value=put_response,
        )
        mocker.patch(
            "sleep_x_ml_reconciliation_service.tables.sleep_stage"
            ".SleepStageDDB.get_session_data",
            return_value=sleep_response,
        )

        # Get predictions from session.
        sleep_predictions = list(chain(*[x["payload"] for x in sleep_response]))
        put_predictions = put_response[0]["state_details_list"]

        try:
            reconciled = sleepxml.get_reconciled_data(session_id)
        # If any exception needs to be raised,
        # ensure they are raised appropriately.
        except IncorrectArchitectureException:
            with pytest.raises(IncorrectArchitectureException):
                validate_sleep_response(sleep_response)
                validate_put_response(put_response)
            continue
        except IncorrectPredictionTimeException:
            with pytest.raises(IncorrectPredictionTimeException):
                validate_sleep_put_overlap(sleep_predictions, put_predictions)
            continue
        # Any other exception propagates and fails the test here.

        # The rest of the code heuristically checks for correctness.

        # If the entire session is marked as UU, check that reconciled is empty.
        if all(chunk["state"] == "UU" for chunk in put_predictions):
            assert not reconciled
            continue

        # Otherwise, get start and end of reconciled session.
        assert reconciled, "reconciled output is empty although not every PUT chunk is UU"
        recon_start = from_timestamp(reconciled[0]["start"])
        recon_end = from_timestamp(reconciled[-1]["end"])

        # Check that reconciled data only exists for the intervals where
        # sleep and PUT predictions are both available.
        start_available = max(sleep_predictions[0]["start"], put_predictions[0]["start"])
        end_available = min(sleep_predictions[-1]["end"], put_predictions[-1]["end"])
        assert recon_start >= start_available
        assert recon_end <= end_available

        # Check that reconciled data does not start on an interval marked as UU.
        put_chunk_at_start = find_chunk(put_predictions, recon_start)
        assert put_chunk_at_start, "no PUT chunk covers the start of the reconciled data"
        assert put_chunk_at_start["state"] == "IU"

        # The end of the reconciled data must at least fall inside a PUT chunk.
        # NOTE(review): the matching '== "IU"' check for the end was disabled in
        # the original test (commented out), so it is deliberately not enforced.
        assert find_chunk(put_predictions, recon_end), (
            "no PUT chunk covers the end of the reconciled data"
        )

        # Largest gap (in seconds) between consecutive sleep chunks. It is the
        # same for every reconciled chunk, so compute it once per session.
        max_gap_secs = 0
        for prev, cur in zip(sleep_predictions, sleep_predictions[1:]):
            max_gap_secs = max(max_gap_secs, (cur["start"] - prev["end"]) / 1000)

        # Check correctness of reconciled sleep predictions.
        for recon in reconciled:
            # Pick a point in that reconciled chunk.
            recon_midpoint = (from_timestamp(recon["end"]) + from_timestamp(recon["start"])) / 2
            # Find the corresponding PUT chunk at that time.
            put = find_chunk(put_predictions, recon_midpoint)
            # Find the corresponding sleep chunk at that time.
            sleep = find_chunk(sleep_predictions, recon_midpoint)

            if not put:
                # Nothing to compare against without a PUT chunk.
                continue

            if put["state"] == "UU":
                # If PUT state was UU, recon should be UNINTENDED_USER_FLAG.
                assert recon["stage"] == UNINTENDED_USER_FLAG
            elif put["state"] == "IU":
                if sleep:
                    # There was sleep data, so recon should be equal to sleep.
                    assert recon["stage"] == sleep["stage"]
                elif recon["stage"] == LARGE_SLEEP_GAP_FLAG:
                    # Recon claims a large gap: assert there was really such a gap.
                    assert max_gap_secs > PUT_GAP_TOLERANCE_SECONDS
                else:
                    # Gap is small: recon stage should equal the stage of the
                    # next sleep chunk.
                    next_sleep = next(
                        (chunk for chunk in sleep_predictions if chunk["start"] >= recon_midpoint),
                        None,
                    )
                    assert next_sleep is not None, (
                        "no sleep chunk starts after the reconciled chunk"
                    )
                    assert recon["stage"] == next_sleep["stage"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment