Skip to content

Instantly share code, notes, and snippets.

@krishnanraman
Created August 24, 2017 22:05
Show Gist options
  • Save krishnanraman/a2f60afd4b162204b62c38e187881246 to your computer and use it in GitHub Desktop.
Save krishnanraman/a2f60afd4b162204b62c38e187881246 to your computer and use it in GitHub Desktop.
UCB1 multi-armed bandit algorithm.
/** Interactive demo that runs a UCB1 bandit over a set of evenly spaced bids.
  *
  * The program prompts for a bid range, the number of arms and the number of
  * trials, builds a randomly generated stand-in for the bid/click database,
  * lets `eps` pull arms for the requested number of trials, and finally prints
  * the bandit's result next to the true best arm of the fake database.
  */
object ucb1 extends App {
  // StdIn / Random are used instead of the Predef read* helpers and math.random:
  // their methods are declared with parentheses in every Scala version, so the
  // calls below compile without deprecation warnings on 2.11 through 3.x.
  import scala.io.StdIn
  import scala.util.Random

  printf("Min Bid:")
  val min: Double = StdIn.readDouble()
  printf("Max Bid:")
  val max: Double = StdIn.readDouble()
  printf("Number of arms:")
  val arms: Int = StdIn.readInt()
  require(arms >= 1, "Number of arms must be at least 1")
  printf("Trials (say 1000):")
  val trials: Int = StdIn.readInt()

  // Associate a bid with each arm: arm x has bid equal to bids(x).
  // With a single arm there is no spacing to compute, so avoid dividing by zero.
  val incr: Double = if (arms > 1) (max - min) / (arms - 1) else 0.0
  // Built by index rather than with a Double range so that the list always has
  // exactly `arms` entries, regardless of floating-point rounding of `incr`.
  val bids: List[Double] = List.tabulate(arms)(i => min + i * incr)

  val eps = Eps(arms)

  /** A uniformly random arm index in [0, arms). */
  def randomArm: Int = Random.nextInt(arms)

  // This is a fake static database. The real databases reside on AWS, they will get updated over time etc.
  // How many times we bid a particular amount; kept >= 1 so that the CTR below is never 0/0 (NaN).
  val bidDB: List[Int] = List.fill(bids.size)(1 + Random.nextInt(1000))
  // How many clicks we got for each bid: a random count in [0, number of bids).
  val clickDB: List[Int] = bidDB.map(timesBid => Random.nextInt(timesBid))
  // Click-through rate per arm, and the arm that truly has the highest one.
  val allCTR: List[Double] = clickDB.zip(bidDB).map { case (clicks, timesBid) => clicks.toDouble / timesBid }
  val bestArm: Int = allCTR.zipWithIndex.maxBy(_._1)._2
  val bestCTR: Double = allCTR.max

  /** Simulates one pull of `arm`.
    *
    * @param arm index of the arm, i.e. of the bid `bids(arm)`
    * @return 1 with probability equal to the arm's click-through rate, otherwise 0
    */
  def reward(arm: Int): Int = {
    // Hit the DB for the bid of this arm and find out two things:
    val timesBid = bidDB(arm) // THE ACTUAL NUMBER OF TIMES WE BID bids(arm) GOES HERE
    val clicks = clickDB(arm) // THE ACTUAL NUMBER OF CLICKS WE GOT for bids(arm) GOES HERE
    val ctr = clicks.toDouble / timesBid
    if (Random.nextDouble() < ctr) 1 else 0
  }

  // Each evaluation of eps.reward picks an arm, pulls it and records the outcome.
  val pulls: List[Int] = List.fill(trials)(eps.reward)

  printf("Our wins in %d trials: %d\n", trials, pulls.sum)
  printf("Max wins possible: %d\n", (bestCTR * trials).toInt)
  printf("The best arm, as per ucb1 algo: %d\n", eps.best)
  printf("The best CTR, as per ucb1 algo: %.2f\n", allCTR(eps.best))
  printf("The TRUE best arm: %d\n", bestArm)
  printf("The TRUE best CTR: %.2f\n", bestCTR)
  println("Bid DB: " + bidDB)
  println("Click DB: " + clickDB)
}
/** Mutable state and arm selection for the UCB1 bandit strategy.
  *
  * @param arms number of arms to choose between; must be positive
  */
case class Eps(arms: Int) {
  require(arms > 0, "Eps needs at least one arm")

  // Initialize rewards & pull counts to 1 instead of 0, to avoid a 0/0 problem
  // when computing an arm's mean reward and exploration bonus.
  val egRewards: Array[Int] = Array.fill[Int](arms)(1) // reward yielded by each arm so far
  val egArms: Array[Int] = Array.fill[Int](arms)(1) // the number of times each arm was pulled so far

  /** Index of the arm that has been pulled most often so far. */
  var best: Int = 0

  /** Records the outcome of one pull.
    *
    * @param arm    index of the arm that was pulled
    * @param reward reward obtained from that pull
    */
  def update(arm: Int, reward: Int): Unit = {
    egArms(arm) += 1 // update the arm pull
    egRewards(arm) += reward // update the reward for the arm
    // Only `arm`'s count grew, so it is the only arm that can overtake the
    // current leader. Tracking the most-pulled arm (rather than simply the last
    // arm pulled) keeps an occasional exploratory pull from being reported as best.
    if (egArms(arm) > egArms(best)) best = arm
  }

  /** Runs one UCB1 step: scores every arm, pulls the highest-scoring one via
    * `ucb1.reward`, records the result and returns it.
    *
    * @return the reward produced by the selected arm
    */
  def reward: Int = {
    // The total pull count is the same for every arm, so compute it once.
    val totalPulls = egArms.sum
    val bonusNumerator = 2 * math.log(totalPulls)

    // UCB1 index of arm j: mean reward so far plus an exploration bonus that
    // shrinks as the arm gets pulled more often.
    def ucbIndex(j: Int): Double =
      egRewards(j).toDouble / egArms(j) + math.sqrt(bonusNumerator / egArms(j))

    // maxBy keeps the first maximum, so ties go to the lowest arm index.
    val arm = (0 until arms).maxBy(ucbIndex)
    val r = ucb1.reward(arm)
    update(arm, r)
    r
  }
}
/*
$ scala ucb1
Min Bid:0.1
Max Bid:0.5
Number of arms:7
Trials (say 1000):10000
Our wins in 10000 trials: 9377
Max wins possible: 9605
The best arm, as per ucb1 algo: 5
The best CTR, as per ucb1 algo: 0.96
The TRUE best arm: 5
The TRUE best CTR: 0.96
Bid DB: List(599, 248, 291, 677, 493, 634, 389)
Click DB: List(52, 8, 88, 633, 307, 609, 283)
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment