How to efficiently cache I/O in a user-transparent way (`@threads` compatible)

@Moelf · Last active June 10, 2023

Comparison

julia> includet("./tmp.jl")

julia> @time user_code(LRU_getindex);
  6.539568 seconds (840.07 k allocations: 69.739 MiB, 0.18% gc time, 1.18% compilation time: 33% of which was recompilation)

julia> @time user_code(illegal_getindex);
  3.374077 seconds (252.80 k allocations: 53.082 MiB, 2.87% compilation time: 25% of which was recompilation)
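(A reproduction note of mine, not stated in the gist: `includet` comes from Revise.jl, and the comparison only exercises the caches concurrently if the session was started with multiple threads, e.g. `julia -t 8`.)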

Speculation of the root problem

The illegal approach avoids the expensive hash-table lookup, since it first checks whether the requested index falls inside the currently cached cluster range.
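To make the per-call cost difference concrete, here is a rough micro-benchmark sketch (my addition, not from the gist; assumes BenchmarkTools.jl) of a single "is it cached?" check in each approach:

using BenchmarkTools, LRUCache
cluster = 1:10_000
cache = LRU{Int, Vector{Vector{Float64}}}(; maxsize = 8)
cache[1] = [Float64[]]
@btime 5_000 in $cluster                    # "illegal" hot path: two integer comparisons
@btime get!(() -> [Float64[]], $cache, 1)   # LRU hot path: lock + hash + recency bookkeeping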

One observation is that the slowdown ratio widens as ClusterSize increases, which is counter-intuitive. This can only be explained if the majority of the time is spent checking whether a cluster is already cached, which happens on every getindex(); the illegal approach reduces that check to a single range-membership test.
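One way to test this hypothesis (again my sketch, assuming the file below has been includet'ed) is to profile the LRU variant and see whether samples land inside LRUCache's get!/locking rather than inside _mock_io:

using Profile
Profile.clear()
@profile user_code(LRU_getindex)
Profile.print()   # look for frames in get!/lock vs. _mock_io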

Based on this, I have also tried ConcurrentDict, which seems to be a linear-probing dict rather than a chained hash table, but the performance is not much better.

using LRUCache
const Nevents = 10^5
const ClusterSize = 10000
## ====================shared utility and mock function================================
function _mock_io(cluster)
    # can be cached using `cluster` or its `start` as the key
    start = first(cluster)
    res = [collect(i:i+20) ./ start for i in 1:ClusterSize]
    # burn CPU to mimic the cost of actually reading/decompressing a cluster
    for _ in 1:400, i in res
        i .= sin.(i)
        i .= exp.(i)
        i .= cos.(i)
    end
    return res
end
const all_indices = 1:Nevents
# this does not have to be evenly distributed
const all_ranges = collect(Base.Iterators.partition(all_indices, ClusterSize))
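# with the constants above this yields 10 ranges: 1:10000, 10001:20000, ..., 90001:100000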
# every time the user indexes into the column they get back one element of a cluster;
# we need to find the cluster range that contains this index
function _findrange(cluster_ranges, idx)
    # in reality this requires more searching, hence this small CPU-burning loop
    for i in 1:20000
        sin(i)
    end
    for cluster in cluster_ranges
        first_entry = first(cluster)
        n_entries = length(cluster) # the real structure records this instead of last()
        if first_entry + n_entries - 1 >= idx
            return cluster
        end
    end
end
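# example (my check, not in the original): _findrange(all_ranges, 12345) == 10001:20000
# returns `nothing` if idx lies beyond the last cluster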
## ====================different getindex implementations=========================
function no_cache_getindex(idx)
    cluster = _findrange(all_ranges, idx)
    localidx = idx - first(cluster) + 1
    data = _mock_io(cluster)
    return data[localidx]
end
# the LRU from LRUCache.jl is thread-safe: every access takes its internal lock
const lru = LRU{Int64, Vector{Vector{Float64}}}(; maxsize = Threads.nthreads())
function LRU_getindex(idx)
    cluster = _findrange(all_ranges, idx)
    # `start` is enough to serve as the key
    start = first(cluster)
    data = get!(lru, start) do
        _mock_io(cluster)
    end
    localidx = idx - first(cluster) + 1
    return data[localidx]
end
# this is considered a bug: with the default :dynamic schedule, a task inside @threads
# can migrate to another thread at a yield point, so threadid() is not a stable index
const illegal_cache = [Vector{Vector{Float64}}() for _ in 1:Threads.nthreads()]
const illegal_cache_ranges = [0:-1 for _ in 1:Threads.nthreads()]
function illegal_getindex(idx)
    tid = Threads.threadid()
    cluster = illegal_cache_ranges[tid]
    local data
    if idx ∉ cluster
        # idx is outside the currently cached cluster, fetch new data
        cluster = _findrange(all_ranges, idx)
        data = illegal_cache[tid] = _mock_io(cluster)
        illegal_cache_ranges[tid] = cluster
    else
        data = illegal_cache[tid]
    end
    localidx = idx - first(cluster) + 1
    return data[localidx]
end
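# quick way to observe the task migration that makes the threadid() trick unsafe
# (my addition, not part of the original script; prints only if a migration happens)
function demo_task_migration()
    Threads.@threads for i in 1:1_000
        t1 = Threads.threadid()
        yield()                     # any yield point: I/O, locks, Threads.@spawn, ...
        t2 = Threads.threadid()
        t1 == t2 || println("task migrated: thread $t1 -> $t2")
    end
end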
## ================================================================
##################### user code ############################
function user_code(MyGet::F) where F
    Threads.@threads for i in 1:Nevents
        sum(MyGet(i))
    end
    # reset the caches so repeated @time runs stay comparable
    empty!(lru)
    illegal_cache_ranges .= Ref(0:-1)
end
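For completeness, here is one "legal" variant of the same idea (my sketch, not part of the gist; tls_getindex is a hypothetical name): keep the last cluster in task-local storage instead of a threadid()-indexed array. Each task created by @threads owns its own storage, so migration is harmless, and the hot path is still a plain range-membership check rather than a locked hash lookup.

## per-task cache sketch: same fast range check as the illegal version, no threadid() indexing
function tls_getindex(idx)
    tls = task_local_storage()
    cluster = get(tls, :cluster_range, 0:-1)::UnitRange{Int}
    if idx ∉ cluster
        # cache miss for this task: locate and "read" the new cluster
        cluster = _findrange(all_ranges, idx)
        tls[:cluster_range] = cluster
        tls[:cluster_data] = _mock_io(cluster)
    end
    data = tls[:cluster_data]::Vector{Vector{Float64}}
    return data[idx - first(cluster) + 1]
end

It can be timed the same way as the other implementations, e.g. @time user_code(tls_getindex).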