Skip to content

Instantly share code, notes, and snippets.

Created September 16, 2019 15:57
Show Gist options
  • Save medewitt/209512a8c042c3cb28b960a39cd28b4c to your computer and use it in GitHub Desktop.
Save medewitt/209512a8c042c3cb28b960a39cd28b4c to your computer and use it in GitHub Desktop.
Exploring Loss Functions and Integrating Over the Loss
# Parameters
n <- 100L

# True model: y = 1 * x + 2 * treat + standard-normal noise
x <- rnorm(n, mean = 5, sd = 1)
# Alternate control (0) and treated (1) units. `length.out` keeps `treat` the
# same length as `x` even when `n` is odd (`rep(c(0, 1), n / 2)` would not).
treat <- rep(c(0, 1), length.out = n)
y <- rnorm(n, mean = 2 * treat + 1 * x)
# Package data for stan
# Element names must match the variables declared in the Stan `data` block
# (N, x, y, status); the simulated treatment indicator is passed as `status`.
stan_dat <- list(
  N = n,
  x = x,
  y = y,
  status = treat
)
# Sampler setup: one core per chain where available, and let rstan write the
# compiled model to disk (auto_write) so it is not recompiled on every run.
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

# Compile the Stan program, then draw from the posterior
linear_regression <- stan_model(file = "loss-function.stan")
fit1 <- sampling(
  linear_regression,
  data = stan_dat,
  chains = 2,
  iter = 1000,
  refresh = 0
)
# Look at output
# Collect the posterior draws of the treatment effect and of its loss, one row
# per draw. Columns are named with `=` so that `results$treatment` and
# `results$loss` exist for the plots below and the global `treat` is not
# overwritten. `rstan::extract()` is namespaced so that an attached
# tidyr/magrittr `extract()` cannot mask it.
posterior_draws <- rstan::extract(fit1, pars = c("treatment", "loss"))
results <- data.frame(
  treatment = as.vector(posterior_draws[["treatment"]]),
  loss = as.vector(posterior_draws[["loss"]])
)
#' Piecewise loss for a single effect size
#'
#' R counterpart of the Stan `loss_function`, used to plot the shape of the
#' loss over a grid of values.
#'
#' @param x A single numeric value (scalar; the condition is used in `if`).
#' @return `1 / log(x)` when `x > 1`, `x` when `0 < x <= 1`, and `20`
#'   otherwise.
my_loss <- function(x) {
  if (x > 1) {
    1 / log(x)
  } else if (x > 0) {
    x
  } else {
    20
  }
}
# Evaluate the loss function on an evenly spaced grid of effect sizes
loss_grid <- seq(-3, 3, by = .1)
estimated_loss <- purrr::map_dbl(loss_grid, my_loss)

# Four panels side by side: posterior of the effect, shape of the loss,
# loss against effect for each draw, and the distribution of the loss.
par(mfrow = c(1, 4))
hist(
  results$treatment,
  main = "Histogram of Treatment Effect",
  col = "grey",
  breaks = 30
)
plot(loss_grid, estimated_loss, main = "Loss Function")
plot(x = results$treatment, y = results$loss, main = "Treatment vs Loss")
hist(results$loss, breaks = 30, main = "Probable Loss")
// The input data are the vectors 'x', 'status' and 'y', each of length 'N'.
functions {
  /**
   * loss_function
   *
   * Piecewise loss evaluated on a single real value.
   *
   * @param x a real value (the treatment effect in this program)
   * @return 1 / log(x) when x > 1, x when 0 < x <= 1, and 20 otherwise
   */
  real loss_function(real x) {
    // Build output value
    real output;
    if (x > 1)
      output = 1 / log(x);
    else if (x > 0)
      output = x;
    else
      output = 20;
    return output;
  }
}
data {
  int<lower=0> N;     // number of observations
  vector[N] x;        // covariate
  vector[N] status;   // treatment indicator (passed from R as 0/1)
  vector[N] y;        // outcome
}
// The parameters accepted by the model. Our model
// accepts four parameters: 'alpha', 'beta', 'treatment' and 'sigma'.
parameters {
  real alpha;            // intercept
  real beta;             // coefficient on x
  real treatment;        // coefficient on status (treatment effect)
  real<lower=0> sigma;   // residual standard deviation
}
// The model to be estimated. We model the output
// 'y' to be normally distributed with mean
// 'alpha + beta * x + treatment * status' and standard deviation 'sigma'.
model {
  // Linear regression likelihood; no explicit priors are declared here.
  y ~ normal(alpha + beta * x + treatment * status, sigma);
}
generated quantities {
  // Loss evaluated at each posterior draw of the treatment effect, so the
  // draws of 'loss' describe the posterior distribution of the loss.
  real loss = loss_function(treatment);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment