Skip to content

Instantly share code, notes, and snippets.

@samcarlos
Created November 13, 2019 22:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samcarlos/e8282e585251e60948c931ddcf8f3084 to your computer and use it in GitHub Desktop.
Save samcarlos/e8282e585251e60948c931ddcf8f3084 to your computer and use it in GitHub Desktop.
# Simulate a randomized experiment with two outcomes, fee and cost.
#
# Args:
#   num_obs: number of observations (rows) to simulate.
#
# Returns:
#   A data.frame with columns fee, cost, x1, x2, t where
#     x1, x2 are Uniform(0, 3) covariates,
#     t      is a Bernoulli(0.5) treatment indicator,
#     fee    = x1 * t + N(0, 1) noise  (treatment raises fee by x1),
#     cost   = t + N(0, 1) noise       (treatment raises cost by 1).
get_data <- function(num_obs) {
  x1 <- runif(num_obs, 0, 3)
  x2 <- runif(num_obs, 0, 3)
  t <- rbinom(num_obs, 1, 0.5)
  noise_fee <- rnorm(num_obs, 0, 1)
  noise_costs <- rnorm(num_obs, 0, 1)
  # Alternative data-generating processes, kept for reference only:
  # fee = log(x1*1.718+1)*(t) + noise_fee
  # cost = .5*x2*(t) + noise_costs
  # cost = x2*(t) + noise_costs
  fee <- x1 * t + noise_fee
  # The x2-dependent cost used to be computed and then immediately overwritten;
  # only this constant-effect version ever reached the returned data.
  cost <- t + noise_costs
  data.frame(fee, cost, x1, x2, t)
}
# Learner used for both outcome models.
library(randomForest)

# Two independent draws from the simulator: one to fit on, one to evaluate on.
training_data <- get_data(100000)
testing_data <- get_data(100000)

# One regression per outcome. The treatment flag t enters as an ordinary
# predictor, so each fitted model can be queried with t forced to 0 or to 1.
reg_fee <- randomForest(fee ~ x1 + x2 + t, data = training_data)
reg_cost <- randomForest(cost ~ x1 + x2 + t, data = training_data)
# Predict every row's outcome under both treatment arms.
#
# Args:
#   model: a fitted model that predict() can score.
#   data:  data.frame of predictors; its column t is overwritten on a copy,
#          the caller's object is left untouched.
#
# Returns:
#   A two-column matrix: preds_0 (all rows scored with t = 0) and
#   preds_1 (all rows scored with t = 1).
get_counterfactuals <- function(model, data) {
  # Score a copy of `data` in which every row has t forced to `arm`.
  predict_under <- function(arm) {
    scenario <- data
    scenario$t <- arm
    predict(model, scenario)
  }
  treated <- predict_under(1)
  control <- predict_under(0)
  cbind(preds_0 = control, preds_1 = treated)
}
# Held-out predictions under control (column 1) and treatment (column 2),
# one matrix per outcome model.
counters_fee <- get_counterfactuals(model = reg_fee, data = testing_data)
counters_cost <- get_counterfactuals(model = reg_cost, data = testing_data)
# Average the observed responses over the rows whose given treatment equals
# the treatment a policy would have assigned.
#
# Args:
#   given_tmts:    treatment each unit actually received.
#   assigned_tmts: treatment the policy proposes for each unit.
#   response:      matrix or data.frame of observed outcomes, one row per unit.
#
# Returns:
#   Column means of `response` taken over the matching rows.
#
# Side effect: prints the number of matching rows.
erupt <- function(given_tmts, assigned_tmts, response) {
  matched <- given_tmts == assigned_tmts
  print(sum(matched))
  # drop = FALSE keeps the subset two-dimensional. Without it a single matching
  # row of a matrix (or a one-column data.frame) collapses to a plain vector
  # and colMeans() fails.
  colMeans(response[which(matched), , drop = FALSE])
}
# Sign each outcome's counterfactual matrix: fee counts positively (+1),
# cost negatively (-1).
# NOTE(review): this object does not appear to be referenced again in the code
# shown here - confirm before relying on it.
weighted_preds <- Map(
  function(x, y) x * y,
  x = list(counters_fee, counters_cost),
  y = c(1, -1)
)
# Evaluate, on observed data, the policy implied by a weight vector.
#
# Each outcome's counterfactual matrix is scaled by its weight, the scaled
# matrices are summed into one score per (row, treatment), and every row is
# assigned the treatment with the highest score. The resulting assignment is
# then passed to erupt() together with the observed treatments and responses.
#
# Args:
#   weights:  numeric vector, one weight per outcome in `preds`.
#   preds:    list of matrices, one per outcome, one column per treatment
#             (column k corresponds to treatment k - 1).
#   tmts:     treatment each unit actually received.
#   response: observed outcomes, one row per unit.
#
# Returns:
#   Whatever erupt() returns for the chosen assignment.
#
# Side effects: prints the weights and a table of the chosen treatments.
get_erupts <- function(weights, preds, tmts, response) {
  print(weights)
  weighted <- mapply(function(x, y) x * y, x = preds, y = weights, SIMPLIFY = FALSE)
  # Sum over all outcomes instead of hard-coding the first two, so a third
  # (or later) outcome is not silently ignored. Identical for two outcomes.
  score <- Reduce(`+`, weighted)
  # which.max is 1-based over columns; treatments are coded from 0.
  best_tmt <- apply(score, 1, which.max) - 1
  print(table(best_tmt))
  erupt(tmts, best_tmt, response)
}
# Model-based value of the policy implied by a weight vector.
#
# Each row is assigned the treatment that maximises the weighted sum of its
# predicted outcomes; the function then averages, per outcome, the prediction
# under that chosen treatment.
#
# Args:
#   weights: numeric vector, one weight per outcome in `preds`.
#   preds:   list of matrices, one per outcome, one column per treatment
#            (column k corresponds to treatment k - 1).
#
# Returns:
#   Unnamed numeric vector with one entry per outcome in `preds`
#   (for two outcomes: c(fees, costs)).
#
# Side effect: prints the weights.
get_erupts_1 <- function(weights, preds) {
  print(weights)
  weighted <- mapply(function(x, y) x * y, x = preds, y = weights, SIMPLIFY = FALSE)
  # which.max is 1-based over columns; treatments are coded from 0.
  best_tmt <- apply(Reduce(`+`, weighted), 1, which.max) - 1
  n <- nrow(preds[[1]])
  # (row, column) pairs picking each row's value under its chosen treatment.
  # Unlike separate sums for treatment 0 and 1, this also covers rows whose
  # best treatment is 2 or higher instead of dropping them from the numerator.
  chosen <- cbind(seq_len(n), best_tmt + 1)
  vapply(preds, function(p) sum(p[chosen]) / n, numeric(1), USE.NAMES = FALSE)
}
# Value of the policy implied by a weight vector, scored against supplied
# outcome tables rather than against the predictions used to choose.
#
# The treatment for each row is chosen from `preds` (highest weighted sum of
# predicted outcomes); the reported averages are read from `true_values` at
# that chosen treatment.
#
# Args:
#   weights:     numeric vector, one weight per outcome in `preds`.
#   preds:       list of matrices used to choose the treatment, one per
#                outcome, one column per treatment (column k corresponds to
#                treatment k - 1).
#   true_values: list of matrices laid out like `preds`, holding the values
#                to average.
#
# Returns:
#   Unnamed numeric vector with one entry per element of `true_values`
#   (for two outcomes: c(fees, costs)).
#
# Side effect: prints the weights.
get_true <- function(weights, preds, true_values) {
  print(weights)
  weighted <- mapply(function(x, y) x * y, x = preds, y = weights, SIMPLIFY = FALSE)
  # which.max is 1-based over columns; treatments are coded from 0.
  best_tmt <- apply(Reduce(`+`, weighted), 1, which.max) - 1
  n <- nrow(preds[[1]])
  # (row, column) pairs picking each row's value under its chosen treatment.
  # This also covers rows whose best treatment is 2 or higher instead of
  # dropping them from the numerator while still counting them in n.
  chosen <- cbind(seq_len(n), best_tmt + 1)
  vapply(true_values, function(v) sum(v[chosen]) / n, numeric(1), USE.NAMES = FALSE)
}
# Expected outcomes for every test row as (control, treated) columns, one
# matrix per outcome. Under get_data() the expected fee is 0 under control and
# x1 under treatment; the expected cost is 0 under control and 1 under
# treatment. The row count is taken from testing_data instead of repeating the
# literal sample size, so the matrices stay aligned if that size changes.
true_values <- list(
  cbind(rep(0, nrow(testing_data)), testing_data[, "x1"]),
  cbind(rep(0, nrow(testing_data)), rep(1, nrow(testing_data)))
)

# Grid of (fee weight, cost weight) pairs: a (0, 0) row, then fee weights
# 0, 0.25, ..., 20 and 10000, each paired with a cost weight of -1.
weights <- rbind(c(0, 0), expand.grid(c(seq(0, 20, .25), 10000), -1))

# Curve 1 ("erupt"): observed responses averaged over the test rows whose
# given treatment matches the policy's choice.
estimated_erupts <- apply(weights, 1, function(x) {
  get_erupts(
    x,
    list(counters_fee, counters_cost),
    testing_data[, "t"],
    testing_data[, c("fee", "cost")]
  )
})
estimated_erupts <- t(estimated_erupts)
estimated_erupts <- data.frame(estimated_erupts)
estimated_erupts[, "estimate"] <- "erupt"

# Curve 2 ("model"): the policy is scored with the same model predictions
# that were used to choose it.
over_estimated_erupts <- apply(weights, 1, function(x) {
  get_true(
    x,
    list(counters_fee, counters_cost),
    list(counters_fee, counters_cost)
  )
})
over_estimated_erupts <- t(over_estimated_erupts)
over_estimated_erupts <- data.frame(over_estimated_erupts)
over_estimated_erupts[, "estimate"] <- "model"
colnames(over_estimated_erupts) <- c("fee", "cost", "estimate")

# Curve 3 ("truth"): the same policy scored against the known expected
# outcomes of the simulation.
true_estimated_erupts <- apply(weights, 1, function(x) {
  get_true(x, list(counters_fee, counters_cost), true_values)
})
true_estimated_erupts <- t(true_estimated_erupts)
colnames(true_estimated_erupts) <- c("fee", "cost")
true_estimated_erupts <- data.frame(true_estimated_erupts)
true_estimated_erupts[, "estimate"] <- "truth"

library(ggplot2)

# Stack the three curves and plot fee against cost, one colour per estimate.
graph_data <- rbind(over_estimated_erupts, estimated_erupts, true_estimated_erupts)
ggplot(
  graph_data,
  aes(x = cost, y = fee, group = as.factor(estimate), colour = as.factor(estimate))
) +
  geom_point() +
  geom_line()
@Mikelew88
Copy link

What would this look like in a programming language? Maybe Python?

@samcarlos
Copy link
Author

What would this look like in a programming language? Maybe Python?

probably pretty similar considering I still use the parenthesis after return statements in python :)

@mattsgithub
Copy link

It might cor-erupt the code

@Mikelew88
Copy link

Well thanks for the pseudocode, neat stuff!

@Mikelew88
Copy link

haha nice Matt, that took me too long

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment