Created
August 24, 2017 22:05
-
-
Save krishnanraman/a2f60afd4b162204b62c38e187881246 to your computer and use it in GitHub Desktop.
UCB1 algorithm.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Interactive simulation that treats a set of evenly spaced bid amounts as the
  * arms of a multi-armed bandit and lets [[Eps]] pick arms for a number of trials.
  *
  * The program reads the minimum bid, maximum bid, number of arms and number of
  * trials from standard input, fabricates per-bid "bid count" and "click count"
  * tables, runs the trials, and prints what the bandit found next to the true
  * best arm of the fabricated tables.
  *
  * NOTE(review): this relies on the `App` trait, so all members below are
  * initialised as part of running the program body; `Eps` calls back into
  * `ucb1.reward` while that body is still executing.
  */
object ucb1 extends App {
  import scala.io.StdIn

  /** Click-through rate for a bid amount; 0.0 when the amount was never bid (avoids 0/0 = NaN). */
  private def ctrOf(clicks: Int, bidCount: Int): Double =
    if (bidCount == 0) 0.0 else clicks.toDouble / bidCount

  printf("Min Bid:")
  val min = StdIn.readDouble()
  printf("Max Bid:")
  val max = StdIn.readDouble()
  printf("Number of arms:")
  val arms = StdIn.readInt()
  printf("Trials (say 1000):")
  val trials = StdIn.readInt()
  require(arms > 0, "Number of arms must be positive")

  // Associate a bid with each arm: arm x has bid equal to bids(x).
  // A single arm has no spacing, so the step is 0 and the only bid is `min`.
  val incr = if (arms > 1) (max - min) / (arms - 1) else 0.0
  // Built by index rather than by stepping a Double range, so there are always
  // exactly `arms` bids regardless of floating-point rounding.
  val bids = List.tabulate[Double](arms) { i => min + i * incr }
  val eps = Eps(arms)

  def randomArm: Int = (math.random * arms).toInt

  // This is my fake static database. The real databases reside on AWS, they will get updated over time etc.
  // How many times we bid a particular amount.
  val bidDB = List.tabulate[Int](bids.size) { _ => (math.random * 1000).toInt }
  // How many clicks we got for that amount (never more than the number of bids).
  val clickDB = List.tabulate[Int](bids.size) { x => (math.random * bidDB(x)).toInt }
  val allCTR = clickDB.zip(bidDB).map { case (clicks, bidCount) => ctrOf(clicks, bidCount) }
  val bestArm: Int = allCTR.zipWithIndex.maxBy { x => x._1 }._2
  val bestCTR: Double = allCTR.max

  /** Simulates one pull of `arm`.
    *
    * @param arm index into `bidDB` / `clickDB` (and `bids`)
    * @return 1 with probability equal to the arm's click-through rate, otherwise 0
    */
  def reward(arm: Int): Int = {
    // Hit the DB and find out two things for this arm's bid:
    // 1. N = how many times we bid this amount
    // 2. C = how many clicks we got in total
    val N = bidDB(arm) // THE ACTUAL NUMBER OF TIMES WE BID THIS AMOUNT GOES HERE
    val C = clickDB(arm) // THE ACTUAL NUMBER OF CLICKS WE GOT FOR THIS AMOUNT GOES HERE
    if (math.random < ctrOf(C, N)) 1 else 0
  }

  // Each element is the 0/1 reward of one trial; eps.reward is evaluated once per trial.
  val pulls = List.fill[Int](trials)(eps.reward)

  printf("Our wins in %d trials: %d\n", trials, pulls.sum)
  printf("Max wins possible: %d\n", (bestCTR * trials).toInt)
  printf("The best arm, as per ucb1 algo: %d\n", eps.best)
  printf("The best CTR, as per ucb1 algo: %.2f\n", allCTR(eps.best))
  printf("The TRUE best arm: %d\n", bestArm)
  printf("The TRUE best CTR: %.2f\n", bestCTR)
  println("Bid DB: " + bidDB)
  println("Click DB: " + clickDB)
}
/** Mutable UCB1 bandit state for a fixed number of arms.
  *
  * Each call to [[reward]] scores every arm as
  * `empirical mean + sqrt(2 * ln(total pulls) / pulls of that arm)`,
  * pulls the highest-scoring arm via `ucb1.reward`, and records the outcome.
  *
  * @param arms number of arms; arm indices are `0 until arms`
  */
case class Eps(arms: Int) {
  // Initialize reward & pull counts by 1 instead of 0, to avoid a 0/0 problem
  // when computing an arm's mean reward.
  val egRewards = Array.fill[Int](arms)(1) // reward yielded by each arm so far
  val egArms = Array.fill[Int](arms)(1) // the number of times each arm was pulled so far

  /** Index of the arm with the highest empirical mean reward as of the last update (0 before any update). */
  var best: Int = 0

  /** Records one pull of `arm` that yielded `reward`, then refreshes [[best]].
    *
    * @param arm    index of the arm that was pulled
    * @param reward reward obtained from that pull
    */
  def update(arm: Int, reward: Int): Unit = {
    egArms(arm) += 1 // update the arm pull
    egRewards(arm) += reward // update the reward for the arm
    // The arm just pulled may have been an exploratory choice, so `best` is
    // recomputed from the observed means instead of being set to `arm`.
    // Ties resolve to the lowest index.
    best = (0 until arms).maxBy { j => egRewards(j).toDouble / egArms(j) }
  }

  /** Runs one UCB1 step: picks the arm with the highest upper confidence bound,
    * pulls it through `ucb1.reward`, records the result and returns it.
    *
    * @return the reward produced by `ucb1.reward` for the chosen arm
    */
  def reward: Int = {
    // The total pull count is the same for every arm in this step, so compute it once.
    val totalPulls = egArms.sum
    val explorationNumerator = 2 * math.log(totalPulls)
    // Ties resolve to the lowest index.
    val arm = (0 until arms).maxBy { j =>
      val pullsOfArm = egArms(j)
      val meanReward = egRewards(j).toDouble / pullsOfArm
      meanReward + math.sqrt(explorationNumerator / pullsOfArm)
    }
    val r = ucb1.reward(arm)
    update(arm, r)
    r
  }
}
/* | |
$ scala ucb1 | |
Min Bid:0.1 | |
Max Bid:0.5 | |
Number of arms:7 | |
Trials (say 1000):10000 | |
Our wins in 10000 trials: 9377 | |
Max wins possible: 9605 | |
The best arm, as per ucb1 algo: 5 | |
The best CTR, as per ucb1 algo: 0.96 | |
The TRUE best arm: 5 | |
The TRUE best CTR: 0.96 | |
Bid DB: List(599, 248, 291, 677, 493, 634, 389) | |
Click DB: List(52, 8, 88, 633, 307, 609, 283) | |
*/ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment