Created
July 12, 2022 22:35
-
-
Save p-geon/a80c07d5496a99bfe1c3ffbdb48b46dc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf  # used throughout this file but not imported elsewhere

print('tf.__version__', tf.__version__)

# Even values 0, 2, ..., 18 and odd values 1, 3, ..., 19 as float32 vectors,
# so interleaving them element-wise yields 0, 1, 2, ..., 19.
z_1 = tf.convert_to_tensor([2 * i for i in range(10)], dtype=tf.float32)
z_2 = tf.convert_to_tensor([2 * i + 1 for i in range(10)], dtype=tf.float32)
@tf.function
def aabb2abab(z_1: tf.Tensor, z_2: tf.Tensor) -> tf.Tensor:
    """Interleave the elements of two equally sized tensors.

    A 1-D "pixel shuffler"-like operation implemented without
    ``tf.nn.depth_to_space``::

        [a, c, e], [b, d, f] -> [[a, b, c, d, e, f]]

    Inputs are flattened first, so tensors of any shape are interleaved
    element-wise in row-major order.

    Args:
        z_1: Tensor supplying the even output positions.
        z_2: Tensor supplying the odd output positions; must have the same
            number of elements and the same dtype as ``z_1``.

    Returns:
        Tensor of shape ``(1, 2 * size(z_1))``. Note the leading axis of
        length 1: the result is a row vector, not a flat vector.

    Raises:
        tf.errors.InvalidArgumentError: if the element counts of ``z_1`` and
            ``z_2`` differ.
    """
    tf.debugging.assert_equal(
        tf.size(z_1), tf.size(z_2),
        message="z_1 and z_2 must have the same number of elements")

    # (n,) and (n,) -> (n, 2): rows are [a, b], [c, d], [e, f], ...
    pairs = tf.stack([tf.reshape(z_1, [-1]), tf.reshape(z_2, [-1])], axis=1)

    # Row-major flattening of (n, 2) gives a, b, c, d, ...; keep the (1, 2n)
    # row-vector shape this function has always returned.
    return tf.reshape(pairs, [1, -1])
# Demo: interleave the even/odd float vectors; expected output is the row
# vector [[0, 1, 2, ..., 19]].
print("Inputs: ", z_1, z_2)
z = aabb2abab(z_1, z_2)
print("outputs: ", z)
Author
p-geon
commented
Jul 12, 2022
@tf.function
def tensor_altanatery_on_batch_2d(z_1: tf.Tensor, z_2: tf.Tensor) -> tf.Tensor:
    """Interleave two batches row by row (without ``tf.nn.depth_to_space``).

    Given ``z_1 = [r0, r2, r4]`` and ``z_2 = [r1, r3, r5]`` (rows along
    axis 0), return ``[r0, r1, r2, r3, r4, r5]``::

        (B, C), (B, C) -> (2B, C)
        out[2 * i] == z_1[i]    out[2 * i + 1] == z_2[i]

    Inputs of any known rank >= 1 are accepted, ``(B, ...) -> (2B, ...)``.
    Dimensions that are unknown at trace time, such as a ``None`` batch, are
    resolved dynamically.

    Args:
        z_1: Tensor of shape (B, ...) supplying the even output rows.
        z_2: Tensor with the same shape and dtype as ``z_1`` supplying the
            odd output rows.

    Returns:
        Tensor of shape (2B, ...) whose rows alternate between ``z_1`` and
        ``z_2``.

    Raises:
        ValueError: if the rank of ``z_1`` is unknown or zero, or the static
            shapes of ``z_1`` and ``z_2`` are incompatible.
        tf.errors.InvalidArgumentError: if the shapes differ at run time.
    """
    rank = z_1.shape.rank
    if rank is None or rank < 1:
        raise ValueError(
            f"expected inputs of known rank >= 1, got shape {z_1.shape}")
    if not z_1.shape.is_compatible_with(z_2.shape):
        raise ValueError(
            "z_1 and z_2 must have the same shape, "
            f"got {z_1.shape} and {z_2.shape}")
    # Catches mismatches in dimensions that are unknown while tracing.
    tf.debugging.assert_equal(
        tf.shape(z_1), tf.shape(z_2),
        message="z_1 and z_2 must have the same shape")

    # (B, ...) x 2 -> (B, 2, ...): pairs row i of z_1 with row i of z_2.
    pairs = tf.stack([z_1, z_2], axis=1)

    # Prefer static sizes, which keeps the static output shape such as
    # (6, 10), and fall back to dynamic sizes for unknown dimensions.
    dynamic_shape = tf.shape(z_1)
    dims = [dim if dim is not None else dynamic_shape[i]
            for i, dim in enumerate(z_1.shape)]

    # Row-major merge of the first two axes: index (i, k) -> 2 * i + k.
    return tf.reshape(pairs, [2 * dims[0], *dims[1:]])
def run(n: int = 3, dim: int = 10) -> None:
    """Demonstrate and sanity-check ``tensor_altanatery_on_batch_2d``.

    Builds two (n, dim) int32 tensors where element (i, j) of ``z_1`` is
    ``2 * (i + j * n)`` and of ``z_2`` is ``2 * (i + j * n) + 1``. After
    row interleaving, each column of the result is a run of consecutive
    integers. For example, column 0 reads 0, 1, ..., 2n - 1.

    Prints the inputs and the result, then asserts the result has shape
    (2n, dim).

    Args:
        n: Number of rows (batch size) in each input.
        dim: Number of columns in each input.
    """
    z_1 = tf.convert_to_tensor(
        [[2 * (i + j * n) for j in range(dim)] for i in range(n)])
    z_2 = tf.convert_to_tensor(
        [[2 * (i + j * n) + 1 for j in range(dim)] for i in range(n)])
    print("z_1: \n", z_1)
    print("z_2: \n", z_2)

    z = tensor_altanatery_on_batch_2d(z_1, z_2)
    tf.print("result: \n", z)

    # tensor check: (n, dim) and (n, dim) -> (2n, dim)
    tf.Assert(z.shape[0] == (z_1.shape[0] + z_2.shape[0]), [z_1, z_2, z])
    tf.Assert(z.shape[1] == z_1.shape[1], [z_1, z_2, z])
    tf.Assert(z.shape[1] == z_2.shape[1], [z_1, z_2, z])


run()
result
z_1:
tf.Tensor(
[[ 0 6 12 18 24 30 36 42 48 54]
[ 2 8 14 20 26 32 38 44 50 56]
[ 4 10 16 22 28 34 40 46 52 58]], shape=(3, 10), dtype=int32)
z_2:
tf.Tensor(
[[ 1 7 13 19 25 31 37 43 49 55]
[ 3 9 15 21 27 33 39 45 51 57]
[ 5 11 17 23 29 35 41 47 53 59]], shape=(3, 10), dtype=int32)
TensorShape([6, 10])
result:
[[0 6 12 ... 42 48 54]
[1 7 13 ... 43 49 55]
[2 8 14 ... 44 50 56]
[3 9 15 ... 45 51 57]
[4 10 16 ... 46 52 58]
[5 11 17 ... 47 53 59]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment