Created
July 24, 2017 10:41
-
-
Save mikigom/bad72795c5e87e3caa9464e64952b524 to your computer and use it in GitHub Desktop.
Tensorflow Implementation of Bilinear Additive Upsampling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
""" | |
Author : @MikiBear_ | |
Tensorflow Implementation of Bilinear Additive Upsampling. | |
Reference : https://arxiv.org/abs/1707.05847 | |
""" | |
def bilinear_additive_upsampling(x, to_channel_num, name):
    """Upsample `x` by 2x spatially and sum channel groups together.

    The input is resized to twice its height and width with
    `tf.image.resize_images`, then every run of
    `from_channel_num // to_channel_num` consecutive channels is summed
    into a single output channel.

    Args:
        x: 4-D tensor whose axes 1 and 2 are the ones being doubled and
            whose axis 3 is the channel axis. Axes 1-3 must have a
            statically known size, because they are read from
            `x.get_shape()`.
        to_channel_num: Number of output channels. Must divide the input
            channel count evenly.
        name: Currently unused; kept so existing callers keep working.

    Returns:
        A tensor with axes 1 and 2 doubled and `to_channel_num` channels
        on the last axis.

    Raises:
        AssertionError: If the input channel count is not a multiple of
            `to_channel_num`.
    """
    input_shape = x.get_shape().as_list()
    from_channel_num = input_shape[3]
    assert from_channel_num % to_channel_num == 0, (
        "input channels (%s) must be divisible by to_channel_num (%s)"
        % (from_channel_num, to_channel_num))

    # Floor division keeps this an int on Python 3. With true division the
    # slice bounds below become floats and TensorFlow rejects them
    # ("Value passed to parameter 'begin' has DataType float32").
    channel_split = from_channel_num // to_channel_num

    # Only the spatial target size is needed by the resize op.
    target_size = [input_shape[1] * 2, input_shape[2] * 2]
    upsampled_x = tf.image.resize_images(x, target_size)

    # Sum each group of `channel_split` adjacent channels into one channel.
    output_list = []
    for i in range(to_channel_num):
        begin = i * channel_split
        end = begin + channel_split
        splited_upsampled_x = upsampled_x[:, :, :, begin:end]
        output_list.append(tf.reduce_sum(splited_upsampled_x, axis=-1))

    output = tf.stack(output_list, axis=-1)
    return output
if __name__ == '__main__':
    # Smoke test: push an all-ones batch through the op, reducing its
    # 20 channels to 5, and print the evaluated result.
    ones_input = tf.ones([20, 100, 100, 20])
    upsampled = bilinear_additive_upsampling(ones_input, 5, '0')

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        result = session.run(upsampled)
        print(result)
In Python 3, the channel split must use integer (floor) division, so the line should be:
channel_split = from_channel_num // to_channel_num
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi,
Thanks for sharing this implementation. But when I run this on TF1.9.0, I'm getting the following error:
Traceback (most recent call last):
  File "tf_bilinear_additive_upsampling.py", line 31, in <module>
    y = bilinear_additive_upsampling(x, 10, '0')
  File "tf_bilinear_additive_upsampling.py", line 22, in bilinear_additive_upsampling
    splited_upsampled_x = upsampled_x[:,:,:,i*channel_split:(i+1)*channel_split]
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 523, in _slice_helper
    name=name)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 689, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8232, in strided_slice
    name=name)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
    param_name=input_name)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'begin' has DataType float32 not in list of allowed values: int32, int64