@Oblynx
Last active August 20, 2020 19:49
Flatten a multidimensional nested array in Julia
## SUMMARY ##
# Let `a` be the nested array and `aref` the multidimensional array we want to recreate:
julia> flatten(x::Array{<:Array,1}) = Iterators.flatten(x) |> collect |> flatten
flatten (generic function with 1 method)
julia> flatten(x::Array{<:Number,1}) = x
flatten (generic function with 2 methods)
julia> reshape(flatten(a), (4,4,3)) == aref
true
## ANALYSIS ##
# Someone wants to create a 3-dimensional array with an array comprehension:
julia> aref = [i * j * k for i in 1:4, j in 2:5, k in 1:3]
4×4×3 Array{Int64,3}:
[:, :, 1] =
 2   3   4   5
 4   6   8  10
 6   9  12  15
 8  12  16  20

[:, :, 2] =
  4   6   8  10
  8  12  16  20
 12  18  24  30
 16  24  32  40

[:, :, 3] =
  6   9  12  15
 12  18  24  30
 18  27  36  45
 24  36  48  60
# But let's say that, instead, they end up with this:
julia> a = [[[i * j * k for i in 1:4] for j in 2:5] for k in 1:3]
3-element Array{Array{Array{Int64,1},1},1}:
[[2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16], [5, 10, 15, 20]]
[[4, 8, 12, 16], [6, 12, 18, 24], [8, 16, 24, 32], [10, 20, 30, 40]]
[[6, 12, 18, 24], [9, 18, 27, 36], [12, 24, 36, 48], [15, 30, 45, 60]]
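# A quick check (a sketch added for illustration, not part of the original transcript)
# makes the relationship between `a` and `aref` explicit: the nesting runs in the
# *reverse* order of the indices, so the innermost vectors of `a` run along the first
# dimension of `aref`. Note that `a` is indexed by position, so `j` below runs over 1:4,
# not the 2:5 used to build the values.

```julia
# Rebuild both arrays exactly as in the transcript.
aref = [i * j * k for i in 1:4, j in 2:5, k in 1:3]
a = [[[i * j * k for i in 1:4] for j in 2:5] for k in 1:3]

# Outermost level of `a` matches the last index of `aref`,
# innermost level matches the first index.
for i in 1:4, j in 1:4, k in 1:3
    @assert a[k][j][i] == aref[i, j, k]
end
println("a[k][j][i] == aref[i, j, k] for every valid (i, j, k)")
```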
# Maybe they made a syntactic mistake. Or maybe, more crucially,
# *this is the format of input data from elsewhere and they have to deal with it*.
# So how can we start from `a` and end up with `aref`, using just Julia Base?
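# Notice that multidimensional arrays in Julia are just reshaped vectors: the elements are
# stored in one contiguous block in column-major order (the first index varies fastest).
# So if we can put the values of `a` in that order, and we know the dimensions we want to
# end up with, we're done: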
julia> vec(aref)
48-element Array{Int64,1}:
2
4
6
8
3
6
9
12
4
8
12
16
5
10
15
20
4
8
12
16
6
12
18
24
8
16
24
32
10
20
30
40
6
12
18
24
9
18
27
36
12
24
36
48
15
30
45
60
julia> reshape(vec(aref), (4,4,3))
4×4×3 Array{Int64,3}:
[:, :, 1] =
 2   3   4   5
 4   6   8  10
 6   9  12  15
 8  12  16  20

[:, :, 2] =
  4   6   8  10
  8  12  16  20
 12  18  24  30
 16  24  32  40

[:, :, 3] =
  6   9  12  15
 12  18  24  30
 18  27  36  45
 24  36  48  60
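# Why this round trip works: in column-major order, element `(i, j, k)` of a 4×4×3 array
# sits at linear position `i + 4(j - 1) + 16(k - 1)` of `vec(aref)`. A small sketch that
# checks this formula (via a hypothetical helper `linear`, introduced only here) against
# Base's `LinearIndices`, which performs the same conversion:

```julia
aref = [i * j * k for i in 1:4, j in 2:5, k in 1:3]
v = vec(aref)  # the 48 values in memory (column-major) order

# Hypothetical helper for this example: column-major linear index for a 4×4×3 array.
linear(i, j, k) = i + 4 * (j - 1) + 16 * (k - 1)

for i in 1:4, j in 1:4, k in 1:3
    @assert v[linear(i, j, k)] == aref[i, j, k]
    @assert LinearIndices(aref)[i, j, k] == linear(i, j, k)
end

# Reshaping the flat vector with the same dimensions recovers the array.
@assert reshape(v, (4, 4, 3)) == aref
```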
# OK, so now how do we take `a` and put its values in order to get `vec(aref)`?
# This is a nested array, so recursion to the rescue. Let's set up a recursive scheme
# using Julia's multiple dispatch.
# General case: strip one level of nesting, then recurse on the result:
julia> flatten(x::Array{<:Array,1}) = Iterators.flatten(x) |> collect |> flatten
flatten (generic function with 1 method)
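# To see what one application of the general case does, here are its stages run by hand
# on `a` (a sketch, independent of the `flatten` methods). `Iterators.flatten` is lazy: it
# returns an iterator rather than an `Array`, which is why the result goes through
# `collect` before recursing; otherwise neither `flatten` method would match it.

```julia
a = [[[i * j * k for i in 1:4] for j in 2:5] for k in 1:3]

# First pass: removes the outermost level, leaving 3 × 4 = 12 vectors of length 4.
lazy = Iterators.flatten(a)
@assert !(lazy isa Array)             # still a lazy iterator, not an Array
step1 = collect(lazy)
@assert step1 isa Vector{Vector{Int}}
@assert length(step1) == 12

# Second pass: removes the next level, leaving 48 plain numbers.
step2 = collect(Iterators.flatten(step1))
@assert step2 isa Vector{Int}
@assert length(step2) == 48
@assert step2[1:8] == [2, 4, 6, 8, 3, 6, 9, 12]
```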
# End case: a vector of numbers is already flat, so return it unchanged:
julia> flatten(x::Array{<:Number,1}) = x
flatten (generic function with 2 methods)
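# The end case only accepts vectors whose element type is a subtype of `Number`. A sketch
# of what that means in practice (the inputs below are made-up examples): numeric nesting
# of any depth works, but nested vectors of strings, or a `Vector{Any}` mixing numbers and
# vectors, fall through both methods and raise a `MethodError`.

```julia
# The two methods from the transcript.
flatten(x::Array{<:Array,1}) = Iterators.flatten(x) |> collect |> flatten
flatten(x::Array{<:Number,1}) = x

# Works at any depth, as long as the innermost vectors hold numbers.
@assert flatten([[1.5, 2.5], [3.5]]) == [1.5, 2.5, 3.5]
@assert flatten([[[1], [2, 3]], [[4]]]) == [1, 2, 3, 4]

# Hypothetical helper for this demo only: returns the exception a call throws, or `nothing`.
caught(f) = try f(); nothing catch e; e end

# Strings are not Numbers: the recursion reaches a Vector{String} with no matching method.
@assert caught(() -> flatten([["a", "b"], ["c"]])) isa MethodError

# A Vector{Any} matches neither signature, even if its contents are numeric.
@assert caught(() -> flatten(Any[[1, 2], 3])) isa MethodError
```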
# So what will happen now if we call our `flatten` function with `a`?
# (note: `flatten` and `Iterators.flatten` are completely separate functions)
julia> flatten(a)
48-element Array{Int64,1}:
2
4
6
8
3
6
9
12
4
8
12
16
5
10
15
20
4
8
12
16
6
12
18
24
8
16
24
32
10
20
30
40
6
12
18
24
9
18
27
36
12
24
36
48
15
30
45
60
# That is exactly `vec(aref)`: the same 48 values in the same order. All that's left is to reshape:
julia> reshape(flatten(a), (4,4,3))
4×4×3 Array{Int64,3}:
[:, :, 1] =
 2   3   4   5
 4   6   8  10
 6   9  12  15
 8  12  16  20

[:, :, 2] =
  4   6   8  10
  8  12  16  20
 12  18  24  30
 16  24  32  40

[:, :, 3] =
  6   9  12  15
 12  18  24  30
 18  27  36  45
 24  36  48  60
# And for completeness, let's compare our result with `aref`:
julia> reshape(flatten(a), (4,4,3)) == aref
true
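# The call above hard-codes `(4,4,3)`. If the nesting is rectangular (all vectors at the
# same level have equal length), the dimensions could instead be read off the nesting
# itself: innermost length first, outermost last, matching column-major order. A sketch
# with two hypothetical helpers, `nested_dims` and `unnest`, which are not part of the
# gist or of Julia Base:

```julia
# The gist's two methods.
flatten(x::Array{<:Array,1}) = Iterators.flatten(x) |> collect |> flatten
flatten(x::Array{<:Number,1}) = x

# Hypothetical helper: dimensions of a rectangular nested vector, innermost first.
nested_dims(x::Array{<:Number,1}) = (length(x),)
function nested_dims(x::Array{<:Array,1})
    inner = nested_dims(first(x))   # assumes `x` is non-empty
    # Reject ragged input instead of silently reshaping it into the wrong shape.
    all(y -> nested_dims(y) == inner, x) ||
        throw(ArgumentError("nested vectors have unequal lengths"))
    return (inner..., length(x))
end

# Hypothetical wrapper: flatten, then reshape to the inferred dimensions.
unnest(x) = reshape(flatten(x), nested_dims(x))

a = [[[i * j * k for i in 1:4] for j in 2:5] for k in 1:3]
aref = [i * j * k for i in 1:4, j in 2:5, k in 1:3]
@assert nested_dims(a) == (4, 4, 3)
@assert unnest(a) == aref
```

# The ragged-input check costs an extra traversal per level; it is there because
# `reshape` alone would only catch a mismatch in the total element count.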
@jaantollander
Just what I was searching and very elegantly written. Thank you!