Skip to content

Instantly share code, notes, and snippets.

@YingboMa
Last active April 11, 2019 21:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YingboMa/2304f8561551095b8ad18781d763a94c to your computer and use it in GitHub Desktop.
Save YingboMa/2304f8561551095b8ad18781d763a94c to your computer and use it in GitHub Desktop.
import Base.Broadcast: _broadcast_getindex, preprocess, preprocess_args, Broadcasted, broadcast_unalias, combine_axes, broadcast_shape, check_broadcast_axes, check_broadcast_shape
import Base: copyto!, tail, axes
"""
    DiffEqBC{T}

Marker wrapper around a broadcast operand `x`.

Broadcasting reads the elements of a `DiffEqBC` straight from `x` through the
dedicated `_broadcast_getindex` methods in this file, and `copyto!` into a
`DiffEqBC` destination writes into `x`.

# Fields
- `x::T`: the wrapped value (an `Array` when created by `diffeqbc`).
"""
struct DiffEqBC{T}
    x::T
end
# `DiffEqBC` has exactly the axes of the value it wraps.
@inline function axes(b::DiffEqBC)
    return axes(b.x)
end

# Element reads used by the broadcast kernel. None of the methods below call
# `newindex`, i.e. they never extrude singleton dimensions: the wrapped array
# is indexed with the destination's index as given.

# Generic payload: recurse on the wrapped value.
Base.@propagate_inbounds function _broadcast_getindex(b::DiffEqBC, i)
    return _broadcast_getindex(b.x, i)
end

# Zero-dimensional array: its single element, whatever the index.
Base.@propagate_inbounds function _broadcast_getindex(b::DiffEqBC{<:AbstractArray{<:Any,0}}, i)
    return b.x[]
end

# Vector: only the first component of the index is used.
Base.@propagate_inbounds function _broadcast_getindex(b::DiffEqBC{<:AbstractVector}, i)
    first_index = i[1]
    return b.x[first_index]
end

# Any other array: index with `i` unchanged.
Base.@propagate_inbounds function _broadcast_getindex(b::DiffEqBC{<:AbstractArray}, i)
    return b.x[i]
end
"""
    diffeqbc(x)

Prepare a broadcast operand for the `DiffEqBC` fast path: a dense `Array` is
wrapped in `DiffEqBC`; any other value is returned as is.
"""
function diffeqbc end
diffeqbc(x) = x
diffeqbc(x::Array) = DiffEqBC(x)
# NOTE(review): `combine_axes` and `check_broadcast_axes` are imported from
# Base.Broadcast above, so these definitions extend Base's own internal
# functions. Where Base already defines the same signatures they are
# overwritten for every broadcast in the session. They presumably exist to
# force `@inline` through long argument lists — confirm against the Julia
# version in use.
# Combined broadcast axes of all arguments, folded pairwise from the right.
@inline combine_axes(A, B, C...) = broadcast_shape(axes(A), combine_axes(B, C...))
# Two-argument base case of the fold above.
@inline combine_axes(A, B) = broadcast_shape(axes(A), axes(B))
# Check that the axes of `A` are broadcast-compatible with the shape `shp`.
@inline check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
# Three-argument variants of Base.Broadcast's `preprocess`/`preprocess_args`:
# walk a `Broadcasted` tree and replace every leaf `x` by
# `f(broadcast_unalias(dest, x))`, keeping each node's style, function and axes.
@inline function preprocess(f, dest, bc::Broadcasted{Style}) where {Style}
    newargs = preprocess_args(f, dest, bc.args)
    return Broadcasted{Style}(bc.f, newargs, bc.axes)
end
function preprocess(f, dest, x)
    unaliased = broadcast_unalias(dest, x)
    return f(unaliased)
end

# Apply `preprocess(f, dest, ⋅)` to each element of an argument tuple by
# explicit recursion (head first, then the tail).
preprocess_args(f, dest, ::Tuple{}) = ()
@inline preprocess_args(f, dest, args::Tuple{Any}) = (preprocess(f, dest, first(args)),)
@inline function preprocess_args(f, dest, args::Tuple)
    head = preprocess(f, dest, first(args))
    return (head, preprocess_args(f, dest, tail(args))...)
end
# Julia >= 1.2 only: direct `getindex` methods for `DiffEqBC` and the
# `@aliasscope` macro that `copyto!` further down wraps its loop in.
@static if VERSION >= v"1.2.0"
# Index the wrapped value with the `Core.const_arrayref` intrinsic; presumably
# this requires `A.x` to be an `Array` (what `diffeqbc` wraps) and marks the
# read as not aliasing writes — TODO confirm against the Base source.
# `$(Expr(:boundscheck))` presumably lets a call-site `@inbounds` elide the
# bounds check, as in Base's own `Array` getindex.
# NOTE(review): the broadcast kernel in this file reads through
# `_broadcast_getindex`, which indexes `b.x` directly, so these methods may
# only serve explicit `A[i]` calls — confirm.
@eval Base.getindex(A::DiffEqBC, i1::Int) =
(Base.@_inline_meta; Core.const_arrayref($(Expr(:boundscheck)), A.x, i1))
# Multi-index form of the method above.
@eval Base.getindex(A::DiffEqBC, i1::Int, i2::Int, I::Int...) =
(Base.@_inline_meta; Core.const_arrayref($(Expr(:boundscheck)), A.x, i1, i2, I...))
# `@aliasscope body`: evaluate `body` between `Expr(:aliasscope)` and
# `Expr(:popaliasscope)` markers and return its value (held in a gensym so the
# caller's variables are untouched).
# NOTE(review): these markers are internal compiler hints, presumably letting
# the optimizer assume no aliasing inside the scope; confirm semantics.
macro aliasscope(body)
sym = gensym()
quote
$(Expr(:aliasscope))
$sym = $(esc(body))
$(Expr(:popaliasscope))
$sym
end
end
end
# Raise the destination/source axes mismatch error. Kept out of line so the
# message construction does not bloat the inlined `copyto!` below.
@noinline _diffeqbc_throwdm(axdest, axsrc) =
    throw(DimensionMismatch("destination axes $axdest are not compatible with source axes $axsrc"))

# `true` when every `DiffEqBC` leaf of a preprocessed broadcast tree can be read
# with the destination's own indices. The `_broadcast_getindex` methods for
# `DiffEqBC` index the wrapped array directly and never extrude singleton
# dimensions, so a leaf qualifies only if no extrusion is needed.
@inline _diffeqbc_indexable(ax, bc::Broadcasted) = _diffeqbc_indexable_args(ax, bc.args)
@inline _diffeqbc_indexable(ax, b::DiffEqBC) = _diffeqbc_indexable_leaf(ax, b.x)
@inline _diffeqbc_indexable(ax, x) = true

# Fold `_diffeqbc_indexable` over an argument tuple.
@inline _diffeqbc_indexable_args(ax, ::Tuple{}) = true
@inline _diffeqbc_indexable_args(ax, args::Tuple) =
    _diffeqbc_indexable(ax, args[1]) && _diffeqbc_indexable_args(ax, tail(args))

# Mirrors the array methods of `_broadcast_getindex(::DiffEqBC, i)`: 0-d arrays
# ignore the index, vectors use `i[1]`, other arrays use `i` unchanged.
_diffeqbc_indexable_leaf(ax, x::AbstractArray{<:Any,0}) = true
_diffeqbc_indexable_leaf(ax, x::AbstractVector) = !isempty(ax) && axes(x, 1) == first(ax)
_diffeqbc_indexable_leaf(ax, x::AbstractArray) = axes(x) == ax
_diffeqbc_indexable_leaf(ax, x) = true

"""
    copyto!(dest::DiffEqBC, bc::Broadcasted)

Evaluate the broadcast `bc` into the array wrapped by `dest` and return `dest`.
This is the method Base reaches for `dest .= ...` when `dest` is a `DiffEqBC`.

Every dense `Array` operand is wrapped with `diffeqbc`, so the loop reads it
without index translation. If some operand has a singleton dimension that must
be broadcast (extruded) across the destination, that direct read would go out
of bounds under `@inbounds`. In that case this falls back to Base's generic
`copyto!` on the wrapped array.

Throws `DimensionMismatch` if `axes(dest)` differs from `axes(bc)`.
"""
@inline function copyto!(dest::DiffEqBC, bc::Broadcasted)
    axes(dest) == axes(bc) || _diffeqbc_throwdm(axes(dest), axes(bc))
    dest′ = dest.x
    # Unalias against the wrapped array, not the wrapper: `broadcast_unalias`
    # cannot see through the non-array `DiffEqBC`, so it would never detect
    # that a source shares memory with the destination.
    bcs′ = preprocess(diffeqbc, dest′, bc)
    if !_diffeqbc_indexable(axes(dest), bcs′)
        copyto!(dest′, bc)
        return dest
    end
    @static if VERSION >= v"1.2.0"
        @aliasscope @simd for I in eachindex(bcs′)
            @inbounds dest′[I] = bcs′[I]
        end
    else
        @simd for I in eachindex(bcs′)
            @inbounds dest′[I] = bcs′[I]
        end
    end
    return dest
end
import Base.Broadcast: broadcasted, broadcastable, combine_styles
"""
    map_nostop(f, t::Tuple)

Apply `f` to every element of `t`, left to right, and collect the results in a
tuple. Tuples of up to three elements are handled by dedicated methods; longer
ones are processed by head/tail recursion over the whole tuple, whatever its
length.
"""
map_nostop(f, ::Tuple{}) = ()
map_nostop(f, t::Tuple{Any}) = (f(t[1]),)
map_nostop(f, t::Tuple{Any,Any}) = (f(t[1]), f(t[2]))
map_nostop(f, t::Tuple{Any,Any,Any}) = (f(t[1]), f(t[2]), f(t[3]))
@inline function map_nostop(f, t::Tuple)
    head = f(first(t))
    rest = map_nostop(f, tail(t))
    return (head, rest...)
end
# `+` / `*` with two or more operands: convert each operand to its broadcastable
# form, pick the combined broadcast style, and hand off to the style-aware
# `broadcasted` method. The trailing operands go through `map_nostop`, which
# recurses over the whole tuple regardless of its length.
# NOTE(review): presumably this mirrors Base's generic
# `broadcasted(f, arg1, arg2, args...)` but with `map_nostop` in place of `map`,
# so long `+`/`*` chains stay fully inlined — confirm against Base.
@inline function broadcasted(f::Union{typeof(*),typeof(+)}, arg1, arg2, args...)
    lhs = broadcastable(arg1)
    rhs = broadcastable(arg2)
    rest = map_nostop(broadcastable, args)
    style = combine_styles(lhs, rhs, rest...)
    return broadcasted(style, f, lhs, rhs, rest...)
end
#=
using BenchmarkTools
function foo(a, b, c, d, e, f)
a = diffeqbc(a)
@. a = b + 0.1 * (0.2c + 0.3d + 0.4e + 0.5f)
nothing
end
function goo(a, b, c, d, e, f)
@assert length(a) == length(b) == length(c) == length(d) == length(e) == length(f)
@inbounds for i in eachindex(a)
a[i] = b[i] + 0.1 * (0.2c[i] + 0.3d[i] + 0.4e[i] + 0.5f[i])
end
nothing
end
function foo2(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z)
a = diffeqbc(a)
@. a = b + 0.1 * (0.2c + 0.3d + 0.4e + 0.5f + 0.6g + 0.6h + 0.6i + 0.6j + 0.6k + 0.6l + 0.6m + 0.6n + 0.6o + 0.6p + 0.6q + 0.6r + 0.6s + 0.6t + 0.6u + 0.6v + 0.6w + 0.6x + 0.6y + 0.6z)
nothing
end
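function goo2(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z)
# check every operand, since all 26 arrays are read under `@inbounds`
@assert all(length(arr) == length(a) for arr in (b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z))
# the loop index is `idx`, not `i`: `i` is already the name of an argument array
@inbounds for idx in eachindex(a)
a[idx] = b[idx] + 0.1 * (0.2c[idx] + 0.3d[idx] + 0.4e[idx] + 0.5f[idx] + 0.6g[idx] + 0.6h[idx] + 0.6i[idx] + 0.6j[idx] + 0.6k[idx] + 0.6l[idx] + 0.6m[idx] + 0.6n[idx] + 0.6o[idx] + 0.6p[idx] + 0.6q[idx] + 0.6r[idx] + 0.6s[idx] + 0.6t[idx] + 0.6u[idx] + 0.6v[idx] + 0.6w[idx] + 0.6x[idx] + 0.6y[idx] + 0.6z[idx])
end
nothing
end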
a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z=[rand(1000) for i in 1:26];
@btime foo($a,$b,$c,$d,$e,$f)
@btime goo($a,$b,$c,$d,$e,$f)
@btime foo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z)
@btime goo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z)
julia> @btime foo($a,$b,$c,$d,$e,$f)
431.934 ns (0 allocations: 0 bytes)
julia> @btime goo($a,$b,$c,$d,$e,$f)
430.662 ns (0 allocations: 0 bytes)
julia> @btime foo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z)
12.816 μs (0 allocations: 0 bytes)
julia> @btime goo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z)
12.729 μs (0 allocations: 0 bytes)
=#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment