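C3 implements single-variable linear regression trained by gradient descent on a mean squared error loss, using Nx's defn so the numeric functions run as tensor operations.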
defmodule C3 do
  import Nx.Defn

  # Linear model: y = w * x + b, applied element-wise over the input tensor.
  defn predict(x, w, b) do
    x
    |> Nx.multiply(w)
    |> Nx.add(b)
  end

  # Mean squared error between predictions and labels.
  defn loss(x, y, w, b) do
    x
    |> predict(w, b)
    |> Nx.subtract(y)
    |> Nx.pow(2)
    |> Nx.mean()
  end
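
  # For the loss above,
  #   L(w, b) = mean((w * x + b - y)^2)
  # the partial derivatives are
  #   dL/dw = 2 * mean((w * x + b - y) * x)
  #   dL/db = 2 * mean(w * x + b - y)
  # The two functions below compute these by hand instead of using
  # Nx's automatic differentiation.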

  defn weight_gradient(x, y, w, b) do
    x
    |> predict(w, b)
    |> Nx.subtract(y)
    |> Nx.multiply(x)
    |> Nx.mean()
    |> Nx.multiply(2)
  end

  defn bias_gradient(x, y, w, b) do
    x
    |> predict(w, b)
    |> Nx.subtract(y)
    |> Nx.mean()
    |> Nx.multiply(2)
  end

  # Both gradients for the current parameters, as a {weight, bias} tuple.
  def gradients(%Nx.Tensor{} = tx, %Nx.Tensor{} = ty, w, b \\ 0) do
    {
      weight_gradient(tx, ty, w, b),
      bias_gradient(tx, ty, w, b)
    }
  end

  # Run `i` steps of gradient descent with learning rate `lr`,
  # starting from w = 0, b = 0, logging the loss at each step.
  def train(%Nx.Tensor{} = tx, %Nx.Tensor{} = ty, i, lr) do
    Enum.reduce(1..i, {0, 0}, fn iteration, {weight, bias} ->
      current_loss = loss(tx, ty, weight, bias)
      IO.puts("Iteration #{iteration} => Loss: #{Nx.to_number(current_loss)}")

      {wg, bg} = gradients(tx, ty, weight, bias)

      {
        adjust(weight, wg, lr),
        adjust(bias, bg, lr)
      }
    end)
  end

  # Step a parameter against its gradient: value - gradient * lr.
  defn adjust(value, gradient, lr) do
    gradient
    |> Nx.multiply(lr)
    |> then(&Nx.subtract(value, &1))
  end
end
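
A minimal usage sketch, assuming made-up inputs: the dataset and learning rate behind the sample output below are not part of the gist, so the values here are illustrative only.

# Hypothetical example data; the gist does not include its original dataset.
x = Nx.tensor([1.0, 4.0, 8.0, 12.0, 16.0])
y = Nx.tensor([5.0, 10.0, 17.0, 26.0, 33.0])

# 20_000 iterations matches the sample output; 0.001 is an assumed learning rate.
{w, b} = C3.train(x, y, 20_000, 0.001)

# Predict for a new input, e.g. x = 20.
prediction = C3.predict(Nx.tensor(20), w, b)
IO.puts("Prediction: x=20 y=#{Nx.to_number(prediction)}")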
Sample output from the original run:

Iteration 20000 => Loss: 28.272876739501953
Prediction: x=20 y=35.5132942199707