Skip to content

Instantly share code, notes, and snippets.

@toranb
Last active January 30, 2024 07:40
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save toranb/e5c48565e83e4baaaf2c5850531a8a58 to your computer and use it in GitHub Desktop.
Save toranb/e5c48565e83e4baaaf2c5850531a8a58 to your computer and use it in GitHub Desktop.
FizzBuzz with Axon (a collaboration with Ian Warshak)
defmodule Mlearning do
  @moduledoc """
  Learns FizzBuzz with a tiny Axon network.

  Each integer `n` is encoded as the feature pair `[rem(n, 3), rem(n, 5)]`
  (see `mods/1`) and labelled with a one-hot vector over the classes
  `"fizz"`, `"buzz"`, `"fizzbuzz"` and `"womp"` (see `fizzbuzz/1`).
  `hello/1` trains the network and prints its guesses for a few sample numbers.
  """

  # Class order shared by the one-hot labels (`fizzbuzz/1`) and the argmax
  # decoding in `guess/3`; a single list keeps encoding and decoding in sync.
  @classes ["fizz", "buzz", "fizzbuzz", "womp"]

  # Training defaults; `hello/0` uses exactly these values.
  @default_opts [range: 1..1000, epochs: 3, learning_rate: 0.01, compiler: EXLA]

  # Numbers whose predictions `hello/1` prints, paired with the printed label.
  @samples [{3, "3"}, {5, "5"}, {15, "15"}, {16, "16"}, {15_432_115, "15,432,115"}]

  @doc """
  Encodes `x` as its remainders modulo 3 and modulo 5.
  """
  @spec mods(integer()) :: [integer()]
  def mods(x), do: [rem(x, 3), rem(x, 5)]

  @doc """
  Returns the one-hot FizzBuzz label for `n`, ordered as
  `["fizz", "buzz", "fizzbuzz", "womp"]`.

  Multiples of 15 are labelled `"fizzbuzz"` (not `"fizz"`), and numbers
  divisible by neither 3 nor 5 are labelled `"womp"`.
  """
  @spec fizzbuzz(integer()) :: [0 | 1]
  def fizzbuzz(n) do
    class = classify(n)
    Enum.map(@classes, fn c -> if c == class, do: 1, else: 0 end)
  end

  @doc """
  Trains the model, then prints its guesses for 3, 5, 15, 16 and 15,432,115.

  ## Options

    * `:range` - integers used as training examples (default `1..1000`)
    * `:epochs` - number of training epochs (default `3`)
    * `:learning_rate` - AdamW learning rate (default `0.01`)
    * `:compiler` - compiler passed to the training loop (default `EXLA`)

  Unknown options raise `ArgumentError`. Returns `:ok`.
  """
  @spec hello(keyword()) :: :ok
  def hello(opts \\ []) do
    opts = Keyword.validate!(opts, @default_opts)

    model = build_model()
    data = training_data(Keyword.fetch!(opts, :range))
    params = train(model, data, opts)
    {_init_fn, predict_fn} = Axon.build(model)

    Enum.each(@samples, fn {n, label} ->
      predict_fn |> guess(params, n) |> IO.inspect(label: label)
    end)

    :ok
  end

  # Maps `n` to its FizzBuzz class name; 15 is tested first so it wins over 3 and 5.
  defp classify(n) do
    cond do
      rem(n, 15) == 0 -> "fizzbuzz"
      rem(n, 3) == 0 -> "fizz"
      rem(n, 5) == 0 -> "buzz"
      true -> "womp"
    end
  end

  # Lazily builds `{features, one_hot_label}` pairs, each a single-row batch.
  defp training_data(range) do
    Stream.map(range, fn n -> {Nx.tensor([mods(n)]), Nx.tensor([fizzbuzz(n)])} end)
  end

  # Two input features -> 3 ReLU units -> softmax over the FizzBuzz classes.
  defp build_model do
    Axon.input("input", shape: {nil, 2})
    |> Axon.dense(3, activation: :relu)
    |> Axon.dense(length(@classes), activation: :softmax)
  end

  # Runs the supervised training loop; the result is what `hello/1` feeds to
  # `predict_fn` as params.
  defp train(model, data, opts) do
    optimizer = Polaris.Optimizers.adamw(learning_rate: Keyword.fetch!(opts, :learning_rate))

    model
    |> Axon.Loop.trainer(:categorical_cross_entropy, optimizer)
    |> Axon.Loop.metric(:accuracy)
    |> Axon.Loop.run(data, %{},
      epochs: Keyword.fetch!(opts, :epochs),
      compiler: Keyword.fetch!(opts, :compiler)
    )
  end

  # Predicts the class name for `n` by taking the argmax of the single output row.
  defp guess(predict_fn, params, n) do
    probabilities = predict_fn.(params, Nx.tensor([mods(n)]))
    [index] = probabilities |> Nx.argmax(axis: 1) |> Nx.to_flat_list()
    Enum.at(@classes, index)
  end
end
@toranb
Copy link
Author

toranb commented Dec 31, 2022

  # mix.exs dependencies for the `Mlearning` FizzBuzz example.
  defp deps do
    [
      {:axon, "~> 0.6"},
      {:exla, "~> 0.6"},
      {:nx, "~> 0.6"},
      # `Mlearning` calls `Polaris.Optimizers` directly, so declare it instead of
      # relying on it arriving transitively through :axon (the compiler warns
      # about calls into applications the project does not list).
      {:polaris, "~> 0.1"}
    ]
  end

@toranb
Copy link
Author

toranb commented Dec 31, 2022

The inspiration for this came from Bruce and the Programmer Passport series on Nx.

https://www.youtube.com/watch?v=NcsqGS6SVXg

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment