Skip to content

Instantly share code, notes, and snippets.

@JohnJocoo
Created November 1, 2023 20:35
Show Gist options
  • Save JohnJocoo/af581c71db5d9acde81f10e148ddeef7 to your computer and use it in GitHub Desktop.
Save JohnJocoo/af581c71db5d9acde81f10e148ddeef7 to your computer and use it in GitHub Desktop.
Testing how Nx tensor caching works
# Script dependencies: Nx for tensors, EXLA as backend/compiler, Benchee for timing.
Mix.install([{:nx, "~> 0.5"}, {:exla, "~> 0.5"}, {:benchee, "~> 1.0"}])

# Make :host the default EXLA client.
Application.put_env(:exla, :default_client, :host)

# Store every tensor on the EXLA backend, in all processes.
Nx.global_default_backend(EXLA.Backend)

# Compile every `defn` with EXLA, in all processes. `Nx.Defn.default_options/1`
# only affects the calling process, whereas the benchmarked functions are
# expected to run in processes spawned by Benchee
# (NOTE(review): confirm for the Benchee version in use), so the global variant
# is used to match `Nx.global_default_backend/1` above.
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
defmodule Test do
  @moduledoc """
  Benchmark fixtures comparing tensors that are built on every call with
  tensors that were created once and cached in an `Agent` named `:cache`.

  The sample data is generated at compile time: `@data` holds 35 maps, each
  mapping the 100 atoms in `@nutrients` (`:name1` .. `:name100`) to a random
  float.
  """

  import Nx.Defn

  @nutrients Enum.map(1..100, &String.to_atom("name#{&1}"))

  # `:rand` replaces the deprecated `:random` module. The values are only
  # benchmark filler, so the exact numbers do not matter.
  @data Enum.map(1..35, fn _ -> Map.new(@nutrients, &{&1, :rand.uniform()}) end)

  @doc """
  Reduces `tensor` along axis 0.

  Returns the axis-0 mean of `tensor * 2.5 * 2.5` plus the axis-0 mean of
  `tensor` itself. The divisor is taken from the tensor's own shape (through
  `Nx.mean/2`) rather than being fixed at 35, so any number of rows is handled.
  """
  defn test_defn(tensor) do
    average = Nx.mean(tensor, axes: [0])

    tensor
    |> Nx.multiply(2.5)
    |> Nx.multiply(2.5)
    |> Nx.mean(axes: [0])
    |> Nx.add(average)
  end

  @doc """
  Stacks `tensors` along a new leading axis and applies `test_defn/1`.

  Within this file it is called with a tuple of one tensor per sample.
  """
  defn test_defn2(tensors) do
    tensors
    |> Nx.stack(axis: 0)
    |> test_defn()
  end

  @doc """
  Builds one f32 tensor from all of `@data` on every call, runs `test_defn/1`
  on it and returns the result as a list.
  """
  @spec call_nx() :: list()
  def call_nx() do
    sample_rows()
    |> Nx.tensor(type: {:f, 32})
    |> test_defn()
    |> Nx.to_list()
  end

  @doc """
  Reads the per-sample tensors from the `:cache` agent, stacks them outside of
  `defn`, runs `test_defn/1` and returns the result as a list.

  `preload/0` must have been called first.
  """
  @spec call_nx_preloaded() :: list()
  def call_nx_preloaded() do
    cached_tensors()
    |> Nx.stack(axis: 0)
    |> test_defn()
    |> Nx.to_list()
  end

  @doc """
  Like `call_nx_preloaded/0`, but hands the cached tensors over as a tuple so
  that the stacking happens inside `test_defn2/1`.

  `preload/0` must have been called first.
  """
  @spec call_nx_preloaded2() :: list()
  def call_nx_preloaded2() do
    cached_tensors()
    |> List.to_tuple()
    |> test_defn2()
    |> Nx.to_list()
  end

  @doc """
  Starts an unlinked `Agent` registered as `:cache` whose state is a list with
  one f32 tensor per sample in `@data`.

  Returns whatever `Agent.start/2` returns.
  """
  @spec preload() :: Agent.on_start()
  def preload() do
    Agent.start(
      fn -> Enum.map(sample_rows(), &Nx.tensor(&1, type: {:f, 32})) end,
      name: :cache
    )
  end

  # Turns each sample map into a list of floats ordered like `@nutrients`.
  defp sample_rows do
    Enum.map(@data, fn sample ->
      Enum.map(@nutrients, &Map.fetch!(sample, &1))
    end)
  end

  # Fetches the list of tensors held by the `:cache` agent.
  defp cached_tensors, do: Agent.get(:cache, & &1)
end
# Fill the :cache agent before timing anything that reads from it.
Test.preload()

# One scenario per strategy: build the tensor on every call, stack cached
# tensors outside of defn, or stack cached tensors inside defn.
scenarios = %{
  "nx" => &Test.call_nx/0,
  "nx_preloaded" => &Test.call_nx_preloaded/0,
  "nx_preloaded2" => &Test.call_nx_preloaded2/0
}

Benchee.run(scenarios)
#Operating System: macOS
#CPU Information: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
#Number of Available Cores: 16
#Available memory: 32 GB
#Elixir 1.14.5
#Erlang 25.3.2.5
#
#Benchmark suite executing with the following configuration:
#warmup: 2 s
#time: 5 s
#memory time: 0 ns
#reduction time: 0 ns
#parallel: 1
#inputs: none specified
#Estimated total run time: 21 s
#
#Benchmarking nx ...
#Benchmarking nx_preloaded ...
#Benchmarking nx_preloaded2 ...
#
#Name ips average deviation median 99th %
#nx 1505.91 0.66 ms ±21.92% 0.61 ms 1.21 ms
#nx_preloaded 885.25 1.13 ms ±18.85% 1.10 ms 1.88 ms
#nx_preloaded2 782.22 1.28 ms ±23.16% 1.23 ms 2.06 ms
#
#Comparison:
#nx 1505.91
#nx_preloaded 885.25 - 1.70x slower +0.47 ms
#nx_preloaded2 782.22 - 1.93x slower +0.61 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment