Skip to content

Instantly share code, notes, and snippets.

@Hiroshiba
Created September 22, 2017 08:22
Show Gist options
  • Save Hiroshiba/cc7bbd84f52004b36c1a440b51ed1588 to your computer and use it in GitHub Desktop.
Save Hiroshiba/cc7bbd84f52004b36c1a440b51ed1588 to your computer and use it in GitHub Desktop.
Recursively concatenate NumPy arrays
import numpy
def concat_recursive(batch, newaxis=False):
    """Concatenate a batch of (possibly nested) NumPy arrays along axis 0.

    The type of the first example in ``batch`` decides how the batch is
    merged:

    * tuple or list -- merged position by position; the result is a tuple.
    * dict -- merged key by key, using the keys of the first example.
    * anything else -- all examples are passed to ``numpy.concatenate``.

    Args:
        batch: Non-empty sequence of examples that share the same structure.
        newaxis: If true, every leaf array gets a new leading axis before
            concatenation, so leaves are stacked instead of joined. The flag
            applies to leaves at every nesting depth.

    Returns:
        A single array, a tuple, or a dict mirroring the structure of the
        examples, whose leaves are the concatenated arrays.

    Raises:
        IndexError: If ``batch`` is an empty sequence.

    >>> from pprint import pprint
    >>> onedata = numpy.arange(3).reshape(1, -1)
    >>> batch = [onedata] * 4
    >>> pprint(concat_recursive(batch))  # doctest: +NORMALIZE_WHITESPACE
    array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]])
    >>> batch = [{'key1': onedata, 'key2': onedata+3}] * 4
    >>> pprint(concat_recursive(batch))  # doctest: +NORMALIZE_WHITESPACE
    {'key1': array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]]),
     'key2': array([[3, 4, 5],
           [3, 4, 5],
           [3, 4, 5],
           [3, 4, 5]])}
    >>> batch = [{'nest1': {'nest2': onedata}}] * 4
    >>> pprint(concat_recursive(batch))  # doctest: +NORMALIZE_WHITESPACE
    {'nest1': {'nest2': array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]])}}
    >>> onedata = numpy.arange(3)
    >>> batch = [onedata] * 4
    >>> pprint(concat_recursive(batch, newaxis=True))  # doctest: +NORMALIZE_WHITESPACE
    array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]])
    >>> pprint(concat_recursive(batch, newaxis=False))
    array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
    >>> batch = [{'key': onedata}] * 4
    >>> concat_recursive(batch, newaxis=True)['key'].shape
    (4, 3)
    """
    first = batch[0]

    if isinstance(first, (tuple, list)):
        # Index explicitly (rather than zip) so that an example shorter than
        # the first one raises instead of being silently truncated.
        return tuple(
            concat_recursive([example[i] for example in batch], newaxis=newaxis)
            for i in range(len(first))
        )

    if isinstance(first, dict):
        return {
            key: concat_recursive(
                [example[key] for example in batch], newaxis=newaxis
            )
            for key in first
        }

    # Leaf level: optionally prepend a length-1 axis so that concatenation
    # along axis 0 stacks the arrays.
    arrays = [array[numpy.newaxis] for array in batch] if newaxis else batch
    return numpy.concatenate(arrays)
if __name__ == "__main__":
    # When executed as a script, run the doctest examples embedded in this
    # module's docstrings.
    from doctest import testmod

    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment