Skip to content

Instantly share code, notes, and snippets.

@Tokazama
Created January 13, 2023 17:20
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Tokazama/eba2474f13754cbb3b9fadf5151ebf48 to your computer and use it in GitHub Desktop.
Save Tokazama/eba2474f13754cbb3b9fadf5151ebf48 to your computer and use it in GitHub Desktop.
Reference types from A(tomic) to S(tatic)
module RefTypes
import .Base.Sys: ARCH, WORD_SIZE
export
AtomicRef,
ImmutableRef,
MutableRef,
StaticRef,
add!,
and!,
dec,
dec!,
inc,
inc!,
max!,
min!,
nand!,
or!,
sub!,
swap!,
xor!
# Identity on `Bool`. Any non-`Bool` argument has no matching method and raises a
# `MethodError`, which guarantees callers a `Bool` result.
function _bool(b::Bool)
    return b
end

# Return `value` unchanged; dispatch enforces that `value` is an instance of `T`.
function _rettype(::Type{T}, value::T) where {T}
    return value
end
"""
    StaticRef{T,value}()
    StaticRef{T}(v)
    StaticRef{T}()
    StaticRef(v)

A field-less reference whose wrapped value lives in the type parameter `value`.

- `StaticRef{T}(v)` converts `v` to `T` first.
- `StaticRef{T}()` wraps `T()`.
- `StaticRef(v)` infers `T` from `v`.
"""
struct StaticRef{T,value} <: Ref{T}
    function StaticRef{T,value}() where {T,value}
        # `value::T` asserts that the type parameter is an instance of `T`.
        return new{T,value::T}()
    end
    function StaticRef{T}(v::T) where {T}
        return new{T,v}()
    end
    function StaticRef{T}(v) where {T}
        return StaticRef{T}(convert(T, v))
    end
    function StaticRef{T}() where {T}
        return StaticRef{T}(T())
    end
    function StaticRef(v::T) where {T}
        return StaticRef{T}(v)
    end
end
"""
    ImmutableRef{T}(v)
    ImmutableRef{T}()
    ImmutableRef(v)

An immutable reference storing one value of type `T` in its `value` field.

- `ImmutableRef{T}(v)` converts `v` to `T` (through `new`).
- `ImmutableRef{T}()` wraps `T()`.
- `ImmutableRef(v)` infers `T` from `v`.
"""
struct ImmutableRef{T} <: Ref{T}
    value::T
    function ImmutableRef{T}(x) where {T}
        return new{T}(x)
    end
    function ImmutableRef{T}() where {T}
        return ImmutableRef{T}(T())
    end
    function ImmutableRef(x::T) where {T}
        return ImmutableRef{T}(x)
    end
end
"""
    MutableRef{T}(v)
    MutableRef{T}()
    MutableRef(v)

A mutable reference storing one value of type `T` in its plain (non-atomic)
`value` field.

- `MutableRef{T}(v)` converts `v` to `T` (through `new`).
- `MutableRef{T}()` wraps `T()`.
- `MutableRef(v)` infers `T` from `v`.
"""
mutable struct MutableRef{T} <: Ref{T}
    value::T
    function MutableRef{T}(x) where {T}
        return new{T}(x)
    end
    function MutableRef{T}() where {T}
        return MutableRef{T}(T())
    end
    function MutableRef(x::T) where {T}
        return MutableRef{T}(x)
    end
end
# Element types with hardware atomic support, filtered per platform:
# - 128-bit integer atomics do not exist on AArch32 (32-bit ARM). The 128-bit
#   types are also omitted on 32-bit x86 (i686) and on ppc64le.
# - LLVM does not currently support atomics on floating-point values on ppc64le.
#   C++20 adds limited support for atomic floats, but Clang does not support it yet.
const inttypes = let arch = Base.Sys.ARCH
    no_int128 = arch === :i686 || startswith(string(arch), "arm") ||
                arch in (:powerpc64le, :ppc64le)
    signed = no_int128 ? (Int8, Int16, Int32, Int64) :
                         (Int8, Int16, Int32, Int64, Int128)
    unsigned = no_int128 ? (UInt8, UInt16, UInt32, UInt64) :
                           (UInt8, UInt16, UInt32, UInt64, UInt128)
    (signed..., unsigned...)
end
const floattypes = (Float16, Float32, Float64)
const arithmetictypes = (inttypes..., floattypes...)
# TODO: Support Ptr
# On ppc64le only the integer types get atomics (see above).
const atomictypes = Base.Sys.ARCH in (:powerpc64le, :ppc64le) ? inttypes : arithmetictypes
const IntTypes = Union{inttypes...}
const FloatTypes = Union{floattypes...}
const ArithmeticTypes = Union{arithmetictypes...}
const AtomicRefTypes = Union{atomictypes...}
"""
    AtomicRef{T}(value)
    AtomicRef{T}()
    AtomicRef(value)

A mutable reference whose single `value` field is declared `@atomic`, so every
read and write of it must specify a memory ordering.

- `AtomicRef{T}(value)` converts `value` to `T` (through `new`), like the
  `ImmutableRef`/`MutableRef` constructors do.
- `AtomicRef{T}()` wraps `T()`.
- `AtomicRef(value)` infers `T` from `value`.
"""
mutable struct AtomicRef{T} <: Ref{T}
    @atomic value::T
    # Untyped argument: `new` converts it to `T`. This lets e.g.
    # `AtomicRef{Float64}(1)` and `AtomicRef{T}(r[])` for a differently-typed
    # `r` succeed instead of raising a MethodError.
    AtomicRef{T}(value) where {T} = new{T}(value)
    AtomicRef{T}() where {T} = AtomicRef{T}(T())
    AtomicRef(value::T) where {T} = AtomicRef{T}(value)
end
# Address of the stored value: `value` is the struct's only field, so its
# address is the object's own address. Callers must keep `x` rooted (e.g. with
# `GC.@preserve`) while the pointer is in use.
Base.unsafe_convert(::Type{Ptr{T}}, x::AtomicRef{T}) where {T} =
    convert(Ptr{T}, pointer_from_objref(x))
# `x[] = v` for an arbitrary `v`: convert to the element type, then re-dispatch
# to the store method for a `T`-typed value.
function Base.setindex!(x::AtomicRef{T}, v) where {T}
    converted = convert(T, v)
    return setindex!(x, converted)
end
# LLVM IR spelling of each element type used by the `llvmcall` atomics.
# Signed and unsigned integers of the same width share one spelling, because
# LLVM integer types carry no signedness.
const llvmtypes = let table = IdDict{Any,String}()
    for (S, U) in ((Int8, UInt8), (Int16, UInt16), (Int32, UInt32),
                   (Int64, UInt64), (Int128, UInt128))
        table[S] = table[U] = "i$(8 * sizeof(S))"
    end
    table[Float16] = "half"
    table[Float32] = "float"
    table[Float64] = "double"
    table
end

# Same-width integer type used to bit-cast a value for integer-only atomic
# instructions. Integer types map to themselves.
inttype(::Type{T}) where {T<:Integer} = T
for (F, I) in ((Float16, Int16), (Float32, Int32), (Float64, Int64))
    @eval inttype(::Type{$F}) = $I
end
import ..Base.gc_alignment
# Hand-written LLVM atomics for every element type in `atomictypes`.
# For each `typ` this defines, specialised on `AtomicRef{typ}`:
#   - `getindex`  : atomic load, acquire ordering
#   - `setindex!` : atomic store, release ordering (returns `nothing`)
#   - `cas!`      : compare-and-swap returning the previously stored value
#   - `add!`, `sub!`, `xchg!`, `and!`, `nand!`, `or!`, `xor!`, `max!`, `min!`
#     via `atomicrmw` for integers; only `xchg!` for floating point.
# These methods are more specific than any `AtomicRef{T}`-generic method, so
# they win dispatch for the listed element types.
#
# All of these operations have acquire and/or release semantics, depending on
# whether they load or store values. Most of the time this is what one wants
# anyway, and it is only moderately expensive on most hardware.
#
# NOTE(review): each llvmcall receives the `Ptr` argument as `%0`, converts it
# with `inttoptr i$(WORD_SIZE)` and uses typed pointers (`$lt*`). This assumes
# the older typed-pointer IR form; confirm it is still accepted by the LLVM
# bundled with the targeted Julia release (opaque pointers).
# NOTE(review): mixing these raw LLVM atomics with `getfield`/`modifyfield!`
# access to the same `@atomic` field assumes Julia also implements that field
# with lock-free hardware atomics. This may not hold for 16-byte types
# (Int128/UInt128) — verify.
for typ in atomictypes
# `lt`: LLVM spelling of `typ`. `ilt`: spelling of the same-width integer type,
# used to bit-cast floating-point values for integer-only instructions.
lt = llvmtypes[typ]
ilt = llvmtypes[inttype(typ)]
# `rt` is the "<type>, <type>*" operand pair of the typed-pointer `load` syntax.
# NOTE(review): `irt` is not used anywhere below.
rt = "$lt, $lt*"
irt = "$ilt, $ilt*"
# Atomic load with acquire ordering. `GC.@preserve x` keeps the object alive
# while its raw address is in use.
@eval Base.getindex(x::AtomicRef{$typ}) =
GC.@preserve x Base.llvmcall($"""
%ptr = inttoptr i$(WORD_SIZE) %0 to $lt*
%rv = load atomic $rt %ptr acquire, align $(gc_alignment(typ))
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}}, Base.unsafe_convert(Ptr{$typ}, x))
# Atomic store with release ordering.
@eval Base.setindex!(x::AtomicRef{$typ}, v::$typ) =
GC.@preserve x Base.llvmcall($"""
%ptr = inttoptr i$(WORD_SIZE) %0 to $lt*
store atomic $lt %1, $lt* %ptr release, align $(gc_alignment(typ))
ret void
""", Cvoid, Tuple{Ptr{$typ}, $typ}, Base.unsafe_convert(Ptr{$typ}, x), v)
# `cas!(x, cmp, new)` returns the value held before the call. It stored `new`
# if and only if that returned value equals `cmp`; for floats the comparison is
# bitwise. Orderings: acq_rel on success, acquire on failure.
if typ <: Integer
@eval cas!(x::AtomicRef{$typ}, cmp::$typ, new::$typ) =
GC.@preserve x Base.llvmcall($"""
%ptr = inttoptr i$(WORD_SIZE) %0 to $lt*
%rs = cmpxchg $lt* %ptr, $lt %1, $lt %2 acq_rel acquire
%rv = extractvalue { $lt, i1 } %rs, 0
ret $lt %rv
""", $typ, Tuple{Ptr{$typ},$typ,$typ},
Base.unsafe_convert(Ptr{$typ}, x), cmp, new)
else
# Floating point: `cmpxchg` runs on the bit pattern, bit-cast to the
# same-width integer type and back.
@eval cas!(x::AtomicRef{$typ}, cmp::$typ, new::$typ) =
GC.@preserve x Base.llvmcall($"""
%iptr = inttoptr i$WORD_SIZE %0 to $ilt*
%icmp = bitcast $lt %1 to $ilt
%inew = bitcast $lt %2 to $ilt
%irs = cmpxchg $ilt* %iptr, $ilt %icmp, $ilt %inew acq_rel acquire
%irv = extractvalue { $ilt, i1 } %irs, 0
%rv = bitcast $ilt %irv to $lt
ret $lt %rv
""", $typ, Tuple{Ptr{$typ},$typ,$typ},
Base.unsafe_convert(Ptr{$typ}, x), cmp, new)
end
# Read-modify-write operations. Each generated `<op>!(x, v)` applies `<op>`
# atomically with acq_rel ordering and returns the value held before the
# update (`atomicrmw` yields the old value).
arithmetic_ops = [:add, :sub]
for rmwop in [arithmetic_ops..., :xchg, :and, :nand, :or, :xor, :max, :min]
rmw = string(rmwop)
fn = Symbol(rmw, "!")
if (rmw == "max" || rmw == "min") && typ <: Unsigned
# LLVM distinguishes signedness in the operation, not the integer type.
rmw = "u" * rmw
end
# Skip add/sub for non-arithmetic types. With the current type tables every
# atomic type is arithmetic, so this never skips.
if rmwop in arithmetic_ops && !(typ <: ArithmeticTypes) continue end
if typ <: Integer
@eval $fn(x::AtomicRef{$typ}, v::$typ) =
GC.@preserve x Base.llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rv = atomicrmw $rmw $lt* %ptr, $lt %1 acq_rel
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}, $typ}, Base.unsafe_convert(Ptr{$typ}, x), v)
else
# Floating point: only `xchg!` is generated here, bit-cast through the
# same-width integer. Float add!/sub!/max!/min! are built from `cas!`
# further down in this module.
rmwop === :xchg || continue
@eval $fn(x::AtomicRef{$typ}, v::$typ) =
GC.@preserve x Base.llvmcall($"""
%iptr = inttoptr i$WORD_SIZE %0 to $ilt*
%ival = bitcast $lt %1 to $ilt
%irv = atomicrmw $rmw $ilt* %iptr, $ilt %ival acq_rel
%rv = bitcast $ilt %irv to $lt
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}, $typ}, Base.unsafe_convert(Ptr{$typ}, x), v)
end
end
end
# Atomic floating-point `add!`, `sub!`, `max!` and `min!`, built on a `cas!`
# retry loop (the `atomicrmw` versions above are generated only for integers).
# Each returns the value held immediately before its own update.
# NOTE(review): `cas!` for floats is generated only when the float types are in
# `atomictypes` (not on ppc64le). On that platform these methods would hit a
# MethodError.
const opnames = Dict{Symbol, Symbol}(:+ => :add, :- => :sub)
for op in [:+, :-, :max, :min]
# `max`/`min` have no entry in `opnames`, so they keep their own name.
opname = get(opnames, op, op)
@eval function $(Symbol(opname, "!"))(var::AtomicRef{T}, val::T) where T<:FloatTypes
IT = inttype(T)
old = var[]
while true
new = $op(old, val)
cmp = old
old = cas!(var, cmp, new)
# `cas!` compares bit patterns, so success is tested on the integer
# reinterpretation as well. A float `==` would never succeed for NaN and
# would treat -0.0 and 0.0 as the same value.
reinterpret(IT, old) == reinterpret(IT, cmp) && return old
# Give the GC a chance to run while spinning. This is a temporary
# workaround until codegen supports GC transitions.
ccall(:jl_gc_safepoint, Cvoid, ())
end
end
end
# Any of the four reference flavours defined in this module, with element type `T`.
const RefType{T} = Union{AtomicRef{T},MutableRef{T},ImmutableRef{T},StaticRef{T}}
# Promotion rules between the reference flavours. The element type is always
# `promote_type(X, Y)`; the resulting flavour is:
#
#   ImmutableRef + StaticRef  -> ImmutableRef
#   ImmutableRef + AtomicRef  -> ImmutableRef
#   ImmutableRef + MutableRef -> ImmutableRef
#   MutableRef   + StaticRef  -> MutableRef
#   MutableRef   + AtomicRef  -> AtomicRef
#   AtomicRef    + StaticRef  -> AtomicRef
#
# Each pair of flavours has exactly one rule. `promote_type` asks `promote_rule`
# in both argument orders, and two rules naming different types for the same
# pair make it recurse until the stack overflows.

# ImmutableRef
function Base.promote_rule(::Type{ImmutableRef{X}}, ::Type{<:StaticRef{Y}}) where {X,Y}
    return ImmutableRef{promote_type(X, Y)}
end
function Base.promote_rule(::Type{ImmutableRef{X}}, ::Type{AtomicRef{Y}}) where {X,Y}
    return ImmutableRef{promote_type(X, Y)}
end
function Base.promote_rule(::Type{ImmutableRef{X}}, ::Type{MutableRef{Y}}) where {X,Y}
    return ImmutableRef{promote_type(X, Y)}
end
# MutableRef
# The partner here is StaticRef. The MutableRef/ImmutableRef pair is already
# resolved by the ImmutableRef rule above; a second rule for that pair returning
# MutableRef would contradict it and send `promote_type` into endless recursion.
function Base.promote_rule(::Type{MutableRef{X}}, ::Type{<:StaticRef{Y}}) where {X,Y}
    return MutableRef{promote_type(X, Y)}
end
function Base.promote_rule(::Type{MutableRef{X}}, ::Type{AtomicRef{Y}}) where {X,Y}
    return AtomicRef{promote_type(X, Y)}
end
# AtomicRef
function Base.promote_rule(::Type{AtomicRef{X}}, ::Type{<:StaticRef{Y}}) where {X,Y}
    return AtomicRef{promote_type(X, Y)}
end
# Conversions between reference flavours:
# - a value that already has the target type is returned unchanged;
# - any other `Ref` is dereferenced, and its contents are wrapped in the target
#   flavour with element type `T`.
for R in (:ImmutableRef, :MutableRef, :AtomicRef)
    @eval begin
        Base.convert(::Type{$R{T}}, v::$R{T}) where {T} = v
        Base.convert(::Type{$R{T}}, v::Ref) where {T} = $R{T}(v[])
    end
end
# `StaticRef{T}` is a `UnionAll` over the stored value, so these methods accept
# any `StaticRef{T,value}` target.
# NOTE(review): converting to a concrete `StaticRef{T,value}` yields a StaticRef
# for the *source* value, which may differ from `value`. Confirm whether a
# mismatch should throw instead.
Base.convert(::Type{<:StaticRef{T}}, v::StaticRef{T}) where {T} = v
Base.convert(::Type{<:StaticRef{T}}, v::Ref) where {T} = StaticRef{T}(v[])

# Element type of every reference flavour in `RefType`.
Base.eltype(::Type{<:RefType{T}}) where {T} = T
# Comparisons between any two reference flavours compare the wrapped values.
# The result passes through `_bool`, so a comparison that does not produce a
# `Bool` (e.g. one involving `missing`) raises a `MethodError`.
for f in (:<, :<=, :>, :>=, :(==), :isequal, :isless)
    @eval Base.$(f)(x::RefType, y::RefType) = _bool(Base.$(f)(x[], y[]))
end

# `isequal` above is value-based, so `hash` must be value-based too:
# `isequal(a, b)` has to imply `hash(a) == hash(b)`, and `Dict`/`Set` rely on it.
# Without this method, mutable refs hash by identity and distinct flavours
# holding equal values hash differently.
Base.hash(x::RefType, h::UInt) = hash(x[], h)
"""
    replace!(x::Union{AtomicRef,MutableRef}, expected, desired)

If the stored value is `expected` (compared with `===`, as in `replacefield!`),
store `desired` in `x`. Return the value held before the call.

`expected` and `desired` are first converted to the element type. For
`AtomicRef` the operation uses sequentially consistent ordering.
"""
function replace!(x::Union{AtomicRef{T},MutableRef{T}}, expected, desired) where {T}
    expected_T = convert(T, expected)
    desired_T = convert(T, desired)
    return replace!(x, expected_T, desired_T)
end
function replace!(x::MutableRef{T}, expected::T, desired::T) where {T}
    outcome = replacefield!(x, 1, expected, desired)
    return outcome.old
end
function replace!(x::AtomicRef{T}, expected::T, desired::T) where {T}
    outcome = replacefield!(x, 1, expected, desired, :sequentially_consistent)
    return outcome.old
end
"""
    swap!(x::Union{AtomicRef,MutableRef}, newval)

Store `newval`, converted to the element type, in `x` and return the value it
replaced.

For `AtomicRef`, integer element types use the `atomicrmw`-based `xchg!`; other
element types use a sequentially consistent `swapfield!`.
"""
function swap!(x::Union{AtomicRef{T},MutableRef{T}}, newval) where {T}
    return swap!(x, convert(T, newval))
end
swap!(x::MutableRef{T}, newval::T) where {T} = swapfield!(x, 1, newval)
swap!(x::AtomicRef{T}, newval::T) where {T<:IntTypes} = xchg!(x, newval)
swap!(x::AtomicRef{T}, newval::T) where {T} =
    swapfield!(x, 1, newval, :sequentially_consistent)
"""
    modify!(r::Union{AtomicRef,MutableRef}, op, x)

Set the stored value to `op(old, x)` and return `r` itself (not the old or new
value). For `AtomicRef` the update is a single sequentially consistent
`modifyfield!`.
"""
modify!(r::MutableRef{T}, op, x) where {T} = (modifyfield!(r, 1, op, x); r)
modify!(r::AtomicRef{T}, op, x) where {T} =
    (modifyfield!(r, 1, op, x, :sequentially_consistent); r)
# Read-modify-write helpers. For each `(name, op)` pair below this defines
#
#     name(x::Union{MutableRef,AtomicRef}, y)
#
# which stores `op(old, y)` in `x` and returns `old`, the value held before the
# update. `y` is first converted to the element type `T`.
#   - `MutableRef`: plain, non-atomic `modifyfield!`.
#   - `AtomicRef`: `modifyfield!` with `:sequentially_consistent` ordering.
#
# All variants are generated from one table so that the orderings stay
# consistent. `modifyfield!` throws `ConcurrencyViolationError` when the
# `@atomic` field of an `AtomicRef` is accessed without an ordering, or when
# the plain field of a `MutableRef` is accessed with one.
#
# Element types that have hand-written LLVM atomics in this module get more
# specific methods, which take precedence over these generic ones.
for (fname, op) in ((:sub!, -), (:add!, +), (:or!, |), (:xor!, ⊻),
                    (:and!, &), (:nand!, ⊼), (:max!, max), (:min!, min))
    @eval begin
        function $fname(x::Union{MutableRef{T},AtomicRef{T}}, y) where {T}
            return $fname(x, convert(T, y))
        end
        function $fname(x::MutableRef{T}, y::T) where {T}
            # `modifyfield!` returns `old => new`; callers get the old value.
            return modifyfield!(x, 1, $op, y).first
        end
        function $fname(x::AtomicRef{T}, y::T) where {T}
            return modifyfield!(x, 1, $op, y, :sequentially_consistent).first
        end
    end
end
"""
    inc(x)

Return `x + one(x)`: `x` increased by the multiplicative identity of its type.
"""
inc(x) = x + one(typeof(x))

"""
    inc!(r::Union{AtomicRef,MutableRef})

Add `one(T)` to the value stored in `r`, via `add!`, and return what `add!`
returns.
"""
inc!(r::Union{AtomicRef{T},MutableRef{T}}) where {T} = add!(r, one(T))

"""
    dec(x)

Return `x - one(x)`: `x` decreased by the multiplicative identity of its type.
"""
dec(x) = x - one(typeof(x))

"""
    dec!(r::Union{AtomicRef,MutableRef})

Subtract `one(T)` from the value stored in `r`, via `sub!`, and return what
`sub!` returns.
"""
dec!(r::Union{AtomicRef{T},MutableRef{T}}) where {T} = sub!(r, one(T))
# Dereferencing (`r[]`) for each flavour.
# `StaticRef` keeps its value in the type; `_rettype` checks that it is a `T`.
Base.getindex(::StaticRef{T,value}) where {T,value} = _rettype(T, value)
function Base.getindex(x::ImmutableRef{T}) where {T}
    return getfield(x, :value)
end
function Base.getindex(x::MutableRef{T}) where {T}
    return getfield(x, :value)
end
# The field of an `AtomicRef` is `@atomic`, so it is read with an explicit ordering.
function Base.getindex(x::AtomicRef{T}) where {T}
    return getfield(x, :value, :sequentially_consistent)
end

# Assignment (`r[] = v`): a value of another type is converted to `T` first.
function Base.setindex!(x::Union{MutableRef{T},AtomicRef{T}}, newval) where {T}
    return setindex!(x, convert(T, newval))
end
function Base.setindex!(x::MutableRef{T}, newval::T) where {T}
    return setfield!(x, :value, newval)
end
function Base.setindex!(x::AtomicRef{T}, newval::T) where {T}
    return setfield!(x, :value, newval, :sequentially_consistent)
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment