Skip to content

Instantly share code, notes, and snippets.

@nburke-salesforce
Created July 22, 2021 15:45
Show Gist options
  • Save nburke-salesforce/c74e3826919adb0e4c421db60a24b468 to your computer and use it in GitHub Desktop.
Save nburke-salesforce/c74e3826919adb0e4c421db60a24b468 to your computer and use it in GitHub Desktop.
In-Person Return-to-Work Simulation
package example
import java.io._
import scala.util.Random
import scala.annotation.tailrec
import com.github.tototoshi.csv._
/**
 * A sales team competing for targets in a `Marketplace`.
 *
 * `pop` is the team's headcount: in `Marketplace.doRound` each team pitches up to
 * `pop` randomly chosen open targets per round.
 */
trait SalesTeam {

  /** Maximum number of targets this team pitches in a single round. */
  val pop: Int

}
/**
 * A team that sells in person.
 *
 * @param pop             maximum number of targets pitched per round
 * @param winProb         chance of closing a single pitch to a target that is NOT in-person
 * @param inPersonWinProb chance of closing a single pitch to an in-person target
 */
case class InPersonTeam(
    pop: Int,
    winProb: Double,
    inPersonWinProb: Double
) extends SalesTeam
/**
 * A team that sells remotely; its odds do not depend on the kind of target.
 *
 * @param pop     maximum number of targets pitched per round
 * @param winProb chance of closing a single pitch to any target, in-person or not
 */
case class RemoteTeam(
    pop: Int,
    winProb: Double
) extends SalesTeam
/**
 * A market of `pop` sales targets over which two [[SalesTeam]]s compete.
 *
 * Each target is independently an in-person target with probability `inPersonFrac`.
 * All randomness (target generation, pitch outcomes and per-round target selection)
 * is drawn from this marketplace's own `rand`.
 *
 * @param pop          number of targets in the market
 * @param inPersonFrac probability that any given target is an in-person target
 */
case class Marketplace(pop: Int, inPersonFrac: Double) {

  /** Source of all randomness for this marketplace. */
  val rand: Random = new Random()

  /** `inPersons(i)` is true when target `i` (for `i` in `0 until pop`) is an in-person target. */
  val inPersons: IndexedSeq[Boolean] =
    (0 until pop).map { _ => rand.nextDouble() < inPersonFrac }

  /** `(id, isInPerson)` for every target, where `id` is the target's index in [[inPersons]]. */
  val targets: IndexedSeq[(Int, Boolean)] =
    inPersons.zipWithIndex.map { case (isInPerson, id) => (id, isInPerson) }

  /**
   * Chance that `team` closes a single pitch to a target of the given kind.
   * Throws `MatchError` for `SalesTeam` implementations other than
   * [[InPersonTeam]] and [[RemoteTeam]].
   */
  private def winProbability(team: SalesTeam, inPersonTarget: Boolean): Double = team match {
    case InPersonTeam(_, winProb, inPersonWinProb) =>
      if (inPersonTarget) inPersonWinProb else winProb
    case RemoteTeam(_, winProb) => winProb
  }

  /**
   * Simulates one pitch.
   *
   * @param team           the team making the pitch
   * @param inPersonTarget whether the target is an in-person target
   * @return true if the team closes the target
   */
  def getsToYes(team: SalesTeam, inPersonTarget: Boolean): Boolean =
    rand.nextDouble() < winProbability(team, inPersonTarget)

  /**
   * Combines two partial round results.
   * Score pairs are summed component-wise and the sets of closed target ids are unioned.
   */
  def reduceResults(a: ((Int, Int), Set[Int]), b: ((Int, Int), Set[Int])): ((Int, Int), Set[Int]) = {
    val ((a1, a2), aClosed) = a
    val ((b1, b2), bClosed) = b
    ((a1 + b1, a2 + b2), aClosed.union(bClosed))
  }

  /**
   * Evaluates `outcome` once per target in `pitched`, in the set's iteration order, and folds
   * the results. `outcome` maps the target's in-person flag to that pitch's score delta:
   * (1, 0) if team 1 closed it, (0, 1) if team 2 did, (0, 0) if nobody did.
   * Returns the summed delta and the ids of every closed target.
   */
  private def tally(pitched: Set[(Int, Boolean)])(outcome: Boolean => (Int, Int)): ((Int, Int), Set[Int]) =
    pitched.foldLeft(((0, 0), Set.empty[Int])) { case (acc, (id, isInPerson)) =>
      val delta = outcome(isInPerson)
      if (delta == (0, 0)) acc else reduceResults(acc, (delta, Set(id)))
    }

  /**
   * Runs selling rounds until no remaining target can be closed.
   *
   * Each round, each team independently picks up to `pop` distinct open targets uniformly at
   * random, so the two teams' picks may overlap. Each picked target gets one pitch per team
   * that picked it. When both teams close an overlapping target, team 2 gets it, unless
   * `team2WinsTies` is false, in which case team 1 gets it. Closed targets are removed from
   * later rounds.
   *
   * If some targets can never be closed (every team with `pop > 0` has a non-positive win
   * probability for that kind of target), the loop stops and those targets are returned
   * rather than being retried forever.
   *
   * @param team1         first competing team
   * @param team2         second competing team
   * @param targets       `(id, isInPerson)` pairs still open at the start
   * @param team2WinsTies who wins an overlapping target both teams closed (default: team 2)
   * @return (final (team1, team2) score, targets left open, cumulative score after each
   *         round, newest first)
   */
  def doRound(
      team1: SalesTeam,
      team2: SalesTeam,
      targets: Seq[(Int, Boolean)],
      team2WinsTies: Boolean = true
  ): ((Int, Int), Seq[(Int, Boolean)], Seq[(Int, Int)]) = {
    // Team sizes and odds are fixed for the whole call, so whether each kind of target can
    // ever be closed is decided once up front.
    def closable(isInPerson: Boolean): Boolean =
      Seq(team1, team2).exists(team => team.pop > 0 && winProbability(team, isInPerson) > 0)
    val inPersonClosable = closable(true)
    val otherClosable = closable(false)

    @tailrec def doRoundRec(
        score: (Int, Int),
        targets: Seq[(Int, Boolean)],
        history: Seq[(Int, Int)]
    ): ((Int, Int), Seq[(Int, Boolean)], Seq[(Int, Int)]) = {
      // Also false for an empty target list, which ends the simulation normally.
      val progressPossible = targets.exists { case (_, isInPerson) =>
        if (isInPerson) inPersonClosable else otherClosable
      }
      if (!progressPossible) {
        (score, targets, history)
      } else {
        val team1Targets = rand.shuffle(targets).take(team1.pop).toSet
        val team2Targets = rand.shuffle(targets).take(team2.pop).toSet
        val overlaps = team1Targets.intersect(team2Targets)

        val overlapResults = tally(overlaps) { isInPerson =>
          val team1Won = getsToYes(team1, isInPerson)
          val team2Won = getsToYes(team2, isInPerson)
          if (team2Won && (team2WinsTies || !team1Won)) (0, 1)
          else if (team1Won) (1, 0)
          else (0, 0)
        }
        val team1OnlyResults = tally(team1Targets.diff(overlaps)) { isInPerson =>
          if (getsToYes(team1, isInPerson)) (1, 0) else (0, 0)
        }
        val team2OnlyResults = tally(team2Targets.diff(overlaps)) { isInPerson =>
          if (getsToYes(team2, isInPerson)) (0, 1) else (0, 0)
        }

        val (scoreAddition, closedIds) =
          reduceResults(reduceResults(overlapResults, team1OnlyResults), team2OnlyResults)
        val newScore = (score._1 + scoreAddition._1, score._2 + scoreAddition._2)
        val remaining = targets.filterNot { case (id, _) => closedIds.contains(id) }
        doRoundRec(newScore, remaining, newScore +: history)
      }
    }

    doRoundRec((0, 0), targets, List.empty[(Int, Int)])
  }
}
/**
 * Monte Carlo driver for the return-to-work model.
 *
 * It repeatedly pits a fixed [[RemoteTeam]] against an [[InPersonTeam]] over one shared
 * [[Marketplace]], sweeping either the in-person team's size or its in-person win rate
 * across Monte Carlo rounds. It writes one CSV row per Monte Carlo round to `out.csv`.
 * Each row holds the in-person team's cumulative share of the whole market after each
 * selling round, oldest first.
 */
object Simulation extends App {
  // Parameter swept across Monte Carlo rounds:
  //   "teamSize" varies the in-person team's size;
  //   "winRate" varies its in-person win probability.
  // Any other value keeps both at their base values (20 and 0.1).
  // val strategy = "teamSize"
  val strategy = "winRate"
  val marketSize = 1000
  val remoteProbWin = 0.1
  val proportionInPerson = 0.9
  val remote = RemoteTeam(100, remoteProbWin)
  // NOTE(review): 99 rounds split unevenly across the 4 sweep settings
  // (round % 4 == 0 occurs 24 times, the other residues 25 times each) — confirm intended.
  val numMonteRounds = 99
  val marketplace = Marketplace(marketSize, proportionInPerson)
  val oot = (1 to numMonteRounds).map { round =>
    val step = round % 4
    val teamSize = strategy match {
      case "teamSize" => step match {
        case 0 => 20
        case 1 => 40
        case 2 => 80
        case 3 => 160
      }
      case _ => 20
    }
    val inPersonWinRate = strategy match {
      case "winRate" => step match {
        case 0 => 0.1
        case 1 => 0.2
        case 2 => 0.4
        case 3 => 0.8
      }
      case _ => 0.1
    }
    val (_, _, history) = marketplace.doRound(
      remote,
      InPersonTeam(teamSize, 0.1, inPersonWinRate),
      marketplace.targets
    )
    // `history` is newest-first and holds cumulative (remote, inPerson) scores. Normalize
    // by the real market size instead of a hard-coded 1/1000, so the output stays a
    // fraction if marketSize changes.
    history.map(_._2.toDouble / marketSize).reverse
  }.toList
  val f = new File("out.csv")
  val writer = CSVWriter.open(f)
  // Release the file handle even if writing fails.
  try writer.writeAll(oot)
  finally writer.close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment