Created
November 1, 2023 20:35
-
-
Save JohnJocoo/af581c71db5d9acde81f10e148ddeef7 to your computer and use it in GitHub Desktop.
Testing how Nx tensor caching works
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Mix.install([{:nx, "~> 0.5"}, {:exla, "~> 0.5"}, {:benchee, "~> 1.0"}])

# The backend and the EXLA client are stored globally (application env), so
# they apply in every process, not only in the one running this script.
Nx.global_default_backend(EXLA.Backend)
Application.put_env(:exla, :default_client, :host)

# Nx.Defn.default_options/1 only affects the calling process. Benchee executes
# each job in processes it spawns (NOTE(review): confirm for the Benchee
# version in use), so the process-local setting would not reach the benchmarked
# defn calls. The global variant makes every defn call compile with EXLA on the
# host client.
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
defmodule Test do
  @moduledoc """
  Benchmark fixtures comparing three ways of feeding the same synthetic data
  into the `defn` function `test_defn/1`:

    * `call_nx/0` builds one 2-D tensor from `@data` on every call;
    * `call_nx_preloaded/0` stacks per-row tensors cached in the `:cache`
      Agent, outside of `defn`;
    * `call_nx_preloaded2/0` passes the cached per-row tensors as a tuple and
      stacks them inside `defn` (`test_defn2/1`).

  `preload/0` must be called before either preloaded variant.
  """
  import Nx.Defn

  # Number of rows in the synthetic data set.
  @row_count 35

  # Element type used for every tensor built here.
  @tensor_type {:f, 32}

  # Column keys :name1 .. :name100. This is a fixed, bounded set built at
  # compile time, so the dynamic atom creation is safe.
  @nutrients Enum.map(1..100, &String.to_atom("name#{&1}"))

  # @row_count maps of column key => random float in [0.0, 1.0). Module
  # attributes are evaluated at compile time, so the values are fixed for a
  # given compilation. `:rand` replaces the deprecated `:random` module.
  @data Enum.map(1..@row_count, fn _ -> Map.new(@nutrients, &{&1, :rand.uniform()}) end)

  @doc """
  Reduces a 2-D tensor along axis 0. Returns the column means of the tensor
  scaled by 2.5 twice, plus the plain column means.

  The divisor is the tensor's own axis-0 size, so the means are correct for
  any number of rows.
  """
  defn test_defn(tensor) do
    rows = Nx.axis_size(tensor, 0)

    average =
      tensor
      |> Nx.sum(axes: [0])
      |> Nx.divide(rows)

    tensor
    |> Nx.multiply(2.5)
    |> Nx.multiply(2.5)
    |> Nx.sum(axes: [0])
    |> Nx.divide(rows)
    |> Nx.add(average)
  end

  @doc """
  Stacks a tuple of equally-shaped tensors along a new leading axis inside
  `defn`, then applies `test_defn/1` to the result.
  """
  defn test_defn2(tensors) do
    tensors
    |> Nx.stack(axis: 0)
    |> test_defn()
  end

  @doc """
  Builds a matrix from `@data` on every call, runs `test_defn/1` on it and
  returns the result as a list.
  """
  def call_nx() do
    @data
    |> Enum.map(&row_values/1)
    |> Nx.tensor(type: @tensor_type)
    |> test_defn()
    |> Nx.to_list()
  end

  @doc """
  Reads the per-row tensors from the `:cache` Agent, stacks them outside
  `defn`, runs `test_defn/1` and returns the result as a list.
  """
  def call_nx_preloaded() do
    :cache
    |> Agent.get(& &1)
    |> Nx.stack(axis: 0)
    |> test_defn()
    |> Nx.to_list()
  end

  @doc """
  Reads the per-row tensors from the `:cache` Agent, passes them as a tuple to
  `test_defn2/1` (which stacks inside `defn`) and returns the result as a list.
  """
  def call_nx_preloaded2() do
    :cache
    |> Agent.get(& &1)
    |> List.to_tuple()
    |> test_defn2()
    |> Nx.to_list()
  end

  @doc """
  Starts an unlinked Agent registered as `:cache`. Its state is one 1-D tensor
  per row of `@data`. Returns the result of `Agent.start/2` unchanged.
  """
  def preload() do
    Agent.start(
      fn ->
        Enum.map(@data, fn row ->
          row
          |> row_values()
          |> Nx.tensor(type: @tensor_type)
        end)
      end,
      name: :cache
    )
  end

  # Values of one data row, ordered by @nutrients; raises if a key is missing.
  defp row_values(row), do: Enum.map(@nutrients, &Map.fetch!(row, &1))
end
# Populate the :cache Agent first; both "preloaded" jobs read from it.
Test.preload()

# Benchmark job name => zero-arity function to time.
jobs = %{
  "nx" => &Test.call_nx/0,
  "nx_preloaded" => &Test.call_nx_preloaded/0,
  "nx_preloaded2" => &Test.call_nx_preloaded2/0
}

Benchee.run(jobs)
# Sample Benchee output recorded by the author with the script as originally
# written (defn options set via the process-local Nx.Defn.default_options/1).
# NOTE(review): if Benchee runs jobs in separate processes, these defn calls
# may not have been compiled with EXLA, so treat the numbers with caution.
#Operating System: macOS
#CPU Information: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
#Number of Available Cores: 16
#Available memory: 32 GB
#Elixir 1.14.5
#Erlang 25.3.2.5
#
#Benchmark suite executing with the following configuration:
#warmup: 2 s
#time: 5 s
#memory time: 0 ns
#reduction time: 0 ns
#parallel: 1
#inputs: none specified
#Estimated total run time: 21 s
#
#Benchmarking nx ...
#Benchmarking nx_preloaded ...
#Benchmarking nx_preloaded2 ...
#
#Name                    ips        average  deviation         median         99th %
#nx                  1505.91        0.66 ms    ±21.92%        0.61 ms        1.21 ms
#nx_preloaded         885.25        1.13 ms    ±18.85%        1.10 ms        1.88 ms
#nx_preloaded2        782.22        1.28 ms    ±23.16%        1.23 ms        2.06 ms
#
#Comparison:
#nx                  1505.91
#nx_preloaded         885.25 - 1.70x slower +0.47 ms
#nx_preloaded2        782.22 - 1.93x slower +0.61 ms
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment