@patrick-kidger
Last active April 15, 2022 02:53
# JAX script adapted from https://news.ycombinator.com/item?id=31029699
# The main two modifications were to switch out `jax.value_and_grad` -> `jax.grad`,
# and to include the model update inside the JIT'd region.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
import time


class MatrixExponentEstimator(eqx.Module):
    d0: eqx.nn.Linear
    d1: eqx.nn.Linear
    d2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3 = jr.split(key, 3)
        self.d0 = eqx.nn.Linear(4, 32, key=key1)
        self.d1 = eqx.nn.Linear(32, 16, key=key2)
        self.d2 = eqx.nn.Linear(16, 4, key=key3)

    def __call__(self, x):
        x = jax.numpy.tanh(self.d0(x))
        x = jax.numpy.tanh(self.d1(x))
        return self.d2(x)


def f(x):
    return jax.scipy.linalg.expm(x.reshape((2, 2))).reshape((4,))


def apply_matrix_exponential(x):
    return jax.numpy.apply_along_axis(f, 1, x)


def train():
    epochs = 10000
    key = jr.PRNGKey(1337)
    trainkey, testkey, modelkey = jr.split(key, 3)
    trainx = jr.normal(trainkey, shape=(10000, 2*2))
    trainy = apply_matrix_exponential(trainx)
    testx = jr.normal(testkey, shape=(10000, 2*2))
    testy = apply_matrix_exponential(testx)

    model = MatrixExponentEstimator(modelkey)
    adam = optax.adam(1e-3)
    opt_state = adam.init(model)

    def loss_fn(model, X, y):
        err = jax.vmap(model)(X) - y
        return jnp.mean(jnp.square(err))  # mse

    @jax.jit
    def make_step(model, X, y, opt_state):
        grads = jax.grad(loss_fn)(model, X, y)
        updates, opt_state = adam.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state

    print('Initial Train Loss: {:.4f}'.format(loss_fn(model, trainx, trainy).item()))
    print('Initial Test Loss: {:.4f}'.format(loss_fn(model, testx, testy).item()))
    for _ in range(3):
        t_start = time.time()
        for _ in range(epochs):
            model, opt_state = make_step(model, trainx, trainy, opt_state)
        print('Took: {:.2f} seconds'.format(time.time() - t_start))
        print('Train Loss: {:.4f}'.format(loss_fn(model, trainx, trainy).item()))
        print('Test Loss: {:.4f}'.format(loss_fn(model, testx, testy).item()))


if __name__ == '__main__':
    train()
# Output:
#
# Initial Train Loss: 6.4230
# Initial Test Loss: 6.1087
# Took: 15.61 seconds
# Train Loss: 0.0310
# Test Loss: 0.0273
# Took: 16.21 seconds
# Train Loss: 0.0035
# Test Loss: 0.0156
# Took: 16.80 seconds
# Train Loss: 0.0018
# Test Loss: 0.0111
# In comparison when running
# https://julialang.org/blog/2022/04/simple-chains/#simplechainsjl_in_action_30x-ing_pytorch_in_tiny_example
# (And switching out `G = SimpleChains.alloc_threaded_grad(mlpd)` for `G = similar(p)` to avoid
# `UndefVarError: alloc_threaded_grad not defined`.)
# I get:
#
# julia> report(p)
#
# ┌ Info: Loss:
# │ train = 113290.805f0
# └ test = 109008.77f0
#
# julia> for _ in 1:3
# @time SimpleChains.train_unbatched!(
# G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
# );
# report(p)
# end
# 19.304949 seconds (10.21 M allocations: 536.257 MiB, 0.85% gc time, 20.75% compilation time)
# ┌ Info: Loss:
# │ train = 143.67769f0
# └ test = 1110.0271f0
# 15.198186 seconds
# ┌ Info: Loss:
# │ train = 40.200905f0
# └ test = 733.7304f0
# 15.536973 seconds
# ┌ Info: Loss:
# │ train = 32.163437f0
# └ test = 639.05725f0
#
# Which produces timings that are (a) very similar to those obtained for JAX, but
# (b) noticeably different from the results reported by Chris in the HN thread.
#
# In addition, the loss results for the Julia script are really really bad. I don't know what's going
# on with that.
@rejuvyesh

rejuvyesh commented Apr 14, 2022

It's possible you are not using the latest version of SimpleChains.jl? I think alloc_threaded_grad is necessary for correct computation, which is likely why the losses are so much worse. You also need to start Julia with an appropriate number of threads.
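(For reference, not part of the original comment: one common way to start Julia with multiple threads and verify the count; the thread count below is just an example.)

julia --threads=6        # or: export JULIA_NUM_THREADS=6

julia> Threads.nthreads()
6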

On my machine, for your JAX version on a CPU (6 cores), I'm seeing:

Initial Train Loss: 6.4232
Initial Test Loss: 6.1088
Took: 19.39 seconds
Train Loss: 0.0307
Test Loss: 0.0270
Took: 18.91 seconds
Train Loss: 0.0037
Test Loss: 0.0157
Took: 20.09 seconds
Train Loss: 0.0018
Test Loss: 0.0111

While for the Julia version:

13.428804 seconds (17.76 M allocations: 949.815 MiB, 2.89% gc time, 100.00% compilation time)
┌ Info: Loss:
│   train = 12.414271f0
└   test = 12.085746f0
 17.685621 seconds (14.99 M allocations: 808.462 MiB, 4.02% gc time, 48.56% compilation time)
┌ Info: Loss:
│   train = 0.034923762f0
└   test = 0.052024134f0
  9.208631 seconds (19 allocations: 608 bytes)
┌ Info: Loss:
│   train = 0.0045825513f0
└   test = 0.03521506f0
  9.258355 seconds (30 allocations: 960 bytes)
┌ Info: Loss:
│   train = 0.0026099205f0
└   test = 0.023117168f0

@ChrisRackauckas

ChrisRackauckas commented Apr 15, 2022

(And switching out G = SimpleChains.alloc_threaded_grad(mlpd) for G = similar(p) to avoid UndefVarError: alloc_threaded_grad not defined.)

Yeah, don't do that; that's not correct. Just get the release version.

With an AMD Ryzen 9 5950X 16-core processor I get:

Took: 14.52 seconds
Train Loss: 0.0304
Test Loss: 0.0268
Took: 14.00 seconds
Train Loss: 0.0033
Test Loss: 0.0154
Took: 13.85 seconds
Train Loss: 0.0018
Test Loss: 0.0112

vs https://gist.github.com/rejuvyesh/1948888e05a4c74d4203760003df6ff2 getting:

julia> Threads.nthreads()
16

  5.097569 seconds (14.81 M allocations: 798.000 MiB, 3.94% gc time, 73.62% compilation time)
┌ Info: Loss:
│   train = 0.022585187f0
└   test = 0.32509857f0
  1.310997 seconds
┌ Info: Loss:
│   train = 0.0038023277f0
└   test = 0.23108596f0
  1.295088 seconds
┌ Info: Loss:
│   train = 0.0023415526f0
└   test = 0.20991518f0

while

julia> Threads.nthreads()
1

 13.182836 seconds (12.07 M allocations: 647.639 MiB, 1.18% gc time, 23.74% compilation time)
┌ Info: Loss:
│   train = 0.011142263f0
└   test = 0.03988739f0
 10.150889 seconds
┌ Info: Loss:
│   train = 0.0038442607f0
└   test = 0.023254404f0
 10.199495 seconds
┌ Info: Loss:
│   train = 0.002967517f0
└   test = 0.019188924f0

So slightly more than 10x, but the advantage is essentially lost when threading is disabled. So I'm also wondering whether multithreading was disabled there, since that's the only way I could recreate it.

(And note that SimpleChains is going to be a bit CPU-dependent since it tries to force AVX512 when it can, so that's where you get the 0.5 seconds with the Intel i9 10980XE. I want that chip now...)

I invite others to submit timings and share CPU and Threads.nthreads().

@extradosages

extradosages commented Apr 15, 2022

From https://gist.github.com/rejuvyesh/1948888e05a4c74d4203760003df6ff2.

 29.738605 seconds (10.81 M allocations: 571.659 MiB, 0.75% gc time, 19.20% compilation time)
┌ Info: Loss:
│   train = 0.012048299f0
└   test = 0.1272674f0
 23.849172 seconds
┌ Info: Loss:
│   train = 0.0028957298f0
└   test = 0.088817276f0
 24.116243 seconds
┌ Info: Loss:
│   train = 0.0019454482f0
└   test = 0.078549586f0
julia> Threads.nthreads()
1

julia> versioninfo()
Julia Version 1.8.0-beta3
Commit 3e092a2521 (2022-03-29 15:42 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 8 × AMD Ryzen 7 PRO 3700U w/ Radeon Vega Mobile Gfx
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
  Threads: 1 on 8 virtual cores
Environment:
  JULIA_PATH = /usr/local/julia
  JULIA_GPG = 3673DF529D9049477F76B37566E3C7DC03D6E495
  JULIA_VERSION = 1.8.0-beta3

julia> 

@ChrisRackauckas

Set your threads and run it again too (and run the Jax code). See https://docs.julialang.org/en/v1/manual/multi-threading/#Starting-Julia-with-multiple-threads for how to start Julia with multiple threads.

@chriselrod

chriselrod commented Apr 15, 2022

In a fresh Julia session, so all first calls will be compiling (note that @time doesn't actually show the full compile time, but I think it's reasonably close here):

julia> using SimpleChains

julia> function f(x)
         N = Base.isqrt(length(x))
         A = reshape(view(x, 1:N*N), (N,N))
         expA = exp(A)
         vec(expA)
       end
f (generic function with 1 method)

julia> T = Float32;

julia> D = 2 # 2x2 matrices
2

julia> X = randn(T, D*D, 10_000); # random input matrices

julia> Y = reduce(hcat, map(f, eachcol(X))); # `mapreduce` is not optimized for `hcat`, but `reduce` is

julia> Xtest = randn(T, D*D, 10_000);

julia> Ytest = reduce(hcat, map(f, eachcol(Xtest)));

julia> mlpd = SimpleChain(
         static(4),
         TurboDense(tanh, 32),
         TurboDense(tanh, 16),
         TurboDense(identity, 4)
       )
SimpleChain with the following layers:
TurboDense static(32) with bias.
Activation layer applying: tanh
TurboDense static(16) with bias.
Activation layer applying: tanh
TurboDense static(4) with bias.

julia> @time p = SimpleChains.init_params(mlpd);
  7.185457 seconds (4.30 M allocations: 226.262 MiB, 0.52% gc time, 100.00% compilation time)

julia> G = SimpleChains.alloc_threaded_grad(mlpd);

julia> mlpdloss = SimpleChains.add_loss(mlpd, SquaredLoss(Y));

julia> mlpdtest = SimpleChains.add_loss(mlpd, SquaredLoss(Ytest));

julia> report = let mtrain = mlpdloss, X=X, Xtest=Xtest, mtest = mlpdtest
         p -> begin
           let train = mlpdloss(X, p), test = mlpdtest(Xtest, p)
             @info "Loss:" train test
           end
         end
       end
#1 (generic function with 1 method)

julia> report(p)
┌ Info: Loss:
│   train = 13.415737f0
└   test = 12.306785f0

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  4.810973 seconds (13.03 M allocations: 686.357 MiB, 8.25% gc time, 89.76% compilation time)
┌ Info: Loss:
│   train = 0.011851382f0
└   test = 0.017254675f0
  0.410168 seconds
┌ Info: Loss:
│   train = 0.0037487738f0
└   test = 0.009099905f0
  0.410368 seconds
┌ Info: Loss:
│   train = 0.002041543f0
└   test = 0.0065089874f0

(simplechainsrelease) pkg> st
Status `~/Documents/progwork/julia/env/simplechainsrelease/Project.toml`
  [eb30cadb] MLDatasets v0.5.16
  [de6bee2f] SimpleChains v0.2.1

julia> versioninfo()
Julia Version 1.8.0-beta3
Commit 3e092a2521 (2022-03-29 15:42 UTC)
Platform Info:
  OS: Linux (x86_64-redhat-linux)
  CPU: 36 × Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, cascadelake)
  Threads: 18 on 36 virtual cores

On the same CPU, running this Jax script:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Initial Train Loss: 6.4232
Initial Test Loss: 6.1088
Took: 9.26 seconds
Train Loss: 0.0304
Test Loss: 0.0268
Took: 8.98 seconds
Train Loss: 0.0036
Test Loss: 0.0156
Took: 9.01 seconds
Train Loss: 0.0018
Test Loss: 0.0111

So after compilation, I get a >20x speedup.

Anyway, I do really appreciate you taking the time to translate a Jax example.
It takes a lot of time and effort for me to try and set up examples and do enough due diligence so that I can be reasonably confident that I am not treating the other libraries too unfairly, aside from the unfairness of not benchmarking them for their intended use cases.

SimpleChains was open sourced and released today, but there is actually an old version that was already open source and registered early in its development.
There have been a few breaking changes since then, and some additions, like alloc_threaded_grad.
The fact that you're getting wrong answers, and that alloc_threaded_grad wasn't defined, makes it seem like you're using the old version by mistake.

You can check with ] st or ] st SimpleChains, like I showed above.
I created a new project, which you can do via

julia --project=/path/to/project

and then installed SimpleChains (and MLDatasets, which isn't needed for this test) in it. I like to keep lots of small projects, as this reduces the risk of version conflicts. It could be that you have other Julia packages causing a conflict, which would leave you with the old version of SimpleChains.
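A minimal sketch of that workflow (the project path and prompt name are placeholders; the UUID and version match the status output shown above; press ] at the julia> prompt to enter Pkg mode):

julia --project=/path/to/project

(project) pkg> add SimpleChains
(project) pkg> st SimpleChains
  [de6bee2f] SimpleChains v0.2.1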

However, there were other breakages, like how TurboDense layers were defined, so you'd have had to hit other errors if you were on this old version. That should have been enough to tip you off that something odd was going on; if this is the case, I have to commend you for your patience in getting the example to run at all!

Maybe the blog post should have emphasized version 0.2.1.

So slightly more than 10x, but the advantage is essentially lost when threading is disabled. So I'm also wondering whether multithreading was disabled there, since that's the only way I could recreate it.

The description said:

In comparison when running
https://julialang.org/blog/2022/04/simple-chains/#simplechainsjl_in_action_30x-ing_pytorch_in_tiny_example
(And switching out G = SimpleChains.alloc_threaded_grad(mlpd) for G = similar(p) to avoid
UndefVarError: alloc_threaded_grad not defined.)

While I'm still a little confused as to what version Patrick was using, if it was in fact the latest release, then using G = similar(p) forces single-threaded training, as there is only enough gradient buffer for a single thread.
G is normally a matrix with one column per thread used (although not an actual Matrix, as it needs a certain alignment to avoid false sharing).
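For intuition, a generic plain-Julia sketch of the per-thread-buffer idea; this is not SimpleChains' actual implementation (which uses a specially aligned type rather than a Matrix), just an illustration of why one column per thread avoids write races:

using Base.Threads

nparams = 100
G = zeros(Float32, nparams, nthreads())    # one gradient column per thread
@threads for i in 1:10_000
    g = @view G[:, threadid()]             # each thread accumulates into its own column,
    g .+= 1f0                              # so no two threads write to the same memory
end
grad = vec(sum(G; dims=2))                 # combine the per-thread partial gradients
# With G = similar(p) there is only a single buffer, so only one thread can safely accumulate.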

(And note that SimpleChains is going to be a bit CPU-dependent since it tries to force AVX512 when it can, so that's where you get the 0.5 seconds with the Intel i9 10980XE. I want that chip now...)

It's amazing for benchmarks that can use AVX512 (which is most of them I write...), but there are rumors that Zen 4 will feature AVX512, and I'm also looking forward to Sapphire Rapids-X from Intel. Both should be coming out this year.
I'll probably buy one of the two, depending on which is better at SIMD.

Also, with a single thread, I get:

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
 12.635821 seconds (10.67 M allocations: 561.711 MiB, 1.42% gc time, 26.70% compilation time)
┌ Info: Loss:
│   train = 0.020518048f0
└   test = 0.1018267f0
  9.236013 seconds
┌ Info: Loss:
│   train = 0.0042917724f0
└   test = 0.06968568f0
  9.249139 seconds
┌ Info: Loss:
│   train = 0.002604088f0
└   test = 0.06139791f0

Which really isn't much better than your 5950X. It seems like the main advantage might be that it scales better with multithreading, perhaps because the die is monolithic?

Looking at perf, I see that I have a ton of L1 data cache misses:

julia> @pstats "(cpu-cycles,task-clock),(instructions,branch-instructions,branch-misses), (L1-dcache-load-misses, L1-dcache-loads, cache-misses, cache-references)" begin
           SimpleChains.train_unbatched!(
                  G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
                );
       end
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌ cpu-cycles               3.93e+10   33.3%  #  4.2 cycles per ns
└ task-clock               9.43e+09   33.3%  #  9.4 s
┌ instructions             4.18e+10   66.7%  #  1.1 insns per cycle
│ branch-instructions      1.50e+09   66.7%  #  3.6% of instructions
└ branch-misses            5.33e+05   66.7%  #  0.0% of branch instructions
┌ L1-dcache-load-misses    3.61e+09   33.3%  # 30.2% of dcache loads
│ L1-dcache-loads          1.20e+10   33.3%
│ cache-misses             2.19e+04   33.3%  #  0.0% of cache references
└ cache-references         3.03e+09   33.3%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

While multithreaded, I get

julia> @pstats "(cpu-cycles,task-clock),(instructions,branch-instructions,branch-misses), (L1-dcache-load-misses, L1-dcache-loads, cache-misses, cache-references)" begin
                 SimpleChains.train_unbatched!(
                        G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
                      );
             end
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌ cpu-cycles               3.30e+10   33.4%  #  4.0 cycles per ns
└ task-clock               8.21e+09   33.4%  #  8.2 s
┌ instructions             4.29e+10   66.7%  #  1.3 insns per cycle
│ branch-instructions      1.76e+09   66.7%  #  4.1% of instructions
└ branch-misses            1.36e+07   66.7%  #  0.8% of branch instructions
┌ L1-dcache-load-misses    3.59e+09   33.3%  # 29.6% of dcache loads
│ L1-dcache-loads          1.21e+10   33.3%
│ cache-misses             6.94e+02   33.3%  #  0.0% of cache references
└ cache-references         4.63e+07   33.3%
                 aggregated from 18 threads
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

I.e., I actually have higher IPC when multithreaded, and less overall CPU time. That is, I have better-than-linear scaling here.
The L1d cache miss rate isn't much better, and I don't think it can explain the difference; the branch miss rate differs more noticeably.
When I see this, I normally assume it's related to the large L2 cache (1 MiB per core); each core has a private L2, so by using multiple cores we get a lot more cache overall.

@chriselrod

Oh, something really important to note when comparing losses: SquaredLoss in SimpleChains.jl multiplies the error by 1/2.
I could change it if that's too confusing, but it does mean a mean squared error of 0.05 in SimpleChains is the same as an error of 0.1 in Jax/PyTorch.
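A tiny illustration of that scaling, using a made-up error vector rather than anything from SimpleChains:

using Statistics

err = randn(Float32, 40_000)             # hypothetical per-element errors
mse_plain  = mean(abs2, err)             # plain mean squared error, as in the Jax script
mse_halved = 0.5f0 * mean(abs2, err)     # with the extra 1/2 factor described above
@assert mse_plain ≈ 2 * mse_halved       # e.g. 0.05 here corresponds to 0.1 there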

So the multithreaded errors from SimpleChains above:

  0.410368 seconds
┌ Info: Loss:
│   train = 0.002041543f0
└   test = 0.0065089874f0

Are about comparable to what I see from Jax:

Took: 9.01 seconds
Train Loss: 0.0018
Test Loss: 0.0111

But, I need to look into why the initial errors are so much higher than what I got from PyTorch or am getting from Jax here.

I also haven't tested single-threaded fitting much; something seems to be going very wrong here:

  9.249139 seconds
┌ Info: Loss:
│   train = 0.002604088f0
└   test = 0.06139791f0

Why is the test loss 10x higher?
I'll take a look...

@chriselrod

chriselrod commented Apr 15, 2022

This is the opposite of what Chris R observed, though.
For me, the pattern is fairly consistent; multithreaded:

julia> SimpleChains.init_params!(mlpdloss, p);

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  0.463921 seconds
┌ Info: Loss:
│   train = 0.0129161235f0
└   test = 0.018706292f0
  0.406318 seconds
┌ Info: Loss:
│   train = 0.0032260346f0
└   test = 0.009573642f0
  0.409833 seconds
┌ Info: Loss:
│   train = 0.0021066444f0
└   test = 0.0070530926f0

julia> SimpleChains.init_params!(mlpdloss, p);

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  0.412915 seconds
┌ Info: Loss:
│   train = 0.012180713f0
└   test = 0.019388271f0
  0.409664 seconds
┌ Info: Loss:
│   train = 0.0027349133f0
└   test = 0.0103735635f0
  0.409908 seconds
┌ Info: Loss:
│   train = 0.001797435f0
└   test = 0.007536795f0

julia> SimpleChains.init_params!(mlpdloss, p);

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  0.412446 seconds
┌ Info: Loss:
│   train = 0.012522352f0
└   test = 0.017825317f0
  0.413578 seconds
┌ Info: Loss:
│   train = 0.0029457442f0
└   test = 0.008456751f0
  0.410407 seconds
┌ Info: Loss:
│   train = 0.0019602645f0
└   test = 0.00613266f0

Single threaded:

julia> SimpleChains.init_params!(mlpdloss, p);

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  9.394842 seconds
┌ Info: Loss:
│   train = 0.014558737f0
└   test = 0.091713026f0
  9.425004 seconds
┌ Info: Loss:
│   train = 0.0032704105f0
└   test = 0.06821577f0
  9.434632 seconds
┌ Info: Loss:
│   train = 0.0021691609f0
└   test = 0.060845226f0

julia> SimpleChains.init_params!(mlpdloss, p);

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  9.368543 seconds
┌ Info: Loss:
│   train = 0.016641248f0
└   test = 0.08612959f0
  9.386725 seconds
┌ Info: Loss:
│   train = 0.0030917034f0
└   test = 0.052406352f0
  9.405264 seconds
┌ Info: Loss:
│   train = 0.0021836497f0
└   test = 0.04452359f0

julia> SimpleChains.init_params!(mlpdloss, p);

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  9.381625 seconds
┌ Info: Loss:
│   train = 0.01729222f0
└   test = 0.09130784f0
  9.396374 seconds
┌ Info: Loss:
│   train = 0.004226131f0
└   test = 0.07051856f0
  9.376561 seconds
┌ Info: Loss:
│   train = 0.0027545807f0
└   test = 0.062378857f0

@ChrisRackauckas

For me it's just due to initializations:

  5.365769 seconds (14.81 M allocations: 917.160 MiB, 6.05% gc time, 75.77% compilation time)
┌ Info: Loss:
│   train = 0.012597987f0
└   test = 0.023731364f0
  1.284829 seconds
┌ Info: Loss:
│   train = 0.0025600272f0
└   test = 0.010612518f0
  1.307484 seconds
┌ Info: Loss:
│   train = 0.0015771794f0
└   test = 0.008084581f0

  1.301913 seconds
┌ Info: Loss:
│   train = 0.0015198826f0
└   test = 0.007395477f0
  1.303582 seconds
┌ Info: Loss:
│   train = 0.0010119737f0
└   test = 0.0063088294f0
  1.310925 seconds
┌ Info: Loss:
│   train = 0.00095060834f0
└   test = 0.0058895415f0

julia> SimpleChains.init_params!(mlpdloss, p);

julia> report(p)
┌ Info: Loss:
│   train = 14.027461f0
└   test = 14.123837f0

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  1.313561 seconds
┌ Info: Loss:
│   train = 0.0147712445f0
└   test = 0.028131848f0
  1.315667 seconds
┌ Info: Loss:
│   train = 0.0036588982f0
└   test = 0.014324798f0
  1.315291 seconds
┌ Info: Loss:
│   train = 0.0024623652f0
└   test = 0.011305514f0

Both threaded, of course.

@chriselrod

With this PR, the number of threads will no longer influence training behavior: PumasAI/SimpleChains.jl#52
That PR also removes the 0.5x multiplier.

But it defines the mean with respect to batch size instead of total data length, meaning we're dividing by 10k instead of 40k. Hence, for equivalent accuracy, the error SimpleChains reports should be about 4x higher, which is now what I observe:

julia> SimpleChains.init_params!(mlpdloss, p);

julia> report(p)
┌ Info: Loss:
│   train = 24.167797f0
└   test = 23.112621f0

julia> for _ in 1:3
         @time SimpleChains.train_unbatched!(
           G, p, mlpdloss, X, SimpleChains.ADAM(), 10_000
         );
         report(p)
       end
  0.413229 seconds
┌ Info: Loss:
│   train = 0.024573077f0
└   test = 0.110939406f0
  0.409671 seconds
┌ Info: Loss:
│   train = 0.0071228216f0
└   test = 0.06868424f0
  0.409380 seconds
┌ Info: Loss:
│   train = 0.0048818965f0
└   test = 0.0587865f0
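
To spell out the 4x normalization claim above (a sketch with the shapes from this example: 10,000 samples with 4 outputs each, and a made-up total sum of squared errors):

S = 123.4f0                        # made-up sum of squared errors over all 40_000 entries
per_element_mean = S / 40_000      # mean over every scalar entry, as jnp.mean computes
per_sample_mean  = S / 10_000      # mean with respect to batch size, as after the PR
@assert per_sample_mean ≈ 4 * per_element_mean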

So accuracy is similar on my computer, but SimpleChains takes just over 0.4 seconds per 10k epochs instead of 9 seconds.

Anyway, I should probably add more options to let users control the error function's behavior, e.g. what it takes the mean with respect to (if it takes the mean at all).
