# Gist yukoba/76cacfee5fb373d944f29a5693859e83 by @yukoba (last active January 21, 2021 19:45)
# 新型コロナウイルスの感染者数の予測 (Forecasting the number of COVID-19 infections)
"""
Copyright (C) 2021 by Yu Kobayashi
Permission to use, copy, modify, and/or distribute this software for any purpose
with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
THIS SOFTWARE.
"""
from datetime import date, timedelta
import matplotlib.pyplot as plt
import numpy as np
from dateutil.relativedelta import relativedelta
from matplotlib.ticker import ScalarFormatter
from scipy import stats
# noinspection PyProtectedMember
from scipy.stats.mstats_basic import LinregressResult
# Daily number of new COVID-19 infections in Tokyo from 2020-10-19 through
# 2021-01-21. Source: https://stopcovid19.metro.tokyo.lg.jp/
initial_date = date(2020, 10, 19)

# Public holidays that fall inside the observation window.
holidays = [date(2020, 11, 3), date(2020, 11, 23), date(2021, 1, 11)]

# One row per week; the first value of each row is a Monday (2020-10-19 is a Monday).
data = np.array([
    78, 139, 145, 185, 186, 201, 124,
    102, 158, 171, 220, 203, 215, 116,
    87, 209, 122, 269, 242, 294, 189,
    157, 293, 316, 392, 374, 352, 255,
    180, 298, 485, 533, 522, 539, 391,
    314, 186, 401, 481, 570, 561, 418,
    311, 372, 500, 533, 449, 584, 327,
    299, 352, 572, 602, 595, 621, 480,
    305, 460, 660, 821, 664, 736, 556,
    392, 563, 748, 888, 884, 949, 708,
    481, 856, 944, 1337, 783, 814, 816,
    884, 1278, 1591, 2447, 2392, 2268, 1494,
    1219, 970, 1433, 1502, 2001, 1809, 1592,
    1204, 1240, 1274, 1471,
])

# Day offsets (0 == initial_date) and the base-10 logarithm of the case counts.
x = np.arange(data.size)
y = np.log10(data)

# Index of the day *after* each holiday (hence the +1); the fit ignores these days.
exclude_indices = [(holiday - initial_date).days + 1 for holiday in holidays]
def search_param(
    iter_count: int = 30000,
    pop_size: int = 300,
    children_size: int = 30,
    seed=None,
) -> np.ndarray:
    """Fit a log-linear trend with day-of-week offsets using an evolution strategy.

    The model is ``y ~ slope * x + intercept + weekly_difference[x % 7]`` where
    ``x``, ``y`` and ``exclude_indices`` are the module-level arrays. The seven
    weekly offsets are constrained to sum to zero, so only six of them are free
    parameters. The objective is the sum of absolute residuals, with the
    residuals at ``exclude_indices`` ignored.

    Every time the best score improves, the iteration number, the score and the
    parameter vector are printed.

    Random numbers are drawn from a local ``numpy.random.Generator``, so
    ``np.random.seed()`` has no effect on this function; pass ``seed`` instead.

    Args:
        iter_count: Number of generations to run.
        pop_size: Number of individuals kept after each selection step.
        children_size: Number of children created per generation.
        seed: Anything accepted by ``numpy.random.default_rng``. ``None`` gives
            a different run every time; a fixed value makes the run repeatable.

    Returns:
        Array of 8 values: ``[slope, intercept, d0, d1, d2, d3, d4, d5]`` where
        ``d0..d5`` are the first six weekly offsets (the seventh is
        ``-(d0 + ... + d5)``).

    Raises:
        ValueError: If ``pop_size`` is smaller than 1.
    """
    if pop_size < 1:
        raise ValueError("pop_size must be at least 1")

    rng = np.random.default_rng(seed)

    # Use the least-squares regression line as the starting point of the search.
    result = stats.linregress(x, y)
    param_init = np.array([result.slope, result.intercept] + [0] * 6)

    # x never changes during the search, so compute the weekly index only once.
    week_index = x % 7

    def evaluate(param: np.ndarray) -> float:
        """Return the sum of absolute residuals of the model for ``param``."""
        slope = param[0]
        intercept = param[1]
        weekly_difference = param[2:8]
        # The seventh offset is implied: all seven offsets sum to zero.
        weekly_difference = np.append(weekly_difference, [-weekly_difference.sum()])
        diff = y - (slope * x + intercept + weekly_difference[week_index])
        # Tokyo's counts drop sharply on the day after a public holiday,
        # so those days are excluded from the score.
        diff[exclude_indices] = 0
        return np.abs(diff).sum()  # minimise the sum of absolute errors

    # Evolution strategy, see https://qiita.com/yukoba/items/ed40e0c4f4a27b73c6b8
    param_len = len(param_init)
    # Scale of the log-normal perturbation applied to the per-parameter step sizes.
    tau = 1.0 / np.sqrt(2.0 * param_len)

    # Each individual is (parameters, per-parameter step sizes, score).
    individuals = [(param_init, np.full([param_len], 0.1), evaluate(param_init))]
    best = individuals[0]
    best_ev = best[2]

    for i in range(iter_count):
        for _ in range(children_size):
            # Crossover: with probability 0.8 mix two random individuals
            # element-wise, otherwise copy a single random individual.
            if rng.random() < 0.8:
                ind0 = individuals[rng.integers(len(individuals))]
                ind1 = individuals[rng.integers(len(individuals))]
                r = rng.integers(0, 2, param_len)
                parent = (
                    ind0[0] * r + ind1[0] * (1 - r),
                    ind0[1] * r + ind1[1] * (1 - r),
                )
            else:
                parent = individuals[rng.integers(len(individuals))]

            # Mutation: perturb the step sizes first, then the parameters
            # using the new step sizes.
            strategies2 = parent[1] * np.exp(tau * rng.standard_normal(param_len))
            params2 = parent[0] + strategies2 * rng.standard_normal(param_len)
            individuals.append((params2, strategies2, evaluate(params2)))

        # Selection: keep only the best pop_size individuals (lowest score first).
        individuals = sorted(individuals, key=lambda ind: ind[2])[:pop_size]
        best = individuals[0]

        # Report whenever the best score improves.
        if best[2] < best_ev:
            best_ev = best[2]
            print(i, best_ev, np.array_repr(best[0], max_line_width=1000))

    return best[0]
def plot_result(param: np.ndarray):
    """Show two log-scale figures comparing the observed data with the model.

    ``param`` is ``[slope, intercept, d0..d5]``; the seventh weekly offset is
    the negative sum of ``d0..d5``. The model curve extends 30 days past the
    end of the data. The first figure includes the weekly offsets, the second
    one has them removed from both the data and the model.
    """
    slope, intercept = param[0], param[1]
    free_offsets = param[2:8]
    # Append the implied seventh offset so that all seven sum to zero.
    offsets = np.append(free_offsets, [-free_offsets.sum()])

    def as_dates(day_numbers):
        """Convert day offsets into calendar dates starting at initial_date."""
        return [initial_date + relativedelta(days=int(n)) for n in day_numbers]

    def draw(observed, predicted, observed_label, predicted_label):
        """Plot one observed/model pair on a log axis and display the figure."""
        plt.plot(observed_dates, observed, label=observed_label)
        plt.plot(forecast_dates, predicted, label=predicted_label)
        plt.yscale("log")
        plt.gca().yaxis.set_major_formatter(ScalarFormatter())
        plt.legend()
        plt.grid()
        plt.show()

    # Extend the model 30 days beyond the observed data.
    forecast_x = np.arange(len(data) + 30)
    observed_dates = as_dates(x)
    forecast_dates = as_dates(forecast_x)

    # Figure 1: including the weekly offsets.
    draw(
        10 ** y,
        10 ** (slope * forecast_x + intercept + offsets[forecast_x % 7]),
        "real",
        "model",
    )

    # Figure 2: weekly offsets removed from both series.
    draw(
        10 ** (y - offsets[x % 7]),
        10 ** (slope * forecast_x + intercept),
        "real - weekly difference",
        "model - weekly difference",
    )
# Run the evolution-strategy search to obtain the fitted parameter vector.
param_best = search_param()
# Alternative: a hard-coded 8-element parameter vector (presumably the result
# of an earlier search run - confirm before relying on it). Uncomment to plot
# without re-running the search.
# param_best = np.array([0.01137184, 2.13545116, -0.20030776, -0.01955204, 0.04198086, 0.09855929, 0.0490876, 0.10989335])
# Plot the observed data against the fitted model.
plot_result(param_best)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment