Skip to content

Instantly share code, notes, and snippets.

@LukasDrude
Created December 2, 2016 11:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save LukasDrude/26d143a87ce6fb8b1c5c1c1297fef75f to your computer and use it in GitHub Desktop.
Save LukasDrude/26d143a87ce6fb8b1c5c1c1297fef75f to your computer and use it in GitHub Desktop.
import sys
import unittest
import numpy as np
import nt.testing as tc
from nt.utils.numpy_utils import reshape
def _normalize(op):
    """Canonicalize a reshape description into space-separated tokens.

    Commas and spaces are stripped first. Every remaining character then
    becomes its own token, separated by single spaces. The exceptions are
    ``*``, which stays glued to its neighbours (so ``B*F`` remains one
    token), and the arrow ``->``, which is kept intact.

    Example: ``'T,B,F->T,B*F'`` becomes ``'T B F -> T B*F'``.
    """
    compact = op.replace(',', '').replace(' ', '')
    spaced = ' '.join(compact)
    # Re-join flatten groups that the per-character spacing split apart.
    spaced = spaced.replace(' * ', '*')
    # Re-join the arrow that the per-character spacing split apart.
    return spaced.replace('- >', '->')
def _only_reshape(array, source, target):
    """Reshape ``array`` from the axis labels in ``source`` to ``target``.

    No transposition happens here. ``array`` must already have its axes in
    the order listed by ``source``.

    Args:
        array: Array whose i-th axis is labelled by the i-th token of
            ``source``.
        source: Whitespace-separated axis labels, e.g. ``'T B F'``.
        target: Whitespace-separated output axes. Each token is one of:
            a label; the literal ``'1'`` (a new singleton axis); or a
            ``*``-joined product of labels (flattened axes), e.g.
            ``'T 1 B*F'``. ``' * '`` with surrounding spaces is accepted
            too. ``'1'`` may also appear as a factor inside a product
            (e.g. ``'1*B'``), where it contributes a size of one.

    Returns:
        ``array`` reshaped to the computed output shape.

    Raises:
        IndexError: If ``source`` has more labels than ``array`` has axes.
        KeyError: If ``target`` uses a label that is not in ``source``.
    """
    labels = source.split()
    input_shape = {label: array.shape[axis] for axis, label in enumerate(labels)}

    output_shape = []
    for token in target.replace(' * ', '*').split():
        size = 1
        for factor in token.split('*'):
            # '1' denotes a singleton axis and never names an input axis,
            # also when it appears inside a flattened group like '1*B'.
            if factor != '1':
                size *= input_shape[factor]
        output_shape.append(size)
    return array.reshape(output_shape)
def reshape(array, operation):
    """Generalized, einsum-like reshape (experimental).

    ``operation`` has the form ``'<source> -> <target>'``. Every axis is
    named by a single letter. Commas and spaces are ignored, so
    ``'T,B,F->F,T,B'`` and ``'T B F -> F T B'`` are equivalent.
    Supported operations:

    * reordering the letters transposes, e.g. ``'T,B,F->F,T,B'``;
    * ``1`` on the source side marks a singleton axis to squeeze out;
    * ``1`` on the target side inserts a new singleton axis;
    * ``*`` on the target side flattens the joined axes, e.g.
      ``'T,B,F->T,B*F'``.

    Args:
        array: Input array. Its axes must match the source labels.
        operation: Reshape description as above.

    Returns:
        The transposed and reshaped array.

    Raises:
        ValueError: If ``operation`` does not contain exactly one ``->``.
            Also raised if the labels do not fit ``array``: by
            ``np.squeeze``, or by ``np.einsum``, whose message is extended
            with the einsum subscripts and the array shape.
        NotImplementedError: If ``*`` appears on the source side
            (unflattening), because the sizes of the split axes are not
            available to this function.
    """
    op = _normalize(operation)

    # Without an arrow np.einsum runs in implicit mode and silently orders
    # the output axes alphabetically, so an explicit '->' is required.
    if op.count('->') != 1:
        raise ValueError(
            "Expected exactly one '->' in operation, got: {!r}".format(operation)
        )
    source, target = op.split('->')

    if '*' in source:
        raise NotImplementedError(
            'Unflatten operation not supported by design. '
            'Actual values for dimensions are not available to this function.'
        )

    # Initial squeeze: drop every source axis marked with '1'.
    singleton_axes = tuple(
        axis for axis, label in enumerate(source.split()) if label == '1'
    )
    if singleton_axes:
        array = np.squeeze(array, axis=singleton_axes)

    # Transpose: blank out the '1' placeholders and '*' joiners so einsum
    # sees only plain labels; this relies on np.einsum skipping spaces.
    transposition_operation = op.replace('1', ' ').replace('*', ' ')
    try:
        array = np.einsum(transposition_operation, array)
    except ValueError as e:
        msg = 'op: {}, shape: {}'.format(transposition_operation,
                                         np.shape(array))
        if e.args:
            e.args = ('{}\n\n{}'.format(e.args[0], msg),) + e.args[1:]
        else:
            e.args = (msg,)
        raise

    # Final reshape: flatten '*' groups and insert the target '1' axes.
    return _only_reshape(
        array, transposition_operation.split('->')[-1], target
    )
# Axis sizes shared by all test fixtures below.
T = 40
B = 6
F = 51

# Random fixtures, drawn once at import time in this fixed order.
# uniform(0.0, 1.0, size) spells out the default low/high bounds.
A = np.random.uniform(0.0, 1.0, (T, B, F))
A2 = np.random.uniform(0.0, 1.0, (T, 1, B, F))
A3 = np.random.uniform(0.0, 1.0, (T * B * F,))
A4 = np.random.uniform(0.0, 1.0, (T, 1, 1, B, 1, F))
class TestReshape(unittest.TestCase):
    """Examples and regression tests for the experimental ``reshape``.

    Every case checks the result's values where possible, not just its
    shape. A shape-only check cannot detect a wrong axis order.
    """

    def test_noop_comma(self):
        result = reshape(A, 'T,B,F->T,B,F')
        tc.assert_equal(result.shape, (T, B, F))
        tc.assert_equal(result, A)

    def test_noop_space(self):
        result = reshape(A, 'T B F->T B F')
        tc.assert_equal(result.shape, (T, B, F))
        tc.assert_equal(result, A)

    def test_noop_mixed(self):
        result = reshape(A, 'tbf->t, b f')
        tc.assert_equal(result.shape, (T, B, F))
        tc.assert_equal(result, A)

    def test_transpose_comma(self):
        result = reshape(A, 'T,B,F->F,T,B')
        tc.assert_equal(result.shape, (F, T, B))
        tc.assert_equal(result, A.transpose(2, 0, 1))

    def test_transpose_mixed(self):
        result = reshape(A, 't, b, f -> f t b')
        tc.assert_equal(result.shape, (F, T, B))
        tc.assert_equal(result, A.transpose(2, 0, 1))

    def test_broadcast_axis_0(self):
        result = reshape(A, 'T,B,F->1,T,B,F')
        tc.assert_equal(result.shape, (1, T, B, F))
        tc.assert_equal(result, A[None, ...])

    def test_broadcast_axis_2(self):
        result = reshape(A, 'T,B,F->T,B,1,F')
        tc.assert_equal(result.shape, (T, B, 1, F))
        tc.assert_equal(result, A[..., None, :])

    def test_broadcast_axis_3(self):
        result = reshape(A, 'T,B,F->T,B,F,1')
        tc.assert_equal(result.shape, (T, B, F, 1))
        tc.assert_equal(result, A[..., None])

    def test_reshape_comma(self):
        result = reshape(A, 'T,B,F->T,B*F')
        tc.assert_equal(result.shape, (T, B*F))
        tc.assert_equal(result, A.reshape(T, B*F))

    def test_reshape_comma_unflatten(self):
        with tc.assert_raises(NotImplementedError):
            reshape(A3, 't*b*f->t, b, f')

    def test_reshape_comma_unflatten_and_transpose_and_flatten(self):
        with tc.assert_raises(NotImplementedError):
            reshape(A3, 't*b*f->f, t*b')

    def test_reshape_comma_flat(self):
        result = reshape(A, 'T,B,F->T*B*F')
        tc.assert_equal(result.shape, (T*B*F,))
        tc.assert_equal(result, A.ravel())

    def test_reshape_comma_with_singleton_input(self):
        result = reshape(A2, 'T, 1, B, F -> T*B*F')
        tc.assert_equal(result.shape, (T*B*F,))
        tc.assert_equal(result, A2.ravel())

    def test_reshape_comma_with_a_lot_of_singleton_inputs(self):
        result = reshape(A4, 'T, 1, 1, B, 1, F -> T*B*F')
        tc.assert_equal(result.shape, (T*B*F,))
        tc.assert_equal(result, A4.ravel())

    def test_reshape_and_broadcast(self):
        result = reshape(A, 'T,B,F->T,1,B*F')
        tc.assert_equal(result.shape, (T, 1, B*F))
        tc.assert_equal(result, A.reshape(T, 1, B*F))

    def test_reshape_and_broadcast_many(self):
        result = reshape(A, 'T,B,F->1,T,1,B*F,1')
        tc.assert_equal(result.shape, (1, T, 1, B*F, 1))
        tc.assert_equal(result, A.reshape(1, T, 1, B*F, 1))

    def test_swap_and_reshape(self):
        result = reshape(A, 'T,B,F->T,F*B')
        tc.assert_equal(result.shape, (T, F * B))
        tc.assert_equal(result, A.swapaxes(-1, -2).reshape(T, F * B))

    def test_transpose_and_reshape(self):
        result = reshape(A, 'T,B,F->F,B*T')
        tc.assert_equal(result.shape, (F, B*T))
        tc.assert_equal(result, A.transpose(2, 1, 0).reshape(F, B*T))

    def test_all_comma(self):
        result = reshape(A, 'T,B,F->F,1,B*T')
        tc.assert_equal(result.shape, (F, 1, B*T))
        # The target order F, B, T is the full reversal of T, B, F.
        tc.assert_equal(result, A.transpose(2, 1, 0).reshape(F, 1, B*T))

    def test_all_space(self):
        result = reshape(A, 't b f -> f1b*t')
        tc.assert_equal(result.shape, (F, 1, B*T))
        tc.assert_equal(result, A.transpose(2, 1, 0).reshape(F, 1, B*T))
# Collect every test method of TestReshape and run them immediately,
# reporting results on stderr. This is unguarded module-level code, so it
# also runs when the module is imported.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestReshape)
_ = unittest.TextTestRunner(stream=sys.stderr, verbosity=1).run(suite)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment