Skip to content

Instantly share code, notes, and snippets.

@hyunjimoon
Last active June 12, 2024 23:00
Show Gist options
  • Save hyunjimoon/233df6ff719719a6fb9dc39bccd9e2ed to your computer and use it in GitHub Desktop.
Save hyunjimoon/233df6ff719719a6fb9dc39bccd9e2ed to your computer and use it in GitHub Desktop.
meditech.py
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from collections import namedtuple
# One design decision: three boolean choices (every table below keys on
# True/False values). The meaning of each choice is not documented here.
Decision = namedtuple('Decision', 'cable assembly pack')
# Equity split across stakeholders. In EQUITY_RATE each row's fields sum to
# roughly 1, so these look like ownership fractions (confirm).
EquityRate = namedtuple('EquityRate', 'CEO VC Emp Supp')
# Starting cash; model() subtracts one PROTOTYPE_COST entry per iteration.
INITIAL_MONEY = 760000
# Cost of one prototype iteration per decision, in the same units as
# INITIAL_MONEY (model() subtracts one entry per loop iteration).
# The table is exactly additive: 1_990_165 for Decision(True, True, True),
# plus 467_185 when cable is False, 291_991 when assembly is False and
# 151_835 when pack is False.
# Insertion order matters: the optimization loop visits decisions in this order.
PROTOTYPE_COST = {
    # Decision(cable, assembly, pack): cost
    Decision(True, True, True): 1_990_165,
    Decision(True, True, False): 2_142_000,
    Decision(True, False, True): 2_282_156,
    Decision(True, False, False): 2_433_991,
    Decision(False, True, True): 2_457_350,
    Decision(False, True, False): 2_609_185,
    Decision(False, False, True): 2_749_341,
    Decision(False, False, False): 2_901_176,
}
# Per-decision parameter consumed by iteration_time().
# The first two entries are written as 1/(a+b+c): the reciprocal of a total
# built from three components, where only the last differs (0 when
# pack=True, 3 when pack=False). What the components measure, and their
# units, is not stated in this file.
ITERATION_TIME_RATE = {
Decision(cable=True, assembly=True, pack=True): 1/(7+2+0),
Decision(cable=True, assembly=True, pack=False): 1/(7+2+3),
# NOTE(review): the six entries below are identical to the matching
# PROTOTYPE_COST values. They look like copy-paste placeholders rather than
# rates; replace them before trusting simulated times for these decisions.
Decision(cable=True, assembly=False, pack=True): 2_282_156,
Decision(cable=True, assembly=False, pack=False): 2_433_991,
Decision(cable=False, assembly=True, pack=True): 2_457_350,
Decision(cable=False, assembly=True, pack=False): 2_609_185,
Decision(cable=False, assembly=False, pack=True): 2_749_341,
Decision(cable=False, assembly=False, pack=False): 2_901_176
}
# Equity split (CEO, VC, Emp, Supp) for each decision.
# Only `cable` changes the split:
#   - every cable=True decision shares the first split;
#   - every cable=False decision shares the second, which gives Supp 0.
# NOTE(review): the cable=True split sums to 1.001, presumably rounding.
EQUITY_RATE = {
    # Decision(cable, assembly, pack): EquityRate(CEO, VC, Emp, Supp)
    Decision(True, True, True): EquityRate(0.226, 0.539, 0.100, 0.136),
    Decision(True, True, False): EquityRate(0.226, 0.539, 0.100, 0.136),
    Decision(True, False, True): EquityRate(0.226, 0.539, 0.100, 0.136),
    Decision(True, False, False): EquityRate(0.226, 0.539, 0.100, 0.136),
    Decision(False, True, True): EquityRate(0.272, 0.608, 0.120, 0.000),
    Decision(False, True, False): EquityRate(0.272, 0.608, 0.120, 0.000),
    Decision(False, False, True): EquityRate(0.272, 0.608, 0.120, 0.000),
    Decision(False, False, False): EquityRate(0.272, 0.608, 0.120, 0.000),
}
def iteration_time(key, decision: Decision):
    """Sample the duration of one prototype iteration for ``decision``.

    ``ITERATION_TIME_RATE[decision]`` is the rate parameter (lambda) of an
    exponential distribution. The first table entries are written as
    ``1/(duration)``, so the sampled duration should have mean
    ``1 / rate``.

    Args:
        key: JAX PRNG key. It should not be reused by the caller after this call.
        decision: Decision used to look up the rate.

    Returns:
        Scalar JAX array holding the sampled duration.
    """
    rate = ITERATION_TIME_RATE[decision]
    # jax.random.exponential draws from the standard Exp(1) distribution.
    # Exp(rate) == Exp(1) / rate. Multiplying instead would give mean `rate`,
    # e.g. 1/9 rather than 9 for the first table entry.
    return jax.random.exponential(key) / rate
def fda_approved(iter_number):
    """Return the approval rate for a given iteration counter.

    The schedule is piecewise constant:
      - 0.1 while the counter is at most 4;
      - 0.2 while it is at most 5;
      - 0.35 for anything larger.
    """
    for upper_bound, probability in ((4, 0.1), (5, 0.2)):
        if iter_number <= upper_bound:
            return probability
    return 0.35
def model(key, decision: Decision):
    """Simulate prototype iterations for ``decision`` until cash runs out.

    Starting from INITIAL_MONEY, each loop iteration:
      - pays ``PROTOTYPE_COST[decision]``;
      - adds one sampled ``iteration_time`` to the running total.
    The loop stops once cash is no longer positive.

    Args:
        key: JAX PRNG key for this simulation run.
        decision: Decision used to index the module-level tables.

    Returns:
        Tuple ``(approval_rate, time, equity)``:
          - approval_rate: ``fda_approved`` applied to the final iteration counter;
          - time: the summed iteration time;
          - equity: ``EQUITY_RATE[decision]``.
    """
    equity = EQUITY_RATE[decision]
    # TODO: company evaluation (20_000_000 + EVALUATION_CHANGE[decision]) is
    # not modeled yet. EVALUATION_CHANGE is not defined in the visible source
    # (NameError on every call), and the value was never used below.
    # Reinstate it once that table exists.
    cash = INITIAL_MONEY
    iter_number, time = 1, 0
    # NOTE(review): with the current tables every PROTOTYPE_COST exceeds
    # INITIAL_MONEY, so this loop runs exactly once.
    while cash > 0:
        cash = cash - PROTOTYPE_COST[decision]
        # Split before sampling. The key carried into the next iteration is
        # then never also consumed by a draw (JAX keys must not be reused).
        key, subkey = jax.random.split(key)
        time = time + iteration_time(subkey, decision)
        iter_number = iter_number + 1
    # iter_number started at 1, so here it equals completed iterations + 1.
    approval_rate = fda_approved(iter_number)
    # XXX TODO: the model appears unfinished from this point.
    return approval_rate, time, equity
#### optimization
# Monte Carlo estimate of the CEO's average cost for each decision.
N = 1000  # simulation runs per decision
SEED = 0
# Common random numbers: every decision is evaluated on the same N PRNG keys.
# Differences between decisions are then not masked by sampling noise.
sample_keys = jax.random.split(jax.random.PRNGKey(SEED), N)
for decision in PROTOTYPE_COST:
    ceo_costs = []
    for sample_key in sample_keys:
        # model() requires a PRNG key as its first argument.
        approval_rate, time, equity = model(sample_key, decision)
        # NOTE(review): ceo_cost_fn is not defined in the visible source.
        ceo_costs.append(ceo_cost_fn(approval_rate, equity.CEO, time))
    # Label each average with its decision; otherwise the 8 lines are
    # indistinguishable.
    print(decision, 'average cost for ceo', sum(ceo_costs) / N)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment