Skip to content

Instantly share code, notes, and snippets.

@oxinabox
Created August 9, 2023 11:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oxinabox/878c16d1f6011f15138dd2cbf690e750 to your computer and use it in GitHub Desktop.
Save oxinabox/878c16d1f6011f15138dd2cbf690e750 to your computer and use it in GitHub Desktop.
VectorPrisms -- a way of viewing any simple nested struct as a vector
"""
    VectorPrisms

View a simple nested value (a struct, tuple or named tuple built only from concrete
field types) as a flat `AbstractVector` of its leaves. Leaves are visited depth-first,
in field-declaration order.
"""
module VectorPrisms

"""
    check_compatible(T::Type)

Throw an error unless `T` and, recursively, every one of its field types is concrete.
`Array` types and `Nothing` are rejected wherever they occur.
"""
function check_compatible(::Type{T}) where T
    isconcretetype(T) || error("Type is not fully concrete")
    for ft in fieldtypes(T)
        check_compatible(ft)
    end
    return nothing
end
check_compatible(::Type{<:Array}) = error("Type contains an array")
check_compatible(::Type{Nothing}) = error("Nothing is not supported.")

"""
    AbstractVectorViewable{T} <: AbstractVector{T}

Subtype this to make a struct automatically behave as an `AbstractVector{T}`. Its
elements are the field values of type `T` found by walking the struct's fields
depth-first; fields of any other type are descended into.
"""
abstract type AbstractVectorViewable{T} <: AbstractVector{T} end

# Elements are addressed by a single linear index.
Base.IndexStyle(::Type{<:AbstractVectorViewable}) = IndexLinear()

"""
    PrismView(x)

Wrap any value `x` whose type passes `check_compatible` so it can be used as an
`AbstractVector`. The element type is `determine_eltype(typeof(x))`.
"""
struct PrismView{T, B} <: AbstractVectorViewable{T}
    backing::B
end

function PrismView(x::B) where B
    check_compatible(B)
    T = determine_eltype(B)
    return PrismView{T, B}(x)
end

"""
    determine_eltype(B::Type)

Return the `Union` of every leaf type reachable from `B` through its fields, where a
leaf is a type with no fields. A type with no fields is its own element type.
"""
function determine_eltype(::Type{B}) where B
    fieldcount(B) == 0 && return B
    E = Union{}  # named `E` to avoid shadowing `Base.eltype`
    for ft in fieldtypes(B)
        E = Union{E, determine_eltype(ft)}
    end
    return E
end

# Number of leaves held by a value whose type is `V`: a `V <: T` is itself one leaf,
# anything else holds the leaves of all of its fields.
_leafcount(::Type{V}, ::Type{T}) where {V, T} = V <: T ? 1 : _fieldleafcount(V, T)

# Total number of leaves held in the fields of a value of type `V` (not counting the
# value itself). `init = 0` makes field-less types contribute zero instead of erroring.
_fieldleafcount(::Type{V}, ::Type{T}) where {V, T} =
    mapreduce(ft -> _leafcount(ft, T), +, fieldtypes(V); init = 0)

# Return the `i`-th leaf (1-based) found among the fields of `v`. Whole fields are
# skipped using their leaf counts, so only the branch holding the leaf is descended.
function _getleaf(v, ::Type{T}, i::Int) where T
    V = typeof(v)
    for k in 1:fieldcount(V)
        FT = fieldtype(V, k)
        n = _leafcount(FT, T)
        if i <= n
            child = getfield(v, k)
            return FT <: T ? child : _getleaf(child, T, i)
        end
        i -= n
    end
    # Only reachable if `i` exceeds the number of leaves; callers check bounds first.
    throw(BoundsError(v, i))
end

# Overwrite the `i`-th leaf (1-based) found among the fields of `v` with `value`,
# converted to the declared type of the field that holds it.
function _setleaf!(v, ::Type{T}, value, i::Int) where T
    V = typeof(v)
    for k in 1:fieldcount(V)
        FT = fieldtype(V, k)
        n = _leafcount(FT, T)
        if i <= n
            if FT <: T
                # `setfield!` throws if `v` is an immutable struct or a tuple.
                setfield!(v, k, convert(FT, value))
            else
                _setleaf!(getfield(v, k), T, value, i)
            end
            return nothing
        end
        i -= n
    end
    # Only reachable if `i` exceeds the number of leaves; callers check bounds first.
    throw(BoundsError(v, i))
end

# Number of leaves in `x`, returned as a 1-tuple as the `AbstractArray` interface requires.
Base.size(x::AbstractVectorViewable{T}) where T = (_fieldleafcount(typeof(x), T),)

# The `ii`-th leaf of `x`, counting depth-first in field order.
# Throws `BoundsError` when `ii` is outside `1:length(x)`.
function Base.getindex(x::AbstractVectorViewable{T}, ii::Int) where T
    checkbounds(x, ii)
    return _getleaf(x, T, ii)
end

# Note: this will error if the particular index does not line up with a field of a
# mutable struct, since only such fields can be assigned in place.
# This could be made more powerful using Accessors.jl
# Returns `x`, following the `setindex!` convention.
function Base.setindex!(x::AbstractVectorViewable{T}, value, ii::Int) where T
    checkbounds(x, ii)
    _setleaf!(x, T, value, ii)
    return x
end

end # module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment