@rejuvyesh
Created April 14, 2022 22:52
using SimpleChains
using LinearAlgebra # the matrix exponential method `exp(::AbstractMatrix)` lives here

# Map a flattened N×N matrix (a length-N^2 vector) to its flattened matrix exponential.
function f(x)
    N = Base.isqrt(length(x))
    A = reshape(view(x, 1:N*N), (N, N))
    expA = exp(A)
    vec(expA)
end
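
# Quick sanity check (added for illustration, not part of the original gist): `f`
# should reproduce vec(exp(A)) for a small hand-picked matrix.
let A = Float32[0 1; -1 0]
    @assert f(vec(A)) ≈ vec(exp(A))
end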
T = Float32;
D = 2 # 2x2 matrices
X = randn(T, D*D, 10_000); # random input matrices
Y = reduce(hcat, map(f, eachcol(X))); # `mapreduce` is not optimized for `hcat`, but `reduce` is
Xtest = randn(T, D*D, 10_000);
Ytest = reduce(hcat, map(f, eachcol(Xtest)));
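
# Data layout: each column of X is a flattened 2×2 matrix, and the matching column of
# Y (computed by `f`) is its flattened matrix exponential, so the network learns
# x ↦ vec(exp(reshape(x, 2, 2))).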
mlpd = SimpleChain(
    static(4),              # input size: flattened 2×2 matrices
    TurboDense(tanh, 32),
    TurboDense(tanh, 16),
    TurboDense(identity, 4) # output size: flattened 2×2 matrix exponentials
)
@time p = SimpleChains.init_params(mlpd);
G = SimpleChains.alloc_threaded_grad(mlpd); # one gradient buffer per thread
mlpdloss = SimpleChains.add_loss(mlpd, SquaredLoss(Y));     # training objective
mlpdtest = SimpleChains.add_loss(mlpd, SquaredLoss(Ytest)); # held-out objective
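
# The loss-wrapped chains are callable: `mlpdloss(X, p)` evaluates the squared-error
# training loss for parameters `p`, and `mlpdtest(Xtest, p)` the held-out loss; the
# `report` closure below uses exactly these calls.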
# Report the current train/test loss for a given parameter vector.
report = let mtrain = mlpdloss, X = X, Xtest = Xtest, mtest = mlpdtest
    p -> begin
        let train = mtrain(X, p), test = mtest(Xtest, p)
            @info "Loss:" train test
        end
    end
end
report(p)
for _ in 1:3
    @time SimpleChains.train_unbatched!(
        G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
    )
    report(p)
end
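
# Illustrative follow-up (assumption: a SimpleChain without an attached loss can be
# called as `mlpd(x, p)` to predict on a single input): compare the trained network's
# output with the exact flattened matrix exponential on one held-out sample.
let x = Xtest[:, 1]
    yhat = mlpd(x, p)
    @info "Sample prediction vs. exact exp" yhat f(x)
end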