Last active
April 11, 2019 21:01
-
-
Save YingboMa/2304f8561551095b8ad18781d763a94c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base.Broadcast: _broadcast_getindex, preprocess, preprocess_args, Broadcasted, broadcast_unalias, combine_axes, broadcast_shape, check_broadcast_axes, check_broadcast_shape | |
import Base: copyto!, tail, axes | |
"""
    DiffEqBC{T}

Thin immutable wrapper around a single value `x` of type `T`.

The wrapper type exists so that the broadcasting methods in this file can
dispatch on it; the wrapped value is reached through the `x` field.
"""
struct DiffEqBC{T}
    x::T
end
# A wrapper reports exactly the axes of the value it wraps.
@inline function axes(wrapper::DiffEqBC)
    return axes(wrapper.x)
end
# Element access used while a broadcast is evaluated. Each method looks inside
# the wrapper; they are listed from the most to the least specific wrapped type.
# NOTE(review): none of these methods remaps `i` for size-1 dimensions, so the
# wrapped arrays presumably must already match the destination shape — confirm
# against the callers.

# Zero-dimensional array: the index is ignored and the single element returned.
Base.@propagate_inbounds function _broadcast_getindex(
    w::DiffEqBC{<:AbstractArray{<:Any,0}}, i
)
    return w.x[]
end

# Vector: only the first component of the index is used.
Base.@propagate_inbounds function _broadcast_getindex(w::DiffEqBC{<:AbstractVector}, i)
    return w.x[i[1]]
end

# Any other array: index the wrapped array with `i` as given.
Base.@propagate_inbounds function _broadcast_getindex(w::DiffEqBC{<:AbstractArray}, i)
    return w.x[i]
end

# Everything else: defer to `_broadcast_getindex` of the wrapped value.
Base.@propagate_inbounds function _broadcast_getindex(w::DiffEqBC, i)
    return _broadcast_getindex(w.x, i)
end
"""
    diffeqbc(x)

Wrap `x` in a `DiffEqBC` when it is an `Array`; return any other value unchanged.
"""
function diffeqbc(x)
    return x
end

function diffeqbc(x::Array)
    return DiffEqBC(x)
end
# Methods added to the `Base.Broadcast` functions imported at the top of the file.
# The arguments are untyped, so these definitions apply to every broadcast
# argument, not only to `DiffEqBC` wrappers.

# Two arguments: combine their axes directly.
@inline function combine_axes(A, B)
    return broadcast_shape(axes(A), axes(B))
end

# Three or more arguments: combine the axes of the first with those of the rest.
@inline function combine_axes(A, B, C...)
    return broadcast_shape(axes(A), combine_axes(B, C...))
end

# Check the axes of `A` against the target shape `shp`.
@inline function check_broadcast_axes(shp, A)
    return check_broadcast_shape(shp, axes(A))
end
# Three-argument variants of `preprocess` / `preprocess_args` that take the
# per-leaf transformation `f` as an explicit first argument.

# Leaf argument: unalias it against `dest`, then apply `f`.
function preprocess(f, dest, x)
    return f(broadcast_unalias(dest, x))
end

# Nested broadcast: rebuild it with the same function and axes after
# preprocessing each of its arguments.
@inline function preprocess(f, dest, bc::Broadcasted{Style}) where {Style}
    processed = preprocess_args(f, dest, bc.args)
    return Broadcasted{Style}(bc.f, processed, bc.axes)
end

# Argument tuples are processed recursively: empty, one element, then head + tail.
function preprocess_args(f, dest, args::Tuple{})
    return ()
end

@inline function preprocess_args(f, dest, args::Tuple{Any})
    return (preprocess(f, dest, args[1]),)
end

@inline function preprocess_args(f, dest, args::Tuple)
    head = preprocess(f, dest, args[1])
    return (head, preprocess_args(f, dest, tail(args))...)
end
# Definitions that are only made on Julia 1.2 and newer: direct indexing of a
# `DiffEqBC` through `Core.const_arrayref`, and the `@aliasscope` macro.
# NOTE(review): `Core.const_arrayref` is an internal builtin; confirm that it
# exists on every Julia version this file is meant to support.
@static if VERSION >= v"1.2.0"
    # `@inline` on the definition replaces the former `Base.@_inline_meta` body
    # marker, which is an internal macro that newer Julia releases no longer
    # provide. `$(Expr(:boundscheck))` is spliced in by `@eval` as the
    # bounds-check flag passed to `const_arrayref`.
    @eval @inline Base.getindex(A::DiffEqBC, i1::Int) =
        Core.const_arrayref($(Expr(:boundscheck)), A.x, i1)
    @eval @inline Base.getindex(A::DiffEqBC, i1::Int, i2::Int, I::Int...) =
        Core.const_arrayref($(Expr(:boundscheck)), A.x, i1, i2, I...)

    # Evaluate `body` between an `:aliasscope` and a `:popaliasscope` expression
    # and yield the value of `body`.
    # NOTE(review): presumably mirrors `Base.Experimental.@aliasscope` — confirm.
    macro aliasscope(body)
        result = gensym()
        quote
            $(Expr(:aliasscope))
            $result = $(esc(body))
            $(Expr(:popaliasscope))
            $result
        end
    end
end
# Raise the axes-mismatch error for `copyto!` below. Kept out of line so the
# message construction does not bloat the inlined copy loop. The file does not
# import `Base.Broadcast.throwdm`, so the error is thrown directly here.
@noinline function _throw_axes_mismatch(axdest, axsrc)
    throw(DimensionMismatch(
        "destination axes $axdest are not compatible with source axes $axsrc"))
end

# Materialize the broadcast `bc` into the array wrapped by `dest`.
#
# Arguments:
#   dest — `DiffEqBC` wrapper; results are written into `dest.x`.
#   bc   — the lazy `Broadcasted` expression to evaluate.
# Returns `dest` (the wrapper, not the wrapped array).
# Throws `DimensionMismatch` when the axes of `dest` and `bc` differ.
@inline function copyto!(dest::DiffEqBC, bc::Broadcasted)
    axes(dest) == axes(bc) || _throw_axes_mismatch(axes(dest), axes(bc))
    # Wrap every `Array` leaf of the broadcast tree via `diffeqbc`.
    bcs′ = preprocess(diffeqbc, dest, bc)
    dest′ = dest.x
    # Bounds checks are disabled inside the loops; only the axes comparison
    # above guards them.
    @static if VERSION >= v"1.2.0"
        @aliasscope @simd for I in eachindex(bcs′)
            @inbounds dest′[I] = bcs′[I]
        end
    else
        @simd for I in eachindex(bcs′)
            @inbounds dest′[I] = bcs′[I]
        end
    end
    return dest
end
import Base.Broadcast: broadcasted, broadcastable, combine_styles | |
"""
    map_nostop(f, t::Tuple)

Apply `f` to every element of the tuple `t` and return the results as a tuple
of the same length.

Tuples of length 0 to 3 are handled by dedicated methods; longer tuples are
processed recursively, one element at a time.
"""
map_nostop(f, t::Tuple{}) = ()
map_nostop(f, t::Tuple{Any}) = (f(t[1]),)
map_nostop(f, t::Tuple{Any,Any}) = (f(t[1]), f(t[2]))
map_nostop(f, t::Tuple{Any,Any,Any}) = (f(t[1]), f(t[2]), f(t[3]))
# `@inline` on the definition replaces the internal `Base.@_inline_meta` body
# marker, which newer Julia releases no longer provide.
@inline map_nostop(f, t::Tuple) = (f(t[1]), map_nostop(f, Base.tail(t))...)
# Build the lazy broadcast for `+` and `*` calls with two or more arguments.
# Every argument is passed through `broadcastable` (the trailing ones via
# `map_nostop`), then the call is forwarded to the style-first `broadcasted`
# method with the combined broadcast style.
@inline function broadcasted(f::Union{typeof(*), typeof(+)}, arg1, arg2, args...)
    lhs = broadcastable(arg1)
    rhs = broadcastable(arg2)
    rest = map_nostop(broadcastable, args)
    style = combine_styles(lhs, rhs, rest...)
    return broadcasted(style, f, lhs, rhs, rest...)
end
#= | |
using BenchmarkTools | |
function foo(a, b, c, d, e, f) | |
a = diffeqbc(a) | |
@. a = b + 0.1 * (0.2c + 0.3d + 0.4e + 0.5f) | |
nothing | |
end | |
function goo(a, b, c, d, e, f) | |
@assert length(a) == length(b) == length(c) == length(d) == length(e) == length(f) | |
@inbounds for i in eachindex(a) | |
a[i] = b[i] + 0.1 * (0.2c[i] + 0.3d[i] + 0.4e[i] + 0.5f[i]) | |
end | |
nothing | |
end | |
function foo2(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z) | |
a = diffeqbc(a) | |
@. a = b + 0.1 * (0.2c + 0.3d + 0.4e + 0.5f + 0.6g + 0.6h + 0.6i + 0.6j + 0.6k + 0.6l + 0.6m + 0.6n + 0.6o + 0.6p + 0.6q + 0.6r + 0.6s + 0.6t + 0.6u + 0.6v + 0.6w + 0.6x + 0.6y + 0.6z) | |
nothing | |
end | |
function goo2(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z) | |
@assert length(a) == length(b) == length(c) == length(d) == length(e) == length(f) | |
@inbounds for i in eachindex(a) | |
a[i] = b[i] + 0.1 * (0.2c[i] + 0.3d[i] + 0.4e[i] + 0.5f[i] + 0.6g[i] + 0.6h[i] + 0.6i[i] + 0.6j[i] + 0.6k[i] + 0.6l[i] + 0.6m[i] + 0.6n[i] + 0.6o[i] + 0.6p[i] + 0.6q[i] + 0.6r[i] + 0.6s[i] + 0.6t[i] + 0.6u[i] + 0.6v[i] + 0.6w[i] + 0.6x[i] + 0.6y[i] + 0.6z[i]) | |
end | |
nothing | |
end | |
a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z=[rand(1000) for i in 1:26]; | |
@btime foo($a,$b,$c,$d,$e,$f) | |
@btime goo($a,$b,$c,$d,$e,$f) | |
@btime foo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z) | |
@btime goo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z) | |
julia> @btime foo($a,$b,$c,$d,$e,$f) | |
431.934 ns (0 allocations: 0 bytes) | |
julia> @btime goo($a,$b,$c,$d,$e,$f) | |
430.662 ns (0 allocations: 0 bytes) | |
julia> @btime foo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z) | |
12.816 μs (0 allocations: 0 bytes) | |
julia> @btime goo2($a, $b, $c, $d, $e, $f, $g, $h, $i, $j, $k, $l, $m, $n, $o, $p, $q, $r, $s, $t, $u, $v, $w, $x, $y, $z) | |
12.729 μs (0 allocations: 0 bytes) | |
=# |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment