Skip to content

Instantly share code, notes, and snippets.

defmodule MNIST do
import Nx.Defn
# Default compiler for the `defn` definitions in this module: EXLA, invoked with
# the `keep_on_device: true` run option.
# NOTE(review): presumably this leaves results on the accelerator between calls
# instead of copying them back to the BEAM — confirm against the EXLA version in use.
@default_defn_compiler {EXLA, run_options: [keep_on_device: true]}
defn init_random_params do
w1 = Nx.random_normal({32, 1, 7, 7}, 0.0, 0.1)
b1 = Nx.random_normal({1, 32, 1, 1}, 0.0, 0.1)
w2 = Nx.random_normal({64, 32, 4, 4}, 0.0, 0.1)
b2 = Nx.random_normal({1, 64, 1, 1}, 0.0, 0.1)