Created October 30, 2019 20:02
-
-
Save FlorianDe/2346055908ba8744e53260ad4d375d6b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// A candidate for weighted random selection: the `element` paired with its
/// relative `weight` (its chance of being picked is `weight / sum of all weights`).
typealias WeightedElement<Element> = (element: Element, weight: Float)
/// Efficiency levels used as the sample elements for the weighted-selection test.
/// `CaseIterable` exposes them via `allCases` in declaration order.
enum Efficiency: CaseIterable {
    case low, normal, high, perfect
}
/// Sample input for the probability check: `normal` and `perfect` carry equal
/// weight, while `low` and `high` have zero weight and so are never selected.
let testWeightElements: [WeightedElement<Efficiency>] = [
    (element: .low, weight: 0),
    (element: .normal, weight: 1),
    (element: .high, weight: 0),
    (element: .perfect, weight: 1),
]
/// Picks elements at random with probability proportional to their weights.
class WeightedElementCalculator {
    /// Returns a random element of `values`, where each element is chosen with
    /// probability `weight / totalWeight`.
    ///
    /// Weights that are negative, NaN or infinite are treated as `0`, so those
    /// elements are never selected. Without this, a negative or non-finite total
    /// would make `Float.random(in:)` trap.
    ///
    /// - Parameter values: The candidates and their relative weights.
    /// - Returns: The selected element, or `nil` if `values` is empty or no element
    ///   has a positive, finite weight.
    static func getRndmElement<ElementType>(values: [WeightedElement<ElementType>]) -> ElementType? {
        var generator = SystemRandomNumberGenerator()
        return getRndmElement(values: values, using: &generator)
    }

    /// Same as `getRndmElement(values:)`, but draws randomness from `generator`
    /// (for example a seeded generator, for reproducible results).
    ///
    /// - Parameters:
    ///   - values: The candidates and their relative weights.
    ///   - generator: The random number generator to draw from.
    /// - Returns: The selected element, or `nil` if no element has a positive,
    ///   finite weight.
    static func getRndmElement<ElementType, Generator: RandomNumberGenerator>(
        values: [WeightedElement<ElementType>],
        using generator: inout Generator
    ) -> ElementType? {
        // Only positive, finite weights take part in the draw; anything else counts as zero.
        func effectiveWeight(_ weight: Float) -> Float {
            weight.isFinite && weight > 0 ? weight : 0
        }

        let weightSum = values.reduce(Float(0)) { $0 + effectiveWeight($1.weight) }
        // `Float.random(in:)` traps on an empty or non-finite range, so bail out first.
        guard weightSum > 0, weightSum.isFinite else { return nil }

        let rand = Float.random(in: 0..<weightSum, using: &generator)
        var sum: Float = 0
        for (element, weight) in values {
            // Accumulated in the same order as `weightSum`, so the final `sum` equals
            // `weightSum` exactly and some positive-weight element always satisfies this.
            sum += effectiveWeight(weight)
            if rand < sum {
                return element
            }
        }
        // Not reached: the last positive-weight element brings `sum` to `weightSum` > `rand`.
        return nil
    }
}
// Probability test to show whether the WeightedElementCalculator class is working correctly:
// draw `rounds` samples, then compare each element's observed frequency with its
// expected probability `weight / weightSum`.
let rounds = 1_000_000
var effMap: [Efficiency: Int] = Dictionary(uniqueKeysWithValues: Efficiency.allCases.map { ($0, 0) })
let start = DispatchTime.now()
for _ in 0..<rounds {
    // getRndmElement already returns `Efficiency?`, so no cast is needed.
    if let rndmElement = WeightedElementCalculator.getRndmElement(values: testWeightElements) {
        effMap[rndmElement, default: 0] += 1
    }
}
let weightSum = testWeightElements.reduce(Float(0)) { $0 + $1.weight }
testWeightElements.forEach { (elem, weight) in
    let simulated = Float(effMap[elem, default: 0]) / Float(rounds)
    // Avoid printing 0/0 = NaN when every weight is zero (nothing is ever drawn then).
    let expected = weightSum.isZero ? 0 : weight / weightSum
    print("\(elem) simulated prob: \(simulated) -> expected prob: \(expected)")
}
// DispatchTime reads a monotonic clock, so `now >= start` and the UInt64 subtraction cannot underflow.
let timeInterval = Double(DispatchTime.now().uptimeNanoseconds - start.uptimeNanoseconds) / 1_000_000_000
print("Calculated \(rounds) random elements in \(timeInterval) seconds")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment