Last active
August 4, 2021 11:57
-
-
Save ArtemisDicoTiar/067e9095d2eb1db7213abfcada5b51dd to your computer and use it in GitHub Desktop.
CSSE_Arima.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
import multiprocessing | |
from datetime import timedelta | |
from functools import partial | |
from multiprocessing import Manager | |
import pandas as pd | |
from LostChapter.secrets import controller | |
from LostChapter.utils.ARIMAPredictor import get_arima_predictions | |
def get_target_regions(target_date: str):
    """Return the country codes that still need an ARIMA prediction for *target_date*.

    When the ``covid_info.COVID_Cases_prediction`` table exists, every distinct
    ``CountryCode`` of ``covid_info.COVID_Cases`` is returned unless the
    prediction table already holds a row for that same country whose
    ``predicted`` column equals *target_date*.  When the prediction table does
    not exist yet, every distinct ``CountryCode`` of ``covid_info.COVID_Cases``
    is returned.

    :param target_date: run date, interpolated verbatim into the SQL text
        (assumed to be a trusted, caller-supplied value such as ``'2021-05-11'``).
    :return: list of ``CountryCode`` values (order not specified).
    """
    if controller.is_table_exist(target_db='covid_info', target_table='COVID_Cases_prediction'):
        # The NOT EXISTS sub-query is correlated on CountryCode; an uncorrelated
        # one would return *no* country as soon as any single country had been
        # predicted for target_date. Candidates come from COVID_Cases so that
        # countries never predicted before are included as well.
        query = (
            "select distinct c.CountryCode from {destination_db}.{source_table} c "
            "where "
            "not exists(select 1 "
            "from {destination_db}.{prediction_table} p "
            "where p.CountryCode = c.CountryCode "
            "and p.predicted='{target_date}')"
        ).format(destination_db='covid_info',
                 source_table='COVID_Cases',
                 prediction_table='COVID_Cases_prediction',
                 target_date=target_date)
    else:
        query = (
            "select distinct CountryCode from {destination_db}.{table_name} "
        ).format(destination_db='covid_info', table_name='COVID_Cases')
    return controller.get_df_from_sql(target_query=query)['CountryCode'].to_list()
def ARIMA_pred_process(mrg_list: list, target_col: str,
                       base_table: pd.DataFrame, prediction_dates, area):
    """Forecast confirmed cases and deaths for one area and append them to *mrg_list*.

    Meant to be run by a ``multiprocessing`` worker: the result is handed back
    by appending a DataFrame to *mrg_list* (a shared list in the caller), not
    by returning it.

    :param mrg_list: list the one-area prediction frame is appended to.
    :param target_col: column of *base_table* identifying the area, e.g.
        ``'CountryCode'``.
    :param base_table: case history; must contain ``date``, ``confirmed``,
        ``deaths``, ``ContinentName`` and *target_col* columns.
    :param prediction_dates: number of consecutive days to forecast after the
        area's most recent date.
    :param area: value of *target_col* whose rows are forecast.
    """
    print('current_area:', area)
    area_rows = base_table.loc[base_table[target_col] == area]
    if area_rows.empty:
        # Nothing to fit; without this guard the latest-date lookup raises IndexError.
        print('no case data for area:', area)
        return

    # base_table's row order is not guaranteed to be chronological (the caller
    # loads it with a SELECT that has no ORDER BY), so sort before fitting.
    # NOTE(review): assumes get_arima_predictions consumes rows in order — confirm.
    cur_area_df = area_rows[['date', 'confirmed', 'deaths']].sort_values(by='date', kind='mergesort')

    prediction_df = pd.DataFrame(
        data={
            'confirmed_prediction': get_arima_predictions(cur_area_df,
                                                          target_case='confirmed',
                                                          pred_periods=prediction_dates)
        }
    )
    prediction_df['deaths_prediction'] = get_arima_predictions(cur_area_df,
                                                               target_case='deaths',
                                                               pred_periods=prediction_dates)

    # Forecast dates start the day after the latest observed date.
    last_update = cur_area_df['date'].iloc[-1]
    prediction_df['date'] = [last_update + timedelta(days=i) for i in range(1, 1 + prediction_dates)]
    prediction_df['CountryCode'] = area
    prediction_df['ContinentName'] = area_rows['ContinentName'].unique()[0]
    mrg_list.append(prediction_df)
def _run_area_safely(task, area):
    """Run ``task(area)`` in a pool worker, returning an error description instead of raising.

    Returning the failure lets the parent report every area whose prediction
    failed while still keeping the results of the areas that succeeded.
    """
    try:
        task(area)
    except Exception as err:  # per-task boundary: reported by the parent, not swallowed
        return '{}: {!r}'.format(area, err)
    return None


def get_ARIMA_prediction(process_date: str, prediction_dates: int = 7,
                         cores: int = 8) -> 'pd.DataFrame | int':
    """Run per-country ARIMA forecasts in parallel for *process_date*.

    :param process_date: run date; only ``covid_info.COVID_Cases`` rows dated on
        or before it with a null ``SubdivisionCode`` are used, and it is stored
        in the ``predicted`` column of the result.
    :param prediction_dates: number of days forecast per country (default 7).
    :param cores: maximum number of worker processes (default 8).
    :return: ``0`` when ``get_target_regions`` reports no country left to
        predict for *process_date*; otherwise a DataFrame with columns
        ``predicted, ContinentName, CountryCode, date, confirmed_prediction,
        deaths_prediction`` sorted by country and date.
    :raises ValueError: if no country produced a prediction.
    """
    country_list = get_target_regions(target_date=process_date)
    if not country_list:
        print('** ARIMA Prediction already done **')
        return 0

    print('** ARIMA PREDICTION DATA LOAD **')
    base_df = controller.get_df_from_sql(
        target_query="select * from {destination_db}.{table_name} where date <= '{date}' and SubdivisionCode is null"
        .format(destination_db='covid_info', table_name='COVID_Cases', date=process_date)
    ).set_index(keys='index')

    print('** ARIMA PREDICTION PROCESSING **')
    workers = max(1, min(cores, len(country_list)))
    # Context managers guarantee the manager and worker processes are shut
    # down, even if a step below raises.
    with Manager() as manager, multiprocessing.Pool(processes=workers) as pool:
        mgr_list = manager.list()
        task = partial(ARIMA_pred_process, mgr_list, 'CountryCode', base_df, prediction_dates)
        # pool.map blocks until every area has finished; failures come back as strings.
        failures = [msg for msg in pool.map(partial(_run_area_safely, task), country_list) if msg]
        # Copy results out of the proxy before the manager process exits.
        frames = list(mgr_list)

    for msg in failures:
        print('ARIMA prediction failed for', msg)
    if not frames:
        raise ValueError("prediction incomplete")

    pred_df = pd.concat(frames, ignore_index=False)
    pred_df.sort_values(by=['CountryCode', 'date'], inplace=True)
    pred_df['predicted'] = process_date
    return pred_df[['predicted',
                    'ContinentName',
                    'CountryCode',
                    'date',
                    'confirmed_prediction',
                    'deaths_prediction']]
def COVID_case_info_processor(process_date: str):
    """Predict cases for *process_date* and append them to ``COVID_Cases_prediction``.

    Nothing is written when ``get_ARIMA_prediction`` returns something other
    than a DataFrame (it returns ``0`` when the run is already done) or an
    empty DataFrame. Exceptions from ``get_ARIMA_prediction`` propagate.

    :param process_date: run date forwarded to ``get_ARIMA_prediction``.
    """
    case_df = get_ARIMA_prediction(process_date)
    if not isinstance(case_df, pd.DataFrame) or case_df.empty:
        return
    controller.save_df_to_sql(case_df,
                              target_table_name='COVID_Cases_prediction',
                              if_exists='append',
                              index=False)
if __name__ == '__main__':
    import sys

    # Optional first command-line argument overrides the run date; it must use
    # the same format as the default below. Running under this guard is also
    # required for multiprocessing with the "spawn" start method.
    COVID_case_info_processor(sys.argv[1] if len(sys.argv) > 1 else '2021-05-11')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment