Skip to content

Instantly share code, notes, and snippets.

@wch
Last active November 2, 2023 14:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save wch/358b95c0f978f957f462a0cb75b2c9db to your computer and use it in GitHub Desktop.
Save wch/358b95c0f978f957f462a0cb75b2c9db to your computer and use it in GitHub Desktop.
Retirement simulation Quarto Shiny app
---
title: "Retirement: simulating wealth with random returns, inflation and withdrawals"
format: dashboard
logo: retirement-logo.png
server: shiny
execute:
  daemon: false
---
```{python}
#| context: setup
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from shiny import render, reactive, ui
```
## Row {.flow}
```{python}
#| expandable: false
#| title: Scenario A
# Input controls for Scenario A. nav_1 reads each id through input.<id>().
ui.input_slider("start_capital", "Initial investment", 1e5, 1e7, value=2e6, pre="$")
ui.input_slider("return_mean", "Average annual investment return", 0, 30, value=5, step=0.5, post="%")
# Volatility slider is disabled; nav_1 passes a fixed .07 in its place.
# ui.input_slider("return_stdev", "Annual investment volatility", 0, 25, value=7, step=0.5, post="%")
ui.input_slider("inflation_mean", "Average annual inflation", 0, 20, value=2.5, step=0.5, post="%")
# Volatility slider is disabled; nav_1 passes a fixed .015 in its place.
# ui.input_slider("inflation_stdev", "Annual inflation volatility", 0, 5, value=1.5, step=0.5, post="%")
# step=500 keeps this slider in line with the Scenario B withdrawal slider,
# which already uses the same step over the same 0-50000 range.
ui.input_slider("monthly_withdrawal", "Monthly withdrawals", 0, 50000, value=10000, step=500, pre="$")
```
```{python}
#| expandable: false
#| title: Scenario B
# Input controls for Scenario B. nav_2 reads each id through input.<id>().
ui.input_slider(
    "start_capital2",
    "Initial investment",
    min=1e5,
    max=1e7,
    value=2e6,
    pre="$",
)
ui.input_slider(
    "return_mean2",
    "Average annual investment return",
    min=0,
    max=30,
    value=5,
    step=0.5,
    post="%",
)
# Disabled control; nav_2 passes a fixed .07 for the return volatility.
# ui.input_slider("return_stdev2", "Annual investment volatility", 0, 25, value=7, step=0.5, post="%")
ui.input_slider(
    "inflation_mean2",
    "Average annual inflation",
    min=0,
    max=20,
    value=2.5,
    step=0.5,
    post="%",
)
# Disabled control; nav_2 passes a fixed .015 for the inflation volatility.
# ui.input_slider("inflation_stdev2", "Annual inflation volatility", 0, 5, value=1.5, step=0.5, post="%")
ui.input_slider(
    "monthly_withdrawal2",
    "Monthly withdrawals",
    min=0,
    max=50000,
    value=8000,
    step=500,
    pre="$",
)
```
## Row
```{python}
@render.plot()
def nav_1():
    """Render the Scenario A figure from the current slider values."""
    # Percentage sliders are divided by 100 before being handed to the
    # simulation.
    simulated = run_simulation(
        start_capital=input.start_capital(),
        return_mean=input.return_mean() / 100,
        # Fixed value used while the "return_stdev" slider is commented out.
        return_stdev=.07,
        inflation_mean=input.inflation_mean() / 100,
        # Fixed value used while the "inflation_stdev" slider is commented out.
        inflation_stdev=.015,
        monthly_withdrawal=input.monthly_withdrawal(),
        n_years=30,
        n_simulations=100,
    )
    return make_plot(simulated)
```
```{python}
@render.plot()
def nav_2():
    """Render the Scenario B figure from the current slider values."""
    # Percentage sliders are divided by 100 before being handed to the
    # simulation.
    simulated = run_simulation(
        start_capital=input.start_capital2(),
        return_mean=input.return_mean2() / 100,
        # Fixed value used while the "return_stdev2" slider is commented out.
        return_stdev=.07,
        inflation_mean=input.inflation_mean2() / 100,
        # Fixed value used while the "inflation_stdev2" slider is commented out.
        inflation_stdev=.015,
        monthly_withdrawal=input.monthly_withdrawal2(),
        n_years=30,
        n_simulations=100,
    )
    return make_plot(simulated)
```
```{python}
def create_matrix(rows, cols, mean, stdev):
    """Return a ``rows`` x ``cols`` array of normal draws scaled to ``mean``/``stdev``.

    Draws are taken from NumPy's global random state, then multiplied by
    ``stdev`` and shifted by ``mean``.
    """
    draws = np.random.standard_normal(size=(rows, cols))
    return draws * stdev + mean
def run_simulation(
    start_capital,
    return_mean,
    return_stdev,
    inflation_mean,
    inflation_stdev,
    monthly_withdrawal,
    n_years,
    n_simulations
):
    """Simulate month-by-month capital under random returns and inflation.

    Annual means are divided by 12 and annual standard deviations by
    sqrt(12) to get monthly parameters. Every month the previous value is
    multiplied by ``1 + return - inflation`` and ``monthly_withdrawal`` is
    subtracted.

    Returns a DataFrame with ``12 * n_years + 1`` rows (row 0 holds the
    starting capital) and ``n_simulations`` columns. Values are expressed in
    millions, and any value below zero is replaced by NaN.
    """
    months = 12 * n_years
    root_twelve = math.sqrt(12)

    # Returns are drawn first and inflation second, so the random stream is
    # consumed in a fixed order.
    returns = create_matrix(
        months, n_simulations, return_mean / 12, return_stdev / root_twelve
    )
    inflation = create_matrix(
        months, n_simulations, inflation_mean / 12, inflation_stdev / root_twelve
    )

    # Net growth factor for every (month, simulation) cell.
    growth = 1 + returns - inflation

    balances = np.empty((months + 1, n_simulations))
    balances[0, :] = float(start_capital)
    for month, factor in enumerate(growth):
        balances[month + 1, :] = balances[month, :] * factor - monthly_withdrawal

    # Values below zero become NaN so they are treated as missing downstream.
    exhausted = balances < 0
    balances[exhausted] = np.nan

    # Report in millions.
    return pd.DataFrame(balances / 1000000)
def make_plot(nav_df):
    """Draw the two-panel summary figure for one set of simulations.

    The top panel plots every column of ``nav_df`` against its index. The
    bottom panel plots, for each row, the percentage of columns whose value
    is greater than zero (NaN entries do not count as above zero).

    Args:
        nav_df: DataFrame with one row per time step and one column per
            simulation, as produced by ``run_simulation``.

    Returns:
        The matplotlib figure holding both panels.
    """
    fig = plt.figure()
    # Attach both axes to this figure explicitly rather than relying on
    # pyplot's implicit "current figure".
    ax1 = plt.subplot2grid((2, 2), (0, 0), colspan=2, fig=fig)
    ax2 = plt.subplot2grid((2, 2), (1, 0), colspan=2, fig=fig)

    # Top panel: one translucent line per simulation.
    for column in nav_df.columns:
        ax1.plot(nav_df.index, nav_df[column], alpha=0.3)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    # NOTE: the title hard-codes 30 years, matching the n_years value that
    # nav_1 and nav_2 pass to run_simulation.
    ax1.title.set_text("Projected value of capital over 30 years")
    ax1.set_xlabel("Months")
    ax1.set_ylabel("Millions")
    ax1.grid(True)

    # Bottom panel: share of simulations still above zero at each row.
    percent_above_zero = (nav_df > 0).sum(axis=1) / nav_df.shape[1] * 100
    ax2.plot(nav_df.index, percent_above_zero, color='purple')
    ax2.set_xlim(nav_df.index.min(), nav_df.index.max())
    ax2.set_ylim(0, 100)  # Percentage goes from 0 to 100
    ax2.title.set_text("Percent of scenarios still paying")
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.set_xlabel("Months")
    ax2.grid(True)

    fig.tight_layout()
    return fig
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment