Skip to content

Instantly share code, notes, and snippets.

@Mainvooid
Created December 5, 2018 08:46
Show Gist options
  • Select an option

  • Save Mainvooid/84d117f6d24ceaa171b0a76acd9827a7 to your computer and use it in GitHub Desktop.

Select an option

Save Mainvooid/84d117f6d24ceaa171b0a76acd9827a7 to your computer and use it in GitHub Desktop.
keras实现deeplabv3+模型 #Keras #Python
from keras.layers import Activation,Conv2D,MaxPooling2D,BatchNormalization,Input,DepthwiseConv2D,add,Dropout,AveragePooling2D,Concatenate
from keras.models import Model
import keras.backend as K
from keras.engine import Layer,InputSpec
from keras.utils import conv_utils
class BilinearUpsampling(Layer):
    """Upsample a 4-D feature map by integer factors with bilinear resizing.

    ``call`` and ``compute_output_shape`` treat axes 1 and 2 as height and
    width and axis 3 as channels (channels-last). ``data_format`` is
    normalised and saved in the config, but it does not change the
    computation.

    NOTE(review): ``call`` uses ``K.tf.image.resize_bilinear``. This
    presumably ties the layer to TF1-era Keras, where ``K.tf`` exists;
    confirm against the installed versions.
    """

    def __init__(self, upsampling=(2, 2), data_format=None, **kwargs):
        """Store the scale factors.

        Args:
            upsampling: int or 2-tuple of ints, the (height, width) factors.
            data_format: passed through ``conv_utils.normalize_data_format``.
            **kwargs: forwarded to ``Layer``. A ``size`` entry is accepted as
                an alias for ``upsampling``, because configs saved by earlier
                versions of this layer used that key.
        """
        if "size" in kwargs:
            upsampling = kwargs.pop("size")
        super(BilinearUpsampling, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.upsampling = conv_utils.normalize_tuple(upsampling, 2, 'size')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        """Scale the height and width of ``input_shape``; unknown (None) dims stay None."""
        height = (self.upsampling[0] * input_shape[1]
                  if input_shape[1] is not None else None)
        width = (self.upsampling[1] * input_shape[2]
                 if input_shape[2] is not None else None)
        return (input_shape[0], height, width, input_shape[3])

    def call(self, inputs):
        """Bilinearly resize ``inputs`` to (height * factor_h, width * factor_w).

        The target size comes from the static shape, so height and width
        must be known when the layer is called.
        """
        new_size = (int(inputs.shape[1] * self.upsampling[0]),
                    int(inputs.shape[2] * self.upsampling[1]))
        return K.tf.image.resize_bilinear(inputs, new_size)

    def get_config(self):
        """Return a config whose keys match ``__init__`` so ``from_config`` works."""
        # The factors are stored under "upsampling" so from_config can pass
        # them straight back to __init__. A "size" key would instead land in
        # **kwargs and be rejected by Layer.__init__.
        config = {'upsampling': self.upsampling,
                  'data_format': self.data_format}
        base_config = super(BilinearUpsampling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def xception_downsample_block(x, channels, top_relu=False):
    """Stack three depthwise-separable 3x3 convolutions; the last has stride 2.

    Each separable conv is DepthwiseConv2D -> BatchNorm -> 1x1 Conv2D ->
    BatchNorm. A ReLU sits between consecutive separable convs. No ReLU is
    applied to the output.

    Args:
        x: input tensor.
        channels: filter count of every 1x1 pointwise convolution.
        top_relu: if True, apply a ReLU to ``x`` before the first conv.

    Returns:
        Tensor with ``channels`` channels. Its spatial size is halved by the
        final stride-2 depthwise conv ("same" padding).
    """
    stage_strides = ((1, 1), (1, 1), (2, 2))
    for stage, stride in enumerate(stage_strides):
        # Stage 0 is only pre-activated on request; later stages always are.
        if stage > 0 or top_relu:
            x = Activation("relu")(x)
        x = DepthwiseConv2D((3, 3), strides=stride, padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
    return x
def res_xception_downsample_block(x, channels):
    """Stride-2 Xception block with a projected shortcut.

    Shortcut: 1x1 conv (stride 2) + BatchNorm on ``x``. Main path:
    ``xception_downsample_block(x, channels)``. The two are summed.
    """
    # The shortcut layers are created before the main path, as in the
    # original construction order.
    projected = Conv2D(channels, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x)
    projected = BatchNormalization()(projected)
    branch = xception_downsample_block(x, channels)
    return add([branch, projected])
def xception_block(x, channels):
    """Apply three pre-activated depthwise-separable 3x3 convolutions, stride 1.

    Each unit is ReLU -> DepthwiseConv2D(3x3) -> BatchNorm ->
    Conv2D(channels, 1x1) -> BatchNorm.

    Args:
        x: input tensor.
        channels: filter count of every 1x1 pointwise convolution.

    Returns:
        Tensor with ``channels`` channels and the same spatial size as ``x``.
    """
    for _unit in range(3):
        x = Activation("relu")(x)
        x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
    return x
def res_xception_block(x, channels):
    """Identity-shortcut Xception block: returns ``xception_block(x, channels) + x``.

    ``x`` must already have ``channels`` channels for the addition to be valid.
    """
    return add([xception_block(x, channels), x])
def _aspp_separable_branch(x, rate):
    """Build one ASPP branch: dilated 3x3 depthwise + BN + ReLU, then 1x1 conv to 256 + BN + ReLU."""
    x = DepthwiseConv2D((3, 3), dilation_rate=(rate, rate), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    return x


def aspp(x, input_shape, out_stride, atrous_rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling head.

    Runs these branches in parallel on ``x``, each producing 256 channels:
      * a 1x1 conv + BN + ReLU;
      * one dilated depthwise-separable branch per rate in ``atrous_rates``;
      * an image-level branch: average pooling over the whole feature map,
        1x1 conv + BN + ReLU, then bilinear upsampling back to the feature
        map size.
    The outputs are concatenated as [image-level, 1x1, dilated...].

    Args:
        x: channels-last feature map from the backbone.
        input_shape: (height, width, ...) of the network input. Only used to
            size the pooling window when ``x`` has no static spatial shape.
        out_stride: ratio of input size to the spatial size of ``x``. Used
            together with ``input_shape`` in that same fallback case.
        atrous_rates: dilation rates of the separable branches. The default
            (6, 12, 18) is the DeepLabv3+ setting for output stride 16.

    Returns:
        Concatenated tensor with 256 * (2 + len(atrous_rates)) channels.
    """
    b0 = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    b0 = BatchNormalization()(b0)
    b0 = Activation("relu")(b0)

    # Each dilated branch gets its own rate; previously the third branch
    # reused rate 12 instead of 18.
    dilated = [_aspp_separable_branch(x, rate) for rate in atrous_rates]

    # Size the pooling window from the actual feature map so pooling is truly
    # global. This also handles non-square inputs, which the previous
    # input_shape[0]-only computation did not.
    feat_h, feat_w = K.int_shape(x)[1:3]
    if feat_h is None or feat_w is None:
        feat_h = int(input_shape[0] / out_stride)
        feat_w = int(input_shape[1] / out_stride)
    b4 = AveragePooling2D(pool_size=(feat_h, feat_w))(x)
    b4 = Conv2D(256, (1, 1), padding="same", use_bias=False)(b4)
    b4 = BatchNormalization()(b4)
    b4 = Activation("relu")(b4)
    b4 = BilinearUpsampling((feat_h, feat_w))(b4)

    return Concatenate()([b4, b0] + dilated)
def _sep_conv_bn(x, filters, strides=(1, 1), mid_relu=False):
    """Apply depthwise 3x3 conv + BN, an optional ReLU, then pointwise 1x1 conv + BN."""
    x = DepthwiseConv2D((3, 3), strides=strides, padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    if mid_relu:
        x = Activation("relu")(x)
    x = Conv2D(filters, (1, 1), padding="same", use_bias=False)(x)
    return BatchNormalization()(x)


def _entry_flow(img_input):
    """Build the stem and three stride-2 residual blocks; return (features at 1/16, skip at 1/4)."""
    # Stem: 3x3/2 conv to 32 channels, then 3x3 conv to 64 channels (1/2).
    x = Conv2D(32, (3, 3), strides=(2, 2), padding="same", use_bias=False)(img_input)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(64, (3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # Block 1: 128 channels, 1/4.
    x = res_xception_downsample_block(x, 128)

    # Block 2: 256 channels, 1/8. Written out inline so the 1/4-resolution
    # activation before the strided conv can be kept as the decoder skip.
    shortcut = Conv2D(256, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x)
    shortcut = BatchNormalization()(shortcut)
    x = Activation("relu")(x)
    x = _sep_conv_bn(x, 256)
    x = Activation("relu")(x)
    skip = _sep_conv_bn(x, 256)
    x = Activation("relu")(skip)
    x = _sep_conv_bn(x, 256, strides=(2, 2))
    x = add([x, shortcut])

    # Block 3: 728 channels, 1/16. Like blocks 1 and 2 (and the Xception
    # entry flow), it has a projected 1x1/2 shortcut.
    shortcut = Conv2D(728, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x)
    shortcut = BatchNormalization()(shortcut)
    x = xception_downsample_block(x, 728, top_relu=True)
    x = add([x, shortcut])
    return x, skip


def _exit_flow(x):
    """Build the exit flow at stride 1: a 1024-channel residual block, then 1536/1536/2048 separable convs."""
    shortcut = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x)
    shortcut = BatchNormalization()(shortcut)
    for filters in (728, 1024, 1024):
        x = Activation("relu")(x)
        x = _sep_conv_bn(x, filters)
    x = add([x, shortcut])
    for filters in (1536, 1536, 2048):
        x = _sep_conv_bn(x, filters)
        x = Activation("relu")(x)
    return x


def _decoder(x, skip, num_classes):
    """Fuse 4x-upsampled encoder features with the 1/4 skip, predict num_classes maps, upsample 4x."""
    x = BilinearUpsampling((4, 4))(x)
    dec_skip = Conv2D(48, (1, 1), padding="same", use_bias=False)(skip)
    dec_skip = BatchNormalization()(dec_skip)
    dec_skip = Activation("relu")(dec_skip)
    x = Concatenate()([x, dec_skip])
    for _ in range(2):
        x = _sep_conv_bn(x, 256, mid_relu=True)
        x = Activation("relu")(x)
    # Final 1x1 conv has a bias and no activation.
    x = Conv2D(num_classes, (1, 1), padding="same")(x)
    return BilinearUpsampling((4, 4))(x)


def deeplabv3_plus(input_shape=(512, 512, 3), out_stride=16, num_classes=21):
    """Build a DeepLabv3+ segmentation model on a modified Xception backbone.

    Args:
        input_shape: (height, width, channels) of the input, channels last.
            Height and width must be known. They should be multiples of 16 so
            that the 4x-upsampled 1/16 features line up with the 1/4 skip.
        out_stride: ratio of input size to encoder output size. The backbone
            built here always downsamples by 16, so only 16 is accepted.
        num_classes: number of output channels of the final 1x1 conv.

    Returns:
        Keras ``Model`` mapping images to (batch, height, width, num_classes)
        outputs with no final activation.

    Raises:
        ValueError: if ``out_stride`` is not 16.
    """
    if out_stride != 16:
        raise ValueError(
            "deeplabv3_plus only supports out_stride=16 (the backbone always "
            "downsamples by 16), got out_stride=%r" % (out_stride,))

    img_input = Input(shape=input_shape)
    x, skip = _entry_flow(img_input)

    # Middle flow: 16 identity-shortcut blocks at 728 channels, 1/16.
    for _ in range(16):
        x = res_xception_block(x, 728)

    x = _exit_flow(x)

    x = aspp(x, input_shape, out_stride)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    # Keras Dropout takes the fraction of units to DROP, not a keep
    # probability, so 0.1 keeps 90% of the ASPP features.
    x = Dropout(0.1)(x)

    x = _decoder(x, skip, num_classes)
    return Model(img_input, x)
if __name__ == "__main__":
    # Smoke check: build a model with a single output channel and print its layers.
    deeplabv3_plus(num_classes=1).summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment