Skip to content

Instantly share code, notes, and snippets.

Created November 13, 2019 22:05
Show Gist options
  • Save samcarlos/e8282e585251e60948c931ddcf8f3084 to your computer and use it in GitHub Desktop.
Save samcarlos/e8282e585251e60948c931ddcf8f3084 to your computer and use it in GitHub Desktop.
#' Simulate a randomized experiment with two outcomes (fee and cost).
#'
#' Covariates `x1` and `x2` are uniform on [0, 3]; treatment `t` is a fair
#' coin flip. The treatment raises `fee` by `x1` and `cost` by exactly 1;
#' both outcomes carry independent standard-normal noise.
#'
#' @param num_obs Number of observations to simulate.
#' @return A data.frame with columns `fee`, `cost`, `x1`, `x2`, `t`.
get_data <- function(num_obs) {
  x1 <- runif(num_obs, 0, 3)
  x2 <- runif(num_obs, 0, 3)
  t <- rbinom(num_obs, 1, 0.5)
  noise_fee <- rnorm(num_obs, 0, 1)
  noise_costs <- rnorm(num_obs, 0, 1)

  # Alternative outcome models, kept for reference:
  # fee = log(x1*1.718+1)*(t) + noise_fee
  # cost = .5*x2*(t) + noise_costs
  fee <- x1 * t + noise_fee
  # The cost uplift is a constant 1 for treated units. An earlier
  # `x2 * t` definition was immediately overwritten and has been removed.
  cost <- t + noise_costs

  data.frame(fee, cost, x1, x2, t)
}
# randomForest() comes from the randomForest package, which must be attached
# before the models below can be fitted.
library(randomForest)

# Two independent simulated samples: one to fit the models, one to evaluate.
training_data <- get_data(100000)
testing_data <- get_data(100000)

# One regression forest per outcome; the treatment indicator `t` enters as an
# ordinary feature so it can be toggled later to get counterfactuals.
reg_fee <- randomForest(fee ~ x1 + x2 + t, data = training_data)
reg_cost <- randomForest(cost ~ x1 + x2 + t, data = training_data)
#' Predict each row's outcome under control and under treatment.
#'
#' Copies `data` twice, forcing the treatment column `t` to 1 and to 0, and
#' scores both copies with `model`.
#'
#' @param model A fitted model that `predict()` can score on `data`.
#' @param data A data.frame with the model's features, including column `t`.
#' @return A two-column matrix: `preds_0` (prediction with `t = 0`) and
#'   `preds_1` (prediction with `t = 1`), one row per row of `data`.
get_counterfactuals <- function(model, data) {
  data_1 <- data
  data_1$t <- 1
  data_0 <- data
  data_0$t <- 0

  preds_1 <- predict(model, data_1)
  preds_0 <- predict(model, data_0)

  # Column order matters: column k corresponds to treatment k - 1.
  cbind(preds_0, preds_1)
}
# Counterfactual predictions on the held-out sample, one matrix per outcome
# (column 1: prediction with t = 0, column 2: prediction with t = 1).
counters_fee <- get_counterfactuals(model = reg_fee, data = testing_data)
counters_cost <- get_counterfactuals(model = reg_cost, data = testing_data)
#' Average the observed responses over rows where the treatment actually
#' given equals the treatment a policy would have assigned.
#'
#' Prints the number of matching rows as a side effect.
#'
#' @param given_tmts Vector of treatments that were actually administered.
#' @param assigned_tmts Vector of treatments proposed by the policy, the same
#'   length as `given_tmts`.
#' @param response Matrix or data.frame of observed outcomes, one row per
#'   observation.
#' @return Named numeric vector of column means of `response` over the
#'   matching rows (`NaN` entries when no row matches).
erupt <- function(given_tmts, assigned_tmts, response) {
  agrees <- given_tmts == assigned_tmts
  print(sum(agrees))
  # drop = FALSE keeps the 2-d shape when a single row matches or when
  # `response` has a single column; colMeans() errors on a bare vector.
  colMeans(response[which(agrees), , drop = FALSE])
}
# Scale each outcome's counterfactual matrix by its sign: fee counts
# positively (+1), cost negatively (-1). The result is a list of two matrices.
weighted_preds <- Map(
  function(pred, sign) pred * sign,
  list(counters_fee, counters_cost),
  c(1, -1)
)
#' ERUPT estimate for the policy induced by a set of outcome weights.
#'
#' Each outcome's counterfactual predictions are multiplied by its weight and
#' summed; every row is then assigned the treatment with the highest weighted
#' score (ties go to the lowest-numbered treatment). The observed responses
#' of rows whose actual treatment matches that assignment are averaged by
#' `erupt()`.
#'
#' @param weights Numeric vector with one weight per outcome in `preds`.
#' @param preds List of counterfactual matrices, one per outcome; column k
#'   holds the prediction under treatment k - 1.
#' @param tmts Vector of treatments actually given.
#' @param response Matrix or data.frame of observed outcomes.
#' @return The value of `erupt()`: column means of `response` over rows where
#'   the given and policy-assigned treatments agree.
get_erupts <- function(weights, preds, tmts, response) {
  weighted <- mapply(function(pred, w) pred * w, preds, weights, SIMPLIFY = FALSE)
  # Sum across all outcomes rather than hard-coding exactly two.
  score <- Reduce(`+`, weighted)
  # which.max() yields a 1-based column index; treatments are coded from 0.
  best_tmt <- apply(score, 1, which.max) - 1
  erupt(tmts, best_tmt, response)
}
#' Model-implied average fee and cost under the policy induced by `weights`.
#'
#' Each row is assigned the treatment with the highest weighted score, and the
#' model's own prediction for that treatment is averaged over all rows.
#'
#' NOTE(review): the source computed `fees` and `costs` but showed no return
#' statement; returning both mirrors `get_true()` -- confirm this was the
#' intent.
#'
#' @param weights Numeric vector with one weight per outcome in `preds`.
#' @param preds List of counterfactual matrices (fee first, cost second);
#'   column k holds the prediction under treatment k - 1.
#' @return Numeric vector `c(fees, costs)`.
get_erupts_1 <- function(weights, preds) {
  weighted <- mapply(function(pred, w) pred * w, preds, weights, SIMPLIFY = FALSE)
  best_tmt <- apply(Reduce(`+`, weighted), 1, which.max) - 1

  n <- nrow(preds[[1]])
  # (row, column) index pairs selecting each row's value under its chosen
  # treatment.
  chosen <- cbind(seq_len(n), best_tmt + 1)
  fees <- sum(preds[[1]][chosen]) / n
  costs <- sum(preds[[2]][chosen]) / n
  c(fees, costs)
}
#' Average fee and cost, taken from `true_values`, under the policy that the
#' predictions `preds` and the outcome `weights` induce.
#'
#' The policy is chosen from `preds` (highest weighted score per row), but the
#' outcomes are read from `true_values`, so passing ground-truth
#' counterfactuals gives the policy's real value and passing `preds` itself
#' gives the model's own estimate.
#'
#' @param weights Numeric vector with one weight per outcome in `preds`.
#' @param preds List of predicted counterfactual matrices; column k holds the
#'   prediction under treatment k - 1.
#' @param true_values List shaped like `preds` (fee first, cost second) giving
#'   the outcome values to average.
#' @return Numeric vector `c(fees, costs)`.
get_true <- function(weights, preds, true_values) {
  weighted <- mapply(function(pred, w) pred * w, preds, weights, SIMPLIFY = FALSE)
  best_tmt <- apply(Reduce(`+`, weighted), 1, which.max) - 1

  n <- nrow(preds[[1]])
  # (row, column) index pairs selecting each row's value under its chosen
  # treatment.
  chosen <- cbind(seq_len(n), best_tmt + 1)
  fees <- sum(true_values[[1]][chosen]) / n
  costs <- sum(true_values[[2]][chosen]) / n
  c(fees, costs)
}
# Ground-truth expected outcomes per treatment (column 1: t = 0, column 2:
# t = 1), matching the simulation: treatment raises fee by x1 and cost by 1.
# The row count is taken from testing_data instead of a hard-coded 100000 so
# the two cannot drift apart.
true_values <- list(
  cbind(rep(0, nrow(testing_data)), testing_data[, "x1"]),
  cbind(rep(0, nrow(testing_data)), rep(1, nrow(testing_data)))
)

# Grid of (fee weight, cost weight) pairs: a (0, 0) baseline, then fee weights
# 0 to 20 in steps of 0.25 plus a very large 10000, each paired with a cost
# weight of -1.
weights <- rbind(c(0, 0), expand.grid(c(seq(0, 20, 0.25), 10000), -1))
# ERUPT estimate: for every weight pair, average the observed fee and cost
# over test rows whose actual treatment matches the policy's choice.
estimated_erupts <- data.frame(t(apply(
  weights, 1,
  function(w) {
    get_erupts(
      w, list(counters_fee, counters_cost),
      testing_data[, 't'], testing_data[, c('fee', 'cost')]
    )
  }
)))
estimated_erupts$estimate <- 'erupt'

# Model estimate: the policy is scored with the model's own predictions.
over_estimated_erupts <- data.frame(t(apply(
  weights, 1,
  function(w) {
    get_true(
      w, list(counters_fee, counters_cost),
      list(counters_fee, counters_cost)
    )
  }
)))
over_estimated_erupts$estimate <- 'model'
names(over_estimated_erupts) <- c('fee', 'cost', 'estimate')

# Truth: the same policy scored with the simulation's known expected outcomes.
true_estimated_erupts <- t(apply(
  weights, 1,
  function(w) get_true(w, list(counters_fee, counters_cost), true_values)
))
colnames(true_estimated_erupts) <- c('fee', 'cost')
true_estimated_erupts <- data.frame(true_estimated_erupts)
true_estimated_erupts$estimate <- 'truth'
# ggplot(), aes(), geom_point() and geom_line() come from ggplot2, which must
# be attached before plotting.
library(ggplot2)

# Stack the three estimates and plot fee against cost, one curve per
# estimation method.
graph_data <- rbind(over_estimated_erupts, estimated_erupts, true_estimated_erupts)
ggplot(
  graph_data,
  aes(x = cost, y = fee, group = as.factor(estimate), colour = as.factor(estimate))
) +
  geom_point() +
  geom_line()
Copy link

What would this look like in a programming language? Maybe Python?

probably pretty similar considering I still use the parenthesis after return statements in python :)

Copy link

It might cor-erupt the code

Copy link

Well thanks for the sudo code, neat stuff!

Copy link

haha nice Matt, that took me too long

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment