Created
April 4, 2016 05:30
-
-
Save margusmartsepp/987121fdec97ca5253bef0706ffc8c02 to your computer and use it in GitHub Desktop.
A* search on a C5 IntervalHeap (Euclidean-distance heuristic)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Drawing; | |
using System.Drawing.Imaging; | |
using System.Linq; | |
using System.Runtime.InteropServices; | |
using C5; | |
namespace Ai.framework2 | |
{ | |
/// <summary>
/// Loads a spider maze from a bitmap: the wall cells, the start cell and the goal cell.
/// </summary>
/// <remarks>
/// Each pixel is one grid cell. The pixel column becomes <see cref="Position.X"/> and the pixel row
/// becomes <see cref="Position.Y"/>. Pixels are classified in this order: blue channel == 0 is a wall;
/// otherwise red &lt; 127 is the goal; otherwise green &lt; 127 is the start. If several pixels match
/// the goal/start rule, the last one scanned wins. If none match, the property stays <c>null</c>.
/// As a side effect the constructor sets <see cref="Grapher.MaxX"/>/<see cref="Grapher.MaxY"/> to the
/// image size.
/// </remarks>
public class SpiderStates
{
    public Position Initial { get; }
    public Position Goal { get; }
    public List<Position> Walls { get; }

    /// <param name="filename">Path of an image file that <see cref="Bitmap"/> can open.</param>
    public SpiderStates(string filename)
    {
        Walls = new List<Position>();

        // Dispose the bitmap so the file handle and GDI+ memory are released promptly.
        using (var bmp = new Bitmap(filename))
        {
            var width = bmp.Width;
            var height = bmp.Height;

            // Read-only access is sufficient: pixels are copied out, never written back.
            var bmpData = bmp.LockBits(
                new Rectangle(0, 0, width, height),
                ImageLockMode.ReadOnly,
                PixelFormat.Format24bppRgb);
            try
            {
                var stride = bmpData.Stride;
                var bytes = stride * height;
                var rgbValues = new byte[bytes];
                Marshal.Copy(bmpData.Scan0, rgbValues, 0, bytes);

                for (var y = 0; y < height; y++)
                {
                    for (var x = 0; x < width; x++)
                    {
                        // Format24bppRgb stores each pixel as B, G, R.
                        var offset = y * stride + x * 3;
                        var b = rgbValues[offset];
                        var g = rgbValues[offset + 1];
                        var r = rgbValues[offset + 2];

                        if (b == 0) Walls.Add(Position.Create(x, y));
                        else if (r < 127) Goal = Position.Create(x, y);
                        else if (g < 127) Initial = Position.Create(x, y);
                    }
                }
            }
            finally
            {
                // Every LockBits must be paired with UnlockBits.
                bmp.UnlockBits(bmpData);
            }

            Grapher.MaxX = width;
            Grapher.MaxY = height;
        }
    }
}
/// <summary>
/// Circle helpers that take grid coordinates and draw in plot pixels.
/// </summary>
/// <remarks>
/// A grid coordinate <c>c</c> maps to pixel <c>c * 10 + 3</c>. The circle of the given pixel radius
/// is centred on that point.
/// </remarks>
public static class GraphicsExtensions
{
    /// <summary>Outlines a circle centred on grid point (<paramref name="centerX"/>, <paramref name="centerY"/>).</summary>
    /// <param name="g">Target surface.</param>
    /// <param name="pen">Outline pen.</param>
    /// <param name="centerX">Grid X of the centre.</param>
    /// <param name="centerY">Grid Y of the centre.</param>
    /// <param name="radius">Radius in pixels.</param>
    public static void DrawCircle(this Graphics g, Pen pen,
        double centerX, double centerY, float radius)
    {
        float left = (float)centerX * 10 + 3 - radius;
        float top = (float)centerY * 10 + 3 - radius;
        float diameter = radius + radius;
        g.DrawEllipse(pen, left, top, diameter, diameter);
    }

    /// <summary>Fills a circle centred on grid point (<paramref name="centerX"/>, <paramref name="centerY"/>).</summary>
    /// <param name="g">Target surface.</param>
    /// <param name="brush">Fill brush.</param>
    /// <param name="centerX">Grid X of the centre.</param>
    /// <param name="centerY">Grid Y of the centre.</param>
    /// <param name="radius">Radius in pixels.</param>
    public static void FillCircle(this Graphics g, Brush brush,
        double centerX, double centerY, float radius)
    {
        float left = (float)centerX * 10 + 3 - radius;
        float top = (float)centerY * 10 + 3 - radius;
        float diameter = radius + radius;
        g.FillEllipse(brush, left, top, diameter, diameter);
    }
}
/// <summary>
/// Debug visualiser that renders the current search state of a <see cref="SpiderProblem"/> to a PNG.
/// </summary>
/// <remarks>
/// Each call writes <c>C:\img\result_{n}.png</c>, where n increments per call, and prints the state of
/// the fringe's minimum node to the console. Grid coordinates are scaled by 10 px with a 3 px offset.
/// NOTE(review): the output directory is hard-coded and must already exist, otherwise saving throws.
/// </remarks>
public class Grapher
{
    /// <summary>
    /// Generic entry point used by the search loop. Only spider problems searched with a C5
    /// <c>IntervalHeap</c> fringe can be drawn; any other combination is silently ignored
    /// instead of failing on a null cast result.
    /// </summary>
    public static void Plot<T>(IProblem<T> problem, IPriorityQueue<INode<T>> fringe)
    {
        var spiderProblem = problem as SpiderProblem;
        var heap = fringe as IntervalHeap<INode<Position>>;
        if (spiderProblem == null || heap == null)
            return;
        Plot(spiderProblem, heap);
    }

    private static int _counter;
    public static double MaxX = 10;
    public static double MaxY = 10;

    /// <summary>
    /// Draws fringe links, visited cells, the current best path, walls, start and goal, then saves the PNG.
    /// Does nothing when the fringe is empty (there is no minimum to draw a path to).
    /// </summary>
    public static void Plot(SpiderProblem problem, IntervalHeap<INode<Position>> fringe)
    {
        if (fringe.IsEmpty)
            return;

        var data = new INode<Position>[fringe.Count];
        fringe.CopyTo(data, 0);
        var dataStates = new C5.HashSet<Position>();
        dataStates.AddAll(data.Select(o => o.State));
        var findMin = fringe.FindMin();
        var dimX = (int)(MaxX * 10 + 3);
        var dimY = (int)(MaxY * 10 + 3);

        // Pens and brushes wrap GDI+ handles, so every one of them is disposed.
        using (var initial = new Pen(Color.Red))
        using (var goal = new Pen(Color.Green))
        using (var known = new Pen(Color.Blue))
        using (var link = new Pen(Color.LightGray))
        using (var initialb = new SolidBrush(Color.Red))
        using (var goalb = new SolidBrush(Color.Green))
        using (var knownb = new SolidBrush(Color.Blue))
        using (var wall = new SolidBrush(Color.Black))
        using (var deadlink = new SolidBrush(Color.LightGray))
        using (var bmp = new Bitmap(dimX, dimY, PixelFormat.Format32bppPArgb))
        using (var gr = Graphics.FromImage(bmp))
        {
            gr.Clear(Color.WhiteSmoke);

            // Parent links of every node currently on the fringe.
            foreach (var item in data)
            {
                if (item.Parent != null)
                    gr.DrawLine(link, ToPixel(item.State), ToPixel(item.Parent.State));
            }

            // Every remembered state: a blue outline if it is still on the fringe, otherwise a grey dot.
            foreach (var item in problem._memory)
            {
                if (dataStates.Contains(item))
                {
                    gr.DrawCircle(known, item.X, item.Y, 3);
                }
                else
                {
                    gr.DrawCircle(link, item.X, item.Y, 3);
                    gr.FillCircle(deadlink, item.X, item.Y, 3);
                }
            }

            // Path from the root to the current fringe minimum, in blue.
            foreach (var item in Util.GetPathFromRoot(findMin))
            {
                if (item.Parent != null)
                    gr.DrawLine(known, ToPixel(item.State), ToPixel(item.Parent.State));
                gr.FillCircle(knownb, item.State.X, item.State.Y, 3);
            }

            // _walls may be null when the problem was constructed without walls.
            if (problem._walls != null)
            {
                foreach (var item in problem._walls)
                    gr.FillRectangle(wall, new RectangleF((float)item.X * 10 - 1, (float)item.Y * 10 - 1, 9, 9));
            }

            gr.FillCircle(initialb, problem._root.State.X, problem._root.State.Y, 3);
            gr.FillCircle(goalb, problem._goal.X, problem._goal.Y, 3);
            gr.DrawCircle(initial, problem._root.State.X, problem._root.State.Y, 3);
            gr.DrawCircle(goal, problem._goal.X, problem._goal.Y, 3);

            var filename = $"C:\\img\\result_{_counter++}.png";
            bmp.Save(filename, ImageFormat.Png);
        }
        Console.WriteLine(findMin.State);
    }

    /// <summary>Maps a grid position to the plot pixel at its centre (cell * 10 + 3).</summary>
    private static Point ToPixel(Position p) => new Point((int)p.X * 10 + 3, (int)p.Y * 10 + 3);
}
/// <summary>
/// Grid path-finding problem: move in eight directions from the start to the goal while avoiding walls.
/// </summary>
/// <remarks>
/// <c>_memory</c> holds every state already generated, plus all walls. Its <c>Contains</c> adds
/// unseen states, so filtering successors through it marks them visited at generation time.
/// NOTE(review): because states are closed on generation rather than on expansion, a state first
/// reached by a costlier route is never re-opened via a cheaper one, so returned paths may be
/// suboptimal when diagonal and straight costs mix.
/// </remarks>
public class SpiderProblem : IProblem<Position>
{
    public readonly Memory<Position> _memory;
    public readonly Node<Position> _root;
    public readonly Position _goal;
    public readonly Memory<Position> _walls;

    /// <param name="initial">Start cell.</param>
    /// <param name="goal">Target cell.</param>
    /// <param name="walls">Blocked cells; <c>null</c> means no walls.</param>
    public SpiderProblem(Position initial, Position goal, IEnumerable<Position> walls = null)
    {
        _goal = goal;
        _root = new Node<Position>(initial);
        _memory = new Memory<Position> { initial };

        // Always initialise _walls so readers (e.g. plotting code) never see null.
        _walls = new Memory<Position>();
        if (walls == null) return;

        var items = walls.ToList();
        _memory.AddAll(items);
        _walls.AddAll(items);
    }

    public INode<Position> GetRoot()
    {
        return _root;
    }

    public bool IsGoal(INode<Position> current)
    {
        return Equals(current.State, _goal);
    }

    /// <summary>Euclidean distance from the current state to <paramref name="next"/>.</summary>
    public double GetStepCost(INode<Position> current, Position next)
    {
        return Util.EuclideanDistance(current.State, next);
    }

    /// <summary>
    /// The eight neighbouring moves whose target state has not been seen before. Seen states are
    /// marked as a side effect.
    /// </summary>
    /// <remarks>
    /// The result is materialised. The filter mutates <c>_memory</c>, so a lazy query would yield
    /// different results each time it was enumerated.
    /// </remarks>
    public IEnumerable<INode<Position>> GetSuccessors(INode<Position> parent)
    {
        return new INode<Position>[]
            {
                new Move.NorthWest(parent),
                new Move.North(parent),
                new Move.NorthEast(parent),
                new Move.West(parent),
                new Move.East(parent),
                new Move.SouthWest(parent),
                new Move.South(parent),
                new Move.SouthEast(parent)
            }
            .Where(o => !_memory.Contains(o.State))
            .ToList();
    }

    /// <summary>Straight-line distance from <paramref name="next"/> to the goal.</summary>
    public double GetHeuristicValue(Position next)
    {
        return Util.EuclideanDistance(next, _goal);
    }
}
/// <summary>Geometry and search-tree helpers shared by the spider search.</summary>
public static class Util
{
    /// <summary>Straight-line distance between two points stored as (Item1, Item2).</summary>
    public static double EuclideanDistance(Tuple<double, double> first, Tuple<double, double> second)
    {
        double dx = first.Item1 - second.Item1, dy = first.Item2 - second.Item2;
        return Math.Sqrt(dx * dx + dy * dy);
    }

    /// <summary>
    /// Follows parent links from <paramref name="current"/> up to, but not including, the root.
    /// </summary>
    /// <returns>
    /// The nodes in root-to-current order. The sequence is empty when <paramref name="current"/> is the root.
    /// </returns>
    public static IEnumerable<INode<T>> GetPathFromRoot<T>(INode<T> current)
    {
        var path = new Stack<INode<T>>();
        var node = current;
        while (!node.IsRootNode)
        {
            path.Push(node);
            node = node.Parent;
        }
        // Enumerating a stack yields the last pushed item first, i.e. the node nearest the root.
        return path;
    }

    /// <summary>Returns <paramref name="position"/> shifted by (<paramref name="x"/>, <paramref name="y"/>).</summary>
    public static Position CreateRelative(Position position, double x, double y) =>
        Position.Create(position.X + x, position.Y + y);
}
/// <summary>
/// Immutable 2-D grid coordinate. It derives from <see cref="Tuple{T1,T2}"/>, so equality and hashing
/// are structural over (X, Y).
/// </summary>
public class Position : Tuple<double, double>
{
    private Position(double x, double y) : base(x, y) { }

    /// <summary>Horizontal coordinate (tuple <c>Item1</c>).</summary>
    public double X
    {
        get { return Item1; }
    }

    /// <summary>Vertical coordinate (tuple <c>Item2</c>).</summary>
    public double Y
    {
        get { return Item2; }
    }

    /// <summary>Creates a position; this is the only way, since the constructor is private.</summary>
    public static Position Create(double x, double y) => new Position(x, y);
}
/// <summary>
/// Orders search nodes by the A* evaluation f(n) = g(n) + h(n), cheapest first.
/// </summary>
/// <remarks>
/// g(n) is computed as <c>PathCost + StepCost</c>. In this framework a node's own step cost and the
/// accumulated cost up to its parent are carried in these two properties, and <c>Move</c> builds each
/// child's accumulated cost as <c>parent.PathCost + parent.StepCost</c>. Their sum is therefore the
/// cost from the root. Using <c>PathCost</c> alone dropped most of g and degraded the search to
/// near-greedy best-first.
/// Nodes whose f-values differ by less than the tolerance compare equal.
/// NOTE(review): tolerance-based equality is not transitive, which strictly violates the
/// <see cref="IComparer{T}"/> contract; it is kept for backward compatibility.
/// </remarks>
public class NodeComparator<T> : IComparer<INode<T>>
{
    private readonly IProblem<T> _problem;
    private readonly double _tolerance;

    /// <param name="problem">Supplies the heuristic h(n).</param>
    /// <param name="tolerance">f-values closer than this are treated as equal.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
    public NodeComparator(IProblem<T> problem, double tolerance = 0.00000001d)
    {
        if (problem == null) throw new ArgumentNullException(nameof(problem));
        _tolerance = tolerance;
        _problem = problem;
    }

    public int Compare(INode<T> one, INode<T> two)
    {
        var s1 = Evaluate(one);
        var s2 = Evaluate(two);
        if (Math.Abs(s1 - s2) < _tolerance)
            return 0;
        return s1 < s2 ? -1 : 1;
    }

    /// <summary>f(n): total cost from the root plus the heuristic estimate to the goal.</summary>
    private double Evaluate(INode<T> node) =>
        node.PathCost + node.StepCost + _problem.GetHeuristicValue(node.State);
}
/// <summary>
/// A C5 hash set whose <c>Contains</c> doubles as "check and mark visited".
/// </summary>
/// <remarks>
/// <c>Contains</c> returns whether the item was already present and, if it was not, adds it. A second
/// call with the same item therefore returns <c>true</c>.
/// NOTE(review): the override also affects any caller that expects <c>Contains</c> to be a pure query.
/// </remarks>
public class Memory<T> : C5.HashSet<T>
{
    public override bool Contains(T item)
    {
        var seenBefore = base.Contains(item);
        if (!seenBefore)
        {
            Add(item);
        }
        return seenBefore;
    }
}
/// <summary>
/// Priority search whose fringe is ordered by a <see cref="NodeComparator{T}"/> with its default tolerance.
/// </summary>
public class AStarSearch<T> : PrioritySearch<T>
{
    /// <summary>Builds the comparator the caller should use to construct the fringe.</summary>
    public override IComparer<INode<T>> GetComparator(IProblem<T> p) => new NodeComparator<T>(p);
}
/// <summary>
/// Best-first search skeleton. It repeatedly removes the minimum node from the fringe and stops at the
/// first goal; otherwise it adds the node's successors.
/// </summary>
/// <remarks>
/// Expansion order is determined entirely by the comparator the caller used to build the fringe.
/// <see cref="GetComparator"/> supplies the one this search expects.
/// </remarks>
public abstract class PrioritySearch<T> : ISearch<T>
{
    /// <summary>
    /// Callback invoked with the problem and fringe before every expansion. It defaults to
    /// <c>Grapher.Plot</c>, which renders an image per call, so existing behaviour is unchanged.
    /// Set it to <c>null</c> to disable plotting, or replace it with another observer.
    /// </summary>
    public Action<IProblem<T>, IPriorityQueue<INode<T>>> Plotter { get; set; } = Grapher.Plot;

    public abstract IComparer<INode<T>> GetComparator(IProblem<T> p);

    /// <summary>Runs the search on <paramref name="problem"/> using <paramref name="fringe"/> as the open list.</summary>
    /// <returns>
    /// The nodes from the root's child to the goal node (empty if the root itself is a goal), or
    /// <c>null</c> if the fringe is exhausted without reaching a goal.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> or <paramref name="fringe"/> is null.</exception>
    public IEnumerable<INode<T>> Search(IProblem<T> problem, IPriorityQueue<INode<T>> fringe)
    {
        if (problem == null) throw new ArgumentNullException(nameof(problem));
        if (fringe == null) throw new ArgumentNullException(nameof(fringe));

        fringe.Add(problem.GetRoot());
        while (!fringe.IsEmpty)
        {
            Plotter?.Invoke(problem, fringe);
            var node = fringe.DeleteMin();
            if (problem.IsGoal(node))
                return Util.GetPathFromRoot(node);
            fringe.AddAll(problem.GetSuccessors(node));
        }
        return null;
    }
}
/// <summary>
/// The eight single-cell moves on the grid. Each nested class is a search node whose state is the
/// parent's position shifted by one cell. Straight moves cost <see cref="Direct"/> and diagonal moves
/// cost <see cref="Diagonal"/> (the square root of 2).
/// </summary>
/// <remarks>
/// North is y - 1 and south is y + 1; west is x - 1 and east is x + 1. The accumulated cost handed to
/// <see cref="Node{T}"/> is <c>parent.PathCost + parent.StepCost</c>. The action label is the direction
/// followed by the target position, e.g. <c>"ne (4, 3)"</c>.
/// </remarks>
public abstract class Move : Node<Position>
{
    protected const double Diagonal = 1.414213562373095048801688724209698078569671875376948073176d;
    protected const double Direct = 1;

    private Move(INode<Position> origin, double dx, double dy, string label = "root", double cost = Direct)
        : base(
            Util.CreateRelative(origin.State, dx, dy),
            origin,
            cost,
            origin.PathCost + origin.StepCost,
            label + " " + Util.CreateRelative(origin.State, dx, dy))
    {
    }

    public class NorthWest : Move
    {
        public NorthWest(INode<Position> parent) : base(parent, dx: -1, dy: -1, label: "nw", cost: Diagonal) { }
    }

    public class North : Move
    {
        public North(INode<Position> parent) : base(parent, dx: 0, dy: -1, label: "n") { }
    }

    public class NorthEast : Move
    {
        public NorthEast(INode<Position> parent) : base(parent, dx: 1, dy: -1, label: "ne", cost: Diagonal) { }
    }

    public class West : Move
    {
        public West(INode<Position> parent) : base(parent, dx: -1, dy: 0, label: "w") { }
    }

    public class East : Move
    {
        public East(INode<Position> parent) : base(parent, dx: 1, dy: 0, label: "e") { }
    }

    public class SouthWest : Move
    {
        public SouthWest(INode<Position> parent) : base(parent, dx: -1, dy: 1, label: "sw", cost: Diagonal) { }
    }

    public class South : Move
    {
        public South(INode<Position> parent) : base(parent, dx: 0, dy: 1, label: "s") { }
    }

    public class SouthEast : Move
    {
        public SouthEast(INode<Position> parent) : base(parent, dx: 1, dy: 1, label: "se", cost: Diagonal) { }
    }
}
/// <summary>
/// Immutable search-tree node: a state, a link to its parent, and the costs used to rank it.
/// </summary>
/// <remarks>
/// <see cref="StepCost"/> is the cost of the move that produced this node. <see cref="PathCost"/> is the
/// accumulated cost supplied by the creator (<c>Move</c> passes <c>parent.PathCost + parent.StepCost</c>,
/// i.e. the parent's total). The total cost from the root is therefore <c>PathCost + StepCost</c>.
/// The constructor previously stored the two arguments crosswise (step cost into PathCost and
/// vice versa).
/// </remarks>
public class Node<T> : INode<T>
{
    public T State { get; }
    public INode<T> Parent { get; }
    public bool IsRootNode => Parent == null;
    public double StepCost { get; }
    public double PathCost { get; }
    public string Action { get; }

    /// <param name="state">State represented by this node.</param>
    /// <param name="parent">Node this one was generated from; <c>null</c> makes this a root.</param>
    /// <param name="stepCost">Cost of the move that produced this node.</param>
    /// <param name="pathCost">Accumulated cost before this move.</param>
    /// <param name="action">Human-readable label of the move.</param>
    public Node(T state, INode<T> parent = null, double stepCost = 0, double pathCost = 0, string action = "root")
    {
        Parent = parent;
        State = state;
        StepCost = stepCost;
        PathCost = pathCost;
        Action = action;
    }

    public override string ToString()
    {
        return State.ToString();
    }
}
/// <summary>
/// Read-only view of a search-tree node. It is covariant in <typeparamref name="T"/>.
/// </summary>
public interface INode<out T>
{
    /// <summary>The problem state this node represents.</summary>
    T State { get; }

    /// <summary>Label describing the move that produced this node.</summary>
    string Action { get; }

    /// <summary>The node this one was generated from.</summary>
    INode<T> Parent { get; }

    /// <summary>Whether this node is the root of the search tree; path reconstruction stops here.</summary>
    bool IsRootNode { get; }

    /// <summary>Step-cost component of the node (see the implementing class for its exact convention).</summary>
    double StepCost { get; }

    /// <summary>Path-cost component of the node (see the implementing class for its exact convention).</summary>
    double PathCost { get; }
}
/// <summary>Definition of a search problem consumed by <see cref="ISearch{T}"/> implementations.</summary>
public interface IProblem<T>
{
    /// <summary>Root node of the search tree (the initial state).</summary>
    INode<T> GetRoot();

    /// <summary>Whether <paramref name="current"/> satisfies the goal test.</summary>
    bool IsGoal(INode<T> current);

    /// <summary>Nodes reachable from <paramref name="current"/> in one move.</summary>
    IEnumerable<INode<T>> GetSuccessors(INode<T> current);

    /// <summary>Cost of moving from <paramref name="current"/> to state <paramref name="next"/>.</summary>
    double GetStepCost(INode<T> current, T next);

    /// <summary>Heuristic estimate of the remaining cost from <paramref name="next"/> to a goal.</summary>
    double GetHeuristicValue(T next);
}
/// <summary>A search strategy over an <see cref="IProblem{T}"/> that uses a caller-supplied priority-queue fringe.</summary>
public interface ISearch<T>
{
    /// <summary>Searches <paramref name="problem"/>, using <paramref name="fringe"/> as the open list.</summary>
    /// <returns>The solution path; the failure value is implementation-defined (PrioritySearch returns <c>null</c>).</returns>
    IEnumerable<INode<T>> Search(
        IProblem<T> problem,
        IPriorityQueue<INode<T>> fringe);
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Linq; | |
using Ai.DataStructures; | |
using Ai.framework2; | |
using C5; | |
using FluentAssertions; | |
using NUnit.Framework; | |
namespace AiTest.DataStructures | |
{ | |
/// <summary>
/// End-to-end tests that run A* on spider mazes and check that the returned path reaches the goal.
/// </summary>
[TestFixture]
public class SolveSpiderTests
{
    /// <summary>A three-cell vertical wall sits between start and goal; the path must go around it.</summary>
    [Test]
    public void SolveSpider()
    {
        // Arrange
        var algorithm = new AStarSearch<Position>();
        var goal = Position.Create(9, 4);
        var walls = new[]
        {
            Position.Create(6, 3),
            Position.Create(6, 4),
            Position.Create(6, 5)
        };
        var problem = new SpiderProblem(
            initial: Position.Create(3, 4),
            goal: goal,
            walls: walls);
        var fringe = new IntervalHeap<INode<Position>>(algorithm.GetComparator(problem));

        // Act
        var result = algorithm.Search(problem, fringe);

        // Assert
        result.Should().NotBeNull("the goal is reachable around the wall");
        var path = result.ToList();
        Console.WriteLine(string.Join(", ", path.Select(o => o.Action)));
        path.Last().State.Should().Be(goal);
        path.Select(o => o.State).Intersect(walls).Should().BeEmpty();
    }

    /// <summary>Loads the maze from an image (see SpiderStates for the colour encoding) and solves it.</summary>
    [Test]
    [TestCase(@"C:\img\input\input1.png")]
    [TestCase(@"C:\img\input\input3.png")]
    public void SolveSpiderFile(string filename)
    {
        // Arrange
        var algorithm = new AStarSearch<Position>();
        var states = new SpiderStates(filename);
        states.Initial.Should().NotBeNull("the image must contain a start pixel");
        states.Goal.Should().NotBeNull("the image must contain a goal pixel");
        var problem = new SpiderProblem(
            initial: states.Initial,
            goal: states.Goal,
            walls: states.Walls);
        var fringe = new IntervalHeap<INode<Position>>(algorithm.GetComparator(problem));

        // Act
        var result = algorithm.Search(problem, fringe);

        // Assert
        result.Should().NotBeNull("the maze is expected to be solvable");
        var path = result.ToList();
        Console.WriteLine(string.Join(", ", path.Select(o => o.Action)));
        path.Last().State.Should().Be(states.Goal);
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment