Skip to content

Instantly share code, notes, and snippets.

@MatsuuraKentaro
Created November 24, 2023 09:21
Show Gist options
  • Save MatsuuraKentaro/c7457d8a207c4e631e0377b8d9efa8e6 to your computer and use it in GitHub Desktop.
Save MatsuuraKentaro/c7457d8a207c4e631e0377b8d9efa8e6 to your computer and use it in GitHub Desktop.
『Pythonではじめる数理最適化』の7章「商品推薦のための興味のスコアリング」をStanで解く
// Binomial model for PV successes out of N trials in each (Rcen, Freq) cell.
// The cell probability is q = inv_logit(x), where x is constructed so that it
// never increases along either index of the R x F grid: every cell is derived
// from its upper/left neighbours by subtracting a non-negative step dx.
data {
int I; // number of data rows
int R; // number of Rcen levels (rows of the grid)
int F; // number of Freq levels (columns of the grid)
array[I] int Rcen; // Rcen index of each data row; used as the row index into q
array[I] int Freq; // Freq index of each data row; used as the column index into q
array[I] int N; // number of binomial trials for each data row
array[I] int PV; // binomial outcome (count) for each data row
}
parameters {
// Non-negative decrement for each grid cell. No sampling statement is given
// for dx in the model block, so only the lower=0 bound constrains it.
matrix<lower=0>[R, F] dx;
}
transformed parameters {
matrix[R, F] x; // logit-scale value of each cell
matrix[R, F] q; // probability of each cell, inv_logit(x)
// Corner cell: at most 5 on the logit scale, so every q <= inv_logit(5).
// NOTE(review): the rationale for the constant 5 is not shown here.
x[1,1] = 5 - dx[1,1];
// First column: each cell is the cell above it minus a non-negative step.
for (r in 2:R) {
x[r,1] = x[r-1,1] - dx[r,1];
}
// First row: each cell is the cell to its left minus a non-negative step.
for (f in 2:F) {
x[1,f] = x[1,f-1] - dx[1,f];
}
// Interior cells: start from the smaller of the upper and left neighbours so
// that x[r,f] is no larger than either of them. Must run after the first row
// and first column are filled, since it reads x[r-1,f] and x[r,f-1].
for (r in 2:R) {
for (f in 2:F) {
x[r,f] = min([x[r-1,f], x[r,f-1]]) - dx[r,f];
}
}
// Map the whole grid from the logit scale to probabilities.
q[1:R,1:F] = inv_logit(x[1:R,1:F]);
}
model {
// Likelihood: each row's PV is binomial with N[i] trials and the probability
// of the grid cell selected by (Rcen[i], Freq[i]).
for (i in 1:I) {
PV[i] ~ binomial(N[i], q[Rcen[i], Freq[i]]);
}
}
import pandas
import cmdstanpy
import numpy as np
import matplotlib.pyplot as plt
from plotnine import ggplot, aes, theme, element_text, geom_ribbon, geom_line, geom_point, labs
# Observed table and the posterior draws of q written by the sampling run.
rf_df = pandas.read_csv('input/rf_df.csv')
fit = cmdstanpy.from_csv('output/result')
q_ms = fit.stan_variable(var='q')

# ---- 3D wireframe of the posterior median of q over the (rcen, freq) grid ----
Freq = rf_df.freq.unique().tolist()
Rcen = rf_df.rcen.unique().tolist()

# Rows follow Freq and columns follow Rcen, the same layout that
# np.meshgrid(Rcen, Freq) produces. q_ms is indexed 0-based on its last two
# axes: rcen maps to rcen - 1 and freq maps to 7 - freq.
Z = np.empty((len(Freq), len(Rcen)))
for row, freq in enumerate(Freq):
    for col, rcen in enumerate(Rcen):
        Z[row, col] = np.median(q_ms[:, rcen - 1, 7 - freq])
X, Y = np.meshgrid(Rcen, Freq)

fig = plt.figure(dpi=250)
ax = fig.add_subplot(
    111,
    projection='3d',
    xlabel='rcen',
    ylabel='freq',
    zlabel='pred_prob',
)
ax.plot_wireframe(X, Y, Z)
plt.savefig('output/q_median.png')

# ---- Posterior quantile bands of q against rcen at a single freq ----
plot_f = 5
quantile_levels = [0.025, 0.25, 0.50, 0.75, 0.975]
qua = np.quantile(q_ms[:, :, 7 - plot_f], quantile_levels, axis=0)

# One row per rcen value 1..7: the rcen value followed by the five quantiles.
d_est = pandas.DataFrame(
    np.column_stack([np.arange(1, 8), qua.T]),
    columns=['rcen', '2.5%', '25%', '50%', '75%', '97.5%'],
)
rf_df_at_plot_f = rf_df[rf_df['freq'] == plot_f]

# Layers: outer band (2.5%-97.5%), inner band (25%-75%), median line, and the
# observed 'prob' values for the selected freq.
outer_band = geom_ribbon(d_est, aes(x='rcen', ymin='2.5%', ymax='97.5%'), fill='blue', alpha=1/6)
inner_band = geom_ribbon(d_est, aes(x='rcen', ymin='25%', ymax='75%'), fill='blue', alpha=2/6)
median_line = geom_line(d_est, aes(x='rcen', y='50%'), color='blue', size=1)
observed = geom_point(rf_df_at_plot_f, aes(x='rcen', y='prob'), size=1)

p = ggplot()
p = p + theme(text=element_text(size=18))
p = p + outer_band
p = p + inner_band
p = p + median_line
p = p + observed
p = p + labs(y='prob')
p.save(filename=f'output/q_at_f_{plot_f}.png', dpi=300, width=5, height=4)
rcen freq N pv prob
1 1 19602 245 0.012498724619936742
1 2 3323 132 0.039723141739392114
1 3 1120 81 0.07232142857142858
1 4 539 36 0.06679035250463822
1 5 285 36 0.12631578947368421
1 6 177 20 0.11299435028248588
1 7 120 21 0.175
2 1 19126 112 0.0058559029593223885
2 2 3162 67 0.02118912080961417
2 3 1001 27 0.026973026973026972
2 4 459 26 0.05664488017429194
2 5 302 20 0.06622516556291391
2 6 162 16 0.09876543209876543
2 7 94 6 0.06382978723404255
3 1 22596 138 0.006107275624004248
3 2 3616 84 0.023230088495575223
3 3 1161 46 0.03962101636520241
3 4 582 31 0.05326460481099656
3 5 279 11 0.03942652329749104
3 6 185 10 0.05405405405405406
3 7 119 6 0.05042016806722689
4 1 24385 133 0.005454172647119131
4 2 4035 62 0.015365551425030979
4 3 1305 32 0.024521072796934867
4 4 597 28 0.04690117252931323
4 5 300 11 0.03666666666666667
4 6 185 7 0.03783783783783784
4 7 109 2 0.01834862385321101
5 1 25363 111 0.0043764538895241095
5 2 3999 62 0.015503875968992248
5 3 1225 29 0.0236734693877551
5 4 603 9 0.014925373134328358
5 5 274 6 0.021897810218978103
5 6 173 5 0.028901734104046242
5 7 98 3 0.030612244897959183
6 1 26034 116 0.004455711761542598
6 2 3757 37 0.009848283204684588
6 3 1183 29 0.024513947590870666
6 4 511 10 0.019569471624266144
6 5 235 2 0.00851063829787234
6 6 121 3 0.024793388429752067
6 7 79 2 0.02531645569620253
7 1 25611 109 0.004255983756979423
7 2 3522 32 0.009085746734809767
7 3 996 14 0.014056224899598393
7 4 385 9 0.023376623376623377
7 5 220 2 0.00909090909090909
7 6 98 2 0.02040816326530612
7 7 43 0 0.0
import pandas
import cmdstanpy
# Observed counts, one row per (rcen, freq) cell.
d = pandas.read_csv('input/rf_df.csv')

# Inputs for the Stan data block. Freq is passed reversed (8 - freq), so
# freq = 7 becomes index 1 and freq = 1 becomes index 7.
data = {
    'I': len(d),
    'R': 7,
    'F': 7,
    'Rcen': d['rcen'],
    'Freq': 8 - d['freq'],
    'N': d['N'],
    'PV': d['pv'],
}

# Compile the model, sample with a fixed seed, then persist the draws and the
# summary table.
model = cmdstanpy.CmdStanModel(stan_file='model/model.stan')
fit = model.sample(data=data, seed=123)
fit.save_csvfiles('output/result')
summary = fit.summary()
summary.to_csv('output/fit-summary.csv')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment