Skip to content

Instantly share code, notes, and snippets.

@gvaughn
Forked from toranb/fizzbuzz.ex
Created December 31, 2022 18:13
Show Gist options
  • Save gvaughn/912c9e5177787e180f00dc916254d0ff to your computer and use it in GitHub Desktop.
Save gvaughn/912c9e5177787e180f00dc916254d0ff to your computer and use it in GitHub Desktop.
fizzbuzz with Axon (collaboration with Ian Warshak)
defmodule Mlearning do
  @moduledoc """
  FizzBuzz learned by a small Axon neural network.

  Each integer is turned into a three-element feature vector (see `mods/1`) and
  labelled with a one-hot class vector (see `fizzbuzz/1`). `hello/0` trains a
  two-layer dense network on the integers 1..1000 and prints its predictions
  for a handful of sample numbers.
  """

  # Divisors whose remainders make up the feature vector, in feature order.
  @divisors [3, 5, 15]

  @doc """
  Returns the remainders of `x` divided by 3, 5 and 15, in that order.

  ## Examples

      iex> Mlearning.mods(10)
      [1, 0, 10]

  """
  @spec mods(integer()) :: [integer()]
  def mods(x), do: Enum.map(@divisors, &rem(x, &1))

  @doc """
  Returns the one-hot label for `n`.

  The four positions stand for "fizz", "buzz", "fizzbuzz" and "womp" (none of
  the above), matching the class order used when decoding predictions in
  `hello/0`.

  ## Examples

      iex> Mlearning.fizzbuzz(15)
      [0, 0, 1, 0]

      iex> Mlearning.fizzbuzz(7)
      [0, 0, 0, 1]

  """
  @spec fizzbuzz(integer()) :: [0 | 1]
  def fizzbuzz(n) do
    # Divisible by 15 is the same as divisible by both 3 and 5, so the pair of
    # remainders is enough to pick the class.
    case {rem(n, 3), rem(n, 5)} do
      {0, 0} -> [0, 0, 1, 0]
      {0, _} -> [1, 0, 0, 0]
      {_, 0} -> [0, 1, 0, 0]
      {_, _} -> [0, 0, 0, 1]
    end
  end

  @doc """
  Trains the network and prints its guesses for a few sample numbers.

  Training runs for 5 epochs over the integers 1..1000 using the EXLA compiler.
  Afterwards the predicted class for 3, 5, 15, 16 and 15,432,115 is written
  with `IO.inspect/2`. Always returns `:ok`.
  """
  @spec hello() :: :ok
  def hello do
    examples = training_data()
    network = build_model()
    trained = train(network, examples)

    {_init_fn, predict_fn} = Axon.build(network)

    samples = [
      {3, "3"},
      {5, "5"},
      {15, "15"},
      {16, "16"},
      {15_432_115, "15,432,115"}
    ]

    Enum.each(samples, fn {number, tag} ->
      IO.inspect(classify(predict_fn, trained, number), label: tag)
    end)

    :ok
  end

  # Lazy stream of {features, label} pairs for 1..1000; every pair is a batch
  # holding a single example (the inner list is wrapped in an outer list).
  defp training_data do
    Stream.map(1..1000, fn n ->
      {Nx.tensor([mods(n)]), Nx.tensor([fizzbuzz(n)])}
    end)
  end

  # 3 inputs -> 10 relu units -> 4 softmax outputs (one per class).
  defp build_model do
    inputs = Axon.input("input", shape: {nil, 3})
    hidden = Axon.dense(inputs, 10, activation: :relu)
    Axon.dense(hidden, 4, activation: :softmax)
  end

  # Runs the supervised training loop and returns whatever Axon.Loop.run/4
  # yields; that value is later handed to the prediction function as params.
  defp train(network, examples) do
    optimizer = Axon.Optimizers.adamw(0.005)

    loop =
      network
      |> Axon.Loop.trainer(:categorical_cross_entropy, optimizer)
      |> Axon.Loop.metric(:accuracy)

    Axon.Loop.run(loop, examples, %{}, epochs: 5, compiler: EXLA)
  end

  # Predicts the class for `x` and maps the argmax index to its name.
  defp classify(predict_fn, trained, x) do
    features = Nx.tensor([mods(x)])
    prediction = predict_fn.(trained, features)
    winner = prediction |> Nx.argmax() |> Nx.to_flat_list()

    case winner do
      [0] -> "fizz"
      [1] -> "buzz"
      [2] -> "fizzbuzz"
      [3] -> "womp"
    end
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment