
@rkawajiri
Last active January 18, 2017 23:21
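Deconvolution with TensorFlow
`unpool` undoes max pooling for the decoder half of a deconvolution network. It writes each pooled value back to the position recorded in the `argmax` output of `tf.nn.max_pool_with_argmax` and fills every other position with zero.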
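import numpy as np
import tensorflow as tf


def unpool(input_images, argmax, output_shape, name='unpooling'):
    """Max-unpooling: place each pooled value back at its argmax position.

    input_images: pooled NHWC tensor; argmax: matching int64 indices from
    tf.nn.max_pool_with_argmax (default include_batch_in_index=False, so the
    indices are per image); output_shape: fully defined TensorShape of the
    unpooled result. All positions not named by argmax are zero.
    """
    out_shape = output_shape.as_list()  # not `os`, which shadows the os module
    output_sz = int(np.prod(out_shape))  # elements in the whole output batch
    b = out_shape[0]
    output_hwc = int(np.prod(out_shape[1:]))  # elements per output image
    input_hwc = int(np.prod(argmax.get_shape().as_list()[1:]))  # per pooled image
    # argmax indexes within one image; add n * output_hwc for image n so the
    # indices address the flattened batch instead.
    offset = tf.tile(tf.reshape(tf.range(b, dtype=tf.int64), [b, 1]),
                     [1, input_hwc]) * output_hwc
    reshaped_argmax = tf.reshape(argmax, [b, input_hwc])
    indices = tf.reshape(reshaped_argmax + offset, [-1, 1])
    updates = tf.reshape(input_images, [-1])
    # Scatter into a flat zero tensor (duplicate indices would be summed).
    scatter = tf.scatter_nd(indices=indices,
                            updates=updates,
                            shape=tf.constant([output_sz], dtype=tf.int64))
    return tf.reshape(scatter, out_shape, name=name)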
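To check the index arithmetic without running TensorFlow, here is a NumPy-only sketch that mirrors the same offset-and-scatter steps. `max_pool_2x2_with_argmax` and `unpool_np` are hypothetical helpers written for illustration, not part of the gist or of any library. The pooling helper assumes 2x2, stride-2 pooling on NHWC input and records each winner's flat index `(y * width + x) * channels + c` within its own image (no batch term), which is the layout the per-image offset in `unpool` expects. `np.add.at` stands in for `tf.scatter_nd`; both sum values at duplicate indices.

```python
import numpy as np


def max_pool_2x2_with_argmax(x):
    """Hypothetical helper: 2x2, stride-2 max pooling on an NHWC array.

    Returns pooled values and, for each, the flat index
    (y * width + x) * channels + c of the winner within its own image.
    """
    b, h, w, c = x.shape
    pooled = np.zeros((b, h // 2, w // 2, c), dtype=x.dtype)
    argmax = np.zeros((b, h // 2, w // 2, c), dtype=np.int64)
    for n in range(b):
        for i in range(h // 2):
            for j in range(w // 2):
                for ch in range(c):
                    window = x[n, 2 * i:2 * i + 2, 2 * j:2 * j + 2, ch]
                    dy, dx = np.unravel_index(np.argmax(window), window.shape)
                    y, xx = 2 * i + dy, 2 * j + dx
                    pooled[n, i, j, ch] = window[dy, dx]
                    argmax[n, i, j, ch] = (y * w + xx) * c + ch
    return pooled, argmax


def unpool_np(pooled, argmax, output_shape):
    """Hypothetical NumPy mirror of unpool(): same offset-and-scatter steps."""
    b = output_shape[0]
    output_hwc = int(np.prod(output_shape[1:]))  # elements per output image
    input_hwc = int(np.prod(argmax.shape[1:]))  # elements per pooled image
    # Shift image n's indices by n * output_hwc to address the flat batch.
    offset = np.arange(b, dtype=np.int64).reshape(b, 1) * output_hwc
    indices = (argmax.reshape(b, input_hwc) + offset).ravel()
    out = np.zeros(int(np.prod(output_shape)), dtype=pooled.dtype)
    np.add.at(out, indices, pooled.ravel())  # sums duplicates, like tf.scatter_nd
    return out.reshape(output_shape)


x = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
pooled, argmax = max_pool_2x2_with_argmax(x)
restored = unpool_np(pooled, argmax, x.shape)
```

Because `x` increases along every axis, each 2x2 window's maximum is its bottom-right element, so `restored` keeps `x` at odd rows and odd columns and is zero everywhere else.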