Skip to content

Instantly share code, notes, and snippets.

@bluepost59
Last active September 11, 2021 15:57
Show Gist options
  • Save bluepost59/ba8cde090fdd29241dd16cb4ddf022c0 to your computer and use it in GitHub Desktop.
Save bluepost59/ba8cde090fdd29241dd16cb4ddf022c0 to your computer and use it in GitHub Desktop.
economy_simulation.py
import os
import time as t
import numpy as np
import matplotlib.pyplot as plt
# ---- Simulation parameters --------------------------------------------------
# Number of agents taking part in the economy.
N_POPULATION = 100
# Balance every agent starts with.
N_INITIAL_BALANCE = 100
# Bound passed to np.random.randint when drawing a trade amount.
MAX_TRADE = 10
# Length of the simulation, in ticks.
N_TICKS = 10000
# A snapshot graph is written every OUTPUT_INTERVAL ticks.
OUTPUT_INTERVAL = 1000
# Output directory, made unique per run by the integer Unix timestamp at import.
GRAPH_DIR = "./output/" + str(int(t.time()))
def main():
    """Run the whole simulation and write all graphs.

    ``balances[k]`` holds every agent's balance after ``k`` ticks, so the
    history has ``N_TICKS + 1`` rows and exactly ``N_TICKS`` ticks are
    simulated.  A snapshot graph (balances sorted richest first) is written
    for the initial state and after every ``OUTPUT_INTERVAL`` ticks,
    including the final one.  The per-tick summary statistics are plotted
    at the end.
    """
    balances = np.full((N_TICKS + 1, N_POPULATION), N_INITIAL_BALANCE, dtype=float)
    os.makedirs(GRAPH_DIR, exist_ok=True)
    # Sort the initial snapshot too, so every graph uses the same ordering.
    draw_graph(0, np.sort(balances[0])[::-1])
    for i_tick in range(1, N_TICKS + 1):
        balances[i_tick] = tick(balances[i_tick - 1])
        if i_tick % OUTPUT_INTERVAL == 0:
            # Progress indicator.
            print(i_tick)
            draw_graph(i_tick, np.sort(balances[i_tick])[::-1])
    # Summary statistics over the whole history.
    summary(balances)
    # np.savetxt("result_{}".format(int(t.time())), balances, delimiter=",")
def summary(balances, graph_dir=None):
    """Plot per-tick statistics of the balance history.

    For each row of *balances* (statistics are taken along axis 1), this
    computes the mean, median, standard deviation and range (max - min).
    Each series is saved as its own PNG: mean.png, median.png, std.png and
    ranges.png.

    Args:
        balances: 2-D array whose rows are successive snapshots of all
            agents' balances.
        graph_dir: output directory; defaults to ``GRAPH_DIR``.
    """
    if graph_dir is None:
        graph_dir = GRAPH_DIR
    statistics = (
        ("mean", np.mean(balances, axis=1), "mean.png"),
        ("median", np.median(balances, axis=1), "median.png"),
        ("std", np.std(balances, axis=1), "std.png"),
        ("range", np.max(balances, axis=1) - np.min(balances, axis=1), "ranges.png"),
    )
    for title, values, filename in statistics:
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            ax.set_title(title)
            ax.plot(values)
            fig.savefig(os.path.join(graph_dir, filename))
        finally:
            # clf() alone keeps the figure registered with pyplot; closing it
            # releases the memory and avoids the "too many open figures" warning.
            plt.close(fig)
def tick(balances, *, max_trade=None, tax_fn=None):
    """Advance the economy by one tick and return the new balances.

    One tick consists of three steps:
      1. Trading: there are as many trades as agents.  Each trade picks a
         payer and a payee uniformly at random (they may coincide, which is
         a no-op) and an integer amount drawn from ``[0, max_trade)``.  The
         payer loses the amount and the payee gains it.
      2. Taxation: ``tax_fn(balance)`` is deducted from every agent,
         evaluated on the post-trade balance.
      3. Redistribution: the total tax collected is shared equally among
         all agents.

    Each step conserves the total amount of money, up to floating-point
    rounding.

    Args:
        balances: 1-D array of agent balances.  It is not modified.
        max_trade: exclusive upper bound for trade amounts.  Defaults to
            ``MAX_TRADE``.
        tax_fn: scalar function mapping a balance to the tax owed.
            Defaults to ``tax``.

    Returns:
        A new float64 array with the balances after the tick.  Empty input
        is returned unchanged, as a copy.
    """
    if max_trade is None:
        max_trade = MAX_TRADE
    if tax_fn is None:
        tax_fn = tax
    new_balances = np.array(balances, dtype=float)
    n_agents = len(new_balances)
    if n_agents == 0:
        return new_balances

    # Trading.  The random draws are made in the same order and with the same
    # shapes as the original loop, so a seeded run produces the same trades.
    pairs = np.random.randint(n_agents, size=(n_agents, 2))
    amount = np.random.randint(max_trade, size=n_agents)
    payers, payees = pairs[:, 0], pairs[:, 1]
    # bincount sums the amounts per agent, so an agent appearing in several
    # trades is credited/debited for all of them (unlike fancy-index +=).
    new_balances += np.bincount(payees, weights=amount, minlength=n_agents)
    new_balances -= np.bincount(payers, weights=amount, minlength=n_agents)

    # Taxation: tax_fn is a scalar function, so evaluate it per agent.
    taxes = [tax_fn(value) for value in new_balances]
    new_balances -= taxes

    # Redistribution: everyone receives an equal share of the tax collected.
    new_balances += sum(taxes) / n_agents
    return new_balances
def tax(value, rate=0.1, threshold=None):
    """Return the tax owed by an agent whose balance is *value*.

    Active policy (example 2 below): only the part of the balance above
    *threshold* is taxed, at *rate*.  Balances at or below the threshold
    pay nothing.

    Args:
        value: the agent's current balance.
        rate: tax rate applied to the excess over *threshold* (default 0.1).
        threshold: balance up to which nothing is taxed.  Defaults to
            ``N_INITIAL_BALANCE``.

    Returns:
        The (non-negative) tax amount.
    """
    if threshold is None:
        threshold = N_INITIAL_BALANCE
    # Alternative policies: replace the return statement below with one of these.
    # -----------------------------------------------
    # Example 1: only agents at or above the initial balance pay, and they
    # lose 40% of their whole balance.
    # return value*0.4*float(value >= N_INITIAL_BALANCE)
    # -----------------------------------------------
    # Example 2 (active): tax only the amount earned above the initial balance.
    # -----------------------------------------------
    # Example 3: flat rate applied to everyone's whole balance.
    # return value*0.8
    # -----------------------------------------------
    # Example 4: progressive (bracketed) taxation.
    # if value >= N_INITIAL_BALANCE:
    #     tax_rate = 0.4
    # elif value >= 0.5*N_INITIAL_BALANCE:
    #     tax_rate = 0.2
    # elif value >= 0.25*N_INITIAL_BALANCE:
    #     tax_rate = 0.1
    # else:
    #     tax_rate = 0
    # return tax_rate * value
    # -----------------------------------------------
    # Example 5: sigmoid-shaped rate, i.e. value * sigmoid(value / N_INITIAL_BALANCE).
    # return (value)/(1+np.exp(-(value/N_INITIAL_BALANCE)))
    # -----------------------------------------------
    return rate * max(value - threshold, 0)
def draw_graph(n_ticks, balances, graph_dir=None):
    """Save a line plot of one balance snapshot.

    The plot is written to ``<graph_dir>/<n_ticks zero-padded to 6>.png``.
    Balances are plotted in the order given, with the array index on the
    x axis, together with a dashed horizontal line at zero.

    Args:
        n_ticks: tick number, used in the title and the file name.
        balances: 1-D array of agent balances.
        graph_dir: output directory.  Defaults to ``GRAPH_DIR``.
    """
    if graph_dir is None:
        graph_dir = GRAPH_DIR
    myfig = plt.figure()
    try:
        myax = myfig.add_subplot(111)
        myax.set_title("ticks {}".format(n_ticks))
        # myax.set_ylim(-3*N_INITIAL_BALANCE, 3*N_INITIAL_BALANCE)
        myax.plot(balances)
        # Reference line at zero, so negative balances are easy to spot.
        myax.plot(np.zeros_like(balances, dtype=float), color="black", linestyle="--")
        # To shade the area between the curve and zero:
        # myax.fill_between(np.arange(len(balances)), balances, 0)
        myfig.savefig(os.path.join(graph_dir, "{0:06d}.png".format(n_ticks)))
    finally:
        # clf() alone leaves the figure registered with pyplot; close it so
        # repeated calls do not accumulate open figures.
        plt.close(myfig)
# Script entry point: run the simulation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment