
@veelenga
Last active November 2, 2021 19:02
Flattening a nested list in Elixir
defmodule Flatten do
  def flatten(list), do: flatten(list, []) |> Enum.reverse()  # reverse once at the end
  def flatten([h | t], acc) when h == [], do: flatten(t, acc)  # skip empty sublists
  def flatten([h | t], acc) when is_list(h), do: flatten(t, flatten(h, acc))
  def flatten([h | t], acc), do: flatten(t, [h | acc])  # prepend leaf: O(1)
  def flatten([], acc), do: acc
end
@mareksuscak

It's not necessary to reverse the list. You can use list concatenation to append the head in the flatten(t, [h | acc]) clause instead of prepending it.

def flatten(list), do: flatten(list, [])
def flatten([h | t], acc) when h == [], do: flatten(t, acc)
def flatten([h | t], acc) when is_list(h), do: flatten(t, flatten(h, acc))
def flatten([h | t], acc), do: flatten(t, acc ++ [h])
def flatten([], acc), do: acc

@veelenga
Author

veelenga commented Sep 1, 2017

@mareksuscak ++ is a very expensive operation here: acc ++ [h] copies the entire accumulator every time an element is appended. Check out this benchmark:

defmodule FlattenReverse do
  def flatten(list), do: flatten(list, []) |> Enum.reverse
  def flatten([h | t], acc) when h == [], do: flatten(t, acc)
  def flatten([h | t], acc) when is_list(h), do: flatten(t, flatten(h, acc))
  def flatten([h | t], acc), do: flatten(t, [h | acc])
  def flatten([], acc), do: acc
end

defmodule FlattenAppend do
  def flatten(list), do: flatten(list, [])
  def flatten([h | t], acc) when h == [], do: flatten(t, acc)
  def flatten([h | t], acc) when is_list(h), do: flatten(t, flatten(h, acc))
  def flatten([h | t], acc), do: flatten(t, acc ++ [h])
  def flatten([], acc), do: acc
end

list = List.duplicate(0, 200) |> List.duplicate(200)

{time, _} = :timer.tc fn -> FlattenReverse.flatten(list) end
IO.puts "Flatten reverse took #{time}"

{time, _} = :timer.tc fn -> FlattenAppend.flatten(list) end
IO.puts "Flatten append took #{time}"
Output (:timer.tc reports time in microseconds):
Flatten reverse took 2637
Flatten append took 3240488
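A rough count of the work explains the gap, assuming (as on the BEAM) that a ++ b copies every cell of a and shares b. With n leaf elements, appending the k-th element via acc ++ [h] copies the k cells already in acc, so the append version copies in total:

$$
\sum_{k=0}^{n-1} k = \frac{n(n-1)}{2}
$$

For the benchmark list, n = 200 × 200 = 40,000, so that is 40,000 × 39,999 / 2 ≈ 8 × 10^8 copied cells, while the reverse version does 40,000 constant-time prepends plus a single 40,000-element Enum.reverse. Quadratic versus linear work is consistent with the more than 1,000× difference in the timings above.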

@chx

chx commented Sep 12, 2017

What if you wanted to flatten only one level? I can't figure that out without ++.

  def flatten(list), do: flatten(list, [])
  def flatten([h | t], acc), do: flatten(t, h ++ acc)
  def flatten([], acc), do: acc

Or List.foldl(list, [], &(&1 ++ &2))

@veelenga
Author

veelenga commented Sep 13, 2017

@chx ++ works only on lists, so your code (both samples) will fail on this input: [1, [2]].
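For the one-level case, a minimal sketch that avoids ++ entirely (so a non-list element such as the 1 in [1, [2]] is fine) can reuse the reverse-accumulator idea from the gist. The module name FlattenOnce is hypothetical:

```elixir
defmodule FlattenOnce do
  # Flatten exactly one level without ++: elements are pushed onto an
  # accumulator in reverse order, and the accumulator is reversed once at the end.
  def flatten(list), do: list |> Enum.reduce([], &push/2) |> Enum.reverse()

  # A sublist is spliced in: Enum.reverse/2 reverses it onto the accumulator.
  defp push(x, acc) when is_list(x), do: Enum.reverse(x, acc)
  # Any other element is kept as-is.
  defp push(x, acc), do: [x | acc]
end
```

Each sublist is walked once by Enum.reverse/2, so the cost stays linear in the number of top-level elements plus the lengths of the spliced sublists.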

Here is an adjusted version that flattens only down to a given depth (without ++):

defmodule FlattenDepth do
  # Named FlattenDepth rather than List: defining a module called List would replace Elixir's built-in List module.
  def flatten(list, depth \\ -2), do: flatten(list, depth + 1, []) |> Enum.reverse()
  def flatten(list, 0, acc), do: [list | acc]
  def flatten([h | t], depth, acc) when h == [], do: flatten(t, depth, acc)
  def flatten([h | t], depth, acc) when is_list(h), do: flatten(t, depth, flatten(h, depth - 1, acc))
  def flatten([h | t], depth, acc), do: flatten(t, depth, [h | acc])
  def flatten([], _, acc), do: acc
end
list = [[1], 2, [[3, 4], 5], [[[]]], [[[6]]], 7, 8, []]

FlattenDepth.flatten(list, 0)   # [[1], 2, [[3, 4], 5], [[[]]], [[[6]]], 7, 8]
FlattenDepth.flatten(list, 1)   # [1, 2, [3, 4], 5, [[]], [[6]], 7, 8]
FlattenDepth.flatten(list, 2)   # [1, 2, 3, 4, 5, [6], 7, 8]
FlattenDepth.flatten(list, 3)   # [1, 2, 3, 4, 5, 6, 7, 8]
FlattenDepth.flatten(list)      # [1, 2, 3, 4, 5, 6, 7, 8]

There is one exceptional case:

FlattenDepth.flatten(list, -1)   # [[[1], 2, [[3, 4], 5], [[[]]], [[[6]]], 7, 8, []]]

so you may need to update it a bit :)
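One possible update, sketched here as an assumption rather than the author's fix: treat any negative depth as "flatten completely", so -1 no longer wraps the whole list. Non-negative depths go through the same clauses as the depth version above. The module name FlattenAnyDepth is hypothetical:

```elixir
defmodule FlattenAnyDepth do
  # depth < 0 (the default) flattens completely; depth >= 0 behaves as in the depth version above.
  def flatten(list, depth \\ -1)
  # Internally -1 only ever decreases, so it never reaches the depth-0 stop clause.
  def flatten(list, depth) when depth < 0, do: list |> flatten(-1, []) |> Enum.reverse()
  def flatten(list, depth), do: list |> flatten(depth + 1, []) |> Enum.reverse()

  # Depth budget used up: keep this (sub)list as a single element.
  def flatten(list, 0, acc), do: [list | acc]
  # Empty sublists are dropped.
  def flatten([h | t], depth, acc) when h == [], do: flatten(t, depth, acc)
  def flatten([h | t], depth, acc) when is_list(h), do: flatten(t, depth, flatten(h, depth - 1, acc))
  def flatten([h | t], depth, acc), do: flatten(t, depth, [h | acc])
  def flatten([], _depth, acc), do: acc
end
```

With this change, flatten(list, -1) returns the same fully flattened list as flatten(list).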

@mrinterweb

Very late to the party, but here is a shorter version. My benchmarks put it at roughly 0.1x to 2x slower than the FlattenReverse method.

defmodule FlattenConcat do
  # Named FlattenConcat rather than List to avoid redefining Elixir's built-in List module.
  def flatten([head | tail]), do: flatten(head) ++ flatten(tail)
  def flatten([]), do: []
  def flatten(head), do: [head]
end

@anphung

anphung commented Apr 25, 2019

An improvement, at about 0.5x the running time of FlattenReverse (the result is built in order, so no final Enum.reverse is needed):

def flatten(list), do: flatten(list, [])
def flatten([h | t], acc) when h == [], do: flatten(t, acc)
def flatten([h | t], acc) when is_list(h), do: flatten(h, flatten(t, acc))
def flatten([h | t], acc), do: [h | flatten(t, acc)]
def flatten([], acc), do: acc

@stevenferrer

@chx, solved my problem, thanks very much!

@davidsulc

For anyone ending up here looking for a way to flatten only the first level of a list without concatenation or multiple list traversals (i.e., @chx's request):

list = [[1], 2, [[3, 4], 5], [[[]]], [[[6]]], 7, 8, []]
Enum.flat_map(list, fn x when is_list(x) -> x; x -> [x] end)
# [1, 2, [3, 4], 5, [[]], [[6]], 7, 8]

@heri16

heri16 commented Aug 26, 2021

The above may overflow the stack.
See the tail-recursive version I wrote: https://gist.github.com/heri16/e726ee7f335d2ca61bbbb016e6b884e1
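For reference, one way to make every recursive call a tail call is to keep the pending tails on an explicit stack instead of the call stack. This is a minimal sketch under that assumption, not the code from the linked gist; the module name TailFlatten is hypothetical:

```elixir
defmodule TailFlatten do
  def flatten(list), do: do_flatten(list, [], [])

  # do_flatten(current, stack, acc): stack holds tails still to visit,
  # acc holds the output in reverse order. Every call below is a tail call.
  defp do_flatten([], [], acc), do: Enum.reverse(acc)
  # Current list exhausted: resume the most recently saved tail.
  defp do_flatten([], [next | stack], acc), do: do_flatten(next, stack, acc)
  # Descend into a sublist, saving the rest of the current list for later.
  defp do_flatten([h | t], stack, acc) when is_list(h), do: do_flatten(h, [t | stack], acc)
  # Leaf element: prepend it, as in the gist's first version.
  defp do_flatten([h | t], stack, acc), do: do_flatten(t, stack, [h | acc])
end
```

Because acc is built by prepending, a single Enum.reverse at the end restores the original order, just like the FlattenReverse approach.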
