Skip to content

Instantly share code, notes, and snippets.

@ericphanson
Created June 26, 2021 23:55
Show Gist options
  • Save ericphanson/03c9905a24b5b40f0108edb381ae2da7 to your computer and use it in GitHub Desktop.
Save ericphanson/03c9905a24b5b40f0108edb381ae2da7 to your computer and use it in GitHub Desktop.
TransformDagger

TransformDagger

A simple implementation of DataFrames.transform (restricted in some ways, expanded in others) that uses Dagger and works on generic Tables.jl tables.

# This file is machine-generated - editing it directly is not advised
[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.11.0"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.12.8"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.31.0"
[[Crayons]]
git-tree-sha1 = "3f71217b538d7aaee0b69ab47d9b7724ca8afa0d"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.4"
[[Dagger]]
deps = ["Colors", "Distributed", "LinearAlgebra", "MemPool", "Profile", "Random", "Requires", "Serialization", "SharedArrays", "SparseArrays", "Statistics", "StatsBase"]
git-tree-sha1 = "8d59bf882d9c8a1e5eb64207ee830e4efdfdc940"
uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
version = "0.11.3"
[[DataAPI]]
git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.7.0"
[[DataFrames]]
deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "66ee4fe515a9294a8836ef18eea7239c6ac3db5e"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "1.1.1"
[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.9"
[[DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
version = "1.0.0"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[ExprTools]]
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"
[[FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"
[[Formatting]]
deps = ["Printf"]
git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8"
uuid = "59287772-0a20-5a39-b81b-1366585eb4c0"
version = "0.4.2"
[[Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[InvertedIndices]]
deps = ["Test"]
git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc"
uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
version = "1.0.0"
[[IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"
[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
[[MemPool]]
deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets"]
git-tree-sha1 = "cb17c1dff8d9c89065c55ac4b0222b93d147e983"
uuid = "f9f48841-c794-520a-933b-121f7ba6ed94"
version = "0.3.4"
[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.0.0"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"
[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[PooledArrays]]
deps = ["DataAPI", "Future"]
git-tree-sha1 = "cde4ce9d6f33219465b55162811d8de8139c0414"
uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
version = "1.2.1"
[[PrettyTables]]
deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"]
git-tree-sha1 = "0d1245a357cc61c8cd61934c07447aa569ff22e6"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "1.1.0"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Reexport]]
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.1.0"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.0.0"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsAPI]]
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.0.0"
[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.8"
[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
[[TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.1"
[[Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"]
git-tree-sha1 = "8ed4a3ea724dac32670b062be3ef1c1de6773ae8"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.4.4"
[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "9f494bc54b4c31404a9eff449235836615929de1"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.10"
[[TrackingTimers]]
deps = ["Distributed", "PrettyTables", "Printf", "Tables"]
git-tree-sha1 = "16e1d1b40436284f6a0f3b965101da3f8c807564"
uuid = "88ba133c-8695-4d62-9a5c-bcf16b6d2e1a"
version = "0.1.2"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
[deps]
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TrackingTimers = "88ba133c-8695-4d62-9a5c-bcf16b6d2e1a"
using Distributed
# Ensure there is at least one worker besides the master process; otherwise start two
# local workers (presumably so Dagger has worker processes to schedule tasks on).
nprocs() >= 2 || addprocs(2)
# LinearAlgebra is needed on every process by the definitions in the `@everywhere` block
# (svdvals, diag, BLAS).
@everywhere using LinearAlgebra
# Timer-wrapped functions execute on workers (the recorded timing table reports proc ID
# 3), so TrackingTimers must be loaded there too.
@everywhere using TrackingTimers
@everywhere begin
    # Row-wise wrapper modeled on `DataFrames.ByRow` (borrowed from DataFrames.jl):
    # `ByRow(f)(cols...)` calls `f` once per row of the given columns.
    struct ByRow{F}
        f::F
    end

    # For every index shared by all `columns` (as given by `eachindex`), call the wrapped
    # function on the elements at that index and collect the results into an array.
    function (wrapper::ByRow)(columns...)
        return map(eachindex(columns...)) do i
            wrapper.f(map(col -> col[i], columns)...)
        end
    end

    # Toy per-row function with two named outputs: `x` is the product, `y` the sum.
    hi(a, b) = (x = a * b, y = a + b)

    # Singular values of the outer product `x * y'`.
    function svdfn(x, y)
        outer = x * adjoint(y)
        return svdvals(outer)
    end

    # Deliberately wasteful workload: allocates four random n×n×n arrays
    # (n = length(a)), keeps only the last one, and returns the diagonal of its sum
    # over the third dimension plus `a`. `@time` prints the cost on whichever process
    # runs it.
    function garbo(a)
        n = length(a)
        r = @time begin
            m = rand(n, n, n)
            for _ in 2:4
                m = rand(n, n, n)
            end
            diag(dropdims(sum(m; dims=3); dims=3)) .+ a
        end
        return r
    end

    # Keep BLAS single-threaded on every process (presumably to avoid oversubscribing
    # cores that the Distributed workers already share).
    BLAS.set_num_threads(1)
end
# Label for a `ByRow`-wrapped function, e.g. "ByRow(hi)", built from the label of the
# wrapped function; `instrument` uses `stringify` when naming timers.
stringify(f::ByRow) = "ByRow($(stringify(f.f)))"
include("transform.jl")
# Small synthetic table exercising each kind of transformation: single-column, multi-input,
# chained (`:svd` feeds `:sum`), multi-column output via `ByRow(hi)`, and the deliberately
# allocation-heavy `garbo` (~30 GiB for 1000 rows in the recorded run).
test_table = (; a=randn(1000), b=rand(1000))
# The Dagger event log is bound to `event_log`, not `log`: a global named `log` in `Main`
# shadows `Base.log`, and the assignment fails if `Base.log` was already used in this
# session (e.g. by an earlier REPL call).
result, event_log, t = @time transform(test_table, :a => (x -> 2x) => :c,
                                       (:a, :c) => svdfn => :svd,
                                       (:svd, :a) => (+) => :sum,
                                       (:sum, :svd) => ByRow(hi) => [:x, :y],
                                       [:y, :a] => (+) => :z,
                                       (:a, :b) => svdfn => :svd_ab,
                                       :b => garbo => :g);
# Write the execution plan recorded in the event log to `logs.gv` (presumably Graphviz,
# given the extension).
open("logs.gv"; write=true) do io
    return Dagger.show_plan(io, Dagger.get_logs!(event_log))
end
using Dagger, Tables, OrderedCollections
using PrettyTables, TrackingTimers
# Unpack the syntax `input_cols => f => output_cols` into the tuple
# `(input_cols, f, output_cols)`. Pairs iterate as (first, second), so nested
# destructuring peels both levels at once.
function decompose_pairs(p::Pair{<:Any,<:Pair})
    input, (f, output) = p
    return input, f, output
end
# Compact text form of `f` (what `show` prints under `:compact => true`), used as part of
# timer labels.
stringify(f) = sprint(show, f; context=:compact => true)
# Wrap the function of `input_cols => f => output_cols` with the tracking timer `t`,
# labelled "input ↦ f ↦ output", and return the same pair shape with the timed function.
function instrument(t::TrackingTimer, p::Pair{<:Any,<:Pair})
    input, f, output = decompose_pairs(p)
    label = "$(input) ↦ $(stringify(f)) ↦ $(output)"
    timed_f = t(f, label)
    return input => (timed_f => output)
end
# Normalize a column selector into something that iterates over column names:
# anything other than a lone `Symbol` (e.g. a tuple or vector of names) passes through
# unchanged, while a single `Symbol` becomes a 1-tuple.
wrap(input) = input
wrap(input::Symbol) = (input,)
# Adapt `f` so its result is passed through `Tables.columns`; used for transformations
# whose output is split into several named columns.
function columnify(f)
    return function (args...)
        return Tables.columns(f(args...))
    end
end
"""
    transform(table, ps...) -> (result, log, t)

Build a Dagger task graph that adds (or replaces) columns of `table`, execute it, and
return the materialized columns.

`table` may be any Tables.jl table. Each element of `ps` has the form
`input_cols => f => output_cols`:

- `input_cols` is a single column name (`Symbol`) or a collection of names; those columns
  are passed to `f` as positional arguments, in order.
- If `output_cols` is a `Symbol`, whatever `f` returns becomes that column.
- Otherwise `output_cols` is a collection of names (a one-element one such as `[:x]`
  included): `f`'s result is converted with `Tables.columns` and each named column is
  extracted from it.

Transformations may consume columns created by earlier elements of `ps`.

# Returns
- `result`: an `OrderedDict{Symbol,AbstractVector}` of all columns (a Tables.jl column
  table), with original columns first and new ones in the order they were first defined.
- `log`: the `Dagger.LocalEventLog` attached to the scheduler context.
- `t`: the `TrackingTimer` holding per-transformation timings.
"""
function transform(table, ps...)
    t = TrackingTimer()
    ctx = Context()
    log = Dagger.LocalEventLog()
    ctx.log_sink = log
    tab = Tables.columns(table)
    delayed_cols = OrderedDict{Symbol,Thunk}()
    # Pre-populate with thunks for the existing columns.
    for c in Tables.columnnames(tab)
        delayed_cols[c] = delayed(Tables.getcolumn)(tab, c)
    end
    # Add in new columns from transformations. Inputs are looked up in `delayed_cols`, so a
    # transformation can depend on the output of an earlier one.
    for p in ps
        input, f, output = decompose_pairs(instrument(t, p))
        cols = (delayed_cols[i] for i in wrap(input))
        if output isa Symbol
            delayed_cols[output] = delayed(f; cache=true)(cols...)
        else
            # Multi-column output: run `f` once, then pull each requested column out of
            # its (column-table) result.
            table_thunk = delayed(columnify(f); cache=true)(cols...)
            for col in output
                # not sure if `getcolumn` should be cached here... it should be very cheap
                getter = delayed(t(Tables.getcolumn); cache=true)
                delayed_cols[col] = getter(table_thunk, col)
            end
        end
    end
    # Collect every column through one root thunk so the scheduler sees the whole graph at
    # once and can run independent transformations concurrently, rather than executing
    # one column's sub-graph per `collect` call.
    colnames = collect(keys(delayed_cols))
    colvalues = collect(ctx, delayed(tuple)(values(delayed_cols)...))
    result = OrderedDict{Symbol,AbstractVector}()
    for (k, v) in zip(colnames, colvalues)
        result[k] = v
    end
    @info "Timing information" t
    return result, log, t
end
@ericphanson
Copy link
Author

Results:

julia> include("test.jl")
      From worker 3:     30.067755 seconds (187 allocations: 29.810 GiB, 5.14% gc time)
┌ Info: Timing information
│   t =
│    TrackingTimer: 31.38 s since creation (98% measured).
│                    name                   time    gctime  n_allocs    allocs    thread ID  proc ID 
│    ────────────────────────────────────────────────────────────────────────────────────────────────
│     b ↦ garbo ↦ g                        30.10 s      5%       236  29.810 GiB          2        3
│     (:a, :b) ↦ svdfn ↦ svd_ab             0.30 s      2%      1093  15.891 MiB          2        3
│     (:a, :c) ↦ svdfn ↦ svd                0.28 s      0%        44  15.841 MiB          2        3
│     (:sum, :svd) ↦ ByRow(hi) ↦ [:x, :y]   0.01 s      0%       575  45.547 KiB          2        3
│     (:svd, :a) ↦ + ↦ sum                  0.00 s      0%         2   7.969 KiB          2        3                                                                   4 rows omitted

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment