Skip to content

Instantly share code, notes, and snippets.

@adam-r-kowalski
Last active January 14, 2019 23:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adam-r-kowalski/e0ee08fcc9dbd9143456c887c2f7c658 to your computer and use it in GitHub Desktop.
Save adam-r-kowalski/e0ee08fcc9dbd9143456c887c2f7c658 to your computer and use it in GitHub Desktop.
Named Array
# Usage demo for the `Demo.NamedArray` module defined in NamedArray.jl.
# Each `# ...` line below an expression is the value that expression printed.
# NOTE(review): `includet` is Revise.jl's tracked include, so Revise must already
# be loaded (e.g. via startup.jl) — TODO confirm; plain `include` works without it.
includet("NamedArray.jl")
# Wrap an existing 4-d array, naming its dimensions in positional order.
ims = randn(6, 96, 96, 3)
named_ims = Demo.NamedArray(ims, (:batch, :height, :width, :channels))
# Demo.NamedArray{Float64, (batch = 6, height = 96, width = 96, channels = 3), 4, Array{Float64, 4}}
# `size` returns a NamedTuple of dimension name => length, not a plain tuple.
size(named_ims)
# (batch = 6, height = 96, width = 96, channels = 3)
# Sample a new named array straight from a NamedTuple of sizes; this relies on the
# `randn(::NamedTuple)` method that the Demo module adds to `randn`.
ex = randn((height=96, width=96, channels=3))
# Demo.NamedArray{Float64, (height = 96, width = 96, channels = 3), 3, Array{Float64, 3}}
# Same, with an explicit element type.
ex2 = randn(Float32, (height=96, width=96, channels=3))
# Demo.NamedArray{Float32, (height = 96, width = 96, channels = 3), 3, Array{Float32, 3}}
# Reorder dimensions by name rather than by position.
ex3 = permutedims(ex, (:height, :channels, :width))
# Demo.NamedArray{Float64, (height = 96, channels = 3, width = 96), 3, Array{Float64, 3}}
module Demo

export NamedArray

import Base: size, getindex, setindex!, IndexStyle,
    randn, length, permutedims

"""
    NamedArray{T, S, N, A <: AbstractArray{T, N}}

Wrap an `N`-dimensional array `data::A` with element type `T` and give every
dimension a name. The type parameter `S` is a `NamedTuple` mapping each dimension
name to its length, so the names and sizes are encoded in the type itself.

NOTE(review): this type is not declared `<: AbstractArray{T, N}`. Presumably that
is because `size` here returns a `NamedTuple`, which the `AbstractArray` interface
(which expects a tuple of integers) would not accept.
"""
struct NamedArray{T, S, N, A <: AbstractArray{T, N}}
    data::A
end

"""
    NamedArray(data::AbstractArray{T, N}, names::NTuple{N, Symbol})

Wrap `data`, naming its dimensions `names` in positional order. The number of
names must equal `ndims(data)` (enforced by the shared `N`).
"""
NamedArray(data::AbstractArray{T, N}, names::NTuple{N, Symbol}) where {T, N} =
    NamedArray{T, NamedTuple{names}(size(data)), ndims(data), typeof(data)}(data)

# Forward the indexing style of the wrapped array type `A`.
# NOTE(review): since NamedArray is not an AbstractArray, Base's generic array code
# presumably only reaches this method when it is called explicitly.
IndexStyle(::NamedArray{T, S, N, A}) where {T, S, N, A} = IndexStyle(A)

"""
    size(array::NamedArray)

Return the dimension sizes as a `NamedTuple`, e.g. `(height = 96, width = 96)`.
"""
size(array::NamedArray{T, S}) where {T, S} = S

# Linear and Cartesian element access, delegated to the wrapped array.
getindex(array::NamedArray, i::Int) = array.data[i]
getindex(array::NamedArray, I::Vararg{Int}) = array.data[I...]

# Element assignment, delegated to the wrapped array; returns the assigned value `v`.
setindex!(array::NamedArray, v, i::Int) = array.data[i] = v
setindex!(array::NamedArray, v, I::Vararg{Int}) = array.data[I...] = v

# Total number of elements of the wrapped array.
length(array::NamedArray) = length(array.data)

# Sample a standard-normal NamedArray whose names and sizes come from `dims`,
# e.g. `randn((height = 96, width = 96))`, optionally with element type `T`.
# NOTE(review): these methods extend `randn` (imported from Base) for `NamedTuple`,
# neither of which this module owns — i.e. type piracy. They are kept because the
# demo calls `randn((height = 96, ...))` unqualified.
randn(dims::NamedTuple) = NamedArray(randn(values(dims)...), keys(dims))
randn(T::Type, dims::NamedTuple) =
    NamedArray(randn(T, values(dims)...), keys(dims))

"""
    permutedims(array::NamedArray{T, S, N}, dims::NTuple{N, Symbol})

Return a new `NamedArray` whose dimensions are reordered to follow `dims`, given by
name instead of by position. The data is permuted with `Base.permutedims`.

# Throws
- `ArgumentError` if a name in `dims` is not a dimension of `array`.
- `ArgumentError` (from `Base.permutedims`) if `dims` repeats a name.
"""
function permutedims(array::NamedArray{T, S, N, A},
                     dims::NTuple{N, Symbol}) where {T, S, N, A}
    known = keys(S)
    # Map over the tuple (not a Vector) so the permutation stays an NTuple{N, Int}
    # with statically known length, and report unknown names explicitly instead of
    # passing `nothing` into Base.permutedims.
    perm = map(dims) do name
        index = findfirst(n -> n === name, known)
        index === nothing && throw(ArgumentError(
            "dimension name $(repr(name)) not found among $(known)"))
        index
    end
    return NamedArray(permutedims(array.data, perm), dims)
end

end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment