Created
October 5, 2022 20:40
-
-
Save george-polevoy/e7c9098a74b4ef813c4308f56dbd7871 to your computer and use it in GitHub Desktop.
Simulated Annealing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
namespace SimulatedAnnealing; | |
public class Tests | |
{ | |
/// <summary>
/// Verifies that simulated annealing, driven by segment-reversal mutations, turns a shuffled
/// permutation into the ascending sequence within <paramref name="kMax"/> steps.
/// </summary>
[TestCase(new[] { 6, 1, 9, 14, 19, 11, 0, 12, 8, 3, 5, 2, 4, 7, 10, 16, 15, 17, 18, 13, },
    new[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 }, 300000)]
public void CanSort(int[] source, int[] expected, int kMax)
{
    // One seeded generator is shared by solver and problem, so the whole run is reproducible.
    var rng = new Random(123);
    ISolver solver = new SimulatedAnnealing(kMax, rng);
    var family = new SortingProblemFamily(rng);

    var (solution, quality) = solver.Solve(family, source.ToList());

    TestContext.WriteLine(quality);
    TestContext.WriteLine(string.Join(", ", solution));
    CollectionAssert.AreEqual(expected, solution);
}
/// <summary>
/// Verifies that annealing over keep/drop digit masks finds the smallest number obtainable by
/// removing exactly <paramref name="k"/> digits from <paramref name="num"/> (leading zeros kept).
/// </summary>
[TestCase("1432219", 3, "1219", 32)]
[TestCase("10200", 1, "0200", 8)]
[TestCase("10", 2, "0", 3)]
public void CanRemoveKDigits(string num, int k, string expectedOutput, int kMax)
{
    var rng = new Random(123);
    ISolver solver = new SimulatedAnnealing(kMax, rng);
    var family = new RemoveKDigitsProblemFamily(rng, num, k);

    // Start mask: the first k digits are dropped (false) and every remaining digit is kept (true).
    var dropped = Enumerable.Repeat(false, k);
    var kept = Enumerable.Repeat(true, num.Length - k);
    var startMask = new BitSet(dropped.Concat(kept).ToList());

    var (solution, _) = solver.Solve(family, startMask);

    Assert.That(family.DerivePhenotype(solution), Is.EqualTo(expectedOutput));
}
/// <summary>
/// An optimisation problem described in genotype/phenotype terms: a solver searches over
/// <typeparamref name="TGenotype"/> values, each of which is decoded into a
/// <typeparamref name="TPhenotype"/> that can be scored.
/// </summary>
/// <typeparam name="TGenotype">The representation the solver mutates.</typeparam>
/// <typeparam name="TPhenotype">The decoded form that is scored.</typeparam>
public interface IProblemFamily<TGenotype, TPhenotype>
{
    /// <summary>Decodes <paramref name="genotype"/> into the phenotype that will be scored.</summary>
    TPhenotype DerivePhenotype(TGenotype genotype);

    /// <summary>
    /// Produces a neighbouring genotype of <paramref name="genotype"/>. Implementations should
    /// return a new value rather than modifying the argument: a solver may keep the original
    /// when the candidate is rejected.
    /// </summary>
    TGenotype Mutate(TGenotype genotype);

    /// <summary>
    /// Scores <paramref name="phenotype"/>. Lower is better; a solver may treat 0 as "solved".
    /// </summary>
    double EstimateError(TPhenotype phenotype);
}
/// <summary>
/// Sorting expressed as an annealing problem. The genotype (a list of ints) is scored directly,
/// and a mutation reverses a short segment that may wrap around the end of the list.
/// </summary>
/// <remarks>
/// The error function assumes the list is a permutation of 0..n-1: the error is zero only when
/// every element equals its own index.
/// </remarks>
public class SortingProblemFamily : IProblemFamily<List<int>, List<int>>
{
    private readonly Random _random;

    /// <param name="random">Source of mutation choices; may be shared with the solver.</param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    public SortingProblemFamily(Random random)
    {
        _random = random ?? throw new ArgumentNullException(nameof(random));
    }

    /// <summary>The list is scored as-is, so the phenotype is the genotype itself.</summary>
    public List<int> DerivePhenotype(List<int> genotype)
    {
        return genotype;
    }

    /// <summary>
    /// Returns a copy of <paramref name="genotype"/> with one segment reversed. The segment starts
    /// at a random index and wraps past the end of the list. Its length comes from
    /// <c>Random.Next(2, max(2, n / 10))</c>; because that upper bound is exclusive, the length
    /// is always 2 for lists shorter than 40 elements.
    /// Lists with fewer than two elements are returned as an unchanged copy.
    /// </summary>
    public List<int> Mutate(List<int> genotype)
    {
        var result = new List<int>(genotype);
        var count = result.Count;
        if (count < 2)
        {
            // Nothing can be reordered; this also avoids `% 0` on an empty list.
            return result;
        }

        // Draw order (start, then length) is kept stable: the generator may be shared with the solver.
        var start = _random.Next(count);
        var len = _random.Next(2, Math.Max(2, count / 10));
        ReverseWrapping(result, start, len);
        return result;
    }

    /// <summary>Mean squared displacement of each element from its index, normalised by n².</summary>
    public double EstimateError(List<int> phenotype)
    {
        return UnsortedCost(phenotype);
    }

    /// <summary>
    /// Reverses <paramref name="len"/> consecutive elements of <paramref name="list"/> starting at
    /// <paramref name="start"/>, treating the list as circular.
    /// </summary>
    private static void ReverseWrapping(List<int> list, int start, int len)
    {
        var n = list.Count;
        for (var i = 0; i < len / 2; i++)
        {
            var a = (start + i) % n;
            var b = (start + len - i - 1) % n;
            (list[a], list[b]) = (list[b], list[a]);
        }
    }

    /// <summary>
    /// Sum of squared (value - index) over the list, divided by n². Returns 0 for an empty list.
    /// </summary>
    private static double UnsortedCost(List<int> source)
    {
        var n = source.Count;
        if (n == 0)
        {
            // An empty list is trivially sorted; without this the result would be 0/0 = NaN.
            return 0;
        }

        double sum = 0;
        for (var i = 0; i < n; i++)
        {
            // Subtract and square in double so large element values cannot overflow int.
            double displacement = (double)source[i] - i;
            sum += displacement * displacement;
        }

        return sum / ((double)n * n);
    }
}
/// <summary>
/// A search strategy that starts from an initial genotype and looks for one with a low error,
/// as scored by an <see cref="IProblemFamily{TGenotype, TPhenotype}"/>.
/// </summary>
public interface ISolver
{
    /// <summary>Searches for a low-error genotype starting from <paramref name="start"/>.</summary>
    /// <param name="problemFamily">Supplies mutation, decoding and scoring.</param>
    /// <param name="start">The initial genotype.</param>
    /// <returns>
    /// The best genotype found (<c>solution</c>) and its error score (<c>quality</c>; lower is better).
    /// </returns>
    (TGenotype solution, double quality) Solve<TGenotype, TPhenotype>(
        IProblemFamily<TGenotype, TPhenotype> problemFamily,
        TGenotype start);
}
/// <summary>
/// Simulated annealing with a linear cooling schedule. Each step mutates the current genotype,
/// always accepts a candidate that is no worse, and accepts a worse one with probability
/// exp(-Δ/T). T falls linearly from 1 - 1/kMax on the first step to 0 on the last. The best
/// candidate ever evaluated is returned.
/// </summary>
public class SimulatedAnnealing : ISolver
{
    private readonly int _kMax;
    private readonly Random _random;

    /// <param name="kMax">Maximum number of annealing steps. If zero or negative, the start is returned unchanged.</param>
    /// <param name="random">
    /// Source of acceptance draws. It may be shared with the problem family; in that case the
    /// exact sequence of draws determines the whole run.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    public SimulatedAnnealing(int kMax, Random random)
    {
        _kMax = kMax;
        _random = random ?? throw new ArgumentNullException(nameof(random));
    }

    /// <summary>
    /// Runs up to <c>kMax</c> annealing steps. The run stops early once the current error is
    /// exactly 0, which assumes errors are non-negative.
    /// </summary>
    /// <returns>The lowest-error genotype evaluated and that error.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemFamily"/> is null.</exception>
    public (TGenotype solution, double quality) Solve<TGenotype, TPhenotype>(
        IProblemFamily<TGenotype, TPhenotype> problemFamily,
        TGenotype start)
    {
        if (problemFamily is null)
        {
            throw new ArgumentNullException(nameof(problemFamily));
        }

        var startPhenotype = problemFamily.DerivePhenotype(start);
        var current = (solution: start, cost: problemFamily.EstimateError(startPhenotype));
        var best = current;

        for (var k = 0; k < _kMax; k++)
        {
            if (current.cost == 0)
            {
                TestContext.WriteLine($"Solution found at step: {k}");
                break;
            }

            var candidateSolution = problemFamily.Mutate(current.solution);
            var candidatePhenotype = problemFamily.DerivePhenotype(candidateSolution);
            var candidate = (solution: candidateSolution, cost: problemFamily.EstimateError(candidatePhenotype));

            // Linear cooling. On the final step the temperature is exactly 0, so exp(-Δ/0) is 0
            // for any worse candidate; only non-worse candidates are accepted.
            var temperature = 1 - (k + 1.0) / _kMax;
            var costDiff = candidate.cost - current.cost;
            var acceptance = Math.Exp(-costDiff / temperature);

            // Always draw, even when the candidate will be accepted outright. The generator may be
            // shared with the problem family, so skipping a draw would change every later mutation.
            var draw = _random.NextDouble();

            if (candidate.cost <= current.cost || draw < acceptance)
            {
                current = candidate;
            }

            // Track the best candidate ever evaluated, whether or not it was accepted.
            if (candidate.cost < best.cost)
            {
                best = candidate;
            }
        }

        return best;
    }
}
/// <summary>
/// RemoveKDigitsProblemFamily. This problem is intentionally formulated as naively as possible.
/// The genotype is a keep/drop mask over the digits of the input number. The error is simply
/// the kept number's value divided by the original number's value.
/// </summary>
/// <remarks>
/// Leading zeros are not stripped from the phenotype (e.g. "10200" with one digit removed can
/// yield "0200"). Mutation swaps two mask bits, so the number of kept digits never changes from
/// the start genotype.
/// </remarks>
public class RemoveKDigitsProblemFamily : IProblemFamily<BitSet, string>
{
    // BitSet packs its bits into a single long, so longer masks cannot be represented.
    private const int MaxDigits = 64;

    private readonly Random _random;
    private readonly string _num;
    private readonly double _actualNum;

    /// <param name="random">Source of mutation choices; may be shared with the solver.</param>
    /// <param name="num">The decimal digit string to remove digits from (at most 64 digits).</param>
    /// <param name="k">
    /// Number of digits to remove. This class does not use it directly: the start genotype given
    /// to the solver must already drop k digits, and <see cref="Mutate"/> preserves that count.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> or <paramref name="num"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="num"/> has more than 64 digits.</exception>
    /// <exception cref="FormatException"><paramref name="num"/> is not a number.</exception>
    public RemoveKDigitsProblemFamily(Random random, string num, int k)
    {
        _random = random ?? throw new ArgumentNullException(nameof(random));
        _num = num ?? throw new ArgumentNullException(nameof(num));
        if (num.Length > MaxDigits)
        {
            throw new ArgumentOutOfRangeException(nameof(num), num.Length, $"At most {MaxDigits} digits are supported.");
        }

        // Invariant culture: the input is machine data, not locale-formatted text.
        _actualNum = double.Parse(num, System.Globalization.CultureInfo.InvariantCulture);
    }

    /// <summary>
    /// Keeps, in order, the digits whose mask bit is set. If no digit is kept, returns "0".
    /// </summary>
    public string DerivePhenotype(BitSet genotype)
    {
        var keptDigits = _num
            .Where((_, index) => genotype[index])
            .ToArray();

        return keptDigits.Length == 0 ? "0" : new string(keptDigits);
    }

    /// <summary>
    /// Returns a copy of <paramref name="genotype"/> with two randomly chosen bits swapped (the two
    /// positions may coincide). The number of set bits is preserved.
    /// </summary>
    public BitSet Mutate(BitSet genotype)
    {
        // Draw order (a, then b) is kept stable: the generator may be shared with the solver.
        var a = _random.Next(_num.Length);
        var b = _random.Next(_num.Length);
        var mutated = new BitSet(genotype);
        (mutated[a], mutated[b]) = (mutated[b], mutated[a]); // swap bits
        return mutated;
    }

    /// <summary>
    /// Error amount: the phenotype's value divided by the original number's value. For phenotypes
    /// produced by <see cref="DerivePhenotype"/> this lies in [0, 1]. A subsequence of the digits
    /// is never larger than the original, and keeping every digit gives exactly 1.
    /// </summary>
    public double EstimateError(string phenotype)
    {
        if (phenotype == "")
        {
            return 0;
        }

        if (_actualNum == 0)
        {
            // The original number is zero, so every subsequence of its digits is zero as well.
            // Dividing would give 0/0 = NaN, which the solver can never accept or finish on.
            return 0;
        }

        return double.Parse(phenotype, System.Globalization.CultureInfo.InvariantCulture) / _actualNum;
    }
}
/// <summary>
/// A fixed-capacity set of up to 64 bits packed into a single <see cref="long"/>, addressed by
/// index. As a struct, assignment copies the bits.
/// </summary>
public struct BitSet
{
    /// <summary>Number of addressable bits; valid indices are 0..63.</summary>
    public const int Capacity = 64;

    private long _bits;

    /// <summary>Gets or sets the bit at index <paramref name="x"/>.</summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="x"/> is outside [0, 64).</exception>
    public bool this[int x]
    {
        get => (_bits & Mask(x)) != 0;
        set
        {
            if (value)
            {
                _bits |= Mask(x);
            }
            else
            {
                _bits &= ~Mask(x);
            }
        }
    }

    /// <summary>
    /// Packs <paramref name="bits"/> so that the i-th element becomes bit i.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="bits"/> is null.</exception>
    /// <exception cref="ArgumentException">The sequence has more than 64 elements.</exception>
    public BitSet(IEnumerable<bool> bits)
    {
        if (bits is null)
        {
            throw new ArgumentNullException(nameof(bits));
        }

        long packed = 0;
        var index = 0;
        foreach (var bit in bits)
        {
            if (index >= Capacity)
            {
                throw new ArgumentException($"At most {Capacity} bits are supported.", nameof(bits));
            }

            if (bit)
            {
                packed |= 1L << index;
            }

            index++;
        }

        _bits = packed;
    }

    /// <summary>Creates an independent copy of <paramref name="source"/>.</summary>
    public BitSet(BitSet source)
    {
        _bits = source._bits;
    }

    /// <summary>Returns the single-bit mask for index <paramref name="x"/>.</summary>
    private static long Mask(int x)
    {
        // C# masks a long's shift count to its low 6 bits. Without this check, index 64 would
        // silently alias bit 0 and a negative index would alias a high bit.
        if ((uint)x >= Capacity)
        {
            throw new ArgumentOutOfRangeException(nameof(x), x, $"Bit index must be in [0, {Capacity}).");
        }

        return 1L << x;
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment