Skip to content

Instantly share code, notes, and snippets.

@ftes
Created December 12, 2022 19:42
Show Gist options
  • Save ftes/afa1aadeb4fab4b4dccc5d202b5f4166 to your computer and use it in GitHub Desktop.
Dijkstra's Shortest Path Algorithm in Elixir

dijkstra

# Notebook dependencies: Kino for rendering interactive output, VegaLite for
# building chart specs, and kino_vega_lite for the `Kino.VegaLite` widget that
# the visualisation below renders and streams data into.
#
# Protocol consolidation is turned off so that protocol implementations defined
# later in this notebook (such as the `Inspect` implementation for `PrioQ`)
# are actually dispatched to.
notebook_deps = [
  {:kino, "~> 0.7.0"},
  {:vega_lite, "~> 0.1.6"},
  {:kino_vega_lite, "~> 0.1.7"}
]

Mix.install(notebook_deps, consolidate_protocols: false)

alias VegaLite, as: Vl

Dijkstra's Algorithm

defmodule Dijkstra do
  @moduledoc """
  Dijkstra's single-source shortest-path algorithm over an implicit graph.

  The graph is never built explicitly; it is described by two callbacks:
  `get_neighbours.(u)` returns the vertices reachable from `u`, and
  `get_distance.(u, v)` returns the weight of the edge `u -> v`. As with any
  Dijkstra implementation, edge weights are assumed to be non-negative.

  The numbered comments quote line-numbered pseudocode (apparently from the
  Wikipedia article on Dijkstra's algorithm) to show which step each piece of
  code implements.
  """

  @doc """
  Computes the shortest distance from `start` to every reachable vertex.

  `on_step` is invoked with `{v, dist, prev}` every time the tentative distance
  of vertex `v` improves (after `dist` and `prev` have been updated), and once
  with `{:done, dist, prev}` when the search finishes.

  Returns `{dist, prev}`. `dist` maps every reached vertex to its distance from
  `start`. `prev` maps every reached vertex other than `start` to its
  predecessor on a shortest path. Unreached vertices appear in neither map.
  """
  def dijkstra(start, get_neighbours, get_distance, on_step \\ fn _ -> nil end) do
    # 2      dist[source] ← 0                           // Initialization
    # 8              dist[v] ← INFINITY                 // Unknown distance from source to v
    # "Infinity" is represented by absence: vertices missing from `dist` are unreached.
    dist = %{start => 0}
    # 4      create vertex priority queue Q
    # Instead of filling the priority queue with all nodes in the initialization phase,
    # it is also possible to initialize it to contain only source;
    q = PrioQ.new([{0, start}])
    # 9              prev[v] ← UNDEFINED                // Predecessor of v
    prev = %{}

    loop(q, dist, prev, get_neighbours, get_distance, on_step)
  end

  # 14     while Q is not empty:                      // The main loop
  defp loop(q, dist, prev, get_neighbours, get_distance, on_step) do
    case q |> PrioQ.extract_min() |> check_outdated(dist) do
      :empty ->
        on_step.({:done, dist, prev})
        # 23     return dist, prev
        {dist, prev}

      # Yet another alternative is to add nodes unconditionally to the priority queue
      # and to instead check after extraction that no shorter connection was found yet.
      # This can be done by additionally extracting the associated priority p
      # from the queue and only processing further if p == dist[u] inside the
      # while Q is not empty loop.
      #
      # The stale entry has already been removed from `rest`, so the search must
      # continue with `rest`; retrying with `q` would extract the same stale
      # entry again and never terminate.
      {:outdated, rest} ->
        loop(rest, dist, prev, get_neighbours, get_distance, on_step)

      # 15         u ← Q.extract_min()                    // Remove and return best vertex
      {:current, u, rest} ->
        # With non-negative weights dist[u] cannot change while u's edges are
        # relaxed, so read it once instead of once per neighbour.
        dist_u = Map.fetch!(dist, u)

        # 16         for each neighbor v of u:              // Go through all v neighbors of u
        {dist, prev, rest} =
          for v <- get_neighbours.(u), reduce: {dist, prev, rest} do
            {dist, prev, q} ->
              # 17             alt ← dist[u] + Graph.Edges(u, v)
              alt = dist_u + get_distance.(u, v)

              # 18             if alt < dist[v]:
              # 19                 dist[v] ← alt
              # 20                 prev[v] ← u
              # 21                 Q.decrease_priority(v, alt)
              if shorter?(alt, Map.fetch(dist, v)) do
                dist = Map.put(dist, v, alt)
                prev = Map.put(prev, v, u)
                # The queue has no decrease_priority: insert a fresh entry instead
                # and let check_outdated/2 discard the superseded one later.
                q = PrioQ.add_with_priority(q, v, alt)
                on_step.({v, dist, prev})
                {dist, prev, q}
              else
                {dist, prev, q}
              end
          end

        loop(rest, dist, prev, get_neighbours, get_distance, on_step)
    end
  end

  # Tags a PrioQ.extract_min/1 result. A vertex can be queued several times
  # (once per improvement of its distance); only the entry whose priority still
  # equals dist[u] is current, older ones are outdated and must be skipped.
  # The tags keep a vertex that happens to be named :outdated unambiguous.
  defp check_outdated({{prio, u}, rest}, dist) do
    if prio == dist[u], do: {:current, u, rest}, else: {:outdated, rest}
  end

  defp check_outdated(:empty, _dist), do: :empty

  # Whether `alt` improves on the known distance (a Map.fetch/2 result).
  # An unreached vertex (`:error`) counts as infinitely far away.
  defp shorter?(_alt, :error), do: true
  defp shorter?(alt, {:ok, known}), do: alt < known
end

defmodule PrioQ do
  @moduledoc """
  A minimal min-priority queue backed by an Erlang `:gb_sets` ordered set.

  Entries are stored as `{priority, element}` tuples, so the set's term ordering
  sorts by priority first and breaks ties by element. Being a set, it keeps a
  single copy of an identical `{priority, element}` pair. The same element may
  still be present several times under different priorities.
  """
  defstruct [:set]

  @doc "Returns an empty queue."
  def new(), do: %__MODULE__{set: :gb_sets.empty()}

  @doc """
  Builds a queue from a list of `{priority, element}` tuples.

  A non-empty list whose head is not a 2-tuple raises `FunctionClauseError`.
  """
  def new(entries)

  def new([]), do: new()

  def new([{_prio, _elem} | _] = entries) do
    %__MODULE__{set: :gb_sets.from_list(entries)}
  end

  @doc "Returns a new queue that also holds `elem` at priority `prio`."
  def add_with_priority(%__MODULE__{set: set} = q, elem, prio) do
    %__MODULE__{q | set: :gb_sets.add({prio, elem}, set)}
  end

  @doc "Returns the number of entries in the queue."
  def size(%__MODULE__{set: set}), do: :gb_sets.size(set)

  @doc """
  Removes the lowest-priority entry.

  Returns `{{priority, element}, rest}`, or `:empty` when the queue holds no
  entries.
  """
  def extract_min(%__MODULE__{set: set} = q) do
    if :gb_sets.is_empty(set) do
      :empty
    else
      {smallest, remaining} = :gb_sets.take_smallest(set)
      {smallest, %__MODULE__{q | set: remaining}}
    end
  end

  defimpl Inspect do
    import Inspect.Algebra

    # Renders as `#PrioQ.new([{prio, elem}, ...])`, listing entries in priority order.
    def inspect(%PrioQ{set: set}, opts) do
      entries_doc = set |> :gb_sets.to_list() |> to_doc(opts)
      concat(["#PrioQ.new(", entries_doc, ")"])
    end
  end
end

Visualization

defmodule DijkstraChart do
  @moduledoc """
  Animated Kino/VegaLite visualisation of `Dijkstra.dijkstra/4` on a grid.

  Cells are `{x, y}` coordinates in `1..@max_coord`. Every move changes x by
  +1, y by +1, or both, and costs 1. Cells in `@block` are obstacles. Each
  improved distance is pushed to the `:distances` dataset as soon as it is
  found. When the search finishes, the shortest path from `@start` to `@goal`
  is pushed to the `:path` dataset.
  """

  # Largest coordinate on either axis (also the upper end of the axis domains).
  @max_coord 20
  @block_1 for x <- 6..16, y <- 14..16, into: MapSet.new(), do: {x, y}
  @block_2 for x <- 14..16, y <- 9..13, into: MapSet.new(), do: {x, y}
  @block MapSet.union(@block_1, @block_2)
  @start {1, 1}
  @goal {18, 18}
  # Pause in milliseconds after each distance update.
  @interval 25
  # Pause in milliseconds after each point of the final path.
  @path_interval 50

  @doc """
  Renders the chart, waits one second, then runs the search with `on_step/2`
  animating every step. Returns Dijkstra's `{dist, prev}`.
  """
  def chart() do
    chart =
      Vl.new(width: 400, height: 400)
      |> Vl.layers([blocks_layer(), distances_layer(), path_layer(), endpoints_layer()])
      |> Kino.VegaLite.new()
      |> Kino.render()

    Process.sleep(1000)
    Dijkstra.dijkstra(@start, &get_neighbours/1, &get_distance/2, &on_step(chart, &1))
  end

  # Obstacles as black squares; this layer also sets the axis domains.
  defp blocks_layer() do
    Vl.new()
    |> Vl.mark(:point, filled: true, size: 400, shape: :square, color: :black)
    |> Vl.data_from_values(to_values(@block))
    |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [1, @max_coord]])
    |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [1, @max_coord]])
  end

  # Streamed distances, coloured from red (near) to green (far).
  defp distances_layer() do
    Vl.new(data: [name: :distances])
    |> Vl.mark(:point, filled: true, size: 200, shape: :circle)
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative)
    |> Vl.encode_field(:color, "color",
      type: :quantitative,
      legend: false,
      scale: [scheme: :redyellowgreen, domain: [0, 30]]
    )
  end

  # The final shortest path, streamed point by point as a thick line.
  defp path_layer() do
    Vl.new(data: [name: :path])
    |> Vl.mark(:line, stroke_width: 10)
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative)
  end

  # Start and goal markers.
  defp endpoints_layer() do
    Vl.new()
    |> Vl.mark(:point, filled: true, size: 600, shape: "triangle-up", color: :blue)
    |> Vl.data_from_values(to_values([@start, @goal]))
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative)
  end

  # Converts `{x, y}` cells into the row maps VegaLite expects.
  defp to_values(cells), do: Enum.map(cells, fn {x, y} -> %{x: x, y: y} end)

  @doc """
  Pushes one visualisation step to `chart`.

  `{{x, y}, dist, prev}` plots cell `{x, y}` coloured by its new distance.
  `{:done, dist, prev}` draws the shortest path from `@start` to `@goal` one
  point at a time. Nothing is drawn if `@goal` was never reached.
  """
  def on_step(chart, {{x, y}, dist, _prev}) do
    Kino.VegaLite.push(chart, %{x: x, y: y, color: dist[{x, y}]}, dataset: :distances)
    Process.sleep(@interval)
  end

  def on_step(chart, {:done, _dist, prev}) do
    for {x, y} <- get_path(prev, @goal, @start) do
      Kino.VegaLite.push(chart, %{x: x, y: y}, dataset: :path)
      Process.sleep(@path_interval)
    end
  end

  @doc """
  Returns the cells reachable from `{x, y}` in one move: x + 1, y + 1, or both.
  Obstacles and cells beyond `@max_coord` are excluded.
  """
  def get_neighbours({x, y}) do
    [{1, 0}, {0, 1}, {1, 1}]
    |> Enum.map(fn {dx, dy} -> {x + dx, y + dy} end)
    |> Enum.reject(&MapSet.member?(@block, &1))
    |> Enum.filter(fn {nx, ny} -> nx <= @max_coord and ny <= @max_coord end)
  end

  @doc "Every move costs 1, diagonal moves included."
  def get_distance(_, _), do: 1

  @doc """
  Rebuilds a path by walking the `prev` predecessor map backwards from
  `current` until `source` is reached.

  Returns the cells in order from `source` to `current`. Returns `[]` when
  `current` cannot be traced back to `source` (e.g. it was never reached).
  `path` accumulates the cells already walked.
  """
  def get_path(prev, current, source, path \\ [])

  def get_path(_prev, source, source, path), do: [source | path]

  def get_path(prev, current, source, path) do
    case Map.fetch(prev, current) do
      {:ok, parent} -> get_path(prev, parent, source, [current | path])
      # Dead end: `current` has no predecessor, so `source` cannot be reached.
      :error -> []
    end
  end
end

# Draw the grid, then run and animate the search; the cell evaluates to Dijkstra's {dist, prev}.
{_dist, _prev} = DijkstraChart.chart()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.