
@mikigom
Created July 24, 2017 10:41
Tensorflow Implementation of Bilinear Additive Upsampling
```python
"""
Author : @MikiBear_
Tensorflow Implementation of Bilinear Additive Upsampling.
Reference : https://arxiv.org/abs/1707.05847
"""
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.image.resize_images)


def bilinear_additive_upsampling(x, to_channel_num, name):
    # x is NHWC; its static height, width and channel count must be known.
    from_channel_num = x.get_shape().as_list()[3]
    assert from_channel_num % to_channel_num == 0
    # Integer division: under Python 3, `/` returns a float, which cannot be
    # used as a slice bound (TypeError in strided_slice).
    channel_split = from_channel_num // to_channel_num

    new_shape = x.get_shape().as_list()
    new_shape[1] *= 2
    new_shape[2] *= 2
    new_shape[3] = to_channel_num

    with tf.name_scope(name):
        # Step 1: 2x bilinear upsampling (BILINEAR is the default resize method).
        upsampled_x = tf.image.resize_images(x, new_shape[1:3])

        # Step 2: sum every `channel_split` consecutive channels into one output channel.
        output_list = []
        for i in range(to_channel_num):
            splited_upsampled_x = upsampled_x[:, :, :, i * channel_split:(i + 1) * channel_split]
            output_list.append(tf.reduce_sum(splited_upsampled_x, axis=-1))
        output = tf.stack(output_list, axis=-1)
    return output


if __name__ == '__main__':
    x = tf.ones([20, 100, 100, 20])
    y = bilinear_additive_upsampling(x, 5, '0')
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        new_x = sess.run(y)
        print(new_x)
```
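For checking the graph's output without TensorFlow, here is a framework-free reference sketch in NumPy. It is not part of the original gist: the helper names `bilinear_upsample_2x` and `bilinear_additive_upsampling_np` are hypothetical, and it assumes TF1's legacy resize behaviour (`align_corners=False`, no half-pixel centres), where output index `j` samples source coordinate `j / 2`. Resize APIs that use half-pixel centres will give slightly different values.

```python
import numpy as np


def bilinear_upsample_2x(x):
    """Hypothetical helper: 2x bilinear upsampling of an NHWC array.

    Assumes TF1 legacy semantics: output index j samples source coordinate j / 2,
    with the upper neighbour clamped at the border.
    """
    def upsample_axis(a, axis):
        n = a.shape[axis]
        idx = np.arange(2 * n) / 2.0            # source coordinate for each output index
        lo = np.floor(idx).astype(int)          # lower neighbour
        hi = np.minimum(lo + 1, n - 1)          # upper neighbour, clamped at the border
        w = idx - lo                            # interpolation weight (0 or 0.5 for 2x)
        shape = [1] * a.ndim
        shape[axis] = -1
        w = w.reshape(shape)                    # broadcast weights along `axis`
        return np.take(a, lo, axis=axis) * (1 - w) + np.take(a, hi, axis=axis) * w

    # Bilinear interpolation is separable: interpolate along H, then along W.
    return upsample_axis(upsample_axis(x, 1), 2)


def bilinear_additive_upsampling_np(x, to_channel_num):
    """Hypothetical NumPy counterpart: bilinear 2x upsampling, then summing
    groups of consecutive channels."""
    n, h, w, c = x.shape
    assert c % to_channel_num == 0
    channel_split = c // to_channel_num         # integer division, as in the Python 3 fix
    up = bilinear_upsample_2x(x)
    # Splitting the last axis into (to_channel_num, channel_split) groups channels
    # [i*channel_split, (i+1)*channel_split) into output channel i, like the slicing loop.
    return up.reshape(n, 2 * h, 2 * w, to_channel_num, channel_split).sum(axis=-1)


if __name__ == '__main__':
    x = np.ones([2, 4, 4, 20], dtype=np.float32)
    y = bilinear_additive_upsampling_np(x, 5)
    print(y.shape)      # (2, 8, 8, 5)
```

The reshape-and-sum replaces the per-group Python loop with a single reduction; each output channel of an all-ones input is the sum of `channel_split` ones.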
@dhaneshr

Hi,
Thanks for sharing this implementation. However, when I run it on TF 1.9.0, I get the following error:
```
Traceback (most recent call last):
  File "tf_bilinear_additive_upsampling.py", line 31, in <module>
    y = bilinear_additive_upsampling(x, 10, '0')
  File "tf_bilinear_additive_upsampling.py", line 22, in bilinear_additive_upsampling
    splited_upsampled_x = upsampled_x[:,:,:,i*channel_split:(i+1)*channel_split]
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 523, in _slice_helper
    name=name)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 689, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8232, in strided_slice
    name=name)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
    param_name=input_name)
  File "/home/dhaneshr/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'begin' has DataType float32 not in list of allowed values: int32, int64
```

@mrgloom commented Sep 6, 2019

In Python 3, `/` is true division and always returns a float, so this line should use floor division:

```python
channel_split = from_channel_num // to_channel_num
```
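A plain-Python sketch (no TensorFlow needed) of why the original `/` fails under Python 3. The channel counts 20 and 5 are illustrative values taken from the gist's `__main__` block. A list slice is used here as a stand-in: it rejects float bounds just as TensorFlow's `strided_slice` rejects a float32 `begin`, which is the TypeError in the traceback above.

```python
from_channel_num, to_channel_num = 20, 5

true_div = from_channel_num / to_channel_num    # Python 3 true division -> float
floor_div = from_channel_num // to_channel_num  # floor division -> int

print(true_div, type(true_div).__name__)    # 4.0 float
print(floor_div, type(floor_div).__name__)  # 4 int

channels = list(range(from_channel_num))
try:
    channels[0:true_div]                    # float slice bounds are rejected
except TypeError as err:
    print('TypeError:', err)
print(channels[0:floor_div])                # [0, 1, 2, 3]
```

Because the gist already asserts `from_channel_num % to_channel_num == 0`, floor division loses nothing here; it only changes the type from float to int.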
