Skip to content

Instantly share code, notes, and snippets.

@malte-j
Created December 4, 2023 10:35
Show Gist options
  • Save malte-j/5d846a92159f00f83a1d7db69adaf68a to your computer and use it in GitHub Desktop.
Save malte-j/5d846a92159f00f83a1d7db69adaf68a to your computer and use it in GitHub Desktop.
Dart Thompson sampling
import 'dart:math';
/// Gaussian Thompson sampling over a fixed set of arms.
///
/// Each arm keeps a normal belief about its expected reward: [means] holds
/// the belief means and [variances] the belief variances. Observed rewards
/// are modelled as normally distributed around the arm's true mean with a
/// known [observationVariance], so [updateObservations] is an exact
/// conjugate (normal–normal) posterior update. [selectArm] draws one sample
/// per arm from those beliefs and returns the arm with the largest draw.
class ThompsonSampling {
  /// Current belief mean of each arm's expected reward.
  final List<double> means;

  /// Current belief variance of each arm's expected reward.
  ///
  /// Never increases on update, which is what moves the sampler from
  /// exploration towards exploitation as evidence accumulates.
  final List<double> variances;

  /// Assumed variance of a single observed reward around the arm's mean.
  final double observationVariance;

  final Random _random;

  /// Creates a sampler with one arm per entry of [initialMeans].
  ///
  /// [initialMeans] and [initialVariances] are the prior belief for each arm
  /// and are copied, so later changes to the passed lists have no effect.
  /// Pass a seeded [random] to make selections reproducible.
  ///
  /// Throws [ArgumentError] if the two lists differ in length, if any initial
  /// variance is negative or NaN, or if [observationVariance] is not
  /// positive.
  ThompsonSampling(
    List<double> initialMeans,
    List<double> initialVariances, {
    this.observationVariance = 1.0,
    Random? random,
  })  : means = List.of(initialMeans),
        variances = List.of(initialVariances),
        _random = random ?? Random() {
    if (initialMeans.length != initialVariances.length) {
      throw ArgumentError(
        'initialMeans (${initialMeans.length}) and initialVariances '
        '(${initialVariances.length}) must have the same length.',
      );
    }
    // `!(v >= 0)` also rejects NaN, which `v < 0` would let through.
    if (variances.any((v) => !(v >= 0))) {
      throw ArgumentError.value(
          initialVariances, 'initialVariances', 'must be non-negative');
    }
    if (!(observationVariance > 0)) {
      throw ArgumentError.value(
          observationVariance, 'observationVariance', 'must be positive');
    }
  }

  /// Incorporates one observed reward [newObservation] for arm [armIndex].
  ///
  /// Applies the normal–normal conjugate update with known observation
  /// variance σ² = [observationVariance]. With prior mean μ and variance v:
  ///
  ///     v' = v·σ² / (v + σ²)
  ///     μ' = (σ²·μ + v·x) / (v + σ²)
  ///
  /// The posterior mean is therefore a precision-weighted average of the
  /// prior mean and the observation, and the posterior variance is never
  /// larger than the prior variance.
  ///
  /// Throws [RangeError] if [armIndex] is not a valid arm index.
  void updateObservations(int armIndex, double newObservation) {
    RangeError.checkValidIndex(armIndex, means, 'armIndex');
    final priorMean = means[armIndex];
    final priorVariance = variances[armIndex];
    // Positive because observationVariance > 0 and priorVariance >= 0.
    final denominator = priorVariance + observationVariance;
    means[armIndex] =
        (observationVariance * priorMean + priorVariance * newObservation) /
            denominator;
    variances[armIndex] = priorVariance * observationVariance / denominator;
  }

  /// Returns the index of the arm whose sampled value is the largest.
  ///
  /// Draws one value per arm from N(means[i], variances[i]). Ties go to the
  /// lowest index.
  ///
  /// Throws [StateError] if there are no arms.
  int selectArm() {
    if (means.isEmpty) {
      throw StateError('No arms to select from.');
    }
    var bestArm = 0;
    var bestSample = double.negativeInfinity;
    for (var arm = 0; arm < means.length; arm++) {
      final sample = means[arm] + sqrt(variances[arm]) * _standardNormal();
      if (sample > bestSample) {
        bestSample = sample;
        bestArm = arm;
      }
    }
    return bestArm;
  }

  /// Draws one standard-normal value using the Box–Muller transform.
  double _standardNormal() {
    // `1 - nextDouble()` lies in (0, 1], so `log` never receives 0.
    final u1 = 1.0 - _random.nextDouble();
    final u2 = _random.nextDouble();
    return sqrt(-2.0 * log(u1)) * cos(2 * pi * u2);
  }
}
/// Runs a one-step example: three identical arms, one observation, one pick.
///
/// Every arm starts at mean 1.0 and variance 2.0. Arm 0 then receives an
/// observation of 11.0, after which an arm is selected and printed.
void main() {
  final sampler = ThompsonSampling(
    [1.0, 1.0, 1.0], // starting mean per arm
    [2.0, 2.0, 2.0], // starting variance per arm
  );

  // Feed in new data. Other arms are updated the same way, e.g.
  // sampler.updateObservations(1, 10.0) or sampler.updateObservations(2, 3.0).
  sampler.updateObservations(0, 11.0);

  final selectedArm = sampler.selectArm();
  print("Selected Arm: $selectedArm");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment