Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matteo-peltarion/47928f1b85a9e7f07969d9fd7cf8d173 to your computer and use it in GitHub Desktop.
Save matteo-peltarion/47928f1b85a9e7f07969d9fd7cf8d173 to your computer and use it in GitHub Desktop.
COVID-19 analysis code 05 - plot predictions
# Imports for performing ML analysis
from sklearn.linear_model import LinearRegression
from scipy.optimize import curve_fit
from datetime import timedelta
# Fit window: only rows dated on or after START_DATE feed the models.
# Dropping the earliest points can help, because the exponential trend
# was not yet evident at the very beginning of the series.
START_DATE = datetime(2020, 2, 23).date()

# Dependent variable: confirmed counts inside the fit window (raw scale),
# plus their natural log, which is what the linear models are fit on.
y = country_df.loc[country_df['Date'] >= START_DATE, 'Confirmed']
y_log = np.log(y)

# Independent variable: 0-based position of each observation in the window.
x = np.arange(len(y))

# scikit-learn expects a 2-D feature matrix, so present x as a single column.
x_column = x[:, np.newaxis]

# Two ordinary-least-squares fits of log(count) against position:
#   - unweighted: every observation counts equally;
#   - weighted:   each observation is weighted by its raw count, so rows
#                 with larger counts pull harder on the fitted line.
reg_unweighted = LinearRegression().fit(x_column, y_log)
reg_weighted = LinearRegression().fit(x_column, y_log, sample_weight=y)
# Forecast horizon: predictions run from START_DATE through tomorrow.
# PREDICT_UNTIL is kept as an MM/DD/YYYY string.
tomorrow = (datetime.today() + timedelta(days=1)).date()
PREDICT_UNTIL = tomorrow.strftime("%m/%d/%Y")

# One timestamp per calendar day across the forecast window.
estimate_dates = pd.date_range(
    start=START_DATE.strftime("%m/%d/%Y"),
    end=PREDICT_UNTIL,
)

# Position of each forecast day relative to START_DATE, shaped as the
# single-column feature matrix the regressors expect.
estimate_x = np.arange(len(estimate_dates))[:, np.newaxis]

# The regressors predict log(count); exponentiate to get back to counts.
estimate_cases_ols_unweighted = np.exp(reg_unweighted.predict(estimate_x))
estimate_cases_ols_weighted = np.exp(reg_weighted.predict(estimate_x))

# Collect both prediction series alongside their dates.
df_estimates = pd.DataFrame(
    {
        "Date": estimate_dates,
        "Predictions (unweighted)": estimate_cases_ols_unweighted,
        "Predictions (weighted)": estimate_cases_ols_weighted,
    }
)
# Draw the observed counts and the fitted curve(s) on one shared axes.
ax = plt.gca()

# Keyword arguments common to both plot calls.
shared_plot_kwargs = dict(x='Date', figsize=(20, 10), ax=ax, marker='o')

# Observed confirmed cases.
country_df.plot(y=["Confirmed"], **shared_plot_kwargs)

# Predicted cases. Uncomment the "Predictions (unweighted)" entry to also
# draw the unweighted fit; its values are much higher, so it would compress
# the other curves.
prediction_columns = [
    # "Predictions (unweighted)",
    "Predictions (weighted)",
]
df_estimates.plot(
    y=prediction_columns,
    alpha=0.4,
    color=['green', 'orange'],
    **shared_plot_kwargs,
)

# Place a major x tick at every multiple of 3 in axis units
# (presumably every third day; depends on how the dates are plotted).
ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
# The trailing semicolon suppresses the echoed return value in a notebook.
ax.set_ylabel("# of confirmed cases");

# Optional zoom on a narrower date window:
# ax.set_xlim(["2020-02-20", "2020-03-12"]);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment