Last active
May 18, 2021 20:09
-
-
Save rjrajivjha/91f2a8e256f56323dae16087a410e121 to your computer and use it in GitHub Desktop.
Predict the next 10 trains departing from the MBTA's Park Street station using the MBTA API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest
from collections import defaultdict
from datetime import datetime, timezone
from typing import Optional, Union

import pytz
import requests
# from unittest.mock import patch # Mock current time function and API result to test complete functionality | |
# How many upcoming departures to report.
NUMBER_OF_RESULTS = 10
# MBTA stop id of the station to report on (Park Street).
STATION_NAME = 'place-pktrm'
# Root of the MBTA v3 REST API (note the trailing slash).
BASE_URL = 'https://api-v3.mbta.com/'
# Predictions for STATION_NAME sorted by departure time, requesting at most
# 20 results via page[limit].
predict_by_schedule_api = (
    f'{BASE_URL}predictions'
    f'?page[limit]=20&filter[stop]={STATION_NAME}&sort=departure_time'
)
# Stop-detail endpoint; already ends in '/', so append a stop id directly.
stop_detail_api = f'{BASE_URL}stops/'
def call_api(api: str, timeout: float = 10.0) -> dict:
    """Fetch ``api`` with an HTTP GET and return the decoded JSON body.

    Args:
        api: Fully-qualified URL to request.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely.

    Returns:
        The parsed JSON response.

    Raises:
        Exception: If the request fails, the server answers with a status
            other than 200, or the body is not valid JSON. The message names
            the URL and the cause, and the underlying error is chained.
    """
    try:
        response = requests.get(api, timeout=timeout)
    except requests.RequestException as err:
        raise Exception(f'Request to {api} failed: {err}') from err
    if response.status_code != 200:
        print('Seems train is magically Gone. Next One is coming')
        raise Exception(f'Request to {api} returned HTTP {response.status_code}')
    try:
        return response.json()
    except ValueError as err:
        # requests' JSON decoding errors subclass ValueError.
        raise Exception(f'Response from {api} is not valid JSON: {err}') from err
def get_minute_from_timedelta(difference) -> int:
    """Return the whole number of minutes in a time difference.

    The full duration is used (days included) and floored to whole minutes,
    so a 75-minute gap yields 75 rather than wrapping around to 15.

    Args:
        difference: A ``datetime.timedelta``.

    Returns:
        Whole minutes as an ``int``. Negative differences floor towards
        minus infinity (e.g. -30 seconds gives -1).
    """
    return int(difference.total_seconds() // 60)
def get_time_utc(given_time: Union[Optional[datetime], str]) -> Optional[datetime]:
    """Convert a timestamp with an explicit UTC offset to an aware UTC datetime.

    Args:
        given_time: Either a string in ``"%Y-%m-%dT%H:%M:%S%z"`` form
            (e.g. ``"2021-05-18T20:09:00-04:00"``) or a timezone-aware
            ``datetime``.

    Returns:
        The same instant as an aware ``datetime`` in UTC. Returns ``None``
        (after printing a message) when the value is missing, is a naive
        datetime, or does not match the expected format.
    """
    if isinstance(given_time, datetime):
        # A naive datetime has no offset, so the instant it denotes is
        # ambiguous; reject it just like a string without an offset.
        if given_time.tzinfo is None or given_time.utcoffset() is None:
            print('Date time is not in expected format : naive datetime has no UTC offset')
            return None
        return given_time.astimezone(timezone.utc)
    try:
        parsed = datetime.strptime(given_time, "%Y-%m-%dT%H:%M:%S%z")
    except (TypeError, ValueError) as msg:
        # TypeError: not a string (e.g. None); ValueError: wrong format.
        print(f'Date time is not in expected format : {msg}')
        return None
    return parsed.astimezone(timezone.utc)
def get_current_time_utc():
    """Return the current time as an aware UTC ``datetime``, truncated to whole seconds."""
    # Ask for UTC directly rather than converting a naive local now() through
    # America/New_York and a format/parse round trip. It is the same instant
    # with fewer conversions. Sub-second precision is dropped, as before.
    return datetime.now(timezone.utc).replace(microsecond=0)
def get_time_string(given_time):
    """Describe how long until a train departs.

    Args:
        given_time: Departure timestamp in any form accepted by ``get_time_utc``.

    Returns:
        ``'Departed'`` when the departure is before the current time,
        ``'Departing in N minutes'`` otherwise, or ``None`` when the
        timestamp cannot be converted.
    """
    departure = get_time_utc(given_time)
    if not departure:
        return None
    now = get_current_time_utc()
    remaining = get_minute_from_timedelta(departure - now)
    return 'Departed' if departure < now else f'Departing in {remaining} minutes'
def check_not_departed(train_predicted_time):
    """Return whether a predicted departure is still in the future.

    Args:
        train_predicted_time: Departure timestamp in any form accepted by
            ``get_time_utc``. It may be ``None`` or empty when no prediction
            is available.

    Returns:
        ``True`` only when the timestamp converts successfully and lies
        strictly after the current UTC time; ``False`` otherwise.
    """
    if not train_predicted_time:
        return False
    departure = get_time_utc(train_predicted_time)
    if departure is None:
        return False
    # get_current_time_utc() already yields an aware UTC datetime, so it can
    # be compared directly without converting it again.
    return departure > get_current_time_utc()
class TestDepartureTimeByStation(unittest.TestCase):
    """Live checks that the MBTA endpoints return the fields this script reads.

    These tests issue real HTTP requests through ``call_api``, so they need
    network access.
    """

    def test_prediction_departure_api(self):
        """Check the predictions response carries every field ``predict`` reads.

        Number of API hits: 1 per test.
        """
        response = call_api(predict_by_schedule_api)
        self.assertIn('data', response)
        for res in response['data']:
            # Fall back to {} so a missing section fails the assertion
            # instead of raising TypeError.
            attributes = res.get('attributes') or {}
            relationships = res.get('relationships') or {}
            self.assertIn('departure_time', attributes)
            for relation in ('stop', 'route'):
                self.assertIn(relation, relationships)
                relation_data = relationships[relation] or {}
                self.assertIn('data', relation_data)
                self.assertIn('id', relation_data['data'] or {})

    def test_stop_data_api(self):
        """Check the stop-detail response for a sample stop id has a platform name.

        Number of API hits: 1 per test.
        """
        stop_id = 70070
        # stop_detail_api already ends with '/', so append the id directly.
        response = call_api(f'{stop_detail_api}{stop_id}')
        self.assertIn('data', response)
        self.assertIn('platform_name', response['data'].get('attributes') or {})
class DepartureTimeByStation:
    """Collect upcoming departures from STATION_NAME, grouped by route id."""

    def __init__(self):
        # route id -> list of {platform name: departure description}
        self.predict_schedule_dict = defaultdict(list)
        # stop id -> platform name, so each stop is fetched at most once
        self.stop_detail_dict = dict()

    def print_schedule(self) -> None:
        """Print the collected departures, one section per route."""
        for route_id, departures in self.predict_schedule_dict.items():
            print(f'\n------{route_id}------')
            for departure in departures:
                print(f'{departure}')

    def stop_detail_by_id(self, stop_id: str) -> str:
        """Return the platform name for ``stop_id``, caching the API lookup.

        Args:
            stop_id: Stop id as found in a prediction's ``relationships``.
        """
        if stop_id not in self.stop_detail_dict:
            # stop_detail_api already ends with '/', so append the id directly.
            stop_data = call_api(f'{stop_detail_api}{stop_id}')
            self.stop_detail_dict[stop_id] = stop_data['data'].get('attributes').get('platform_name')
        return self.stop_detail_dict[stop_id]

    def predict(self, max_results: Optional[int] = None) -> None:
        """Record up to ``max_results`` departures that have not left yet.

        Args:
            max_results: Maximum number of departures to record; defaults to
                NUMBER_OF_RESULTS.

        The predictions endpoint is requested once. Re-requesting the same
        URL would mostly return the same predictions again, which duplicates
        entries and can loop forever when the page holds fewer than
        ``max_results`` upcoming trains. In that case fewer are recorded.
        """
        if max_results is None:
            max_results = NUMBER_OF_RESULTS
        count = 0
        prediction_data = call_api(predict_by_schedule_api)
        for prediction in prediction_data['data']:
            if count >= max_results:
                break
            departure_time = prediction.get('attributes').get('departure_time')
            if not check_not_departed(departure_time):
                continue
            relationships = prediction.get('relationships')
            stop_id = relationships.get('stop').get('data').get('id')
            stop_name = self.stop_detail_by_id(stop_id)
            time_string = get_time_string(departure_time)
            route_id = relationships.get('route').get('data').get('id')
            self.predict_schedule_dict[route_id].append({stop_name: time_string})
            count += 1
if __name__ == '__main__':
    # Answer 'false' (case- and whitespace-insensitive) to print upcoming
    # departures; any other answer runs the live API tests.
    test = input('Do You Want to Run Test, True or False : ')
    if test.strip().lower() == 'false':
        predictor = DepartureTimeByStation()
        predictor.predict()
        predictor.print_schedule()
    else:
        unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment