Skip to content

Instantly share code, notes, and snippets.

@HurricanKai
Last active June 25, 2022 16:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save HurricanKai/0d05a73cd8bdb67c1a1361bfcc79dd81 to your computer and use it in GitHub Desktop.
Save HurricanKai/0d05a73cd8bdb67c1a1361bfcc79dd81 to your computer and use it in GitHub Desktop.
Snake game implemented with tensors

Snake on a Tensor

# Notebook dependencies, listed alphabetically. Nx/EXLA provide the tensors,
# and Kino renders the SVG previews. :axon and :kino_vega_lite are not
# referenced by any of the cells shown here.
Mix.install([
  {:axon, "~> 0.1.0"},
  {:exla, "~> 0.2.2"},
  {:kino, "~> 0.6.1"},
  {:kino_vega_lite, "~> 0.1.1"},
  {:nx, "~> 0.2.1"}
])
:ok

The Game

defmodule Game do
  @moduledoc """
  Snake game whose whole state is held in 2-D tensors.

  `snake_tensor` holds the head as `1.0`. On every step each occupied cell is
  lowered by `1 / (rows * cols)`, and cells that fall below the length allowed
  by `score` are cleared. A snake with score `s` therefore occupies `s + 1`
  cells, and a segment's value encodes how many steps ago the head was there.
  `apple_tensor` marks apple cells with non-zero values.
  """
  import Nx.Defn

  # TODO: Stack these tensors
  defstruct [:score, :snake_tensor, :apple_tensor]

  @doc """
  Advances the game one step in `direction` (`:up`, `:down`, `:left` or `:right`).

  The step does four things:

    * moves the head;
    * collects every apple the moved snake overlaps, adding one to `score` per apple;
    * trims the tail to the new score;
    * spawns as many new apples as were eaten, preferring cells that hold
      neither snake nor apple.

  Returns the new `%Game{}`.
  """
  def advance(
        %__MODULE__{score: score, snake_tensor: snake_tensor, apple_tensor: apple_tensor},
        direction
      ) do
    new_snake_tensor = advance_snake(snake_tensor, direction)

    {new_score, new_apple_tensor} = process_apple(score, new_snake_tensor, apple_tensor)

    final_snake_tensor = filter_snake_tensor(new_snake_tensor, Nx.tensor(new_score))

    to_spawn = new_score - score

    final_apple_tensor =
      if to_spawn > 0,
        do:
          spawn_new_apples(
            new_apple_tensor,
            to_spawn,
            # every cell already occupied by an apple or the snake is > 0
            Nx.add(new_apple_tensor, final_snake_tensor)
          ),
        else: new_apple_tensor

    %__MODULE__{
      score: new_score,
      snake_tensor: final_snake_tensor,
      apple_tensor: final_apple_tensor
    }
  end

  @doc """
  Returns `true` when the snake tensor no longer contains a head cell, i.e. its
  maximum value is not `1`.

  The head disappears when it is shifted off the board, or when it moves onto
  an occupied cell (that cell is decayed instead of being set to `1`).
  """
  def dead?(tensor) do
    # reduce_max yields a 0-d tensor. It must be converted to a number before
    # comparing; otherwise `!=` compares the %Nx.Tensor{} struct with the
    # integer 1, which is always true.
    Nx.to_number(Nx.reduce_max(tensor)) != 1
  end

  @doc """
  Same as `dead?/1`. Kept for existing callers.
  """
  def is_dead(tensor), do: dead?(tensor)

  @doc """
  Adds `to_spawn` apples to `tensor` at random cells and returns the result.

  `valid_tensor` must have the same shape as `tensor`. Cells where it is
  positive are treated as blocked: they get a weight of at least `1.0e4` added
  to a standard-normal draw. Picking the `to_spawn` smallest weights therefore
  practically never lands on a blocked cell, unless fewer than `to_spawn`
  cells are free.

  NOTE(review): `to_spawn` must be smaller than the number of cells, since it
  is used as an index into the sorted weights.
  """
  def spawn_new_apples(tensor, to_spawn, valid_tensor) do
    random = tensor |> Nx.shape() |> Nx.random_normal()
    # the valid_tensor contains > 0 on blocked fields; ceil turns any positive
    # value into at least 1, so blocked fields are weighted upward by >= 1e4
    # and the selection below is very unlikely to ever hit them
    weighted = Nx.add(Nx.multiply(Nx.ceil(valid_tensor), 1.0e4), random)

    # the weight at sorted position `to_spawn` (0-based) is the first one NOT
    # selected, so the `to_spawn` smallest weights lie strictly below it
    max_selected =
      (weighted
       |> Nx.flatten()
       |> Nx.sort())[to_spawn]

    new_apples = Nx.less(weighted, max_selected)
    Nx.add(tensor, new_apples)
  end

  # Clears every snake cell whose value is too old for the current `score`.
  #
  # Segment values are produced by repeatedly subtracting 1 / (rows * cols) in
  # f32. On boards whose cell count is not a power of two, that drifts by a few
  # ulps from 1 - k / (rows * cols). The threshold is therefore placed half a
  # step below the last segment to keep (k = score). Cells with k <= score are
  # kept and cells with k >= score + 1 are cleared, regardless of rounding.
  defnp filter_snake_tensor(snake, score) do
    {rows, cols} = Nx.shape(snake)
    minimum = 1 - (score + 0.5) / (rows * cols)

    Nx.select(Nx.greater_equal(snake, minimum), snake, 0.0)
  end

  # Removes apples that overlap the (already moved) snake.
  # Returns {old_score + number of apples removed, remaining apple mask}.
  defp process_apple(old_score, snake, old_apple) do
    apple_overlap = Nx.logical_and(snake, old_apple)
    apples_collected = Nx.sum(apple_overlap)
    new_apple = Nx.logical_and(Nx.logical_not(apple_overlap), old_apple)
    {old_score + Nx.to_number(apples_collected), new_apple}
  end

  # Shifts a copy of the board one cell in `direction` (zero-filling the
  # vacated edge and dropping the cells pushed past the opposite edge), then
  # merges it with the unshifted board via advance_snake_core/2.
  defp advance_snake(snake_tensor, direction) do
    padded =
      Nx.pad(
        snake_tensor,
        0,
        case direction do
          # pad one top, cut off bottom
          :down -> [{1, -1, 0}, {0, 0, 0}]
          # pad one bottom, cut off top
          :up -> [{-1, 1, 0}, {0, 0, 0}]
          # pad one left, cut off right
          :right -> [{0, 0, 0}, {1, -1, 0}]
          # pad one right, cut off left
          :left -> [{0, 0, 0}, {-1, 1, 0}]
        end
      )

    advance_snake_core(snake_tensor, padded)
  end

  # Combines the current board with its shifted copy.
  # Occupied cells are decayed by one step. Empty cells take the floor of the
  # shifted value, which is 1 only where the head moved in, because body
  # values are < 1. A head that moves onto an occupied cell is therefore
  # decayed instead of written, leaving no 1.0 on the board (see dead?/1).
  defnp advance_snake_core(tensor, padded) do
    {rows, cols} = Nx.shape(tensor)
    decayed = tensor - 1 / (rows * cols)
    moved_head = Nx.floor(padded)

    # logical_not is 1 on empty cells, so those take the moved head
    Nx.select(Nx.logical_not(tensor), moved_head, decayed)
  end
end
{:module, Game, <<70, 79, 82, 49, 0, 0, 25, ...>>, {:advance_snake_core, 2}}
ExUnit.start(autorun: false)

defmodule Game.Tests do
  use ExUnit.Case, async: true
  require Game

  # Every fixture is a 4x4 board, so one step decays a segment by 1/16.
  @max_length 16

  # 4x4 integer 0/1 apple mask with a 1 at each 1-based {x, y} in `positions`.
  defp apples(positions \\ []) do
    rows =
      for y <- 1..4 do
        for x <- 1..4, do: if({x, y} in positions, do: 1, else: 0)
      end

    Nx.tensor(rows)
  end

  # Runs a single Game.advance/2 step from the given pieces of state.
  defp step(snake, apple_mask, score, direction) do
    state = %Game{score: score, snake_tensor: snake, apple_tensor: apple_mask}
    Game.advance(state, direction)
  end

  defp count(tensor), do: tensor |> Nx.sum() |> Nx.to_number()

  test "advance advances linear correctly up" do
    dec1 = 1 - 1 / @max_length

    snake =
      Nx.tensor([
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, dec1, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0]
      ])

    %Game{} = result = step(snake, apples(), 1, :up)

    assert result.score === 1

    assert Nx.to_flat_list(result.snake_tensor) ==
             [0.0, 1.0, 0.0, 0.0, 0.0, dec1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    assert count(result.apple_tensor) === 0
  end

  test "advance lengthens correctly up" do
    dec1 = 1 - 1 / @max_length
    dec2 = 1 - 2 / @max_length

    snake =
      Nx.tensor([
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, dec1, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0]
      ])

    %Game{} = result = step(snake, apples([{2, 1}]), 1, :up)

    assert result.score === 2

    assert Nx.to_flat_list(result.snake_tensor) ==
             [0.0, 1.0, 0.0, 0.0, 0.0, dec1, 0.0, 0.0, 0.0, dec2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    assert count(result.apple_tensor) === 1
  end

  test "advance advances linear correctly left" do
    dec1 = 1 - 1 / @max_length

    snake =
      Nx.tensor([
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, dec1, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0]
      ])

    %Game{} = result = step(snake, apples(), 1, :left)

    assert result.score === 1

    assert Nx.to_flat_list(result.snake_tensor) ==
             [0.0, 0.0, 0.0, 0.0, 1.0, dec1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    assert count(result.apple_tensor) === 0
  end

  test "advance lengthens correctly left" do
    dec1 = 1 - 1 / @max_length
    dec2 = 1 - 2 / @max_length

    snake =
      Nx.tensor([
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, dec1, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0]
      ])

    %Game{} = result = step(snake, apples([{1, 2}]), 1, :left)

    assert result.score === 2

    assert Nx.to_flat_list(result.snake_tensor) ==
             [0.0, 0.0, 0.0, 0.0, 1.0, dec1, dec2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    assert count(result.apple_tensor) === 1
  end
end

ExUnit.run()
....

Finished in 0.00 seconds (0.00s async, 0.00s sync)
4 tests, 0 failures

Randomized with seed 833031
%{excluded: 0, failures: 0, skipped: 0, total: 4}
ExUnit.start(autorun: false)

defmodule Game.CollisionTests do
  use ExUnit.Case, async: true
  require Game

  # 4x4 boards: one step decays a segment by 1/16.
  @max_length 16

  # Value of a segment the head left `k` steps ago.
  defp decayed(k), do: 1 - k / @max_length

  # 4x4 integer board with no apples.
  defp no_apples, do: Nx.tensor(List.duplicate(List.duplicate(0, 4), 4))

  # Advances once and returns only the resulting snake tensor.
  defp final_snake(snake, score, direction) do
    state = %Game{score: score, snake_tensor: snake, apple_tensor: no_apples()}
    %Game{snake_tensor: snake_final} = Game.advance(state, direction)
    snake_final
  end

  test "self collision yields dead" do
    [dec1, dec2, dec3] = Enum.map(1..3, &decayed/1)

    snake =
      Nx.tensor([
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, dec1, 0.0],
        [0.0, dec3, dec2, 0.0],
        [0.0, 0.0, 0.0, 0.0]
      ])

    assert snake |> final_snake(3, :down) |> Game.is_dead()
  end

  test "wall collision yields dead" do
    [dec1, dec2] = Enum.map(1..2, &decayed/1)

    snake =
      Nx.tensor([
        [0.0, 0.0, 0.0, 0.0],
        [1.0, dec1, dec2, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0]
      ])

    assert snake |> final_snake(2, :left) |> Game.is_dead()
  end
end

ExUnit.run()
..

Finished in 0.00 seconds (0.00s async, 0.00s sync)
2 tests, 0 failures

Randomized with seed 833031
%{excluded: 0, failures: 0, skipped: 0, total: 2}

Display Game State as SVG

defmodule Game.Svg do
  @moduledoc """
  Renders a `%Game{}` as an SVG document in which every board cell is a 1x1
  unit square.
  """

  @doc """
  Returns an SVG string for `game`. Snake cells are drawn green and apple
  cells red, each with an opacity proportional to the cell's value.

  The SVG `viewBox` is `min-x min-y width height`, so the column count comes
  before the row count. This matters for non-square boards.
  """
  def to_svg(%Game{apple_tensor: apples, snake_tensor: snake}) do
    {height, width} = Nx.shape(apples)

    """
    <svg version="1.1"
     viewBox="0 0 #{width} #{height}"
     xmlns="http://www.w3.org/2000/svg">
     #{tensor_to_svg_rects(snake, "green")}

     #{tensor_to_svg_rects(apples, "red")}
    </svg>
    """
  end

  # Emits one <rect> per cell with a positive value, joined by newlines.
  # The flat list of a {rows, cols} tensor is row-major, so it is split into
  # rows of `width` cells. `x` is the column index and `y` the row index.
  defp tensor_to_svg_rects(tensor, color) do
    {_height, width} = Nx.shape(tensor)

    tensor
    |> Nx.to_flat_list()
    |> Enum.chunk_every(width)
    |> Enum.with_index()
    |> Enum.flat_map(fn {row, y} ->
      for {value, x} <- Enum.with_index(row), value > 0, do: rect(x, y, value, color)
    end)
    |> Enum.join("\n")
  end

  # One unit square at column `x`, row `y`; opacity is the value as a floored percentage.
  defp rect(x, y, value, color) do
    opacity = :math.floor(value * 100)
    ~s(<rect x="#{x}" y="#{y}" width="1" height="1" fill="#{color}" fill-opacity="#{opacity}%" />)
  end
end
{:module, Game.Svg, <<70, 79, 82, 49, 0, 0, 16, ...>>, {:tensor_to_svg_rects, 2}}
defmodule ImgHelper do
  @moduledoc """
  Tiles a collection of standalone SVG documents into a single SVG grid in
  which each input occupies one 1x1 cell with a thin black border.
  """

  @doc """
  Lays out `list` (a list of SVG strings, each starting with `<svg`) in rows of
  `per_row` cells.
  """
  def img_helper(list, per_row) when is_list(list) do
    img_helper(list, length(list), per_row)
  end

  @doc """
  Same as `img_helper/2`, but accepts any enumerable of SVG strings (e.g. a
  lazy `Stream`).

  `list_len` must be the number of elements. It sizes the `viewBox` without
  traversing the enumerable a second time.
  """
  def img_helper(list, list_len, per_row) do
    content =
      list
      |> Stream.chunk_every(per_row)
      |> Stream.with_index()
      |> Enum.flat_map(fn {row, y} ->
        row
        |> Enum.with_index()
        |> Enum.map(fn {svg, x} -> cell(svg, x, y) end)
      end)
      |> Enum.join("\n")

    # integer ceiling division, so the viewBox reads "5" rather than "5.0"
    row_count = div(list_len + per_row - 1, per_row)

    """
    <svg version="1.1"
      viewBox="0 0 #{per_row} #{row_count}"
      xmlns="http://www.w3.org/2000/svg">
        #{content}
    </svg>
    """
  end

  # Wraps one SVG document in a group translated to grid position (x, y).
  # A width/height of 1 is injected into the nested <svg> so it fills exactly one cell.
  defp cell(svg, x, y) do
    sized = String.replace_prefix(svg, "<svg", "<svg width=\"1\" height=\"1\"")

    """
    <g transform="translate(#{x}, #{y})">
      <rect width="1" height="1" stroke="black" stroke-width="0.01" fill="none" />
      #{sized}
    </g>
    """
  end
end
{:module, ImgHelper, <<70, 79, 82, 49, 0, 0, 14, ...>>, {:img_helper, 3}}

Visual Tests

# Visual check of Game.spawn_new_apples/3. It renders rect_size x rect_size
# independent previews of the same snake, each with one freshly spawned apple.
# Apples (red) should never land on snake (green) cells.
grid_size = 9

# Game decays a segment by 1 / (rows * cols) per step, so the fixture uses the
# same step. That keeps it a state Game.advance/2 could actually produce: a
# score of 5 with 6 snake cells.
cell_count = grid_size * grid_size
[dec1, dec2, dec3, dec4, dec5] = Enum.map(1..5, fn k -> 1 - k / cell_count end)

snake_tensor =
  Nx.tensor([
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, dec1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, dec2, dec3, dec4, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, dec5, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  ])

# number of previews per row (the preview grid is rect_size x rect_size)
rect_size = 5

# f32 board of zeros, the same as a 9x9 literal of 0.0 values
empty_apples = Nx.broadcast(0.0, {grid_size, grid_size})

1..(rect_size * rect_size)
|> Stream.map(fn _ ->
  %Game{
    score: 5,
    snake_tensor: snake_tensor,
    apple_tensor: Game.spawn_new_apples(empty_apples, 1, snake_tensor)
  }
end)
|> Stream.map(&Game.Svg.to_svg/1)
|> ImgHelper.img_helper(rect_size * rect_size, rect_size)
|> Kino.Image.new(:svg)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment