Created
December 2, 2016 11:56
-
-
Save LukasDrude/26d143a87ce6fb8b1c5c1c1297fef75f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import unittest | |
import numpy as np | |
import nt.testing as tc | |
from nt.utils.numpy_utils import reshape | |
def _normalize(op):
    """Canonicalize an operation string into space-separated tokens.

    Commas and spaces are dropped, then every remaining character becomes its
    own token (so axis names are single characters). The two multi-character
    constructs are glued back together afterwards: ``X * Y`` becomes ``X*Y``
    and ``- >`` becomes ``->``.

    Example: ``'T,B,F->T,B*F'`` -> ``'T B F -> T B*F'``.
    """
    compact = op.replace(',', '').replace(' ', '')
    spaced = ' '.join(compact)
    return spaced.replace(' * ', '*').replace('- >', '->')
def _only_reshape(array, source, target):
    """Reshape ``array`` from a space-separated axis spec to a target spec.

    ``source`` names the current axes of ``array`` in order, one token per
    axis. Each token of ``target`` is one of:

    * ``1``, which inserts a new singleton axis;
    * one or more source names joined by ``*``, whose sizes are multiplied
      into a single axis.

    No axes are permuted here; only ``array.reshape`` is applied.

    Args:
        array: Array whose axes are described by ``source``.
        source: Space-separated axis names, e.g. ``'F B T'``.
        target: Space-separated target groups, e.g. ``'F 1 B*T'``. A spaced
            ``' * '`` is accepted and glued back to ``'*'``.

    Returns:
        The result of ``array.reshape`` with the computed target shape.
    """
    source_axes = source.split()
    target_groups = target.replace(' * ', '*').split()

    axis_size = {}
    for position, name in enumerate(source_axes):
        axis_size[name] = array.shape[position]

    def group_size(group):
        # '1' is a fresh singleton; anything else is a product of named axes.
        if group == '1':
            return 1
        size = 1
        for name in group.split('*'):
            size *= axis_size[name]
        return size

    return array.reshape([group_size(group) for group in target_groups])
def reshape(array, operation):
    """Experimental generalized reshape driven by an einsum-like string.

    The operation has the form ``'<source> -> <target>'``, where every axis
    name is a single character, e.g. ``'T,B,F->F,1,B*T'``. Commas and spaces
    are optional separators.

    * In the source, ``1`` marks a singleton axis that is squeezed away.
    * In the target, ``1`` inserts a new singleton axis, and ``X*Y`` merges
      axes ``X`` and ``Y`` (in that order) into one.

    See the test cases for more examples.

    Args:
        array: Input array; it should have one axis per source token.
        operation: The reshape specification described above.

    Returns:
        The squeezed, transposed and reshaped array.

    Raises:
        NotImplementedError: If the source side contains ``*``. Unflattening
            would need axis sizes that are not available to this function.
        ValueError: Propagated from numpy, e.g. when a source ``1`` axis
            does not have length 1 or when the einsum step fails. In the
            einsum case, the message is extended with the subscripts and the
            array shape.
    """
    operation = _normalize(operation)
    sides = operation.split('->')
    lhs = sides[0]
    if '*' in lhs:
        raise NotImplementedError(
            'Unflatten operation not supported by design. '
            'Actual values for dimensions are not available to this function.'
        )

    # Squeeze singleton source axes, right to left so that earlier axis
    # indices remain valid after each squeeze.
    lhs_tokens = lhs.split()
    for axis in range(len(lhs_tokens) - 1, -1, -1):
        if lhs_tokens[axis] == '1':
            array = np.squeeze(array, axis=axis)

    # Blank out singleton markers and flatten operators; what remains is a
    # pure einsum permutation of the named axes.
    subscripts = operation.replace('1', ' ').replace('*', ' ')
    try:
        array = np.einsum(subscripts, array)
    except ValueError as err:
        context = 'op: {}, shape: {}'.format(subscripts, np.shape(array))
        if len(err.args) != 1:
            print(context)
        else:
            err.args = (err.args[0] + '\n\n' + context,)
        raise

    # The einsum output side lists the axes in their new order; merge and
    # insert axes from there to reach the requested target layout.
    permuted_axes = subscripts.split('->')[-1]
    return _only_reshape(array, permuted_axes, sides[-1])
# Axis sizes named T, B and F, matching the axis letters used in the
# operation strings of the tests below.
T, B, F = 40, 6, 51

# Random fixtures, drawn in this order: plain (T, B, F); with a singleton
# axis after T; fully flattened; and with several interleaved singletons.
A, A2, A3, A4 = (
    np.random.uniform(size=shape)
    for shape in (
        (T, B, F),
        (T, 1, B, F),
        (T * B * F,),
        (T, 1, 1, B, 1, F),
    )
)
class TestReshape(unittest.TestCase):
    """Behavioral tests for the string-driven ``reshape`` in this module.

    Each test checks the result shape and, where possible, also compares the
    values against the equivalent plain numpy transpose/reshape. A shape-only
    check would not detect a wrong axis ordering.
    """

    def test_noop_comma(self):
        result = reshape(A, 'T,B,F->T,B,F')
        tc.assert_equal(result.shape, (T, B, F))
        tc.assert_equal(result, A)

    def test_noop_space(self):
        result = reshape(A, 'T B F->T B F')
        tc.assert_equal(result.shape, (T, B, F))
        tc.assert_equal(result, A)

    def test_noop_mixed(self):
        result = reshape(A, 'tbf->t, b f')
        tc.assert_equal(result.shape, (T, B, F))
        tc.assert_equal(result, A)

    def test_transpose_comma(self):
        result = reshape(A, 'T,B,F->F,T,B')
        tc.assert_equal(result.shape, (F, T, B))
        tc.assert_equal(result, A.transpose(2, 0, 1))

    def test_transpose_mixed(self):
        result = reshape(A, 't, b, f -> f t b')
        tc.assert_equal(result.shape, (F, T, B))
        tc.assert_equal(result, A.transpose(2, 0, 1))

    def test_broadcast_axis_0(self):
        result = reshape(A, 'T,B,F->1,T,B,F')
        tc.assert_equal(result.shape, (1, T, B, F))
        tc.assert_equal(result, A[None, ...])

    def test_broadcast_axis_2(self):
        result = reshape(A, 'T,B,F->T,B,1,F')
        tc.assert_equal(result.shape, (T, B, 1, F))
        tc.assert_equal(result, A[..., None, :])

    def test_broadcast_axis_3(self):
        result = reshape(A, 'T,B,F->T,B,F,1')
        tc.assert_equal(result.shape, (T, B, F, 1))
        tc.assert_equal(result, A[..., None])

    def test_reshape_comma(self):
        result = reshape(A, 'T,B,F->T,B*F')
        tc.assert_equal(result.shape, (T, B*F))
        tc.assert_equal(result, A.reshape(T, B*F))

    def test_reshape_comma_unflatten(self):
        # Unflattening needs sizes the function cannot know: must be rejected.
        with tc.assert_raises(NotImplementedError):
            reshape(A3, 't*b*f->t, b, f')

    def test_reshape_comma_unflatten_and_transpose_and_flatten(self):
        with tc.assert_raises(NotImplementedError):
            reshape(A3, 't*b*f->f, t*b')

    def test_reshape_comma_flat(self):
        result = reshape(A, 'T,B,F->T*B*F')
        tc.assert_equal(result.shape, (T*B*F,))
        tc.assert_equal(result, A.ravel())

    def test_reshape_comma_with_singleton_input(self):
        result = reshape(A2, 'T, 1, B, F -> T*B*F')
        tc.assert_equal(result.shape, (T*B*F,))
        tc.assert_equal(result, A2.ravel())

    def test_reshape_comma_with_a_lot_of_singleton_inputs(self):
        result = reshape(A4, 'T, 1, 1, B, 1, F -> T*B*F')
        tc.assert_equal(result.shape, (T*B*F,))
        tc.assert_equal(result, A4.ravel())

    def test_reshape_and_broadcast(self):
        result = reshape(A, 'T,B,F->T,1,B*F')
        tc.assert_equal(result.shape, (T, 1, B*F))
        tc.assert_equal(result, A.reshape(T, 1, B*F))

    def test_reshape_and_broadcast_many(self):
        result = reshape(A, 'T,B,F->1,T,1,B*F,1')
        tc.assert_equal(result.shape, (1, T, 1, B*F, 1))
        tc.assert_equal(result, A.reshape(1, T, 1, B*F, 1))

    def test_swap_and_reshape(self):
        result = reshape(A, 'T,B,F->T,F*B')
        tc.assert_equal(result.shape, (T, F * B))
        tc.assert_equal(result, A.swapaxes(-1, -2).reshape(T, F * B))

    def test_transpose_and_reshape(self):
        result = reshape(A, 'T,B,F->F,B*T')
        tc.assert_equal(result.shape, (F, B*T))
        tc.assert_equal(result, A.transpose(2, 1, 0).reshape(F, B*T))

    def test_squeeze_transpose_and_reshape(self):
        # Squeeze a source singleton, permute, and merge in one operation.
        result = reshape(A2, 'T,1,B,F->F,B*T')
        tc.assert_equal(result.shape, (F, B*T))
        tc.assert_equal(
            result, A2[:, 0].transpose(2, 1, 0).reshape(F, B*T))

    def test_all_comma(self):
        result = reshape(A, 'T,B,F->F,1,B*T')
        tc.assert_equal(result.shape, (F, 1, B*T))
        tc.assert_equal(result, A.transpose(2, 1, 0).reshape(F, 1, B*T))

    def test_all_space(self):
        result = reshape(A, 't b f -> f1b*t')
        tc.assert_equal(result.shape, (F, 1, B*T))
        tc.assert_equal(result, A.transpose(2, 1, 0).reshape(F, 1, B*T))
# Collect every TestReshape case and run it immediately, reporting to stderr.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestReshape)
_ = unittest.TextTestRunner(stream=sys.stderr, verbosity=1).run(suite)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment