Skip to content

Instantly share code, notes, and snippets.

@margusmartsepp
Created April 4, 2016 05:30
Show Gist options
  • Save margusmartsepp/987121fdec97ca5253bef0706ffc8c02 to your computer and use it in GitHub Desktop.
Save margusmartsepp/987121fdec97ca5253bef0706ffc8c02 to your computer and use it in GitHub Desktop.
A* search on a C5 IntervalHeap (Euclidean-distance heuristic)
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Drawing.Imaging;
using System.Linq;
using System.Runtime.InteropServices;
using C5;
namespace Ai.framework2
{
/// <summary>
/// Loads a spider-problem grid from an image file, one grid cell per pixel.
/// Pixels are classified in this order: blue channel == 0 is a wall; otherwise red &lt; 127 is the
/// goal; otherwise green &lt; 127 is the start; anything else is open floor. When several pixels
/// match the goal (or start) rule, the last one scanned wins.
/// As a side effect, <see cref="Grapher.MaxX"/>/<see cref="Grapher.MaxY"/> are set to the image size.
/// </summary>
public class SpiderStates
{
    /// <summary>Start cell, or null when no pixel matched the start rule.</summary>
    public Position Initial { get; }

    /// <summary>Goal cell, or null when no pixel matched the goal rule.</summary>
    public Position Goal { get; }

    /// <summary>All wall cells, in row-major scan order.</summary>
    public List<Position> Walls { get; }

    /// <param name="filename">Path of an image readable by <see cref="Bitmap"/>.</param>
    public SpiderStates(string filename)
    {
        Walls = new List<Position>();

        // Dispose the bitmap: it wraps a GDI+ handle and keeps the file open until released.
        using (var bmp = new Bitmap(filename))
        {
            var width = bmp.Width;
            var height = bmp.Height;
            var rect = new Rectangle(0, 0, width, height);

            // Only read access is needed, so lock ReadOnly; always unlock, even if the copy throws.
            var bmpData = bmp.LockBits(rect, ImageLockMode.ReadOnly, PixelFormat.Format24bppRgb);
            int stride;
            byte[] rgbValues;
            try
            {
                stride = bmpData.Stride;
                rgbValues = new byte[stride * height];
                Marshal.Copy(bmpData.Scan0, rgbValues, 0, rgbValues.Length);
            }
            finally
            {
                bmp.UnlockBits(bmpData);
            }

            for (var y = 0; y < height; y++)
            {
                for (var x = 0; x < width; x++)
                {
                    // 24bpp pixel bytes are read here in B, G, R order.
                    var offset = y * stride + x * 3;
                    var b = rgbValues[offset];
                    var g = rgbValues[offset + 1];
                    var r = rgbValues[offset + 2];
                    if (b == 0) Walls.Add(Position.Create(x, y));
                    else if (r < 127) Goal = Position.Create(x, y);
                    else if (g < 127) Initial = Position.Create(x, y);
                }
            }

            Grapher.MaxX = width;
            Grapher.MaxY = height;
        }
    }
}
/// <summary>
/// Circle helpers for <see cref="Graphics"/> that take grid coordinates: each coordinate is
/// scaled by 10 and offset by 3 pixels before the circle is centred there.
/// </summary>
public static class GraphicsExtensions
{
    /// <summary>Outlines a circle of <paramref name="radius"/> pixels around a grid point.</summary>
    public static void DrawCircle(this Graphics g, Pen pen,
        double centerX, double centerY, float radius)
    {
        var diameter = radius + radius;
        var left = CornerPixel(centerX, radius);
        var top = CornerPixel(centerY, radius);
        g.DrawEllipse(pen, left, top, diameter, diameter);
    }

    /// <summary>Fills a circle of <paramref name="radius"/> pixels around a grid point.</summary>
    public static void FillCircle(this Graphics g, Brush brush,
        double centerX, double centerY, float radius)
    {
        var diameter = radius + radius;
        var left = CornerPixel(centerX, radius);
        var top = CornerPixel(centerY, radius);
        g.FillEllipse(brush, left, top, diameter, diameter);
    }

    // Pixel coordinate of the bounding box's near edge for a circle centred on a grid coordinate.
    private static float CornerPixel(double center, float radius)
    {
        return (float)center * 10 + 3 - radius;
    }
}
/// <summary>
/// Debug renderer that writes a PNG snapshot of a <see cref="SpiderProblem"/> search on every call.
/// A snapshot shows fringe links, states in memory, the current best path, walls, start and goal.
/// </summary>
public class Grapher
{
    // Pixels per grid cell, and the margin added to every grid coordinate.
    private const int CellSize = 10;
    private const int Margin = 3;

    // Radius, in pixels, of the circle markers drawn for states.
    private const float MarkerRadius = 3;

    private static int _counter;

    /// <summary>Grid width used to size the output image (updated by <see cref="SpiderStates"/>).</summary>
    public static double MaxX = 10;

    /// <summary>Grid height used to size the output image (updated by <see cref="SpiderStates"/>).</summary>
    public static double MaxY = 10;

    /// <summary>Directory that receives the <c>result_N.png</c> snapshots; created if missing.</summary>
    public static string OutputDirectory = @"C:\img";

    /// <summary>
    /// Generic entry point called by <see cref="PrioritySearch{T}"/>. Plots only when the problem is a
    /// <see cref="SpiderProblem"/> and the fringe is an <c>IntervalHeap</c> of position nodes. Any other
    /// combination is skipped rather than failing with a NullReferenceException.
    /// </summary>
    public static void Plot<T>(IProblem<T> problem, IPriorityQueue<INode<T>> fringe)
    {
        var spiderProblem = problem as SpiderProblem;
        var heap = fringe as IntervalHeap<INode<Position>>;
        if (spiderProblem == null || heap == null)
            return;
        Plot(spiderProblem, heap);
    }

    /// <summary>
    /// Renders one snapshot of <paramref name="problem"/> and its <paramref name="fringe"/> to
    /// <see cref="OutputDirectory"/>, then prints the state of the fringe's minimum node.
    /// </summary>
    /// <param name="problem">Problem whose memory, walls, root and goal are drawn.</param>
    /// <param name="fringe">Current open set; must be non-empty (its minimum is read).</param>
    public static void Plot(SpiderProblem problem, IntervalHeap<INode<Position>> fringe)
    {
        // Copy the fringe out so it can be iterated without touching the heap itself.
        var data = new INode<Position>[fringe.Count];
        fringe.CopyTo(data, 0);
        var fringeStates = new C5.HashSet<Position>();
        fringeStates.AddAll(data.Select(o => o.State));
        var findMin = fringe.FindMin();

        var dimX = (int)(MaxX * CellSize + Margin);
        var dimY = (int)(MaxY * CellSize + Margin);

        // Pens and brushes wrap GDI+ handles. This method runs on every search step, so dispose them.
        using (var initial = new Pen(Color.Red))
        using (var goal = new Pen(Color.Green))
        using (var known = new Pen(Color.Blue))
        using (var link = new Pen(Color.LightGray))
        using (var initialb = new SolidBrush(Color.Red))
        using (var goalb = new SolidBrush(Color.Green))
        using (var knownb = new SolidBrush(Color.Blue))
        using (var wall = new SolidBrush(Color.Black))
        using (var deadlink = new SolidBrush(Color.LightGray))
        using (var bmp = new Bitmap(dimX, dimY, PixelFormat.Format32bppPArgb))
        using (var gr = Graphics.FromImage(bmp))
        {
            gr.Clear(Color.WhiteSmoke);

            // Grey edge from every fringe node to its parent.
            foreach (var item in data)
            {
                if (item.Parent != null)
                    gr.DrawLine(link, ToPixel(item.State), ToPixel(item.Parent.State));
            }

            // States in memory: blue outline when still on the fringe, grey disc otherwise.
            foreach (var item in problem._memory)
            {
                if (fringeStates.Contains(item))
                {
                    gr.DrawCircle(known, item.X, item.Y, MarkerRadius);
                }
                else
                {
                    gr.DrawCircle(link, item.X, item.Y, MarkerRadius);
                    gr.FillCircle(deadlink, item.X, item.Y, MarkerRadius);
                }
            }

            // Path from the root (exclusive) to the best fringe node.
            foreach (var item in Util.GetPathFromRoot(findMin))
            {
                if (item.Parent != null)
                    gr.DrawLine(known, ToPixel(item.State), ToPixel(item.Parent.State));
                gr.FillCircle(knownb, item.State.X, item.State.Y, MarkerRadius);
            }

            // Guard: _walls can be null when the problem was built without walls.
            if (problem._walls != null)
            {
                foreach (var item in problem._walls)
                {
                    gr.FillRectangle(wall, new RectangleF((float)item.X * CellSize - 1, (float)item.Y * CellSize - 1, 9, 9));
                }
            }

            gr.FillCircle(initialb, problem._root.State.X, problem._root.State.Y, MarkerRadius);
            gr.FillCircle(goalb, problem._goal.X, problem._goal.Y, MarkerRadius);
            gr.DrawCircle(initial, problem._root.State.X, problem._root.State.Y, MarkerRadius);
            gr.DrawCircle(goal, problem._goal.X, problem._goal.Y, MarkerRadius);

            // Bitmap.Save fails when the target directory does not exist.
            System.IO.Directory.CreateDirectory(OutputDirectory);
            var filename = System.IO.Path.Combine(OutputDirectory, $"result_{_counter++}.png");
            bmp.Save(filename, ImageFormat.Png);
        }
        Console.WriteLine(findMin.State);
    }

    // Pixel at the centre of a grid position's marker.
    private static Point ToPixel(Position p)
    {
        return new Point((int)p.X * CellSize + Margin, (int)p.Y * CellSize + Margin);
    }
}
/// <summary>
/// Grid path-finding problem. The spider moves one cell in any of eight directions from the start
/// toward the goal and may not enter a wall cell. Step costs and the heuristic are Euclidean distances.
/// </summary>
public class SpiderProblem : IProblem<Position>
{
    /// <summary>
    /// States seen so far: the start, every wall, and each successor generated. Also read by
    /// <see cref="Grapher"/> for rendering.
    /// </summary>
    public readonly Memory<Position> _memory;

    /// <summary>Root node built from the start position.</summary>
    public readonly Node<Position> _root;

    /// <summary>Target position.</summary>
    public readonly Position _goal;

    /// <summary>Wall cells; never null (empty when no walls were supplied).</summary>
    public readonly Memory<Position> _walls;

    /// <param name="initial">Start position.</param>
    /// <param name="goal">Goal position.</param>
    /// <param name="walls">Impassable cells; optional.</param>
    public SpiderProblem(Position initial, Position goal, IEnumerable<Position> walls = null)
    {
        _goal = goal;
        _root = new Node<Position>(initial);
        _memory = new Memory<Position> { initial };
        // Initialise before the early return so consumers such as Grapher never see null.
        _walls = new Memory<Position>();
        if (walls == null) return;
        var items = walls.ToList();
        _memory.AddAll(items);
        _walls.AddAll(items);
    }

    /// <summary>Returns the root node built from the start position.</summary>
    public INode<Position> GetRoot()
    {
        return _root;
    }

    /// <summary>True when the node's state equals the goal (structural tuple equality).</summary>
    public bool IsGoal(INode<Position> current)
    {
        return Equals(current.State, _goal);
    }

    /// <summary>Euclidean distance from the node's state to <paramref name="next"/>.</summary>
    public double GetStepCost(INode<Position> current, Position next)
    {
        return Util.EuclideanDistance(current.State, next);
    }

    /// <summary>
    /// Builds the eight neighbouring nodes of <paramref name="parent"/> and keeps those whose state
    /// has not been seen before. Every kept state is recorded in <see cref="_memory"/>.
    /// </summary>
    public IEnumerable<INode<Position>> GetSuccessors(INode<Position> parent)
    {
        var candidates = new INode<Position>[]
        {
            new Move.NorthWest(parent),
            new Move.North(parent),
            new Move.NorthEast(parent),
            new Move.West(parent),
            new Move.East(parent),
            new Move.SouthWest(parent),
            new Move.South(parent),
            new Move.SouthEast(parent)
        };

        // Memory.Contains records each unseen state as a side effect. Materialise the filter once;
        // otherwise enumerating the result a second time would silently yield nothing.
        // NOTE(review): states are marked when generated rather than when expanded, so a cheaper path
        // found later to an already-generated state is discarded. Path optimality is therefore not
        // guaranteed.
        return candidates.Where(o => !_memory.Contains(o.State)).ToList();
    }

    /// <summary>Straight-line distance from <paramref name="next"/> to the goal.</summary>
    public double GetHeuristicValue(Position next)
    {
        return Util.EuclideanDistance(next, _goal);
    }
}
/// <summary>Small geometry and search-tree helpers.</summary>
public static class Util
{
    /// <summary>Straight-line distance between two (x, y) pairs.</summary>
    public static double EuclideanDistance(Tuple<double, double> first, Tuple<double, double> second)
    {
        var dx = first.Item1 - second.Item1;
        var dy = first.Item2 - second.Item2;
        return Math.Sqrt(dx * dx + dy * dy);
    }

    /// <summary>
    /// Walks parent links from <paramref name="current"/> up to the root. It returns the visited
    /// nodes ordered root-side first, excluding the root itself.
    /// </summary>
    public static IEnumerable<INode<T>> GetPathFromRoot<T>(INode<T> current)
    {
        var path = new Stack<INode<T>>();
        var node = current;
        while (!node.IsRootNode)
        {
            path.Push(node);
            node = node.Parent;
        }
        return path;
    }

    /// <summary>Returns a new position offset from <paramref name="position"/> by (x, y).</summary>
    public static Position CreateRelative(Position position, double x, double y) =>
        Position.Create(position.X + x, position.Y + y);
}
/// <summary>
/// Immutable grid coordinate. It derives from <see cref="Tuple{T1,T2}"/>, so equality, hashing and
/// <c>ToString</c> are structural over (X, Y).
/// </summary>
public class Position : Tuple<double, double>
{
    private Position(double x, double y) : base(x, y) { }

    /// <summary>Horizontal coordinate (tuple item 1).</summary>
    public double X => Item1;

    /// <summary>Vertical coordinate (tuple item 2).</summary>
    public double Y => Item2;

    /// <summary>Factory for a position at (x, y).</summary>
    public static Position Create(double x, double y) => new Position(x, y);
}
/// <summary>
/// Orders nodes by <c>PathCost + heuristic(State)</c>. Scores closer than the tolerance compare
/// as equal.
/// </summary>
public class NodeComparator<T> : IComparer<INode<T>>
{
    private readonly IProblem<T> _problem;
    private readonly double _tolerance;

    /// <param name="problem">Supplies the heuristic value for each node's state.</param>
    /// <param name="tolerance">Maximum score difference treated as a tie.</param>
    public NodeComparator(IProblem<T> problem, double tolerance = 0.00000001d)
    {
        _problem = problem;
        _tolerance = tolerance;
    }

    /// <summary>Returns 0 for a tie, -1 if <paramref name="one"/> scores lower, otherwise 1.</summary>
    public int Compare(INode<T> one, INode<T> two)
    {
        var delta = Score(one) - Score(two);
        if (Math.Abs(delta) < _tolerance)
        {
            return 0;
        }
        return delta < 0 ? -1 : 1;
    }

    // f-value used for ordering: recorded path cost plus heuristic estimate.
    private double Score(INode<T> node) => node.PathCost + _problem.GetHeuristicValue(node.State);
}
/// <summary>
/// A C5 hash set whose <see cref="Contains"/> also records the item. The first query for an item
/// returns false and inserts it; every later query returns true.
/// </summary>
public class Memory<T> : C5.HashSet<T>
{
    /// <summary>Test-and-insert: returns whether the item was already present, adding it if not.</summary>
    public override bool Contains(T item)
    {
        if (!base.Contains(item))
        {
            Add(item);
            return false;
        }
        return true;
    }
}
/// <summary>Priority search ordered by path cost plus heuristic via <see cref="NodeComparator{T}"/>.</summary>
public class AStarSearch<T> : PrioritySearch<T>
{
    /// <summary>Creates a <see cref="NodeComparator{T}"/> with the default tolerance.</summary>
    public override IComparer<INode<T>> GetComparator(IProblem<T> p) => new NodeComparator<T>(p);
}
/// <summary>
/// Best-first search skeleton. It repeatedly removes the minimum node from the fringe and stops at
/// the first goal removed; otherwise it adds that node's successors to the fringe.
/// </summary>
public abstract class PrioritySearch<T> : ISearch<T>
{
    /// <summary>Comparator the caller should use to build the fringe.</summary>
    public abstract IComparer<INode<T>> GetComparator(IProblem<T> p);

    /// <summary>
    /// Runs the search. Before each removal, the fringe is passed to <see cref="Grapher.Plot{T}"/>.
    /// </summary>
    /// <returns>Path from the root (exclusive) to the goal, or null if the fringe empties first.</returns>
    public IEnumerable<INode<T>> Search(IProblem<T> problem, IPriorityQueue<INode<T>> fringe)
    {
        fringe.Add(problem.GetRoot());
        while (!fringe.IsEmpty)
        {
            Grapher.Plot(problem, fringe);
            var current = fringe.DeleteMin();
            if (problem.IsGoal(current))
            {
                return Util.GetPathFromRoot(current);
            }
            fringe.AddAll(problem.GetSuccessors(current));
        }
        return null;
    }
}
/// <summary>
/// Base class for the eight single-cell grid moves. Each nested subclass builds the successor node
/// reached from <c>parent</c> by a fixed (dx, dy) offset.
/// </summary>
public abstract class Move : Node<Position>
{
    // Cost of a diagonal step (sqrt 2).
    protected const double Diagonal = 1.414213562373095048801688724209698078569671875376948073176d;
    // Cost of an orthogonal step.
    protected const double Direct = 1;

    private Move(INode<Position> parent, double x, double y, string action = "root", double stepCost = Direct)
        : this(parent, Util.CreateRelative(parent.State, x, y), action, stepCost)
    { }

    // PathCost is cumulative (parent's path cost + this step). NodeComparator can therefore rank
    // nodes by g + h, as A* requires.
    private Move(INode<Position> parent, Position state, string action, double stepCost)
        : base(state, parent, stepCost, parent.PathCost + stepCost, action + " " + state)
    { }

    public class NorthWest : Move
    {
        public NorthWest(INode<Position> parent) : base(parent, -1, -1, "nw", Diagonal) { }
    }

    public class North : Move
    {
        public North(INode<Position> parent) : base(parent, 0, -1, "n") { }
    }

    public class NorthEast : Move
    {
        public NorthEast(INode<Position> parent) : base(parent, 1, -1, "ne", Diagonal) { }
    }

    public class West : Move
    {
        public West(INode<Position> parent) : base(parent, -1, 0, "w") { }
    }

    public class East : Move
    {
        public East(INode<Position> parent) : base(parent, 1, 0, "e") { }
    }

    public class SouthWest : Move
    {
        public SouthWest(INode<Position> parent) : base(parent, -1, 1, "sw", Diagonal) { }
    }

    public class South : Move
    {
        public South(INode<Position> parent) : base(parent, 0, 1, "s") { }
    }

    public class SouthEast : Move
    {
        public SouthEast(INode<Position> parent) : base(parent, 1, 1, "se", Diagonal) { }
    }
}

/// <summary>Immutable search-tree node.</summary>
public class Node<T> : INode<T>
{
    public T State { get; }
    public INode<T> Parent { get; }
    public bool IsRootNode => Parent == null;

    /// <summary>Cost of the single step that produced this node (0 for the root).</summary>
    public double StepCost { get; }

    /// <summary>Total cost of the path from the root to this node (0 for the root).</summary>
    public double PathCost { get; }

    public string Action { get; }

    /// <param name="state">State held by the node.</param>
    /// <param name="parent">Node this one was generated from; null for the root.</param>
    /// <param name="stepCost">Cost of the last step.</param>
    /// <param name="pathCost">Cumulative cost from the root.</param>
    /// <param name="action">Label of the move that produced the node.</param>
    public Node(T state, INode<T> parent = null, double stepCost = 0, double pathCost = 0, string action = "root")
    {
        State = state;
        Parent = parent;
        // Each argument is stored in its own property; these two were previously swapped.
        StepCost = stepCost;
        PathCost = pathCost;
        Action = action;
    }

    public override string ToString()
    {
        return State.ToString();
    }
}
/// <summary>A search-tree node: a state, its parent link, and cost/label data.</summary>
public interface INode<out T>
{
    /// <summary>Node this one was generated from.</summary>
    INode<T> Parent { get; }

    /// <summary>Whether this node is the search root.</summary>
    bool IsRootNode { get; }

    /// <summary>Problem state held by the node.</summary>
    T State { get; }

    /// <summary>Label of the move that produced the node.</summary>
    string Action { get; }

    /// <summary>Path cost; <see cref="NodeComparator{T}"/> adds the heuristic to it to rank nodes.</summary>
    double PathCost { get; }

    /// <summary>Step cost recorded for the node.</summary>
    double StepCost { get; }
}

/// <summary>Definition of a search problem over states of type <typeparamref name="T"/>.</summary>
public interface IProblem<T>
{
    /// <summary>Node the search starts from.</summary>
    INode<T> GetRoot();

    /// <summary>Nodes reachable from <paramref name="current"/> in one step.</summary>
    IEnumerable<INode<T>> GetSuccessors(INode<T> current);

    /// <summary>Whether <paramref name="current"/> satisfies the goal.</summary>
    bool IsGoal(INode<T> current);

    /// <summary>Cost of stepping from <paramref name="current"/> to <paramref name="next"/>.</summary>
    double GetStepCost(INode<T> current, T next);

    /// <summary>Heuristic estimate of the remaining cost from <paramref name="next"/>.</summary>
    double GetHeuristicValue(T next);
}

/// <summary>A search strategy driven by a caller-supplied priority queue.</summary>
public interface ISearch<T>
{
    /// <summary>Searches <paramref name="problem"/> using <paramref name="fringe"/> as the open set.</summary>
    IEnumerable<INode<T>> Search(IProblem<T> problem, IPriorityQueue<INode<T>> fringe);
}
}
using System;
using System.Linq;
using Ai.DataStructures;
using Ai.framework2;
using C5;
using FluentAssertions;
using NUnit.Framework;
namespace AiTest.DataStructures
{
    /// <summary>End-to-end tests of A* on the spider grid problem.</summary>
    [TestFixture]
    public class SolveSpiderTests
    {
        /// <summary>A three-cell vertical wall lies between start and goal; a path around it must be found.</summary>
        [Test]
        public void SolveSpider()
        {
            // Arrange
            var algorithm = new AStarSearch<Position>();
            var goal = Position.Create(9, 4);
            var walls = new[]
            {
                Position.Create(6, 3),
                Position.Create(6, 4),
                Position.Create(6, 5)
            };
            var problem = new SpiderProblem(
                initial: Position.Create(3, 4),
                goal: goal,
                walls: walls);
            var fringe = new IntervalHeap<INode<Position>>(algorithm.GetComparator(problem));

            // Act
            var result = algorithm.Search(problem, fringe)?.ToList();

            // Assert
            result.Should().NotBeNull("a path around the wall exists");
            result.Should().NotBeEmpty();
            result.Last().State.Should().Be(goal);
            foreach (var node in result)
                walls.Should().NotContain(node.State);
            Console.WriteLine(string.Join(", ", result.Select(o => o.Action)));
        }

        /// <summary>Solves maps loaded from image files.</summary>
        [Test]
        [TestCase(@"C:\img\input\input1.png")]
        [TestCase(@"C:\img\input\input3.png")]
        public void SolveSpiderFile(string filename)
        {
            // Arrange
            var algorithm = new AStarSearch<Position>();
            var states = new SpiderStates(filename);
            var problem = new SpiderProblem(
                initial: states.Initial,
                goal: states.Goal,
                walls: states.Walls);
            var fringe = new IntervalHeap<INode<Position>>(algorithm.GetComparator(problem));

            // Act
            var result = algorithm.Search(problem, fringe)?.ToList();

            // Assert
            result.Should().NotBeNull("the input image is expected to be solvable");
            result.Should().NotBeEmpty();
            result.Last().State.Should().Be(states.Goal);
            Console.WriteLine(string.Join(", ", result.Select(o => o.Action)));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment