Last active: February 24, 2021 04:04
-
-
Save GiggleLiu/a6d2bed21731fa344f4d7c1660f35952 to your computer and use it in GitHub Desktop.
Tropical BLAS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using TropicalNumbers, VectorizationBase | |
using Test | |
using LoopVectorization, Octavian | |
using VectorizationBase: OffsetPrecalc, StaticBool, Bit, static, NativeTypes, Index, gep_quote, VectorIndex | |
"""
    distance(a::AbstractArray{<:Tropical}, b::AbstractArray{<:Tropical})

Return the L1 distance between two arrays of tropical numbers, i.e. the sum over all
positions of `abs(content(a[i]) - content(b[i]))`, comparing the wrapped scalar values.

Throws `DimensionMismatch` when `a` and `b` do not have identical axes, rather than
silently broadcasting mismatched shapes (e.g. an `n`-vector against a `1×n` matrix would
otherwise produce an `n×n` difference array).
"""
function distance(a::AbstractArray{<:Tropical}, b::AbstractArray{<:Tropical})
    if axes(a) != axes(b)
        throw(DimensionMismatch("distance: axes $(axes(a)) and $(axes(b)) differ"))
    end
    # Single fused broadcast over the unwrapped values, then reduce.
    return sum(abs.(content.(a) .- content.(b)))
end
# Accept a pair of identical `Tropical` element types in `LoopVectorization.check_args`;
# this method unconditionally answers `true` for them.
function LoopVectorization.check_args(::Type{T}, ::Type{T}) where {T<:Tropical}
    return true
end
# Scalar store into a strided pointer of tropical eltype: unwrap the tropical value and
# store its content.
@inline function VectorizationBase.vstore!(
    ptr::VectorizationBase.StridedPointer{T}, v::T
) where {T<:Tropical}
    raw = content(v)
    return vstore!(ptr, raw)
end
# Indexed SIMD store of a tropical vector: view the tropical pointer as a pointer to the
# wrapped scalar type `T` and store the unwrapped `Vec`, forwarding the index `i`, the
# lane vector `m` and the static flags/size arguments unchanged.
@inline function VectorizationBase.vstore!(
    ptr::Ptr{Tropical{T}}, v::Tropical{Vec{N,T}}, i::VectorIndex{W},
    m::VectorizationBase.AbstractSIMDVector{W}, a::A, s::S, nt::NT, si::StaticInt{RS}
) where {T,W,S<:StaticBool,A<:StaticBool,NT<:StaticBool,RS,N}
    rawptr = Ptr{T}(ptr)
    return vstore!(rawptr, content(v), i, m, a, s, nt, si)
end
# Unindexed SIMD store of a tropical vector: same as the indexed variant, minus the
# index argument; the pointer is reinterpreted as `Ptr{T}` and the raw `Vec` is stored.
@inline function VectorizationBase.vstore!(
    ptr::Ptr{Tropical{T}}, v::Tropical{Vec{N,T}},
    m::VectorizationBase.AbstractSIMDVector{W}, a::A, s::S, nt::NT, si::StaticInt{RS}
) where {T,W,S<:StaticBool,A<:StaticBool,NT<:StaticBool,RS,N}
    rawptr = Ptr{T}(ptr)
    return vstore!(rawptr, content(v), m, a, s, nt, si)
end
# Masked load from tropical memory: load raw lanes of the wrapped scalar type `T`
# through a reinterpreted pointer, then rewrap the result as `Tropical`.
@inline function VectorizationBase.vload(
    ptr::Ptr{Tropical{T}}, i::I, m::Mask, a::A, si::StaticInt{RS}
) where {A<:StaticBool,T<:NativeTypes,I<:Index,RS}
    raw = vload(Ptr{T}(ptr), i, m, a, si)
    return Tropical(raw)
end
# Unmasked load from tropical memory: load raw lanes of `T`, then rewrap as `Tropical`.
@inline function VectorizationBase.vload(
    ptr::Ptr{Tropical{T}}, i::I, a::A, si::StaticInt{RS}
) where {A<:StaticBool,T<:NativeTypes,I<:Index,RS}
    raw = vload(Ptr{T}(ptr), i, a, si)
    return Tropical(raw)
end
# Broadcast a scalar tropical number to `W` lanes: broadcast its wrapped value and wrap
# the resulting vector back into `Tropical`.
@inline function VectorizationBase.vbroadcast(
    a::Union{Val{W},StaticInt{W}}, s::Tropical{T}
) where {W,T}
    lanes = VectorizationBase.vbroadcast(a, content(s))
    return Tropical(lanes)
end
# Build a strided pointer for an array with tropical eltype by querying the array's
# memory layout traits and forwarding them to the low-level `stridedpointer` method.
@inline function VectorizationBase.stridedpointer(A::AbstractArray{T}) where {T<:Tropical}
    VB = VectorizationBase
    ref = VB.memory_reference(A)
    caxis = VB.contiguous_axis(A)
    batch = VB.contiguous_batch_size(A)
    rank = VB.val_stride_rank(A)
    strides = VB.bytestrides(A)
    offs = VB.offsets(A)
    return stridedpointer(ref, caxis, batch, rank, strides, offs)
end
using Base.Cartesian: @nexprs | |
# Tropical fused multiply-add on SIMD vectors: for every lane compute
# `max(z, x + y)`, i.e. z ⊕ (x ⊗ y) with ⊕ = max and ⊗ = +. The lanes are
# unrolled at compile time via `ntuple(..., Val(N))`.
@inline function VectorizationBase.fma(
    x::Tropical{Vec{N,T}}, y::Tropical{Vec{N,T}}, z::Tropical{Vec{N,T}}
) where {N,T}
    xd = content(x).data
    yd = content(y).data
    zd = content(z).data
    lanes = ntuple(k -> max(zd[k].value, xd[k].value + yd[k].value), Val(N))
    return Tropical(Vec(lanes...))
end
# Construct a `StridedPointer` over tropical eltype `T` directly from a raw pointer,
# static contiguous-axis `C`, batch size `B`, stride rank `R`, strides and offsets.
@inline function VectorizationBase.stridedpointer(
    ptr::Ptr{T}, ::StaticInt{C}, ::StaticInt{B}, ::Val{R}, strd::X, offsets::O
) where {
    T<:Tropical,C,B,R,N,X<:Tuple{Vararg{Integer,N}},O<:Tuple{Vararg{Integer,N}}
}
    SP = VectorizationBase.StridedPointer{T,N,C,B,R,X,O}
    return SP(ptr, strd, offsets)
end
# `_vzero` for a tropical eltype: delegate to `_vzero` of the wrapped scalar type `FT`
# and wrap the resulting vector in `Tropical`.
# NOTE(review): this forwards to the raw-type `_vzero`, presumably lanes of `zero(FT)`,
# which is the tropical *one*, not the tropical additive identity `-Inf`. Confirm that
# no reduction initializes its accumulator through this method.
function VectorizationBase._vzero(
    in1::StaticInt{W}, ::Type{T}, in3::StaticInt{RS}
) where {FT,W,T<:Tropical{FT},RS}
    rawzero = VectorizationBase._vzero(in1, FT, in3)
    return Tropical(rawzero)
end
# Rebuild an `OffsetPrecalc` around a new tropical pointer: re-point its wrapped strided
# pointer at `ptr` and reuse its `precalc` field unchanged.
# NOTE(review): `sptr` is typed `OffsetPrecalc{T}` while `ptr` is `Ptr{Tropical{T}}`, so
# this method only matches when the wrapper's eltype is the raw scalar type `T`.
# Confirm that this pairing is intended.
@inline function VectorizationBase.similar_no_offset(
    sptr::OffsetPrecalc{T}, ptr::Ptr{Tropical{T}}
) where {T}
    rebased = VectorizationBase.similar_no_offset(getfield(sptr, :ptr), ptr)
    return OffsetPrecalc(rebased, getfield(sptr, :precalc))
end
# A `Tropical{T}` element type passes LoopVectorization's type check exactly when its
# wrapped scalar type `T` does.
function LoopVectorization.check_type(::Type{Tropical{T}}) where {T}
    return LoopVectorization.check_type(T)
end
# Pointer offset for tropical pointers: offset as a `Ptr{T}` to the wrapped type, then
# reinterpret the result back as `Ptr{Tropical{T}}`.
# NOTE(review): this assumes `sizeof(Tropical{T}) == sizeof(T)`, so that offsets computed
# on the raw type agree. Confirm this for every `T` used.
@inline function VectorizationBase.gep(ptr::Ptr{Tropical{T}}, i) where {T}
    shifted = VectorizationBase.gep(Ptr{T}(ptr), i)
    return Ptr{Tropical{T}}(shifted)
end
@testset "mydot" begin
    # Tropical inner product using an `@avx`-vectorized accumulation loop; compared
    # against the generic `transpose(u) * v`.
    function mydot(x, y)
        acc = zero(promote_type(eltype(x), eltype(y)))
        @avx for k in eachindex(x, y)
            acc += x[k] * y[k]
        end
        return acc
    end
    u = Tropical.(randn(10))
    v = Tropical.(randn(10))
    @test mydot(u, v) ≈ transpose(u) * v
    @test LoopVectorization.check_args(TropicalF64, TropicalF64)
end
# TODO: type-piracy workaround. `Base.promote` is extended for an `Int` mixed with
# tropical SIMD vectors, and none of these types is owned by this file. The methods let
# integer constants meet tropical `Vec`s inside the vectorized kernels (NOTE(review):
# presumably Octavian's 0/1 scaling factors; confirm). They should be replaced by a
# proper `promote_rule`/`convert` upstream.

"""
    _tropical_splat(a::Int, like::Tropical{Vec{N,T}})

Return an `N`-lane tropical vector with element type `T` (matching `like`) that stands
for the integer `a`. `0` becomes the tropical zero `zero(Tropical{T})` (`-Inf` for
floats) in every lane. Any other integer becomes the tropical one `one(Tropical{T})`
(`0` in every lane).
"""
@inline function _tropical_splat(a::Int, ::Tropical{Vec{N,T}}) where {N,T}
    # Take the lane value from the tropical semiring identities rather than hard-coding
    # Float64 literals, so that any lane count and scalar type are supported.
    elem = a == 0 ? content(zero(Tropical{T})) : content(one(Tropical{T}))
    return Tropical(Vec(ntuple(_ -> elem, Val(N))...))
end

# Promote an integer against a tropical SIMD vector: splat `a` to `b`'s lane shape.
function Base.promote(a::Int, b::Tropical{Vec{N,T}}) where {N,T}
    return _tropical_splat(a, b), b
end

# Three-argument form of the same promotion; `b` and `c` must share lane count and type.
function Base.promote(a::Int, b::Tropical{Vec{N,T}}, c::Tropical{Vec{N,T}}) where {N,T}
    return _tropical_splat(a, b), b, c
end
@testset "matmul" begin
    # Compare Octavian's serial and threaded tropical matmul against the generic `*`.
    # The tropical sum is a `max`, which does not depend on accumulation order, so the
    # results are expected to agree; a small tolerance is kept as a safety margin.
    tol = 1e-10
    for n in (4, 20)
        a = Tropical.(randn(n, n))
        b = Tropical.(randn(n, n))
        @test distance(Octavian.matmul_serial(a, b), a * b) ≈ 0 atol=tol
        @test distance(Octavian.matmul_serial(a, a), a * a) ≈ 0 atol=tol
        @test distance(Octavian.matmul(a, b), a * b) ≈ 0 atol=tol
    end
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.