Skip to content

Instantly share code, notes, and snippets.

@andyferris
Last active February 21, 2018 12:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andyferris/44ce332118d582e52a554739e1b1286b to your computer and use it in GitHub Desktop.
Save andyferris/44ce332118d582e52a554739e1b1286b to your computer and use it in GitHub Desktop.
Implementation of `broadcast` for `AbstractDict`s and `NamedTuple`s
# Build a `Dict{K,V}` whose keys come from the arbitrary collection `inds` and
# whose value vector is allocated but left uninitialized.
#
# A temporary `Set{K}` is built from `inds`. That set is private to this call,
# so the `slots` and `keys` arrays of its backing dictionary are reused as-is
# instead of being copied.
function Base.Dict{K,V}(::Uninitialized, inds) where {K, V}
    backing = Set{K}(inds).dict
    vals = Vector{V}(uninitialized, length(backing.keys))
    return Dict{K,V}(backing.slots, backing.keys, vals,
                     backing.ndel, backing.count, backing.age,
                     backing.idxfloor, backing.maxprobe)
end
# Build a `Dict{K,V}` with uninitialized values from an existing `Set{K}`.
#
# The set's backing dictionary already holds the hash layout for the keys, so
# its `slots` and `keys` arrays are copied (the caller still owns `inds`) and
# paired with a fresh, uninitialized value vector of matching length.
function Base.Dict{K,V}(::Uninitialized, inds::Set{K}) where {K, V}
    backing = inds.dict
    vals = Vector{V}(uninitialized, length(backing.keys))
    return Dict{K,V}(copy(backing.slots), copy(backing.keys), vals,
                     backing.ndel, backing.count, backing.age,
                     backing.idxfloor, backing.maxprobe)
end
# Build a `Dict{K,V}` with uninitialized values from the key set of an
# existing `Dict` with key type `K`.
#
# `Base.KeySet` has two type parameters, `KeySet{K, T<:AbstractDict{K}}`: the
# key type first, the wrapped dictionary type second. The signature therefore
# has to constrain the *second* parameter to `<:Dict{K}`. Written as
# `KeySet{<:Dict{K}}` the method would only match key sets whose *keys* are
# dictionaries, so `keys(::Dict{K})` would silently fall through to the generic
# method that rebuilds the hash table from scratch.
#
# The source dictionary's `slots` and `keys` arrays are copied (the caller still
# owns them) and paired with a fresh, uninitialized value vector.
function Base.Dict{K,V}(::Uninitialized, inds::Base.KeySet{K, <:Dict{K}}) where {K, V}
    d = inds.dict
    n = length(d.keys)
    return Dict{K,V}(copy(d.slots), copy(d.keys), Vector{V}(uninitialized, n),
                     d.ndel, d.count, d.age, d.idxfloor, d.maxprobe)
end
## Re-declare `combine_indices` so that both methods carry `@inline`.
# One container: its own broadcast indices are the combined indices.
@inline function Broadcast.combine_indices(A)
    return Broadcast.broadcast_indices(A)
end
# Several containers: merge the indices of the first with the recursively
# combined indices of the remaining ones.
@inline function Broadcast.combine_indices(A, B...)
    head_inds = Broadcast.broadcast_indices(A)
    tail_inds = Broadcast.combine_indices(B...)
    return Broadcast.broadcast_shape(head_inds, tail_inds)
end
##
# Named tuples broadcast with their own style, `Broadcast.Style{NamedTuple}`.
function Broadcast.BroadcastStyle(::Type{<:NamedTuple})
    return Broadcast.Style{NamedTuple}()
end

# Binary rules: the named-tuple style wins when combined with scalars, plain
# tuples, or any array style.
function Broadcast.BroadcastStyle(::Broadcast.Style{NamedTuple}, ::Broadcast.Scalar)
    return Broadcast.Style{NamedTuple}()
end
function Broadcast.BroadcastStyle(::Broadcast.Style{NamedTuple}, ::Broadcast.Style{Tuple})
    return Broadcast.Style{NamedTuple}()
end
function Broadcast.BroadcastStyle(::Broadcast.Style{NamedTuple}, ::Broadcast.AbstractArrayStyle)
    return Broadcast.Style{NamedTuple}()
end
# Broadcast style used for every `AbstractDict`.
struct DictStyle <: Broadcast.BroadcastStyle end

function Broadcast.BroadcastStyle(::Type{<:AbstractDict})
    return DictStyle()
end

# Binary rules: the dictionary style wins when combined with scalars, tuples,
# named tuples, or any array style.
function Broadcast.BroadcastStyle(::DictStyle, ::Broadcast.Scalar)
    return DictStyle()
end
function Broadcast.BroadcastStyle(::DictStyle, ::Broadcast.Style{Tuple})
    return DictStyle()
end
function Broadcast.BroadcastStyle(::DictStyle, ::Broadcast.Style{NamedTuple})
    return DictStyle()
end
function Broadcast.BroadcastStyle(::DictStyle, ::Broadcast.AbstractArrayStyle)
    return DictStyle()
end
# Allocate the destination for a `DictStyle` broadcast: a `Dict` keyed by the
# single index collection in `inds`, with value type `ElType` and uninitialized
# values.
function Broadcast.broadcast_similar(f, ::DictStyle, ::Type{ElType}, inds::Tuple{Any}, As...) where {ElType}
    keyset = inds[1]
    return Dict{eltype(keyset), ElType}(uninitialized, keyset)
end
# The broadcast indices of a named tuple are its field names, wrapped in a
# 1-tuple.
@inline function Broadcast.broadcast_indices(::Broadcast.Style{NamedTuple}, ::NamedTuple{names}) where {names}
    return (names,)
end

# The broadcast indices of a dictionary-style container are its keys, wrapped
# in a 1-tuple.
function Broadcast.broadcast_indices(::DictStyle, d)
    return (Base.keys(d),)
end
# `_bcs1(a, b)` merges two broadcast index collections into the one the result
# should use, or throws `DimensionMismatch` when `_bcsm` reports that they are
# incompatible.
#
# Which argument is returned on success:
#   * a set paired with anything      -> the set (the first set if both are sets)
#   * a tuple of names paired with a
#     non-set index collection        -> the tuple of names

# Set-valued indices (dictionary keys).
Broadcast._bcs1(a::AbstractSet, b::AbstractSet) = Broadcast._bcsm(a, b) ? a : throw(DimensionMismatch("containers could not be broadcast to a common size"))
Broadcast._bcs1(a, b::AbstractSet) = Broadcast._bcsm(a, b) ? b : throw(DimensionMismatch("containers could not be broadcast to a common size"))
Broadcast._bcs1(a::AbstractSet, b) = Broadcast._bcsm(a, b) ? a : throw(DimensionMismatch("containers could not be broadcast to a common size"))

# Name-valued indices (named-tuple field names).
@inline Broadcast._bcs1(a::Tuple{Vararg{Symbol}}, b::Tuple{Vararg{Symbol}}) = Broadcast._bcsm(a, b) ? a : throw(DimensionMismatch("containers could not be broadcast to a common size"))
Broadcast._bcs1(a::Tuple{Vararg{Symbol}}, b::AbstractSet) = Broadcast._bcsm(a, b) ? b : throw(DimensionMismatch("containers could not be broadcast to a common size"))
Broadcast._bcs1(a::AbstractSet, b::Tuple{Vararg{Symbol}}) = Broadcast._bcsm(a, b) ? a : throw(DimensionMismatch("containers could not be broadcast to a common size"))
Broadcast._bcs1(a::Tuple{Vararg{Symbol}}, b) = Broadcast._bcsm(a, b) ? a : throw(DimensionMismatch("containers could not be broadcast to a common size"))
# Base's generic `_bcsm(x, y)` checks whether the *second* argument is
# consistent with the first (only `y` may have length 1). The name tuple is the
# index collection being kept here, so it must be passed first; with the
# arguments the other way round a length-1 container on the left of a named
# tuple (e.g. `[1] .+ (a=1, b=2)`) was rejected even though the mirrored call
# is accepted.
Broadcast._bcs1(a, b::Tuple{Vararg{Symbol}}) = Broadcast._bcsm(b, a) ? b : throw(DimensionMismatch("containers could not be broadcast to a common size"))
# `_bcsm(a, b)` answers whether two index collections are compatible for
# broadcasting.

# Name tuples are compared as sets (field order is irrelevant). The comparison
# is marked `@pure` and kept in its own helper.
Base.@pure _issetequal(a::Tuple{Vararg{Symbol}}, b::Tuple{Vararg{Symbol}}) = issetequal(a, b)

# Set vs. unit range: compatible when the range has a single element or both
# contain the same elements.
function Broadcast._bcsm(a::AbstractSet, b::AbstractUnitRange)
    return length(b) == 1 || issetequal(a, b)
end
function Broadcast._bcsm(a::AbstractUnitRange, b::AbstractSet)
    return length(a) == 1 || issetequal(a, b)
end

# Set vs. set: must be equal.
function Broadcast._bcsm(a::AbstractSet, b::AbstractSet)
    return a == b
end

# Set vs. names: must contain the same elements.
function Broadcast._bcsm(a::AbstractSet, b::Tuple{Vararg{Symbol}})
    return issetequal(a, b)
end
function Broadcast._bcsm(a::Tuple{Vararg{Symbol}}, b::AbstractSet)
    return issetequal(a, b)
end

# Names vs. names: same names in any order.
@inline function Broadcast._bcsm(a::Tuple{Vararg{Symbol}}, b::Tuple{Vararg{Symbol}})
    return _issetequal(a, b)
end
# The element type a `DictStyle` argument contributes to a broadcast is its
# value type, not its key/value pair type.
function Broadcast._broadcast_getindex_eltype(::DictStyle, d)
    return valtype(d)
end
# Out-of-place broadcast when the combined indices are a single set: allocate a
# destination with `broadcast_similar`, then fill it through `broadcast!`, with
# the call wrapped in `@inbounds`.
@inline function Broadcast.broadcast(f, s::Broadcast.BroadcastStyle, ::Type{ElType}, inds::Tuple{AbstractSet}, As...) where ElType
    out = Broadcast.broadcast_similar(f, s, ElType, inds, As...)
    return @inbounds broadcast!(f, out, As...)
end
# Out-of-place broadcast producing a named tuple: compute the values field by
# field with `_broadcast`, then wrap them under the combined field names.
@inline function Broadcast.broadcast(f, s::Broadcast.Style{NamedTuple}, ::Type{ElType}, inds::Tuple{Tuple{Vararg{Symbol}}}, As...) where ElType
    fieldnames_ = inds[1]
    fieldvalues = _broadcast(f, s, fieldnames_, As...)
    return NamedTuple{fieldnames_}(fieldvalues)
end
# Recursively apply `f` once per field name in `inds`, returning the results as
# a plain tuple in the same order as `inds`.

# Base case: no names left, nothing to compute.
@inline _broadcast(f, s::Broadcast.Style{NamedTuple}, inds::Tuple{}, As...) = ()

# Recursive case: evaluate `f` for the first name, then splice in the results
# for the remaining names.
@inline function _broadcast(f, s::Broadcast.Style{NamedTuple}, inds::Tuple{Vararg{Symbol}}, As...)
    name = inds[1]
    remaining = Base.tail(inds)
    args = map(a -> _getindex(a, name), As)
    head = f(args...)
    return (head, _broadcast(f, s, remaining, As...)...)
end
# `check_broadcast_indices(out_inds, x)` validates that argument `x` can be
# broadcast into a destination whose indices are the set `out_inds`. Each
# method returns `nothing` on success and throws `ErrorException` otherwise.

# Dictionaries must have exactly the destination's keys.
function Broadcast.check_broadcast_indices(out_inds::AbstractSet, d::AbstractDict)
    issetequal(out_inds, keys(d)) && return nothing
    throw(ErrorException("Output broadcast indices $out_inds do not match input $(keys(d))"))
end

# Vectors are accepted when they have a single element or when their indices
# match the destination's keys.
function Broadcast.check_broadcast_indices(out_inds::AbstractSet, d::AbstractVector)
    (length(d) === 1 || issetequal(out_inds, keys(d))) && return nothing
    throw(ErrorException("Output broadcast indices $out_inds do not match input $(keys(d))"))
end

# Tuples follow the same rule as vectors.
function Broadcast.check_broadcast_indices(out_inds::AbstractSet, t::Tuple)
    (length(t) === 1 || issetequal(out_inds, keys(t))) && return nothing
    throw(ErrorException("Output broadcast indices $out_inds do not match input $(keys(t))"))
end

# Named tuples must have exactly the destination's keys as field names.
function Broadcast.check_broadcast_indices(out_inds::AbstractSet, d::NamedTuple)
    issetequal(out_inds, keys(d)) && return nothing
    throw(ErrorException("Output broadcast indices $out_inds do not match input $(keys(d))"))
end

# Anything without a more specific method is accepted unchecked.
function Broadcast.check_broadcast_indices(::AbstractSet, ::Any)
    return nothing
end

# Zero-dimensional arrays are always accepted.
function Broadcast.check_broadcast_indices(::AbstractSet, ::AbstractArray{<:Any,0})
    return nothing
end

# Remaining arrays (neither vectors nor zero-dimensional) are rejected.
function Broadcast.check_broadcast_indices(::AbstractSet, ::AbstractArray)
    throw(ErrorException("Broadcasting between dictionaries and multidimensional arrays is not supported"))
end
# In-place broadcast into a dictionary when the arguments combine to the scalar
# style: `f(As...)` is evaluated once per key of `dest` and stored under that
# key. Returns `dest`.
@inline function Broadcast.broadcast!(f, dest::AbstractDict, ::Broadcast.Scalar, As::Vararg{Any, N}) where N
    @inbounds for key in keys(dest)
        value = f(As...)
        dest[key] = value
    end
    return dest
end
# In-place broadcast into a dictionary for any non-scalar style. Inside a
# `@boundscheck` block every argument is first validated against the keys of
# `dest`; then, for each key, the matching element of every argument is fetched
# with `_getindex` and `f` of those elements is stored under that key.
# Returns `dest`.
@inline function Broadcast.broadcast!(f, dest::AbstractDict, ::Broadcast.BroadcastStyle, As::Vararg{Any, N}) where N
    dest_keys = keys(dest)
    @boundscheck foreach(a -> Broadcast.check_broadcast_indices(dest_keys, a), As)
    @inbounds for key in dest_keys
        args = map(a -> _getindex(a, key), As)
        dest[key] = f(args...)
    end
    return dest
end
# In-place broadcast into a vector when the arguments combine to `DictStyle`:
# for each index of `dest`, the matching element of every argument is fetched
# with `_getindex` and `f` of those elements is stored at that index.
# Returns `dest`.
@inline function Broadcast.broadcast!(f, dest::AbstractVector, ::DictStyle, As::Vararg{Any, N}) where N
    @inbounds for idx in keys(dest)
        args = map(a -> _getindex(a, idx), As)
        dest[idx] = f(args...)
    end
    return dest
end
# `_getindex(a, i)` fetches the element of broadcast argument `a` that
# corresponds to index `i`, treating single-element containers and
# non-containers as constant across all indices.

# Fallback: anything else is treated as a scalar and returned unchanged.
@inline function _getindex(a, i)
    return a
end

# Dictionaries and named tuples are looked up directly by `i`.
@inline function _getindex(a::AbstractDict, i)
    return @inbounds a[i]
end
@inline function _getindex(a::NamedTuple, i)
    return @inbounds a[i]
end

# Vectors: a single-element vector yields its only element for every `i`;
# otherwise index normally.
@inline function _getindex(a::AbstractVector, i)
    return length(a) === 1 ? (@inbounds first(a)) : (@inbounds a[i])
end

# Zero-dimensional arrays always yield their only element.
@inline function _getindex(a::AbstractArray{<:Any, 0}, i)
    return @inbounds a[]
end

# Tuples: a 1-tuple yields its only element for every `i`; otherwise index
# normally.
@inline function _getindex(a::Tuple{Any}, i)
    return @inbounds a[1]
end
@inline function _getindex(a::Tuple, i)
    return @inbounds a[i]
end
using Test
# Exercises out-of-place and in-place broadcasting where at least one argument
# is a `Dict`. The `::Dict{...}` assertions also pin the result's key/value types.
@testset "Broadcast dictionaries" begin
d = Dict(1 => 10, 2 => 20)
# Single argument `broadcast`
@test (d .* 2)::Dict{Int, Int} == Dict(1 => 20, 2 => 40)
@test (d .* 2.0)::Dict{Int, Float64} == Dict(1 => 20.0, 2 => 40.0)
# Two argument `broadcast`
x = 2
@test (d .* x)::Dict{Int, Int} == Dict(1 => 20, 2 => 40)
@test (d .* d)::Dict{Int, Int} == Dict(1 => 100, 2 => 400)
# Vectors and tuples whose indices match the dictionary's keys
@test (d .+ [1, 2])::Dict{Int, Int} == Dict(1 => 11, 2 => 22)
@test (d .+ (1, 2))::Dict{Int, Int} == Dict(1 => 11, 2 => 22)
# Single-element containers are applied to every key
@test (d .+ [1])::Dict{Int, Int} == Dict(1 => 11, 2 => 21)
@test (d .+ (1,))::Dict{Int, Int} == Dict(1 => 11, 2 => 21)
@test (d .+ fill(1))::Dict{Int, Int} == Dict(1 => 11, 2 => 21) # zero-dimensional array
# Symbol-keyed dictionary combined with a named tuple
@test Dict(:a=>1, :b=>2) .+ (a=1, b=2) == Dict(:a=>2, :b=>4)
# Mutating `broadcast!`
# `d2` is overwritten step by step, so the order of the statements below matters.
d2 = copy(d)
d2 .= 0
@test d2 == Dict(1 => 0, 2 => 0)
d2 .= [1]
@test d2 == Dict(1 => 1, 2 => 1)
d2 .= (2,)
@test d2 == Dict(1 => 2, 2 => 2)
d2 .= Dict(1 => 3, 2 => 4)
@test d2 == Dict(1 => 3, 2 => 4)
d2 .= [5, 6]
@test d2 == Dict(1 => 5, 2 => 6)
d2 .= (7, 8)
@test d2 == Dict(1 => 7, 2 => 8)
# Broadcasting a dictionary into a vector destination
a = [0, 0]
a .= d
@test a == [10, 20]
# Broadcasting a named tuple into a symbol-keyed dictionary
d3 = Dict(:a=>0, :b=>0)
d3 .= (a=1, b=2)
@test d3 == Dict(:a=>1, :b=>2)
end
# Exercises out-of-place broadcasting where the result is a named tuple.
# `===` is used so that both the field names and the values must match exactly.
@testset "Broadcast named tuples" begin
# Scalar and single-element / zero-dimensional containers apply to every field
@test (a=1, b=2) .+ 1 === (a=2, b=3)
@test (a=1, b=2) .+ [1] === (a=2, b=3)
@test (a=1, b=2) .+ fill(1) === (a=2, b=3)
@test (a=1, b=2) .+ (1,) === (a=2, b=3)
# Two named tuples are matched by field name; the result uses the first argument's field order
@test (a=1, b=2) .+ (a=1, b=2) === (a=2, b=4)
@test (a=1, b=2) .+ (b=2, a=1) === (a=2, b=4)
end
@andyferris
Copy link
Author

The biggest remaining issue that I'm aware of is the speed of named tuple broadcasting (constant propagation of the names).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment