Last active: November 15, 2023 14:45
-
-
Save kalmarek/e7ba0413eabc78a9f3360b248f457089 to your computer and use it in GitHub Desktop.
MA.add_mul!! with Arbs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Revise | |
using Arblib | |
import MutableArithmetics | |
const MA = MutableArithmetics | |
using LinearAlgebra | |
# Register `Arb` as a mutable type with MutableArithmetics and provide the basic
# copy / zero / one operations on it.
MA.mutability(::Type{Arb}) = MA.IsMutable()
# Copy `x` into a fresh `Arb` created at the same precision as `x`. `set!` is
# qualified with `Arblib.` like every other Arblib call in this file, so this does not
# rely on `set!` being exported by Arblib.
MA.mutable_copy(x::Arb) = Arblib.set!(Arb(prec = precision(x)), x)
MA.promote_operation(::typeof(zero), ::Type{Arb}) = Arb
MA.promote_operation(::typeof(one), ::Type{Arb}) = Arb
# In-place `zero`/`one`: overwrite `x` with 0 or 1.
MA.operate!(::typeof(zero), x::Arb) = Arblib.set!(x, 0)
MA.operate!(::typeof(one), x::Arb) = Arblib.set!(x, 1)
# Addition, subtraction, negation and in-place absolute value for `Arb`,
# each forwarded to the corresponding Arblib in-place kernel.
MA.promote_operation(::typeof(+), ::Type{Arb}, ::Type{Arb}) = Arb
MA.promote_operation(::typeof(-), ::Vararg{Type{Arb},N}) where {N} = Arb
function MA.operate_to!(res::Arb, ::typeof(+), a::Arb, b::Arb)
    return Arblib.add!(res, a, b)
end
function MA.operate_to!(res::Arb, ::typeof(-), a::Arb, b::Arb)
    return Arblib.sub!(res, a, b)
end
# Unary minus: `res = -a`.
function MA.operate_to!(res::Arb, ::typeof(-), a::Arb)
    return Arblib.neg!(res, a)
end
# `x = abs(x)`, with `x` used as both destination and source.
function MA.operate!(::typeof(Base.abs), x::Arb)
    return Arblib.abs!(x, x)
end
# Multiplication and fused multiply-add for `Arb`.
MA.promote_operation(::typeof(*), ::Type{Arb}, ::Type{Arb}) = Arb
function MA.operate_to!(res::Arb, ::typeof(*), a::Arb, b::Arb)
    return Arblib.mul!(res, a, b)
end
# `Base.fma(x, y, z)` is `x * y + z`; forwarded to Arblib's `fma!` kernel.
MA.promote_operation(::typeof(Base.fma), ::Type{Arb}, ::Type{Arb}, ::Type{Arb}) = Arb
function MA.operate_to!(res::Arb, ::typeof(Base.fma), x::Arb, y::Arb, z::Arb)
    return Arblib.fma!(res, x, y, z)
end
# In-place variant: the result overwrites the first factor `x`.
function MA.operate!(::typeof(Base.fma), x::Arb, y::Arb, z::Arb)
    return MA.operate_to!(x, Base.fma, x, y, z)
end
# Base.muladd
# `muladd(x, y, z)` (i.e. `x * y + z`) on `Arb`s is delegated to the `Base.fma`
# methods of `MA.operate_to!` / `MA.operate!`.
MA.promote_operation(
    ::typeof(Base.muladd),
    ::Type{Arb},
    ::Type{Arb},
    ::Type{Arb},
) = Arb
# `operate_to!` must be qualified: MutableArithmetics is only `import`ed (as `MA`),
# so an unqualified `operate_to!` is not defined in this module.
MA.operate_to!(
    output::Arb,
    ::typeof(Base.muladd),
    x::Arb,
    y::Arb,
    z::Arb,
) = MA.operate_to!(output, Base.fma, x, y, z)
MA.operate!(::typeof(Base.muladd), x::Arb, y::Arb, z::Arb) =
    MA.operate!(Base.fma, x, y, z)
# Variadic `+`, `-`, `*` on three or more `Arb`s: combine the first two operands into
# `output`, then fold every remaining operand into `output` left to right.
function MA.operate_to!(
    output::Arb,
    op::Union{typeof(+),typeof(-),typeof(*)},
    a::Arb,
    b::Arb,
    c::Vararg{Arb,N},
) where {N}
    result = MA.operate_to!(output, op, a, b)
    for operand in c
        # `output = op(output, operand)`
        result = MA.operate_to!(output, op, output, operand)
    end
    return result
end
# Generic in-place fallback for `Arb`: `operate!(op, x, args...)` stores
# `op(x, args...)` back into `x`, using `x` as both destination and first operand.
MA.operate!(op::Function, x::Arb, args::Vararg{Any,N}) where {N} =
    MA.operate_to!(x, op, x, args...)
# add_mul and sub_mul
# Scratch space for the product term: a fresh `Arb` created without an explicit
# precision, regardless of how many `Arb` operand types are given.
function MA.buffer_for(::MA.AddSubMul, ::Vararg{Type{Arb},N}) where {N}
    return Arb()
end
# `output = x ± y * z * args...` (sign selected by `op`) for all-`Arb` operands.
# A temporary `Arb` is allocated to hold the product, then the buffered method does
# the actual work.
function MA.operate_to!(
    output::Arb,
    op::MA.AddSubMul,
    x::Arb,
    y::Arb,
    z::Arb,
    args::Vararg{Arb,N},
) where {N}
    # Create the product buffer at the largest precision among all operands.
    prec = maximum(precision, (x, y, z, args...))
    # `buffered_operate_to!` belongs to MutableArithmetics, which is only `import`ed,
    # so it must be qualified.
    return MA.buffered_operate_to!(Arb(prec = prec), output, op, x, y, z, args...)
end
# `output = a ± x * y * args...` using a caller-supplied `buffer`: the product is
# formed in `buffer` first, then added to / subtracted from `a` into `output`.
function MA.buffered_operate_to!(
    buffer::Arb,
    output::Arb,
    op::MA.AddSubMul,
    a::Arb,
    x::Arb,
    y::Arb,
    args::Vararg{Arb,N},
) where {N}
    MA.operate_to!(buffer, *, x, y, args...)  # buffer = x * y * args...
    addsub = MA.add_sub_op(op)
    return MA.operate_to!(output, addsub, a, buffer)
end
# In-place buffered add_mul/sub_mul: `x` is both the destination and the operand the
# product is added to / subtracted from.
MA.buffered_operate!(
    buffer::Arb,
    op::MA.AddSubMul,
    x::Arb,
    args::Vararg{Any,N},
) where {N} = MA.buffered_operate_to!(buffer, x, op, x, args...)
# Convert a scaling `x` to `BigFloat` via MutableArithmetics' `_scaling_to`.
# `_scaling_to` must be qualified: method bodies resolve names in the defining module
# (here, not `MA`), where an unqualified `_scaling_to` does not exist.
# NOTE(review): this presumably overwrites MA's own internal helper of the same
# name — confirm it is needed at all.
MA._scaling_to_bigfloat(x) = MA._scaling_to(BigFloat, x)
# `+`, `-`, `*` with `MA.Scaling` operands (e.g. plain numbers mixed with `Arb`s):
# convert every operand to `Arb` and re-dispatch to the `Arb`-specific methods.
function MA.operate_to!(
    output::Arb,
    op::Union{typeof(+),typeof(-),typeof(*)},
    args::Vararg{MA.Scaling,N},
) where {N}
    return MA.operate_to!(output, op, MA._scaling_to.(Ref(Arb), args)...)
end
# Unary `+` and `*` are the identity, so `output` is simply set to `a`. Without this
# method, a single `Arb` operand would match the `Scaling` method above, be converted
# to `Arb` again, and recurse forever (e.g. via `MA.operate!(+, x)`).
MA.operate_to!(output::Arb, ::Union{typeof(+),typeof(*)}, a::Arb) =
    Arblib.set!(output, a)
# add_mul / sub_mul where every operand is an `MA.Scaling`: convert them all to `Arb`
# and re-dispatch to the `Arb`-only implementation.
function MA.operate_to!(
    output::Arb,
    op::MA.AddSubMul,
    x::MA.Scaling,
    y::MA.Scaling,
    z::MA.Scaling,
    args::Vararg{MA.Scaling,N},
) where {N}
    to_arb(s) = MA._scaling_to(Arb, s)
    rest = map(to_arb, args)
    return MA.operate_to!(output, op, to_arb(x), to_arb(y), to_arb(z), rest...)
end
# Fallback for operands that are not all `Arb`/`Scaling` — called for instance when
# `args` is `(v', v)` for a vector `v`. The product `y * z * args...` is evaluated
# out of place, then added to / subtracted from `x` into `output`.
function MA.operate_to!(
    output::Arb,
    op::MA.AddSubMul,
    x,
    y,
    z,
    args::Vararg{Any,N},
) where {N}
    addsub = MA.add_sub_op(op)
    product = *(y, z, args...)
    return MA.operate_to!(output, addsub, x, product)
end
# Scratch buffer for `LinearAlgebra.dot` between two `Arb` vectors of the same type:
# a fresh `Arb` created without an explicit precision.
MA.buffer_for(
    ::typeof(LinearAlgebra.dot),
    ::Type{V},
    ::Type{V},
) where {V<:AbstractVector{Arb}} = Arb()
using BenchmarkTools

"""
    benchmark_add_mul(A, b, c)

Benchmark three ways of computing `c = c + A * b` in place and return the three
trials `(trial1, trial2, trial3)`. Each trial is also displayed.

1. `LinearAlgebra.mul!(c, A, b, true, true)` — 5-argument `mul!` computes
   `c = A*b*α + c*β`; with `α = β = true` it is the same accumulation as `add_mul`.
2. `MA.add_mul!!(c, A, b)`.
3. `MA.buffered_operate!!` with a buffer preallocated via `MA.buffer_for`.

Note that `c` is updated in place on every benchmark sample.
"""
function benchmark_add_mul(A::AbstractMatrix, b::AbstractVector, c::AbstractVector)
    # The 3-argument `mul!(c, A, b)` would overwrite `c` with `A*b`, which is not the
    # operation benchmarked by trials 2 and 3.
    trial1 = @benchmark LinearAlgebra.mul!($c, $A, $b, true, true)
    display(trial1)
    trial2 = @benchmark MA.add_mul!!($c, $A, $b)
    display(trial2)
    buffer = MA.buffer_for(MA.add_mul, typeof(c), typeof(A), typeof(b))
    trial3 = @benchmark MA.buffered_operate!!($buffer, MA.add_mul, $c, $A, $b)
    display(trial3)
    return trial1, trial2, trial3
end

n = 256
A = rand(n, n)
b = rand(n)
c = rand(n)
# MA.mul works for arbitrary types
MA.mul(A, b)
big_trials = benchmark_add_mul(BigFloat.(A), BigFloat.(b), BigFloat.(c))
arb_trials = benchmark_add_mul(Arb.(A), Arb.(b), Arb.(c))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.