Skip to content

Instantly share code, notes, and snippets.

@chethega
Last active February 28, 2019 22:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chethega/b5268b29e55b44fc96c7726eae5968b2 to your computer and use it in GitHub Desktop.
Save chethega/b5268b29e55b44fc96c7726eae5968b2 to your computer and use it in GitHub Desktop.
canonical equipartitions of graphs via diffusion
#Fallback AES implementation via mbedtls.
#Officially, mbedtls is a binary dependency of julia, so this should always work?
#Also, results depend on system endianness (same as for the other versions).
#A factor of 5 slower than native on my system, and causes no allocations. Probably fast enough.
module AESmbed
using Libdl
export Aes_wrap
# Open the mbedTLS crypto library and look up, once when this module is evaluated,
# the C entry points that the wrappers below pass to `ccall`.
# NOTE(review): the handles are raw pointers captured at load time, so this module is
# presumably not safe to precompile — confirm before packaging.
libmbed = Libdl.dlopen("libmbedcrypto")
const cipher_init = Libdl.dlsym(libmbed, "mbedtls_cipher_init")
const cipher_free = Libdl.dlsym(libmbed, "mbedtls_cipher_free")
const cipher_info_from_values = Libdl.dlsym(libmbed, "mbedtls_cipher_info_from_values")
const cipher_setup = Libdl.dlsym(libmbed, "mbedtls_cipher_setup")
const cipher_setkey = Libdl.dlsym(libmbed, "mbedtls_cipher_setkey")
# Only needed by the commented-out "slower variant" of `enc` further down.
#const cipher_update = Libdl.dlsym(libmbed, "mbedtls_cipher_update")
const cipher_crypt = Libdl.dlsym(libmbed, "mbedtls_cipher_crypt")
# AES-128 (ECB, encrypt direction) context backed by mbedTLS.
# The Julia object itself is handed to mbedTLS as the storage of its cipher context;
# construction throws if mbedTLS reports a failure instead of returning a
# half-initialised context.
mutable struct Aes_wrap
    #200 bytes, opaque structure.
    # NOTE(review): assumes the mbedTLS cipher context of the linked library fits
    # in these 200 bytes — confirm against the library version in use.
    data::NTuple{25, UInt64}

    @noinline function Aes_wrap(key::UInt128)
        ctx = new()
        kr = Ref(key)
        GC.@preserve ctx ccall(cipher_init, Cvoid,
            (Ptr{Cvoid},), pointer_from_objref(ctx))
        # Register the cleanup right after init, so that a context whose setup
        # fails below is still released once it is collected.
        finalizer(ctx) do ctx_
            GC.@preserve ctx_ ccall(cipher_free, Cvoid,
                (Ptr{Cvoid},), pointer_from_objref(ctx_))
        end
        GC.@preserve ctx kr begin
            p = pointer_from_objref(ctx)
            # MBEDTLS_CIPHER_ID_AES: 2, keylen:128, MBEDTLS_MODE_ECB:1
            ci = ccall(cipher_info_from_values, Ptr{Cvoid},
                (Cint, Cint, Cint), 2, 128, 1)
            ci == C_NULL &&
                error("mbedtls_cipher_info_from_values returned NULL for AES-128-ECB")
            rc = ccall(cipher_setup, Cint,
                (Ptr{Cvoid}, Ptr{Cvoid}), p, ci)
            rc == 0 || error("mbedtls_cipher_setup failed with code $rc")
            #keylen:128, ENCRYPT:1
            rc = ccall(cipher_setkey, Cint,
                (Ptr{Cvoid}, Ptr{Cvoid}, Cint, Cint),
                p, kr, 128, 1)
            rc == 0 || error("mbedtls_cipher_setkey failed with code $rc")
        end
        return ctx
    end
end
# Display only the address of the context object, never its contents.
function Base.show(io::IO, a::Aes_wrap)
    print(io, "Aes_wrap($(pointer_from_objref(a)))")
    return nothing
end
# Calling a context like a function forwards to `enc(ks, u)`.
function (ks::Aes_wrap)(u)
    return enc(ks, u)
end
#=
slower variant
function enc(ks::Aes_wrap, v::UInt128)
res = Ref(UInt128(0))
#rv = Ref(v)
#rs = Ref(UInt(16))
GC.@preserve ks res begin ccall(cipher_update, Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Ptr{Cvoid}),
pointer_from_objref(ks), Ref(v), 16, res, Ref(16))
end
return res[]
end
=#
# Encrypt the single 16-byte block `v` with the context `ks` through
# `mbedtls_cipher_crypt` and return the resulting block as a UInt128.
# Throws if mbedTLS reports a non-zero status, rather than silently handing back
# the zero-initialised output buffer.
function enc(ks::Aes_wrap, v::UInt128)
    res = Ref(UInt128(0))
    rc = GC.@preserve ks res begin
        # Arguments: context, iv + iv length, input + input length, output, output length.
        ccall(cipher_crypt, Cint,
            (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Ptr{Csize_t}),
            pointer_from_objref(ks), Ref(UInt128(0)), 16, Ref(v), 16,
            res, Ref{Csize_t}(16))
    end
    rc == 0 || error("mbedtls_cipher_crypt failed with code $rc")
    return res[]
end
end
module AESni
#todo: sometimes julia fails at vectorization for some operations. Should be 30% faster.
#todo: ugly, refactor
#todo: Either memalign keyschedules to cache-line or cut cruft. Current layout means that with alignment:
# -encryption and CTR touch 3 lines
# -decryption touches 3 different lines
using Core.Intrinsics: llvmcall
using Random
# 128-bit value in the form the llvmcall wrappers exchange with LLVM (`<2 x i64>`):
# element 1 carries the low 64 bits, element 2 the high 64 bits.
const _u128 = NTuple{2, Base.VecElement{UInt64}}

# UInt128 -> vector form.
function Base.convert(::Type{_u128}, x::UInt128)
    lo = x % UInt64
    hi = (x >> 64) % UInt64
    return (Base.VecElement{UInt64}(lo), Base.VecElement{UInt64}(hi))
end

# Vector form -> UInt128 (inverse of the conversion above).
function Base.convert(::Type{UInt128}, x::_u128)
    lo, hi = x
    return (UInt128(hi.value) << 64) | UInt128(lo.value)
end

# Assemble a `_u128` from a low and a high half, converting each half to
# `VecElement{UInt64}`.
function u128(x, y)
    T = Base.VecElement{UInt64}
    return (convert(T, x), convert(T, y))
end

# Element-wise xor of two `_u128` values.
function _xor(a, b)
    return u128(a[1].value ⊻ b[1].value, a[2].value ⊻ b[2].value)
end
# AES-128 key schedule: encryption round keys k0..k10, a counter for CTR-style use,
# and decryption round keys d10..d0.
#6 cache-lines. we don't get alignment to cache-line, so could be better to cut d0, d10, pad (such that we go down to 5.25 lines).
mutable struct AES_ks
    k0::_u128
    k1::_u128
    k2::_u128
    k3::_u128
    k4::_u128
    k5::_u128
    k6::_u128
    k7::_u128
    k8::_u128
    k9::_u128
    k10::_u128
    ctr::_u128
    d10::_u128 # k10
    d9::_u128
    d8::_u128
    d7::_u128
    d6::_u128
    d5::_u128
    d4::_u128
    d3::_u128
    d2::_u128
    d1::_u128
    d0::_u128 # k0
    pad::_u128 #can fill up last part of cache_line, why not.

    # Build a key schedule from `key`; with no key given, one is drawn from the
    # system's random device.
    function AES_ks(key = nothing)
        n = new()
        # Overwrite the whole object with zeros when it is garbage collected.
        finalizer(n) do x
            GC.@preserve x ccall(:memset, Nothing, (Ptr{Nothing}, Cint, Csize_t),
                pointer_from_objref(x), 0, sizeof(x))
        end
        # `===` is the identity test for the `nothing` sentinel; `==` could dispatch
        # to a user-defined comparison of the key type.
        if key === nothing
            load_key!(n, convert(_u128, rand(Random.RandomDevice(), UInt128)))
        else
            load_key!(n, convert(_u128, key))
        end
        return n
    end
end
# Return the first round key (`k0`) as a UInt128.
function get_key(ks::AES_ks)
    return convert(UInt128, ks.k0)
end

# Return the current counter value as a UInt128.
function get_counter(ks::AES_ks)
    return convert(UInt128, ks.ctr)
end
# Calling a key schedule like a function encrypts `val` after converting it to UInt128.
function (ks::AES_ks)(val)
    block = convert(UInt128, val)
    return enc(ks, block)
end
# Never print key material: the display is a fixed placeholder.
Base.show(io::IO, ks::AES_ks) = print(io, "AES_ks[...]")
#for use in CTR-mode / stream / random number generation.
# Encrypts the current counter, then advances the counter by one (as a UInt128).
# Returns the encrypted block in the counter's `_u128` representation.
@inline function gen_nxt!(gen::AES_ks)
    current = gen.ctr
    block = enc(gen, current)
    gen.ctr = convert(_u128, convert(UInt128, current) + 1)
    return block
end
# Encrypt one block with the schedule `key`: xor with k0, nine rounds with k1..k9,
# final round with k10. The result is converted back to the type of the input.
@inline function enc(key::AES_ks, m_)
    state = _xor(convert(_u128, m_), key.k0)
    state = _aesni_enc(state, key.k1)
    state = _aesni_enc(state, key.k2)
    state = _aesni_enc(state, key.k3)
    state = _aesni_enc(state, key.k4)
    state = _aesni_enc(state, key.k5)
    state = _aesni_enc(state, key.k6)
    state = _aesni_enc(state, key.k7)
    state = _aesni_enc(state, key.k8)
    state = _aesni_enc(state, key.k9)
    state = _aesni_enclast(state, key.k10)
    return convert(typeof(m_), state)
end
# Decrypt one block with the schedule `key`: xor with d10, nine rounds with d9..d1,
# final round with d0. The result is converted back to the type of the input.
# Reads `key.d10`, so that field must hold the last encryption round key.
@inline function dec(key::AES_ks, m_)
    state = _xor(convert(_u128, m_), key.d10)
    state = _aesni_dec(state, key.d9)
    state = _aesni_dec(state, key.d8)
    state = _aesni_dec(state, key.d7)
    state = _aesni_dec(state, key.d6)
    state = _aesni_dec(state, key.d5)
    state = _aesni_dec(state, key.d4)
    state = _aesni_dec(state, key.d3)
    state = _aesni_dec(state, key.d2)
    state = _aesni_dec(state, key.d1)
    state = _aesni_declast(state, key.d0)
    return convert(typeof(m_), state)
end
# Thin wrapper around the LLVM intrinsic `llvm.x86.aesni.aesenc`.
# `enc` calls it with the running state as `a` and a round key as `b`.
# The two strings are the LLVM declaration and the IR body handed to `llvmcall`.
@inline function _aesni_enc(a::_u128, b::_u128)
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aesenc(<2 x i64>, <2 x i64>) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aesenc(<2 x i64> %0, <2 x i64> %1)
ret <2 x i64> %res"), _u128, Tuple{_u128,_u128}, a, b)
end
# Thin wrapper around the LLVM intrinsic `llvm.x86.aesni.aesenclast`.
# `enc` uses it for its final step, with the running state as `a` and `k10` as `b`.
@inline function _aesni_enclast(a::_u128, b::_u128)
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aesenclast(<2 x i64>, <2 x i64>) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aesenclast(<2 x i64> %0, <2 x i64> %1)
ret <2 x i64> %res"), _u128, Tuple{_u128,_u128}, a, b)
end
# Thin wrapper around the LLVM intrinsic `llvm.x86.aesni.aesdec`.
# `dec` calls it with the running state as `a` and a decryption round key as `b`.
@inline function _aesni_dec(a::_u128, b::_u128)
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aesdec(<2 x i64>, <2 x i64>) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aesdec(<2 x i64> %0, <2 x i64> %1)
ret <2 x i64> %res"), _u128, Tuple{_u128,_u128}, a, b)
end
# Thin wrapper around the LLVM intrinsic `llvm.x86.aesni.aesdeclast`.
# `dec` uses it for its final step, with the running state as `a` and `d0` as `b`.
@inline function _aesni_declast(a::_u128, b::_u128)
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aesdeclast(<2 x i64>, <2 x i64>) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aesdeclast(<2 x i64> %0, <2 x i64> %1)
ret <2 x i64> %res"), _u128, Tuple{_u128,_u128}, a, b)
end
#key loading.
# Fill `ks` with the complete schedule derived from `key` and reset its counter to 1.
#
# `key` may be a UInt128 (including the random default) or an already converted
# `_u128`; it is normalised once up front because the keygenassist wrappers only
# accept `_u128`.
# Encryption keys k1..k10 are derived one after another from k0. The decryption keys
# are d10 = k10, d9..d1 = aesimc(k9..k1) and d0 = k0, matching the order in which
# `dec` consumes them.
# Returns `ks`.
function load_key!(ks::AES_ks, key = rand(Random.RandomDevice(), UInt128))
    k = convert(_u128, key)
    ks.k0 = k
    k = _keyexpand(k, _aesni_keygenassist_1(k))
    ks.k1 = k
    k = _keyexpand(k, _aesni_keygenassist_2(k))
    ks.k2 = k
    k = _keyexpand(k, _aesni_keygenassist_3(k))
    ks.k3 = k
    k = _keyexpand(k, _aesni_keygenassist_4(k))
    ks.k4 = k
    k = _keyexpand(k, _aesni_keygenassist_5(k))
    ks.k5 = k
    k = _keyexpand(k, _aesni_keygenassist_6(k))
    ks.k6 = k
    k = _keyexpand(k, _aesni_keygenassist_7(k))
    ks.k7 = k
    k = _keyexpand(k, _aesni_keygenassist_8(k))
    ks.k8 = k
    k = _keyexpand(k, _aesni_keygenassist_9(k))
    ks.k9 = k
    k = _keyexpand(k, _aesni_keygenassist_10(k))
    ks.k10 = k

    # `dec` starts by xoring with d10; without this assignment the field would keep
    # whatever bytes the uninitialised allocation happened to contain.
    ks.d10 = ks.k10
    ks.d9 = _aesni_imc(ks.k9)
    ks.d8 = _aesni_imc(ks.k8)
    ks.d7 = _aesni_imc(ks.k7)
    ks.d6 = _aesni_imc(ks.k6)
    ks.d5 = _aesni_imc(ks.k5)
    ks.d4 = _aesni_imc(ks.k4)
    ks.d3 = _aesni_imc(ks.k3)
    ks.d2 = _aesni_imc(ks.k2)
    ks.d1 = _aesni_imc(ks.k1)
    ks.d0 = ks.k0

    ks.ctr = convert(_u128, UInt128(1))
    ks.pad = convert(_u128, UInt128(0))
    return ks
end
# Derive the next round key from the previous round key `key` and the keygenassist
# output `keyh`. Only the highest 32-bit word of `keyh` is used; output word i is the
# xor of that word with input words 0..i.
function _keyexpand(key, keyh)
    w0, w1, w2, w3 = _unpack(key)
    t = _unpack(keyh)[4]
    n0 = w0 ⊻ t
    n1 = n0 ⊻ w1
    n2 = n1 ⊻ w2
    n3 = n2 ⊻ w3
    return _repack((n0, n1, n2, n3))
end
# Split a `_u128` into four UInt32 words, lowest word first.
@inline function _unpack(a::_u128)
    lo = a[1].value
    hi = a[2].value
    return (lo % UInt32, (lo >> 32) % UInt32, hi % UInt32, (hi >> 32) % UInt32)
end
# Inverse of `_unpack`: combine four 32-bit words (lowest first) into a `_u128`.
@inline function _repack(k)
    w0, w1, w2, w3 = k
    lo = UInt64(w0) | (UInt64(w1) << 32)
    hi = UInt64(w2) | (UInt64(w3) << 32)
    return u128(lo, hi)
end
# Thin wrapper around the LLVM intrinsic `llvm.x86.aesni.aesimc`.
# `load_key!` applies it to k1..k9 to obtain the decryption round keys d1..d9.
@inline _aesni_imc(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aesimc(<2 x i64>) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aesimc(<2 x i64> %0)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# Ten wrappers around the LLVM intrinsic `llvm.x86.aesni.aeskeygenassist`, one per
# key-expansion step used by `load_key!`. The `i8` operand is written as a literal in
# each IR string, so every value gets its own function.
# The literals are 1, 2, 4, 8, 16, 32, 64, 128, 27, 54 for steps 1..10.

# step 1, immediate 1
@inline _aesni_keygenassist_1(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 1)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 2, immediate 2
@inline _aesni_keygenassist_2(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 2)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 3, immediate 4
@inline _aesni_keygenassist_3(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 4)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 4, immediate 8
@inline _aesni_keygenassist_4(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 8)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 5, immediate 16
@inline _aesni_keygenassist_5(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 16)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 6, immediate 32
@inline _aesni_keygenassist_6(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 32)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 7, immediate 64
@inline _aesni_keygenassist_7(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 64)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 8, immediate 128
@inline _aesni_keygenassist_8(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 128)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 9, immediate 27
@inline _aesni_keygenassist_9(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 27)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
# step 10, immediate 54
@inline _aesni_keygenassist_10(a::_u128) =
llvmcall(("declare <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64>, i8) nounwind readnone",
" %res = call <2 x i64> @llvm.x86.aesni.aeskeygenassist(<2 x i64> %0, i8 54)
ret <2 x i64> %res"), _u128, Tuple{_u128}, a)
end
#needs include("./aesni.jl")
#use diffusion algorithm to compute a partial canonical labeling.
#I.e.: output of `colorize` depends only on the isomorphism-type of `g`. Simple use:
#=
julia> for (N,k) in [(100, 10), (1000, 10), (1_000, 50), (10_000, 100)]
g1 = LightGraphs.SimpleGraphs.random_regular_graph(N, k);
perm = shuffle(1:nv(g1)); iperm = collect(1:nv(g1)); for i=1:nv(g1) iperm[perm[i]] = i end;
g2 = relabel(g1, perm)
@show N,k
@time c1 = Equipartitions.colorize(g1; debug=false, use_triangles=true);
@time c2 = Equipartitions.colorize(g2; debug=false, use_triangles=true);
@assert c1.global_color == c2.global_color
@assert c1.color_map == c2.color_map[perm]
end;
=#
#After convergence, the total graph and each vertex have an UInt128-color.
#In many cases without nontrivial automorphism group, each vertex has a distinct color.
#Then, we have a full canonical labeling.
#Based on a random oracle. For reproducibility, this is AES with hard-coded default key.
#The diffusion is fast. Use of triangles is very slow.
#Todo: Figure out a faster kernel that can disambiguate regular graphs.
#Todo: Write a variant for directed graphs (should only need a different kernel).
module Equipartitions
using LightGraphs
# A color: an opaque 128-bit value attached to a vertex or to the whole graph.
struct Color_class
    contents::UInt128
end

# Colors are combined by adding their contents (UInt128 arithmetic, wraps on overflow).
@inline function Base.:+(a::Color_class, b::Color_class)
    return Color_class(a.contents + b.contents)
end

# The hash of a color is simply the low 64 bits of its contents.
@inline function Base.hash(color::Color_class)
    return rem(color.contents, UInt64)
end
#
# Wraps an indexable container so it can be called like a function:
# `_Dict(c)(key)` is `c[key]`.
struct _Dict{DT} <: Function
    parent::DT
end

@inline function (d::_Dict)(key)
    return getindex(d.parent, key)
end
# Working state of the color-diffusion iteration on `graph`.
# HF is the type of the hash function, G the graph type.
mutable struct Colorize_state{HF, G<:AbstractGraph}
# vertices that were found to be alone in their color class
done_vertices::BitSet
# current color of every vertex, indexed by vertex number
color_map::Vector{Color_class}
# number of vertices currently carrying each color
color_count::Dict{Color_class, Int}
# per-round scratch, filled by the kernels: vertex => accumulated proposal color
update_cols::Dict{Int64, Color_class}
# per-round scratch: proposal color => number of vertices holding that proposal
update_counter::Dict{Color_class, Int}
# vertices recolored in the most recent round (the constructor starts with all vertices)
mod_last::Vector{Int}
# vertices recolored since this set was last emptied (it is emptied before a triangle pass)
mod_persistent::BitSet
# running sum of the initial graph color and of every vertex color assigned so far
global_color::Color_class
# whether `iterate` may run the triangle kernel once edge diffusion has stalled
use_triangles::Bool
graph::G
# function applied to UInt128 values to derive new colors
_hash_fun::HF
end
# Add `v` to the entry of `d` under `k`; a missing entry is created with value `v`.
# Returns the value that was stored.
function increment!(d::AbstractDict, k, v)
    total = haskey(d, k) ? d[k] + v : v
    d[k] = total
    return total
end
# Add `v` to the entry of `d` under `k`; a missing entry starts from `default[k]`.
# `default` is only indexed when `k` is absent from `d`. Returns the value that was stored.
function increment!(d::AbstractDict, default, k, v)
    base = haskey(d, k) ? d[k] : default[k]
    total = base + v
    d[k] = total
    return total
end
#hf needs to be an approximately random function UInt128->UInt128. AES encryption fits the bill.
# Build the initial coloring state for `g`.
#
# `hf` is the hash function (UInt128 -> UInt128), `init(v)` an optional non-negative
# per-vertex seed, and `use_triangles` is stored for later use by `iterate`.
# Each vertex starts with the hash of (degree, self-loop flag, init value); vertices
# whose initial color is unique are marked as done right away.
function Colorize_state(g::AbstractGraph, hf = AESni.AES_ks(0xfc969d39fd4ba27a97ae525b1c8d982f); init = i->0, use_triangles = false)
    n = nv(g)
    done_vertices = BitSet()
    color_map = Vector{Color_class}(undef, n)
    color_count = Dict{Color_class, Int}()
    update_cols = Dict{Int64, Color_class}()
    update_counter = Dict{Color_class, Int}()
    # Every vertex counts as "modified" before the first round.
    mod_last = collect(1:n)
    mod_persistent = BitSet(1:n)
    # The graph color starts from the hash of (vertex count, edge count).
    global_color = Color_class(hf((UInt128(n) << 64) + ne(g)))
    for v = 1:n
        # low bits: degree and self-loop flag; high 64 bits: the caller's init value
        col_ = (degree(g, v)<<1 | has_edge(g, v, v)) | UInt128(UInt64(init(v))) << 64
        col = Color_class(hf(col_))
        increment!(color_count, col, 1)
        color_map[v] = col
        global_color += col
    end
    for v = 1:n
        if color_count[color_map[v]] == 1
            push!(done_vertices, v)
        end
    end
    return Colorize_state(done_vertices, color_map, color_count, update_cols,
        update_counter, mod_last, mod_persistent, global_color, use_triangles, g, hf)
end
# Run one refinement round with `kernel` and return `state`.
#
# 1. The kernel is applied to every vertex recolored in the previous round and
#    accumulates proposal colors in `update_cols`.
# 2. Proposals shared by a vertex's entire current color class are dropped.
# 3. The remaining vertices get a freshly hashed color; the counters, the graph
#    color and the modified-vertex sets are updated accordingly.
function _update_colors(state::Colorize_state, kernel::kernelT) where kernelT
    finished = state.done_vertices
    colors = state.color_map
    class_size = state.color_count
    proposals = state.update_cols
    proposal_size = state.update_counter
    frontier = state.mod_last
    touched = state.mod_persistent

    empty!(proposal_size)
    empty!(proposals)
    for v in frontier
        kernel(state, v)
    end

    # Count how many vertices hold each proposal color.
    for (v, proposed) in proposals
        class_size[colors[v]] == 1 && push!(finished, v)
        increment!(proposal_size, proposed, 1)
    end

    # Drop a proposal when as many vertices hold it as the vertex's class has members.
    for (v, proposed) in proposals
        if proposal_size[proposed] == class_size[colors[v]]
            delete!(proposals, v)
        end
    end

    empty!(frontier)
    for (vertex, proposed) in proposals
        #generate new color.
        fresh = Color_class(state._hash_fun(proposed.contents))
        #update counters
        previous = colors[vertex]
        remaining = increment!(class_size, previous, -1)
        if remaining == 0
            delete!(class_size, previous)
        end
        proposal_size[proposed] == 1 && push!(finished, vertex)
        increment!(class_size, fresh, 1)
        state.global_color += fresh
        colors[vertex] = fresh
        push!(touched, vertex)
        push!(frontier, vertex)
    end
    return state
end
# Debug output: print every color class with more than one vertex, followed by a
# summary line of counters.
function color_show(s::Colorize_state)
    d = Dict{Color_class, Vector{Int}}()
    for i in 1:nv(s.graph)
        members = get!(() -> Int[], d, s.color_map[i])
        push!(members, i)
    end
    println()
    for (k,v) in d
        if length(v) > 1
            @show k, length(v), v
        end
    end
    @show length(s.done_vertices), nv(s.graph)-length(s.done_vertices), length(s.mod_last), length(s.mod_persistent)
    println()
end
# Edge diffusion: add the color of `v` to the proposal of each neighbor that is
# neither `v` itself nor already done.
function edge_kernel(s::Colorize_state, v)
    own = s.color_map[v]
    for u in neighbors(s.graph, v)
        if u != v && !(u in s.done_vertices)
            increment!(s.update_cols, s.color_map, u, own)
        end
    end
    return nothing
end
# Triangle diffusion: for every triangle (v, u, x) of pairwise distinct vertices that
# are not done, hash the sum of the three colors and add the result to the proposal
# of each corner.
function triangle_kernel(s::Colorize_state, v)
    done = s.done_vertices
    v in done && return nothing
    g = s.graph
    colors = s.color_map
    for u in neighbors(g, v)
        (u == v || u in done) && continue
        for x in common_neighbors(g, u, v)
            (x == u || x == v || x in done) && continue
            mixed = colors[u].contents + colors[v].contents + colors[x].contents
            new_col = Color_class(s._hash_fun(mixed))
            for w in (u, v, x)
                increment!(s.update_cols, colors, w, new_col)
            end
        end
    end
    return nothing
end
# Mark every vertex whose current color is carried by exactly one vertex as done.
function clear_done(s::Colorize_state)
    for (i, col) in enumerate(s.color_map)
        s.color_count[col] == 1 && push!(s.done_vertices, i)
    end
    return nothing
end
# Compute the coloring of `g` and return the final `Colorize_state`.
#
# Edge diffusion runs until no vertex changes color. If `use_triangles` is set, some
# vertex changed since the last triangle pass and not all vertices are done, one
# triangle pass is run and edge diffusion resumes; otherwise the loop ends.
# With `debug = true` the state is printed after every round.
function colorize(g::AbstractGraph, hf = AESni.AES_ks(0xfc969d39fd4ba27a97ae525b1c8d982f); use_triangles = false, init = i->0, debug = false)
    s = Colorize_state(g, hf; use_triangles=use_triangles, init=init)
    debug && color_show(s)
    while true
        while !isempty(s.mod_last)
            _update_colors(s, edge_kernel)
            debug && color_show(s)
        end
        run_triangles = use_triangles && !isempty(s.mod_persistent) &&
            length(s.done_vertices) < nv(s.graph)
        run_triangles || break
        clear_done(s)
        # Revisit everything that changed since the previous triangle pass.
        append!(s.mod_last, s.mod_persistent)
        empty!(s.mod_persistent)
        _update_colors(s, triangle_kernel)
        debug && color_show(s)
    end
    return s
end
# First iteration step: yield the current graph color without doing any work.
Base.iterate(c::Colorize_state) = (c.global_color, nothing)
# Later iteration steps: run one refinement round and yield the new graph color, or
# return `nothing` once no further round changes any vertex.
# An edge round is tried first; if it leaves nothing modified (or there was nothing
# to do) and triangles are enabled, one triangle round is tried.
function Base.iterate(s::Colorize_state, iter_state)
    if !isempty(s.mod_last)
        _update_colors(s, edge_kernel)
        isempty(s.mod_last) || return (s.global_color, nothing)
    end
    (s.use_triangles && !isempty(s.mod_persistent)) || return nothing
    clear_done(s)
    append!(s.mod_last, s.mod_persistent)
    empty!(s.mod_persistent)
    _update_colors(s, triangle_kernel)
    return isempty(s.mod_last) ? nothing : (s.global_color, nothing)
end
# Return a copy of `g` in which vertex `i` has been renamed to `perm[i]`.
# The adjacency lists are written directly and kept sorted; the edge count is copied.
function relabel(g, perm)
    res = LightGraphs.SimpleGraph(nv(g))
    res.ne = g.ne
    for i in 1:nv(g)
        renamed = perm[g.fadjlist[i]]
        res.fadjlist[perm[i]] = sort(renamed)
    end
    return res
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment