Skip to content

Instantly share code, notes, and snippets.

@jjallaire
Forked from wch/retirement-logo.png
Last active November 2, 2023 14:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jjallaire/bee437084aa8e75a2e7c8eabca529beb to your computer and use it in GitHub Desktop.
Save jjallaire/bee437084aa8e75a2e7c8eabca529beb to your computer and use it in GitHub Desktop.
Retirement simulation Quarto Shiny app
env/
__pycache__/
*-app.py
*_files/
*.html
anyio==4.0.0
appdirs==1.4.4
appnope==0.1.3
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asgiref==3.7.2
asttokens==2.4.1
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.1
beautifulsoup4==4.12.2
bleach==6.1.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
comm==0.1.4
contourpy==1.1.1
cycler==0.12.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.18.1
fonttools==4.43.1
fqdn==1.5.1
h11==0.14.0
htmltools==0.4.1
idna==3.4
ipykernel==6.26.0
ipython==8.17.2
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.8.0
jupyter-lsp==2.2.0
jupyter_client==8.5.0
jupyter_core==5.5.0
jupyter_server==2.9.1
jupyter_server_terminals==0.4.4
jupyterlab==4.0.7
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
kiwisolver==1.4.5
linkify-it-py==2.0.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.1
matplotlib-inline==0.1.6
mdit-py-plugins==0.4.0
mdurl==0.1.2
mistune==3.0.2
nbclient==0.8.0
nbconvert==7.10.0
nbformat==5.9.2
nest-asyncio==1.5.8
notebook==7.0.6
notebook_shim==0.2.3
numpy==1.26.1
overrides==7.4.0
packaging==23.2
pandas==2.1.2
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
Pillow==10.1.0
platformdirs==3.11.0
prometheus-client==0.18.0
prompt-toolkit==3.0.39
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.16.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
python-multipart==0.0.6
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.4.4
QtPy==2.4.1
referencing==0.30.2
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.6
Send2Trash==1.8.2
shiny==0.6.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.31.1
terminado==0.17.1
tinycss2==1.2.1
tomli==2.0.1
tornado==6.3.3
traitlets==5.13.0
types-python-dateutil==2.8.19.14
typing_extensions==4.8.0
tzdata==2023.3
uc-micro-py==1.0.2
uri-template==1.3.0
urllib3==2.0.7
uvicorn==0.23.2
watchfiles==0.21.0
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
websockets==12.0
widgetsnbextension==4.0.9
---
title: "Retirement: simulating wealth with random returns, inflation and withdrawals"
format: dashboard
logo: retirement-logo.png
server: shiny
---
```{python}
#| context: setup
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from shiny import render, reactive, ui
```
## Row {.flow}
```{python}
#| title: Scenario A
# Scenario A inputs. The withdrawal slider uses the same $500 step as the
# Scenario B withdrawal slider so both scenarios can be set to identical values.
ui.input_slider("start_capital", "Initial investment", 1e5, 1e7, value=2e6, pre="$")
ui.input_slider("return_mean", "Average annual investment return", 0, 30, value=5, step=0.5, post="%")
ui.input_slider("inflation_mean", "Average annual inflation", 0, 20, value=2.5, step=0.5, post="%")
ui.input_slider("monthly_withdrawal", "Monthly withdrawals", 0, 50000, value=10000, step=500, pre="$")
```
```{python}
#| title: Scenario B
# Scenario B inputs: the same controls as Scenario A, with ids suffixed "2"
# so the second plot can read them independently.
ui.input_slider(
    "start_capital2", "Initial investment",
    min=1e5, max=1e7, value=2e6, pre="$",
)
ui.input_slider(
    "return_mean2", "Average annual investment return",
    min=0, max=30, value=5, step=0.5, post="%",
)
ui.input_slider(
    "inflation_mean2", "Average annual inflation",
    min=0, max=20, value=2.5, step=0.5, post="%",
)
ui.input_slider(
    "monthly_withdrawal2", "Monthly withdrawals",
    min=0, max=50000, value=8000, step=500, pre="$",
)
```
## Row
```{python}
@render.plot()
def nav_1():
    """Simulate and plot wealth paths for the Scenario A inputs."""
    # Volatilities are fixed (7% annual return stdev, 1.5% annual inflation
    # stdev) because no standard-deviation inputs are defined in the UI.
    # Slider percentages are converted to fractions before simulating.
    simulated = run_simulation(
        start_capital=input.start_capital(),
        return_mean=input.return_mean() / 100,
        return_stdev=0.07,
        inflation_mean=input.inflation_mean() / 100,
        inflation_stdev=0.015,
        monthly_withdrawal=input.monthly_withdrawal(),
        n_years=30,
        n_simulations=100,
    )
    return make_plot(simulated)
```
```{python}
@render.plot()
def nav_2():
    """Simulate and plot wealth paths for the Scenario B inputs."""
    # Volatilities are fixed (7% annual return stdev, 1.5% annual inflation
    # stdev) because no standard-deviation inputs are defined in the UI.
    # Slider percentages are converted to fractions before simulating.
    simulated = run_simulation(
        start_capital=input.start_capital2(),
        return_mean=input.return_mean2() / 100,
        return_stdev=0.07,
        inflation_mean=input.inflation_mean2() / 100,
        inflation_stdev=0.015,
        monthly_withdrawal=input.monthly_withdrawal2(),
        n_years=30,
        n_simulations=100,
    )
    return make_plot(simulated)
```
```{python}
def create_matrix(rows, cols, mean, stdev):
    """Return a ``rows`` x ``cols`` array of normal draws with the given mean and stdev.

    Draws come from NumPy's global random state (``np.random.randn``), so the
    result depends on any prior seeding and advances that state.
    """
    standard_normal = np.random.randn(rows, cols)
    return standard_normal * stdev + mean
def run_simulation(
    start_capital,
    return_mean,
    return_stdev,
    inflation_mean,
    inflation_stdev,
    monthly_withdrawal,
    n_years,
    n_simulations
):
    """Simulate portfolio value under random monthly returns and inflation.

    Annual means are divided by 12 and annual standard deviations by
    sqrt(12) to obtain monthly parameters. Each month every path is scaled by
    ``1 + return - inflation`` and then reduced by ``monthly_withdrawal``.

    Returns:
        DataFrame with ``12 * n_years + 1`` rows (the starting month plus one
        row per simulated month) and ``n_simulations`` columns, expressed in
        millions. Values that fell below zero are replaced with NaN.
    """
    months = 12 * n_years
    root12 = math.sqrt(12)

    # Returns are drawn before inflation; both come from NumPy's global RNG,
    # so this order determines which random numbers each series receives.
    returns = create_matrix(months, n_simulations, return_mean / 12, return_stdev / root12)
    inflation = create_matrix(months, n_simulations, inflation_mean / 12, inflation_stdev / root12)

    # Real (inflation-adjusted) growth factor for every month and path.
    growth = 1 + returns - inflation

    paths = np.empty((months + 1, n_simulations))
    paths[0] = float(start_capital)
    for month, factor in enumerate(growth):
        # Grow the previous balance, then take the withdrawal.
        paths[month + 1] = paths[month] * factor - monthly_withdrawal

    # Paths that ran out of money are marked NaN (missing, like NA in R)
    # instead of carrying negative wealth.
    paths = np.where(paths < 0, np.nan, paths)
    return pd.DataFrame(paths / 1000000)
def make_plot(nav_df):
    """Plot simulated capital paths and the share of paths still paying out.

    Parameters:
        nav_df: DataFrame with one row per month (row 0 is the starting
            value) and one column per simulated path, in millions. NaN
            marks a path whose capital has been exhausted (as produced by
            ``run_simulation``).

    Returns:
        The matplotlib Figure containing two vertically stacked panels.
    """
    # Horizon in years derived from the data rather than hard-coded:
    # N monthly rows span N - 1 months.
    n_years = max(len(nav_df.index) - 1, 0) / 12

    # Two full-width panels: capital paths on top, solvency share below.
    fig, (ax1, ax2) = plt.subplots(2, 1)

    # Top panel: a 2-D y array draws one translucent line per column,
    # cycling colors exactly as a per-column loop would.
    ax1.plot(nav_df.index, nav_df.to_numpy(), alpha=0.3)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.title.set_text(f"Projected value of capital over {n_years:g} years")
    ax1.set_xlabel("Months")
    ax1.set_ylabel("Millions")
    ax1.grid(True)

    # Bottom panel: percentage of paths with capital above zero each month.
    # NaN (exhausted) entries compare as False, so they count as not paying.
    percent_above_zero = (nav_df > 0).sum(axis=1) / nav_df.shape[1] * 100
    ax2.plot(nav_df.index, percent_above_zero, color='purple')
    ax2.set_xlim(nav_df.index.min(), nav_df.index.max())
    ax2.set_ylim(0, 100)  # Percentage goes from 0 to 100
    ax2.title.set_text("Percent of scenarios still paying")
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.set_xlabel("Months")
    ax2.grid(True)

    fig.tight_layout()
    return fig
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment