Created July 24, 2017 10:41
-
-
Save mikigom/bad72795c5e87e3caa9464e64952b524 to your computer and use it in GitHub Desktop.
Tensorflow Implementation of Bilinear Additive Upsampling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tensorflow as tf | |
| """ | |
| Author : @MikiBear_ | |
| Tensorflow Implementation of Bilinear Additive Upsampling. | |
| Reference : https://arxiv.org/abs/1707.05847 | |
| """ | |
def bilinear_additive_upsampling(x, to_channel_num, name):
    """Upsample ``x`` by 2x spatially using bilinear additive upsampling.

    Implements the parameter-free upsampling from the referenced paper
    (https://arxiv.org/abs/1707.05847). The input is first resized bilinearly
    to twice its height and width. Then every ``C_in // to_channel_num``
    consecutive channels are summed into one output channel.

    Args:
        x: 4-D tensor indexed as (batch, height, width, channels). The channel
            dimension must be statically known. Height and width may be
            unknown (``None``) at graph-construction time.
        to_channel_num: Number of output channels. Must be a positive divisor
            of the input channel count.
        name: Name scope under which the ops are created.

    Returns:
        Tensor of shape (batch, 2 * height, 2 * width, to_channel_num).

    Raises:
        ValueError: If the input channel count is statically unknown, or if
            ``to_channel_num`` is not a positive divisor of it.
    """
    with tf.name_scope(name):
        in_shape = x.get_shape().as_list()
        from_channel_num = in_shape[3]
        if from_channel_num is None:
            raise ValueError(
                'The channel dimension of x must be statically known.')
        if to_channel_num <= 0 or from_channel_num % to_channel_num != 0:
            raise ValueError(
                'to_channel_num ({}) must be a positive divisor of the input '
                'channel count ({}).'.format(to_channel_num, from_channel_num))
        # Integer division: Python 3 `/` would produce a float, which is not a
        # valid slice index or reshape dimension.
        channel_split = from_channel_num // to_channel_num

        # Prefer a static target size so the result keeps a fully defined
        # static shape. Fall back to the dynamic shape when H or W is unknown.
        height, width = in_shape[1], in_shape[2]
        if height is None or width is None:
            new_size = tf.shape(x)[1:3] * 2
        else:
            new_size = [2 * height, 2 * width]
        upsampled_x = tf.image.resize_images(x, new_size)

        # Output channel i is the sum of input channels
        # [i * channel_split, (i + 1) * channel_split). Splitting the last axis
        # into (to_channel_num, channel_split) in row-major order groups exactly
        # those channels, so one reduce_sum replaces the slice/sum/stack loop.
        up_static = upsampled_x.get_shape().as_list()
        up_dynamic = tf.shape(upsampled_x)
        lead_dims = [dim if dim is not None else up_dynamic[axis]
                     for axis, dim in enumerate(up_static[:3])]
        grouped = tf.reshape(upsampled_x,
                             lead_dims + [to_channel_num, channel_split])
        return tf.reduce_sum(grouped, axis=-1)
if __name__ == '__main__':
    # Smoke run: a batch of 20 all-ones 100x100 maps with 20 channels,
    # reduced to 5 output channels (4 input channels summed per output).
    demo_input = tf.ones([20, 100, 100, 20])
    demo_output = bilinear_additive_upsampling(demo_input, 5, '0')
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        result = session.run(demo_output)
        print(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
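In Python 3 it should be:
channel_split = from_channel_num // to_channel_num
because `/` performs true division in Python 3 and returns a float, which cannot be used as a slice index.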