Created May 28, 2017 23:27
Save mikigom/757bd0132ace8bf914a82b8d7c200da4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tensorflow as tf | |
# TF-Slim layer library (TensorFlow 1.x contrib). It supplies conv2d,
# conv2d_transpose, max_pool2d, batch_norm and dropout used throughout this file.
import tensorflow.contrib.slim as slim
class Unet(object):
    """U-Net style segmentation network built with TF-Slim (TensorFlow 1.x graph mode).

    The whole graph is built by ``__init__`` (via ``build_model``). Every
    intermediate tensor is kept as an attribute: ``e0_0`` ... ``e3_1``,
    ``max_pool1`` ... ``max_pool4``, ``middle_0`` ... ``middle_3``,
    ``up_conv1`` ... ``up_conv4``, ``crop_0`` ... ``crop_3`` and
    ``d0_0`` ... ``d3_1``. The network result is ``self.output``.

    Layout is NHWC (skip connections are concatenated on axis 3). The shape
    comments trace a (30, 572, 572) input. Most 3x3 convolutions use VALID
    padding, so the output (388 x 388 for that input) is smaller than the
    input. Each encoder tensor is therefore center-cropped to the size of the
    up-sampled tensor before concatenation.

    NOTE(review): ``slim.batch_norm`` registers its moving-average updates in
    ``tf.GraphKeys.UPDATE_OPS`` by default. Confirm the training op depends on
    them, otherwise inference with ``is_training=False`` uses stale statistics.
    """

    def __init__(self, input, class_num, reuse = False, is_training = True):
        """Store the configuration and build the graph immediately.

        Args:
            input: image tensor without a channel axis, assumed (batch, H, W).
                ``build_model`` appends the channel axis. H and W must be
                statically known, because the skip-connection crops are
                computed from static shapes.
            class_num: number of channels of the final 1x1 convolution.
            reuse: if True, call ``reuse_variables()`` on the caller's current
                variable scope before any layer is created.
            is_training: forwarded to every ``slim.batch_norm`` and
                ``slim.dropout`` call through ``slim.arg_scope``. ``True`` is
                also those layers' own default, so omitting it keeps the
                training-mode graph. Pass ``False`` for inference.
        """
        self.input = input
        self.reuse = reuse
        self.class_num = class_num
        self.is_training = is_training
        self.build_model()

    def build_model(self):
        """Build encoder, bottleneck and decoder; the result is ``self.output``."""
        # (30, 572, 572, 1): conv2d needs an explicit channel axis.
        self.input = tf.expand_dims(self.input, -1)
        if self.reuse:
            # Mutates the caller's current variable scope object, not a scope
            # private to this model.
            # NOTE(review): the slim.batch_norm calls pass no ``scope``, so their
            # variable names are auto-generated. Confirm a reuse build resolves
            # them to the same names as the first build.
            tf.get_variable_scope().reuse_variables()
        # arg_scope also reaches the batch_norm/dropout calls made inside the
        # module-level BL/BDR/BR helpers, so those need no extra argument.
        with slim.arg_scope([slim.batch_norm, slim.dropout], is_training = self.is_training):
            self._build_encoder()
            self._build_middle()
            self._build_decoder()

    def _build_encoder(self):
        """Contracting path: stages e0-e3 of two 3x3 VALID convs, with 2x2 max-pooling between stages.

        Stage e0 applies plain leaky ReLU. Later stages use batch norm + leaky
        ReLU (``BL``).
        """
        with tf.name_scope('encoder'):
            with tf.name_scope('e0'):
                # (30, 570, 570, 64)
                self.e0_0 = lrelu(self._conv3x3(self.input, 64, 'e0_0'))
                # (30, 568, 568, 64)
                self.e0_1 = lrelu(self._conv3x3(self.e0_0, 64, 'e0_1'))
            with tf.name_scope('max_pool1'):
                # (30, 284, 284, 64)
                self.max_pool1 = slim.max_pool2d(self.e0_1, [2, 2], scope = 'max_pool1')
            with tf.name_scope('e1'):
                # (30, 282, 282, 128)
                self.e1_0 = BL(self._conv3x3(self.max_pool1, 128, 'e1_0'))
                # (30, 280, 280, 128)
                self.e1_1 = BL(self._conv3x3(self.e1_0, 128, 'e1_1'))
            with tf.name_scope('max_pool2'):
                # (30, 140, 140, 128)
                self.max_pool2 = slim.max_pool2d(self.e1_1, [2, 2], scope = 'max_pool2')
            with tf.name_scope('e2'):
                # (30, 138, 138, 256)
                self.e2_0 = BL(self._conv3x3(self.max_pool2, 256, 'e2_0'))
                # (30, 136, 136, 256)
                self.e2_1 = BL(self._conv3x3(self.e2_0, 256, 'e2_1'))
            with tf.name_scope('max_pool3'):
                # (30, 68, 68, 256)
                self.max_pool3 = slim.max_pool2d(self.e2_1, [2, 2], scope = 'max_pool3')
            with tf.name_scope('e3'):
                # (30, 66, 66, 512)
                self.e3_0 = BL(self._conv3x3(self.max_pool3, 512, 'e3_0'))
                # (30, 64, 64, 512)
                self.e3_1 = BL(self._conv3x3(self.e3_0, 512, 'e3_1'))

    def _build_middle(self):
        """Bottleneck: max-pool, four 3x3 convs at 1024 channels, then the first up-convolution."""
        with tf.name_scope('middle'):
            with tf.name_scope('max_pool4'):
                # (30, 32, 32, 512)
                self.max_pool4 = slim.max_pool2d(self.e3_1, [2, 2], scope = 'max_pool4')
            with tf.name_scope('middle_conv'):
                # SAME and VALID alternate, so spatially 32 -> 32 -> 30 -> 30 -> 28.
                self.middle_0 = BL(self._conv3x3(self.max_pool4, 1024, 'middle_0', padding = 'SAME'))
                self.middle_1 = BL(self._conv3x3(self.middle_0, 1024, 'middle_1'))
                # BDR inserts dropout between batch norm and ReLU.
                self.middle_2 = BDR(self._conv3x3(self.middle_1, 1024, 'middle_2', padding = 'SAME'))
                self.middle_3 = BDR(self._conv3x3(self.middle_2, 1024, 'middle_3'))
            with tf.name_scope('up_conv1'):
                # (30, 56, 56, 512)
                self.up_conv1 = self._up_conv(self.middle_3, 512, 'up_conv1')

    def _build_decoder(self):
        """Expanding path d0-d3, then the 1x1 output convolution.

        Each stage does the following:
        - Concatenates the up-sampled tensor with the center-cropped encoder
          tensor of matching depth.
        - Applies two 3x3 VALID convs with batch norm + ReLU (``BR``).
        - Up-samples again, except after d3.

        The ``*_w`` / ``*_h`` shape attributes are kept for backward
        compatibility. In NHWC, axis 1 (stored as ``*_w``) is actually height.
        """
        with tf.name_scope('decoder'):
            with tf.name_scope('d0'):
                _, self.e3_1_w, self.e3_1_h, _ = self.e3_1.get_shape().as_list()
                _, self.up_conv1_w, self.up_conv1_h, _ = self.up_conv1.get_shape().as_list()
                # (30, 56, 56, 512)
                self.crop_0 = self._center_crop(self.e3_1, self.up_conv1)
                # (30, 54, 54, 512)
                self.d0_0 = BR(self._conv3x3(
                    tf.concat([self.up_conv1, self.crop_0], axis = 3), 512, 'd0_0'))
                # (30, 52, 52, 512)
                self.d0_1 = BR(self._conv3x3(self.d0_0, 512, 'd0_1'))
            with tf.name_scope('up_conv2'):
                # (30, 104, 104, 256)
                self.up_conv2 = self._up_conv(self.d0_1, 256, 'up_conv2')
            with tf.name_scope('d1'):
                _, self.e2_1_w, self.e2_1_h, _ = self.e2_1.get_shape().as_list()
                _, self.up_conv2_w, self.up_conv2_h, _ = self.up_conv2.get_shape().as_list()
                # (30, 104, 104, 256)
                self.crop_1 = self._center_crop(self.e2_1, self.up_conv2)
                # (30, 102, 102, 256)
                self.d1_0 = BR(self._conv3x3(
                    tf.concat([self.up_conv2, self.crop_1], axis = 3), 256, 'd1_0'))
                # (30, 100, 100, 256)
                self.d1_1 = BR(self._conv3x3(self.d1_0, 256, 'd1_1'))
            with tf.name_scope('up_conv3'):
                # (30, 200, 200, 128)
                self.up_conv3 = self._up_conv(self.d1_1, 128, 'up_conv3')
            with tf.name_scope('d2'):
                _, self.e1_1_w, self.e1_1_h, _ = self.e1_1.get_shape().as_list()
                _, self.up_conv3_w, self.up_conv3_h, _ = self.up_conv3.get_shape().as_list()
                # (30, 200, 200, 128)
                self.crop_2 = self._center_crop(self.e1_1, self.up_conv3)
                # (30, 198, 198, 128)
                self.d2_0 = BR(self._conv3x3(
                    tf.concat([self.up_conv3, self.crop_2], axis = 3), 128, 'd2_0'))
                # (30, 196, 196, 128)
                self.d2_1 = BR(self._conv3x3(self.d2_0, 128, 'd2_1'))
            with tf.name_scope('up_conv4'):
                # (30, 392, 392, 64)
                self.up_conv4 = self._up_conv(self.d2_1, 64, 'up_conv4')
            with tf.name_scope('d3'):
                _, self.e0_1_w, self.e0_1_h, _ = self.e0_1.get_shape().as_list()
                _, self.up_conv4_w, self.up_conv4_h, _ = self.up_conv4.get_shape().as_list()
                # (30, 392, 392, 64)
                self.crop_3 = self._center_crop(self.e0_1, self.up_conv4)
                # (30, 390, 390, 64)
                self.d3_0 = BR(self._conv3x3(
                    tf.concat([self.up_conv4, self.crop_3], axis = 3), 64, 'd3_0'))
                # (30, 388, 388, 64)
                self.d3_1 = BR(self._conv3x3(self.d3_0, 64, 'd3_1'))
            with tf.name_scope('conv1x1'):
                # (30, 388, 388, class_num). Sigmoid gives an independent score
                # per class channel, not a softmax across classes.
                self.output = slim.conv2d(self.d3_1, self.class_num, [1, 1], scope = 'output',
                                          padding = 'VALID', activation_fn = tf.sigmoid)

    @staticmethod
    def _conv3x3(x, num_outputs, scope, padding = 'VALID'):
        """3x3 ``slim.conv2d`` with no built-in activation; the caller applies BN/activation."""
        return slim.conv2d(x, num_outputs, [3, 3], scope = scope,
                           padding = padding, activation_fn = None)

    @staticmethod
    def _up_conv(x, num_outputs, scope):
        """2x2 stride-2 VALID transposed conv plus batch norm (no activation); doubles H and W."""
        upsampled = slim.conv2d_transpose(x, num_outputs, [2, 2], [2, 2], scope = scope,
                                          padding = 'VALID', activation_fn = None)
        return slim.batch_norm(upsampled, activation_fn = None)

    @staticmethod
    def _crop_bounds(size, target):
        """Return ``(start, stop)`` of the centered length-``target`` window along a length-``size`` axis.

        Integer division keeps the bounds valid slice indices under Python 3.
        Deriving ``stop`` from ``start`` makes the window exactly ``target``
        long, even when ``target`` is odd.
        """
        start = (size - target) // 2
        return start, start + target

    @classmethod
    def _center_crop(cls, tensor, reference):
        """Center-crop NHWC ``tensor`` to the spatial size of ``reference``.

        Both tensors must have statically known height and width.
        """
        _, height, width, _ = tensor.get_shape().as_list()
        _, ref_height, ref_width, _ = reference.get_shape().as_list()
        top, bottom = cls._crop_bounds(height, ref_height)
        left, right = cls._crop_bounds(width, ref_width)
        return tensor[:, top:bottom, left:right, :]
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU computed as ``max(x, leak * x)``.

    This equals ``x if x > 0 else leak * x`` only when ``leak <= 1``. For
    larger ``leak``, the max selects the wrong branch.

    Args:
        x: input tensor.
        leak: slope applied to negative inputs.
        name: name scope that groups the ops this function creates.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # ``name`` was accepted but unused; it now labels the ops in the graph.
    with tf.name_scope(name):
        return tf.maximum(x, leak * x)
def BL(x, leak = 0.2):
    """Apply batch norm, then leaky ReLU ("B" + "L").

    Args:
        x: pre-activation tensor. In this file it is a conv output built
            with ``activation_fn=None``.
        leak: negative-side slope handed to ``lrelu``.

    Returns:
        The activated tensor.
    """
    normalized = slim.batch_norm(x, activation_fn = None)
    return lrelu(normalized, leak = leak)
def BDR(x):
    """Apply batch norm, then dropout, then ReLU ("B" + "D" + "R").

    No keep probability or training flag is passed here. ``slim.dropout`` and
    ``slim.batch_norm`` receive only the input (plus ``activation_fn=None``
    for batch norm).

    Args:
        x: pre-activation tensor.

    Returns:
        The activated tensor.
    """
    normalized = slim.batch_norm(x, activation_fn = None)
    dropped = slim.dropout(normalized)
    return tf.nn.relu(dropped)
def BR(x):
    """Apply batch norm, then ReLU ("B" + "R").

    Args:
        x: pre-activation tensor.

    Returns:
        The activated tensor.
    """
    normalized = slim.batch_norm(x, activation_fn = None)
    return tf.nn.relu(normalized)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.