Skip to content

Instantly share code, notes, and snippets.

@GiggleLiu
Last active February 24, 2021 04:04
Show Gist options
  • Save GiggleLiu/a6d2bed21731fa344f4d7c1660f35952 to your computer and use it in GitHub Desktop.
Save GiggleLiu/a6d2bed21731fa344f4d7c1660f35952 to your computer and use it in GitHub Desktop.
Tropical BLAS
using TropicalNumbers, VectorizationBase
using Test
using LoopVectorization, Octavian
using VectorizationBase: OffsetPrecalc, StaticBool, Bit, static, NativeTypes, Index, gep_quote, VectorIndex
"""
    distance(a, b)

Return the sum of absolute element-wise differences between the contents of two
arrays of `Tropical` numbers (arrays are combined with broadcasting).

Entries that compare equal contribute exactly zero. This matters for matching
infinities such as `-Inf`: plain subtraction would give `-Inf - -Inf == NaN` and
turn the whole distance into `NaN`.
"""
function distance(a::AbstractArray{<:Tropical}, b::AbstractArray{<:Tropical})
    # `zero(x - y)` keeps the result type identical to the `abs(x - y)` branch.
    gap(x, y) = x == y ? zero(x - y) : abs(x - y)
    return sum(gap.(content.(a), content.(b)))
end
# Opt a pair of identical `Tropical` element types into LoopVectorization's
# argument check so `@avx` loops over them are accepted.
function LoopVectorization.check_args(::Type{T}, ::Type{T}) where {T<:Tropical}
    return true
end
# Scalar store through a strided pointer: unwrap the `Tropical` value and store
# its raw content instead.
@inline function VectorizationBase.vstore!(
    ptr::VectorizationBase.StridedPointer{T}, v::T
) where {T<:Tropical}
    raw = content(v)
    return vstore!(ptr, raw)
end
# Indexed vector store for `Tropical` lanes: reinterpret the `Tropical{T}` pointer
# as a `Ptr{T}` and store the raw `Vec` payload. All remaining arguments (`m`,
# presumably the lane mask, and the static flags) are forwarded unchanged.
@inline function VectorizationBase.vstore!(
    ptr::Ptr{Tropical{T}},
    v::Tropical{Vec{N,T}},
    i::VectorIndex{W},
    m::VectorizationBase.AbstractSIMDVector{W},
    a::A,
    s::S,
    nt::NT,
    si::StaticInt{RS}
) where {T,N,W,RS,A<:StaticBool,S<:StaticBool,NT<:StaticBool}
    rawptr = convert(Ptr{T}, ptr)
    payload = content(v)
    return vstore!(rawptr, payload, i, m, a, s, nt, si)
end
# Index-free vector store for `Tropical` lanes: same reinterpretation as the
# indexed variant, but forwarding without an index argument.
@inline function VectorizationBase.vstore!(
    ptr::Ptr{Tropical{T}},
    v::Tropical{Vec{N,T}},
    m::VectorizationBase.AbstractSIMDVector{W},
    a::A,
    s::S,
    nt::NT,
    si::StaticInt{RS}
) where {T,N,W,RS,A<:StaticBool,S<:StaticBool,NT<:StaticBool}
    rawptr = convert(Ptr{T}, ptr)
    payload = content(v)
    return vstore!(rawptr, payload, m, a, s, nt, si)
end
# Masked load of `Tropical` data: read raw `T` values through a reinterpreted
# pointer, then rewrap the loaded result as `Tropical`.
@inline function VectorizationBase.vload(
    ptr::Ptr{Tropical{T}}, i::I, m::Mask, a::A, si::StaticInt{RS}
) where {A<:StaticBool,T<:NativeTypes,I<:Index,RS}
    raw = vload(Ptr{T}(ptr), i, m, a, si)
    return Tropical(raw)
end
# Unmasked load of `Tropical` data: read raw `T` values through a reinterpreted
# pointer, then rewrap the loaded result as `Tropical`.
@inline function VectorizationBase.vload(
    ptr::Ptr{Tropical{T}}, i::I, a::A, si::StaticInt{RS}
) where {A<:StaticBool,T<:NativeTypes,I<:Index,RS}
    raw = vload(Ptr{T}(ptr), i, a, si)
    return Tropical(raw)
end
# Splat a scalar `Tropical` across `W` lanes: broadcast its raw content, then
# wrap the resulting vector back in `Tropical`.
@inline function VectorizationBase.vbroadcast(
    a::Union{Val{W},StaticInt{W}}, s::Tropical{T}
) where {W,T}
    lanes = VectorizationBase.vbroadcast(a, content(s))
    return Tropical(lanes)
end
# Build a strided pointer for an array of `Tropical` numbers from VectorizationBase's
# layout queries (memory reference, contiguous axis, batch size, stride rank,
# byte strides, offsets).
@inline function VectorizationBase.stridedpointer(A::AbstractArray{T}) where {T<:Tropical}
    ref = VectorizationBase.memory_reference(A)
    caxis = VectorizationBase.contiguous_axis(A)
    cbatch = VectorizationBase.contiguous_batch_size(A)
    srank = VectorizationBase.val_stride_rank(A)
    bstrides = VectorizationBase.bytestrides(A)
    offs = VectorizationBase.offsets(A)
    return stridedpointer(ref, caxis, cbatch, srank, bstrides, offs)
end
using Base.Cartesian: @nexprs
# Lane-wise fused multiply-add for `Tropical` vectors: every lane computes
# `max(z, x + y)` (the max-plus reading of `x*y + z`). The generator unrolls one
# scalar expression per lane of the `N`-wide `Vec` and reassembles them.
@generated function VectorizationBase.fma(
    x::Tropical{Vec{N,T}}, y::Tropical{Vec{N,T}}, z::Tropical{Vec{N,T}}
) where {N,T}
    lanes = [
        :(max(
            content(z).data[$k].value,
            content(x).data[$k].value + content(y).data[$k].value
        ))
        for k in 1:N
    ]
    return :(Tropical(Vec($(lanes...))))
end
# Construct a `StridedPointer` for a raw `Tropical` pointer directly from its
# static layout parameters (`C`, `B`, `R`) and the stride/offset tuples, which
# must have the same length `N`.
@inline function VectorizationBase.stridedpointer(
    ptr::Ptr{T}, ::StaticInt{C}, ::StaticInt{B}, ::Val{R}, strd::X, offsets::O
) where {T<:Tropical,C,B,R,N,X<:Tuple{Vararg{Integer,N}},O<:Tuple{Vararg{Integer,N}}}
    SP = VectorizationBase.StridedPointer{T,N,C,B,R,X,O}
    return SP(ptr, strd, offsets)
end
# Vector "zero" for `Tropical` lanes.
#
# The additive identity of a tropical number is `zero(Tropical{FT})` (e.g. `-Inf`
# for floats), not the raw `zero(FT)`. `zero(FT)` is the tropical *one*, and
# seeding a max-based accumulation with it would clamp every negative result to 0.
# NOTE(review): assumes LoopVectorization/Octavian use `_vzero` to initialize `+`
# accumulators — confirm against their sources.
# `in3` (register size) is unused here but kept for interface compatibility.
@inline function VectorizationBase._vzero(
    in1::StaticInt{W}, ::Type{T}, in3::StaticInt{RS}
) where {W,T<:Tropical{FT},RS} where {FT}
    return Tropical(VectorizationBase.vbroadcast(in1, content(zero(T))))
end
# Rebuild an `OffsetPrecalc` around a new `Tropical` pointer: re-derive the inner
# pointer via `similar_no_offset` and reuse the stored `precalc` unchanged.
# NOTE(review): the signature pairs `OffsetPrecalc{T}` with `Ptr{Tropical{T}}`;
# confirm this element-type pairing is what callers actually pass.
@inline function VectorizationBase.similar_no_offset(
    sptr::OffsetPrecalc{T}, ptr::Ptr{Tropical{T}}
) where {T}
    inner = VectorizationBase.similar_no_offset(getfield(sptr, :ptr), ptr)
    precalc = getfield(sptr, :precalc)
    return OffsetPrecalc(inner, precalc)
end
# A `Tropical{T}` passes LoopVectorization's type check exactly when its wrapped
# scalar type `T` does.
function LoopVectorization.check_type(::Type{Tropical{T}}) where {T}
    return LoopVectorization.check_type(T)
end
# Pointer offset for `Tropical` pointers: compute it on the underlying `Ptr{T}`,
# then cast back. NOTE(review): assumes `Tropical{T}` has the same size and layout
# as `T`.
@inline function VectorizationBase.gep(ptr::Ptr{Tropical{T}}, i) where {T}
    scalar_ptr = Ptr{T}(ptr)
    return Ptr{Tropical{T}}(VectorizationBase.gep(scalar_ptr, i))
end
@testset "mydot" begin
    # Tropical dot product written as an `@avx` reduction loop, compared against
    # the generic `transpose(a) * b` fallback.
    function mydot(x, y)
        acc = zero(promote_type(eltype(x), eltype(y)))
        @avx for k in eachindex(x, y)
            acc += x[k] * y[k]
        end
        return acc
    end
    xs = Tropical.(randn(10))
    ys = Tropical.(randn(10))
    @test mydot(xs, ys) ≈ transpose(xs) * ys
    @test LoopVectorization.check_args(TropicalF64, TropicalF64)
end
# TODO: FIX!!!!!! — integer/`Tropical{Vec}` promotion below is a stop-gap workaround.
#
# Some vectorized code path promotes an integer literal against `Tropical{Vec{N,T}}`
# values (NOTE(review): presumably LoopVectorization/Octavian emitting `0`/`1`;
# confirm). The literal `0` maps to the tropical zero and every other integer maps
# to the tropical one, splatted across all `N` lanes. The lane count and scalar
# type are taken from the vector type, so any SIMD width works, not just 4 lanes.

# Build a `Tropical{Vec{N,T}}` with every lane set to the tropical zero when
# `a == 0`, otherwise to the tropical one.
@inline function _tropical_int_splat(a::Int, ::Val{N}, ::Type{T}) where {N,T}
    elem = a == 0 ? content(zero(Tropical{T})) : content(one(Tropical{T}))
    return Tropical(Vec(ntuple(_ -> elem, Val(N))...))
end

function Base.promote(a::Int, b::Tropical{Vec{N,T}}) where {N,T}
    return _tropical_int_splat(a, Val(N), T), b
end

function Base.promote(a::Int, b::Tropical{Vec{N,T}}, c::Tropical{Vec{N,T}}) where {N,T}
    return _tropical_int_splat(a, Val(N), T), b, c
end
@testset "matmul" begin
    # Compare Octavian's tropical matmul kernels against the generic `*` fallback.
    # Max-plus products use only `+` and `max`, so results should agree essentially
    # exactly; the small tolerance only guards against unexpected reassociation.
    # (Previously these were `@show` calls, so the testset could never fail.)
    @testset "n = $n" for n in (4, 20)
        a = Tropical.(randn(n, n))
        b = Tropical.(randn(n, n))
        @test distance(Octavian.matmul_serial(a, b), a * b) < 1e-8
        @test distance(Octavian.matmul_serial(a, a), a * a) < 1e-8
        @test distance(Octavian.matmul(a, b), a * b) < 1e-8
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment