Last active
July 3, 2019 18:44
-
-
Save jaketf/97a98069bde125549176e15f77b0c995 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Module for sessionizing on sampling rate. | |
""" | |
import argparse | |
import csv | |
import os | |
# Milliseconds per second; divided by a sampling rate in Hz to obtain the
# expected gap between consecutive samples in milliseconds.
MILLIS_IN_A_SECOND = 1000
def load_csv_as_list_of_dict(filename):
    """
    Load a CSV file into memory as a list of row dictionaries.

    The first line of the file is used as the header row that supplies the
    dictionary keys.

    Args:
        filename (str): Path to csv file to load.

    Returns:
        list[dict]: One dict per data row, keyed by column name. Values are
            the strings read from the file; no type conversion is applied.
    """
    # newline='' hands line-ending handling to the csv module so that quoted
    # fields containing embedded line breaks are read back unchanged.
    with open(filename, newline='') as input_file:
        return list(csv.DictReader(input_file))
def sessionize_on_sampling_rate(data, sampling_rate_hz=20,
                                timestamp_col_name='timestamp',
                                min_samples_per_session=2,
                                tolerance_ms=0):
    """
    Extract the sessions of data that were collected at the expected
    sampling rate.

    A session is a maximal run of adjacent rows in which every pair of
    consecutive timestamps is separated by the expected sample period
    (1000 / sampling_rate_hz milliseconds), within `tolerance_ms`.

    Args:
        data (list[dict]): Rows in chronological order. Each row must have
            a `timestamp_col_name` key whose value is milliseconds since the
            Unix Epoch, either as a number or as a numeric string (as
            produced by csv.DictReader).
        sampling_rate_hz (int | float): Sampling rate in Hz.
        timestamp_col_name (str): Column name containing the timestamp
            as milliseconds since Unix Epoch.
        min_samples_per_session (int): The minimum number of samples
            for a valid session; shorter runs are discarded.
        tolerance_ms (int | float): Allowed absolute deviation, in
            milliseconds, of a gap from the expected sample period. The
            default of 0 requires an exact match.

    Returns:
        list[list[dict]]: A list of sessions made up of the adjacent data
            points with the expected sampling rate. The rows are the same
            dict objects that were passed in.

    Raises:
        KeyError: If a row has no `timestamp_col_name` key.
        ValueError: If a timestamp cannot be converted to a number.
        ZeroDivisionError: If `sampling_rate_hz` is 0.
    """
    expected_ms_between_samples = MILLIS_IN_A_SECOND / sampling_rate_hz
    # Rows loaded from CSV carry string values, so coerce to numbers before
    # doing any arithmetic on them.
    timestamps = [float(row[timestamp_col_name]) for row in data]
    sample_count = len(timestamps)

    sessions = []
    start = 0
    # `end` is always one past the last sample of the run being examined.
    # Iterating up to and including `sample_count` makes the final run get
    # flushed by the same code path as every other run.
    for end in range(1, sample_count + 1):
        continues_run = (
            end < sample_count
            and abs((timestamps[end] - timestamps[end - 1])
                    - expected_ms_between_samples) <= tolerance_ms
        )
        if continues_run:
            continue
        if end - start >= min_samples_per_session:
            sessions.append(data[start:end])
        # The sample that broke the cadence may begin the next session.
        start = end
    return sessions
def main():
    """
    Main method for CLI invocation.

    Reads the CSV named by --input_csv_filename, splits it into sessions of
    rows sampled at the expected rate, and writes each session to its own
    CSV file (with a header row) inside --output_dir.
    """
    # The text is a description of the tool, not the program name, so it is
    # passed as `description` rather than as the first positional (`prog`).
    parser = argparse.ArgumentParser(
        description='Sessionize on sampling rate')
    parser.add_argument('--input_csv_filename', dest='input_csv_filename',
                        required=True)
    parser.add_argument('--timestamp_field', dest='timestamp_field',
                        required=True)
    parser.add_argument('--output_dir', dest='output_dir', required=False,
                        default='sessions')
    parser.add_argument('--sampling_rate_hz', dest='sampling_rate_hz',
                        required=False, type=float, default=20)
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_dir, exist_ok=True)

    data = load_csv_as_list_of_dict(args.input_csv_filename)
    sessions = sessionize_on_sampling_rate(
        data,
        sampling_rate_hz=args.sampling_rate_hz,
        timestamp_col_name=args.timestamp_field)

    total = len(sessions)
    for i, session in enumerate(sessions):
        output_file_path = os.path.join(
            os.curdir, args.output_dir,
            'session-{i}-of-{total}.csv'.format(i=i, total=total))
        # newline='' stops the csv writer's own line terminators from being
        # translated a second time by the text layer.
        with open(output_file_path, 'w', newline='') as output_file:
            writer = csv.DictWriter(output_file,
                                    fieldnames=list(session[0].keys()))
            # Emit the column names so each session file can be read back
            # with csv.DictReader.
            writer.writeheader()
            writer.writerows(session)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment