Created
March 11, 2022 20:06
-
-
Save j2kun/317ab2f37902498d93ecaacb3d278def to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''An implementation of a media mix model.''' | |
import pymc3 as pm | |
import numpy as np | |
# Currently no adstock (carry-over) transformation is applied: each observation's sales depend only on the spending at that same observation.
def tanh_saturation(x, users_at_saturation, initial_cost_per_user):
    """Saturation curve built on the hyperbolic tangent.

    Models the diminishing returns of an advertising channel. For small spend
    the curve is roughly linear with slope ``1 / initial_cost_per_user``. As
    spend grows it levels off at ``users_at_saturation``, because tanh
    approaches 1.

    This is PyMC Labs's reparameterization of the original HelloFresh model, cf.
    https://web.archive.org/web/20211224060713/https://www.pymc-labs.io/blog-posts/reducing-customer-acquisition-costs-how-we-helped-optimizing-hellofreshs-marketing-budget/

    Args:
        x: spending on the channel (scalar, sequence, or tensor).
        users_at_saturation: asymptotic number of users the channel can reach.
        initial_cost_per_user: cost per acquired user in the near-linear,
            low-spend regime.

    Returns:
        ``users_at_saturation * tanh(x / (users_at_saturation * initial_cost_per_user))``,
        built with ``pm.math`` so it can be used inside a PyMC3 model.
    """
    # The product b * c rescales spend so that tanh's slope at 0 becomes 1/c.
    scale = users_at_saturation * initial_cost_per_user
    return users_at_saturation * pm.math.tanh(x / scale)
def make_model(channel_data, sales):
    """Build a PyMC3 media mix model.

    The pieces fit together as follows:

    - Each channel contributes
      ``acquisition_rate * tanh_saturation(spend, saturation_users, initial_cost)``.
      Its three parameters get independent per-channel Gamma priors.
    - The channel contributions are summed and added to a shared non-negative
      ``baseline``.
    - The result is the mean of a Normal likelihood over the observed ``sales``.

    Args:
        channel_data: mapping from channel name to that channel's spending
            series. Presumably each series is aligned element-for-element with
            ``sales``; confirm with callers.
        sales: observed sales, one value per observation.

    Returns:
        The constructed ``pm.Model``. Sample from it inside a ``with model:``
        block.
    """
    with pm.Model() as model:
        channel_contributions = []
        for channel, weekly_spending in channel_data.items():
            # Variable names are suffixed with the channel so each channel gets
            # its own independent parameters.
            acquisition_rate = pm.Gamma(f'acquisition_rate_{channel}', alpha=2, beta=1)
            saturation_users = pm.Gamma(f'saturation_users_{channel}', alpha=15, beta=0.5)
            initial_cost = pm.Gamma(f'initial_cost_{channel}', alpha=5, beta=1)
            channel_contributions.append(
                acquisition_rate * tanh_saturation(
                    weekly_spending, saturation_users, initial_cost
                )
            )

        baseline = pm.HalfNormal('baseline', sigma=1)
        output_noise = pm.HalfNormal('output_noise', sigma=1)
        # With no channels, sum([]) == 0 and the mean reduces to the baseline.
        new_sales = baseline + sum(channel_contributions)
        # `sigma` is the canonical PyMC3 keyword (as used for the priors above);
        # `sd` is only a deprecated alias.
        pm.Normal('likelihood', mu=new_sales, sigma=output_noise, observed=sales)
    return model
if __name__ == "__main__":
    import arviz as az
    import matplotlib.pyplot as plt

    # Ten example observations: spending per channel and the matching sales.
    channel_data = {
        'tv': [230.1, 44.5, 17.2, 151.5, 180.8, 8.7, 57.5, 120.2, 8.6, 199.8],
        'radio': [37.8, 39.3, 45.9, 41.3, 10.8, 48.9, 32.8, 19.6, 2.1, 2.6],
        'newspaper': [69.2, 45.1, 69.3, 58.5, 58.4, 75.0, 23.5, 11.6, 1.0, 21.2],
    }
    sales = [22.1, 10.4, 9.3, 18.5, 12.9, 7.2, 11.8, 13.2, 4.8, 10.6]
    print(channel_data, sales)

    # Draws per chain. With only a handful of draws, PyMC3 warns that
    # convergence cannot be checked reliably, and r_hat / ESS in the summary
    # are meaningless. Lower this only for quick smoke tests.
    num_draws = 1000

    model = make_model(channel_data, sales)
    with model:
        # Returning InferenceData directly means the ArviZ calls below do not
        # need to convert a MultiTrace inside the model context.
        trace = pm.sample(num_draws, return_inferencedata=True)

    summary = az.summary(trace, round_to=2)
    print(summary)
    az.plot_trace(trace)
    plt.tight_layout()
    plt.savefig('plot.pdf')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
WARNING (theano.configdefaults): g++ not detected ! Theano will be unable to execute optimized C-implementations (for both CPU and GPU) and will default to Python implementations. Performance will be severely degraded. To remove this warning, set Theano flags cxx to an empty string. | |
WARNING (theano.tensor.blas): Using NumPy C-API based implementation for BLAS functions. | |
{'tv': [230.1, 44.5, 17.2, 151.5, 180.8, 8.7, 57.5, 120.2, 8.6, 199.8], 'radio': [37.8, 39.3, 45.9, 41.3, 10.8, 48.9, 32.8, 19.6, 2.1, 2.6], 'newspaper': [69.2, 45.1, 69.3, 58.5, 58.4, 75.0, 23.5, 11.6, 1.0, 21.2]} [22.1, 10.4, 9.3, 18.5, 12.9, 7.2, 11.8, 13.2, 4.8, 10.6] | |
Only 10 samples in chain. | |
Auto-assigning NUTS sampler... | |
Initializing NUTS using jitter+adapt_diag... | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: invalid value encountered in log | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
Multiprocess sampling (4 chains in 4 jobs) | |
NUTS: [output_noise, baseline, initial_cost_newspaper, saturation_users_newspaper, acquisition_rate_newspaper, initial_cost_radio, saturation_users_radio, acquisition_rate_radio, initial_cost_tv, saturation_users_tv, acquisition_rate_tv] | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: overflow encountered in expmpling 4 chains, 0 divergences] | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: overflow encountered in exp | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: overflow encountered in reduce | |
return ufunc.reduce(obj, axis, dtype, out, **passkwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: overflow encountered in impl (vectorized) | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/scalar/basic.py:2893: RuntimeWarning: divide by zero encountered in log | |
return np.log(x) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: divide by zero encountered in impl (vectorized) | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/scalar/basic.py:2893: RuntimeWarning: divide by zero encountered in log | |
return np.log(x) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: divide by zero encountered in impl (vectorized) | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: invalid value encountered in reduce | |
return ufunc.reduce(obj, axis, dtype, out, **passkwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/scalar/basic.py:1813: RuntimeWarning: invalid value encountered in double_scalars | |
return sum(inputs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/scalar/basic.py:1955: RuntimeWarning: invalid value encountered in true_divide | |
return x / y | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: invalid value encountered in reduce4 chains, 0 divergences] | |
return ufunc.reduce(obj, axis, dtype, out, **passkwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/scalar/basic.py:1813: RuntimeWarning: invalid value encountered in double_scalars | |
return sum(inputs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/scalar/basic.py:1955: RuntimeWarning: invalid value encountered in true_divide | |
return x / y | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: invalid value encountered in multiplyhains, 0 divergences] | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: invalid value encountered in impl (vectorized) | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: invalid value encountered in multiply | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/theano/tensor/elemwise.py:826: RuntimeWarning: invalid value encountered in impl (vectorized) | |
variables = ufunc(*ufunc_args, **ufunc_kwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: overflow encountered in reduce | |
return ufunc.reduce(obj, axis, dtype, out, **passkwargs) | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: overflow encountered in reduceling 4 chains, 0 divergences] | |
return ufunc.reduce(obj, axis, dtype, out, **passkwargs) | |
Sampling 4 chains for 1_000 tune and 10 draw iterations (4_000 + 40 draws total) took 79 seconds.00.00% [4040/4040 01:18<00:00 Sampling 4 chains, 0 divergences] | |
/home/j2kun/pmfp-code/venv/lib/python3.9/site-packages/pymc3/sampling.py:643: UserWarning: The number of samples is too small to check convergence reliably. | |
warnings.warn("The number of samples is too small to check convergence reliably.") | |
The acceptance probability does not match the target. It is 0.510787241074177, but should be close to 0.8. Try to increase the number of tuning steps. | |
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat | |
acquisition_rate_tv 0.49 0.09 0.33 0.64 0.02 0.02 14.44 46.58 1.61 | |
saturation_users_tv 31.24 7.19 19.12 40.76 0.93 0.66 56.02 61.26 1.48 | |
initial_cost_tv 6.11 2.11 2.92 9.55 0.41 0.29 31.40 46.58 1.19 | |
acquisition_rate_radio 0.80 0.36 0.17 1.25 0.08 0.06 16.80 61.26 1.42 | |
saturation_users_radio 32.22 6.71 22.67 46.88 0.87 0.63 58.92 61.26 1.25 | |
initial_cost_radio 5.87 1.91 2.45 9.09 0.46 0.33 18.11 64.08 1.37 | |
acquisition_rate_newspaper 0.23 0.22 0.02 0.60 0.05 0.04 15.53 19.70 1.42 | |
saturation_users_newspaper 26.72 7.87 17.86 42.73 1.36 0.97 27.95 61.26 1.17 | |
initial_cost_newspaper 6.67 3.35 0.89 13.29 0.75 0.54 22.14 27.97 1.25 | |
baseline 0.76 0.58 0.03 1.76 0.08 0.06 41.27 19.70 1.06 | |
output_noise 1.92 0.37 1.44 2.52 0.08 0.06 25.66 61.26 1.16 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output after replacing the Gamma priors with HalfNormal priors